构建连接 Keras 与 Apache Iceberg 的高性能 Rust 特征服务层


模型训练和线上推理之间的数据一致性,是所有 MLOps 体系中绕不过的坎。训练时,我们用 Spark 从数据湖批量拉取数周的样本,生成特征;而在线上,模型需要毫秒级的响应,从某个键值存储中获取最新的实时特征。这两套路径,从数据源、计算引擎到存储介质都完全不同,由此引入的特征偏斜(Feature Skew)常常是模型效果衰减的根源。我们团队最近的痛点就源于此,一个欺诈检测模型的离线 AUC 能做到 0.95,但一上线就掉到 0.8,复盘下来,问题就出在一个关键特征的计算口径在 Python 批处理脚本和 Java 实时流处理任务中出现了细微偏差。

为了根治这个问题,我们决定自建一个统一的特征平台。初步构想是,它必须包含一个离线特征仓库和一个在线特征服务。离线部分负责存储海量的、可回溯的历史特征,用于模型训练和迭代;在线部分则提供低延迟的特征查询接口,供线上模型实时调用。

技术选型阶段经过了不少讨论。离线存储我们很快锁定了 Apache Iceberg。它的事务性、Schema 演进和时间旅行(Time Travel)能力,对于保证训练数据的可复现性和质量至关重要。在真实项目中,特征 Schema 的变更非常频繁,Iceberg 能优雅地处理这些变更,避免了重写整个数据集的痛苦。

真正的抉择点在于在线特征服务层。它必须满足几个严苛的条件:极低的响应延迟(p99 < 10ms)、高并发下的稳定性、以及内存安全,我们不能接受一个核心服务因为 GC 停顿或内存泄漏而抖动。团队最初倾向于使用 Vert.x 或 Netty 等成熟的 JVM 方案,但对 GC 的顾虑始终存在。这时,Rust 进入了我们的视野。其无 GC、所有权系统带来的内存安全保障,以及接近 C++ 的性能,使其成为构建这类基础组件的理想选择。在 Web 框架方面,Axum 因其与 Tokio 生态的无缝集成、优雅的中间件设计以及类型安全的编程体验而胜出。

于是,最终的架构蓝图清晰了:

  1. 离线特征存储: 使用 Apache Spark 定期计算特征,并写入 Apache Iceberg 表中。
  2. 在线特征缓存: 使用 Redis 或 ScyllaDB 存储需要实时访问的热特征。
  3. 特征服务层: 构建一个基于 Axum 和 Tonic (gRPC) 的 Rust 服务,负责从在线缓存中高效地拉取特征。
  4. 模型消费: Python 端的 Keras 模型通过 gRPC 客户端调用 Rust 服务,获取推理所需的特征。

整个系统的核心,就是这个 Rust 特征服务层。它不仅是性能瓶颈所在,也是连接数据工程和算法模型的桥梁。

graph TD
    subgraph "离线训练流程 (Offline Training Flow)"
        A[原始数据源] --> B{Apache Spark};
        B --> C[Apache Iceberg 表];
        C --> D[Python 训练脚本];
        D --> E[Keras 模型 v1.0];
    end

    subgraph "在线推理流程 (Online Inference Flow)"
        F[实时事件流] --> G{实时计算引擎};
        G --> H[在线缓存 Redis];
        I[业务请求] --> J[Keras 模型服务];
        J -- gRPC 请求 --> K((Rust 特征服务));
        K -- 查询 --> H;
        K -- gRPC 响应 --> J;
    end

    style K fill:#f9f,stroke:#333,stroke-width:2px

第一步:定义服务契约 (gRPC & Protobuf)

跨语言的服务调用,第一步永远是定义一个清晰、稳定、向后兼容的接口。gRPC 和 Protobuf 是这里的标准答案。我们定义一个 FeatureStore 服务,提供一个 GetOnlineFeatures 方法。

这里的坑在于请求和响应的设计。一个常见的错误是让 FeatureRequest 只接受一个实体 ID。但在真实场景中,我们可能需要一次性为多个实体(比如一个 mini-batch 的用户)批量获取特征。因此,请求应该接受一个 ID 列表。同样,返回的特征值应该用 map 来组织,以支持多种不同类型的特征。

proto/feature_store.proto:

syntax = "proto3";

package feature_store;

import "google/protobuf/struct.proto";

// 定义特征服务
service FeatureStore {
  // 批量获取在线特征
  rpc GetOnlineFeatures(OnlineFeaturesRequest) returns (OnlineFeaturesResponse);
}

// 请求体: 包含一个或多个实体ID
message OnlineFeaturesRequest {
  // 要查询的实体ID列表,例如 user_id, device_id 等
  repeated string entity_ids = 1;
  // 指定需要拉取的特征名称列表,如果为空则返回所有可用特征
  repeated string feature_names = 2;
}

// 响应体: 每个实体ID对应一组特征
message OnlineFeaturesResponse {
  // key 是实体ID, value 是该实体的特征集合
  map<string, FeatureSet> results = 1;
}

// 单个实体的特征集合
message FeatureSet {
  // key 是特征名, value 是特征值
  // 使用 google.protobuf.Value 来支持多种数据类型 (string, number, bool, etc.)
  map<string, google.protobuf.Value> features = 1;
}

使用 google.protobuf.Value 是一个关键决策。它让我们不必为每一种特征类型(浮点、整型、字符串)都定义一个字段,从而使 API 更具灵活性,以应对未来新增的各种特征类型。

第二步:构建 Axum + Tonic gRPC 服务

现在开始构建 Rust 服务。首先是项目结构和依赖。

Cargo.toml:

[package]
name = "feature-server"
version = "0.1.0"
edition = "2021"

[dependencies]
axum = "0.7"
tokio = { version = "1", features = ["full"] }
tonic = "0.11"
prost = "0.12"
redis = { version = "0.25", features = ["tokio-comp"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
config = "0.14"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

[build-dependencies]
tonic-build = "0.11"

我们需要一个 build.rs 文件来在编译时根据 .proto 文件自动生成 Rust 代码。

build.rs:

fn main() -> Result<(), Box<dyn std::error::Error>> {
    tonic_build::configure()
        .build_server(true)
        .compile(&["proto/feature_store.proto"], &["proto/"])?;
    Ok(())
}

接下来是服务的主体实现。我们将配置、日志、Redis 连接池和 gRPC 服务逻辑清晰地分离开。

src/main.rs:

use std::net::SocketAddr;
use tonic::transport::Server;
use tracing::{info, error};

mod settings;
mod feature_service;

// 引入由 tonic-build 生成的代码
mod proto {
    pub mod feature_store {
        tonic::include_proto!("feature_store");
    }
}

use feature_service::MyFeatureStore;
use proto::feature_store::feature_store_server::FeatureStoreServer;
use settings::Settings;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // 1. 初始化日志
    tracing_subscriber::fmt()
        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
        .init();

    // 2. 加载配置
    let settings = match Settings::new() {
        Ok(s) => s,
        Err(e) => {
            error!("Failed to load configuration: {}", e);
            return Err(e.into());
        }
    };
    info!("Configuration loaded: {:?}", settings);
    
    // 3. 创建 Redis 连接池
    let redis_client = redis::Client::open(settings.redis.url)?;
    let redis_pool = redis::aio::ConnectionManager::new(redis_client).await?;
    info!("Redis connection pool created.");

    // 4. 实例化 gRPC 服务实现
    let feature_store = MyFeatureStore::new(redis_pool);
    let addr: SocketAddr = settings.server.address.parse()?;
    
    info!("Feature server listening on {}", addr);

    // 5. 启动 Tonic gRPC 服务器
    Server::builder()
        .add_service(FeatureStoreServer::new(feature_store))
        .serve(addr)
        .await?;

    Ok(())
}

我们将配置逻辑封装在 settings.rs 中。在生产级项目中,硬编码配置是不可接受的。

src/settings.rs:

use config::{Config, ConfigError, File};
use serde::Deserialize;

#[derive(Debug, Deserialize)]
pub struct Server {
    pub address: String,
}

#[derive(Debug, Deserialize)]
pub struct Redis {
    pub url: String,
}

#[derive(Debug, Deserialize)]
pub struct Settings {
    pub server: Server,
    pub redis: Redis,
}

impl Settings {
    pub fn new() -> Result<Self, ConfigError> {
        let s = Config::builder()
            // 从 `config/default.toml` 加载默认配置
            .add_source(File::with_name("config/default"))
            // 可以通过环境变量覆盖,例如 APP_SERVER_ADDRESS=0.0.0.0:50052
            .add_source(config::Environment::with_prefix("APP").separator("__"))
            .build()?;
        s.try_deserialize()
    }
}

config/default.toml:

[server]
address = "0.0.0.0:50051"

[redis]
url = "redis://127.0.0.1/"

这是服务的核心业务逻辑,即从 Redis 中获取特征。

src/feature_service.rs:

use std::collections::HashMap;
use tonic::{Request, Response, Status};
use prost_types::Value as ProstValue;
use serde_json::Value as JsonValue;
use tracing::{info, warn};

// 引入生成的代码
use crate::proto::feature_store::{
    feature_store_server::FeatureStore,
    FeatureSet, OnlineFeaturesRequest, OnlineFeaturesResponse,
};

// 使用类型别名来简化连接池的类型
type RedisPool = redis::aio::ConnectionManager;

#[derive(Debug)]
pub struct MyFeatureStore {
    // Redis 连接池
    redis_pool: RedisPool,
}

impl MyFeatureStore {
    pub fn new(pool: RedisPool) -> Self {
        Self { redis_pool: pool }
    }
}

#[tonic::async_trait]
impl FeatureStore for MyFeatureStore {
    async fn get_online_features(
        &self,
        request: Request<OnlineFeaturesRequest>,
    ) -> Result<Response<OnlineFeaturesResponse>, Status> {
        let req = request.into_inner();
        info!("Received request for {} entities", req.entity_ids.len());

        if req.entity_ids.is_empty() {
            return Err(Status::invalid_argument("Entity IDs cannot be empty."));
        }

        let mut conn = self.redis_pool.clone();
        
        // 在真实项目中,这里应该用 MGET 批量获取,以获得最佳性能。
        // 为简化示例,这里使用循环,但已为批量操作准备好。
        let mut results = HashMap::new();
        for entity_id in &req.entity_ids {
            let redis_key = format!("features:{}", entity_id);
            let result: Result<String, redis::RedisError> = redis::cmd("GET")
                .arg(&redis_key)
                .query_async(&mut conn)
                .await;

            match result {
                Ok(json_str) => {
                    // Redis 中存储的是 JSON 字符串
                    let features: HashMap<String, JsonValue> = serde_json::from_str(&json_str)
                        .unwrap_or_else(|e| {
                            warn!("Failed to parse JSON for entity {}: {}", entity_id, e);
                            HashMap::new()
                        });
                    
                    let mut feature_map = HashMap::new();
                    // 这里需要将 serde_json::Value 转换为 prost_types::Value
                    for (name, value) in features {
                        // 如果请求指定了 feature_names,则进行过滤
                        if req.feature_names.is_empty() || req.feature_names.contains(&name) {
                            feature_map.insert(name, json_to_prost_value(value));
                        }
                    }
                    results.insert(entity_id.clone(), FeatureSet { features: feature_map });
                },
                Err(e) => {
                    warn!("Failed to get features for entity {}: {}", entity_id, e);
                    // 即使某个 key 失败,也不应该中断整个请求,而是返回一个空结果
                    results.insert(entity_id.clone(), FeatureSet { features: HashMap::new() });
                }
            }
        }
        
        let response = OnlineFeaturesResponse { results };
        Ok(Response::new(response))
    }
}


// 辅助函数: 将 serde_json::Value 转换为 prost_types::Value
fn json_to_prost_value(json_value: JsonValue) -> ProstValue {
    match json_value {
        JsonValue::Null => ProstValue { kind: Some(prost_types::value::Kind::NullValue(0)) },
        JsonValue::Bool(b) => ProstValue { kind: Some(prost_types::value::Kind::BoolValue(b)) },
        JsonValue::Number(n) => ProstValue { kind: Some(prost_types::value::Kind::NumberValue(n.as_f64().unwrap_or(0.0))) },
        JsonValue::String(s) => ProstValue { kind: Some(prost_types::value::Kind::StringValue(s)) },
        JsonValue::Array(a) => {
            let values = a.into_iter().map(json_to_prost_value).collect();
            ProstValue { kind: Some(prost_types::value::Kind::ListValue(prost_types::ListValue { values })) }
        }
        JsonValue::Object(o) => {
            let fields = o.into_iter().map(|(k, v)| (k, json_to_prost_value(v))).collect();
            ProstValue { kind: Some(prost_types::value::Kind::StructValue(prost_types::Struct { fields })) }
        }
    }
}

这里的错误处理逻辑体现了生产级服务的考量:单个实体查询失败不应该导致整个批量请求失败。我们记录警告日志,并为该实体返回一个空的特征集,让调用方来决定如何处理部分失败。此外,json_to_prost_value 这个转换函数是连接数据存储格式 (JSON) 和 API 传输格式 (Protobuf) 的关键胶水代码。

第三步:离线特征的生产与加载

Rust 服务已经就绪,但它需要数据。我们需要一个流程来生成特征并将其写入 Iceberg(用于训练)和 Redis(用于在线服务)。这里使用 PySpark 来完成。

scripts/generate_features.py:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit, rand, to_json, struct
import redis

def write_to_iceberg(df, table_name):
    print(f"Writing data to Iceberg table: {table_name}")
    df.write \
      .format("iceberg") \
      .mode("overwrite") \
      .save(table_name)
    print("Write to Iceberg complete.")

def write_to_redis(df):
    print("Writing data to Redis...")
    r = redis.Redis(host='localhost', port=6379, db=0)
    # 在生产环境中,应该使用 pipeline 来批量写入
    with r.pipeline() as pipe:
        for row in df.collect():
            entity_id = row.entity_id
            # 将除 entity_id 之外的所有列转换为 JSON
            features_df = row.asDict()
            del features_df['entity_id']
            redis_key = f"features:{entity_id}"
            pipe.set(redis_key, to_json(struct([col(c) for c in features_df])).first()[0])
    
    # 模拟 to_json(struct(...)) 的行为
    # 真实项目中 to_json 是 Spark 函数,这里为了能在 collect 后使用,手动转换
    pdf = df.toPandas()
    with r.pipeline() as pipe:
        for _, row in pdf.iterrows():
            entity_id = row['entity_id']
            redis_key = f"features:{entity_id}"
            feature_dict = row.drop('entity_id').to_dict()
            pipe.set(redis_key, json.dumps(feature_dict))
        pipe.execute()
    print("Write to Redis complete.")

if __name__ == "__main__":
    spark = SparkSession.builder \
        .appName("FeatureGeneration") \
        .config("spark.sql.extensions", "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions") \
        .config("spark.sql.catalog.spark_catalog", "org.apache.iceberg.spark.SparkSessionCatalog") \
        .config("spark.sql.catalog.spark_catalog.type", "hive") \
        .getOrCreate()
        
    # 1. 生成模拟特征数据
    data = [(f"user_{i}", i * 0.1, i % 10, "A" if i % 2 == 0 else "B") for i in range(100)]
    columns = ["entity_id", "feature_a", "feature_b", "feature_c"]
    df = spark.createDataFrame(data, columns)

    # 2. 写入 Iceberg 表,用于模型训练
    iceberg_table = "default.user_features"
    write_to_iceberg(df, iceberg_table)

    # 3. 写入 Redis,用于在线服务
    # 在实际场景中,可能会有单独的流处理任务来做这一步
    write_to_redis(df)
    
    spark.stop()

这个脚本模拟了特征工程的两个输出:一份写入 Iceberg 用于模型分析和训练,另一份写入 Redis 用于在线推理。这清晰地展示了“离线”和“在线”两条路径。

第四步:Keras 模型侧的 gRPC 客户端

最后,我们需要让 Keras 模型能够消费这个服务。我们需要在 Python 环境中生成 gRPC 客户端代码,并编写一个简单的客户端来调用它。

首先,生成 Python 代码:
python -m grpc_tools.protoc -I./proto --python_out=. --grpc_python_out=. ./proto/feature_store.proto

然后,编写客户端和模拟的 Keras 推理函数。

client/keras_client.py:

import grpc
import json
import numpy as np
# 假设 Keras 模型已经加载
# from tensorflow import keras

# 导入生成的 gRPC 代码
import feature_store_pb2
import feature_store_pb2_grpc

# 模拟一个已经加载的 Keras 模型
class MockKerasModel:
    def predict(self, features):
        print(f"Mock model predicting with features: {features}")
        # 简单逻辑:基于 feature_a 的值进行预测
        return [0.9 if f[0] > 5.0 else 0.1 for f in features]

def get_features_from_server(stub, entity_ids):
    """通过 gRPC 从特征服务器获取特征"""
    request = feature_store_pb2.OnlineFeaturesRequest(entity_ids=entity_ids)
    try:
        response = stub.GetOnlineFeatures(request, timeout=0.5) # 设置超时
        
        # 将 Protobuf 响应转换为更易于使用的 Python 字典
        feature_dict = {}
        for entity_id, feature_set in response.results.items():
            features = {}
            for name, value in feature_set.features.items():
                # 这个转换逻辑需要根据 `google.protobuf.Value` 的类型来做
                if value.HasField("string_value"):
                    features[name] = value.string_value
                elif value.HasField("number_value"):
                    features[name] = value.number_value
                elif value.HasField("bool_value"):
                    features[name] = value.bool_value
                # ... 其他类型
            feature_dict[entity_id] = features
        return feature_dict
    except grpc.RpcError as e:
        print(f"RPC failed: {e.code()} - {e.details()}")
        return {eid: {} for eid in entity_ids} # 返回空字典以保证流程继续

def online_predict(model, stub, entity_ids):
    """在线推理的完整流程"""
    # 1. 从特征服务获取特征
    features_map = get_features_from_server(stub, entity_ids)
    
    # 2. 特征预处理和排序,确保输入模型的顺序是固定的
    # 这是一个非常容易出错的地方,必须保证这里的顺序和训练时一致
    feature_order = ['feature_a', 'feature_b'] # 假设模型需要这两个特征
    
    batch_features = []
    valid_ids = []
    for entity_id in entity_ids:
        user_features = features_map.get(entity_id, {})
        # 检查所需特征是否存在,如果不存在则跳过该用户的预测
        if all(key in user_features for key in feature_order):
             # 注意:这里的 one-hot encoding 或其他转换逻辑必须和训练时完全一致
            processed = [user_features[name] for name in feature_order]
            batch_features.append(processed)
            valid_ids.append(entity_id)
        else:
            print(f"Skipping prediction for {entity_id} due to missing features.")

    if not batch_features:
        print("No valid features found for prediction.")
        return {}
        
    # 3. 模型预测
    predictions = model.predict(np.array(batch_features))
    
    # 4. 组合结果
    results = {uid: float(pred) for uid, pred in zip(valid_ids, predictions)}
    return results

if __name__ == '__main__':
    with grpc.insecure_channel('localhost:50051') as channel:
        stub = feature_store_pb2_grpc.FeatureStoreStub(channel)
        
        model = MockKerasModel()
        
        # 模拟一个批量推理请求
        user_ids_to_predict = ["user_10", "user_55", "user_99", "user_101"]
        
        final_predictions = online_predict(model, stub, user_ids_to_predict)
        
        print("\n--- Final Predictions ---")
        print(json.dumps(final_predictions, indent=2))

这个 Python 客户端代码暴露了 MLOps 中另一个关键且棘手的问题:在线推理时的特征预处理逻辑必须与离线训练时严格对齐。feature_order 的硬编码、缺失值的处理方式、数值的归一化,任何一个环节的微小差异都可能导致模型表现的断崖式下跌。这也是为什么很多成熟的特征平台会提供一个 SDK,将这部分逻辑也封装起来,确保一致性。

sequenceDiagram
    participant Python Client
    participant Rust gRPC Server
    participant Redis

    Python Client->>Rust gRPC Server: GetOnlineFeatures(["user_10", "user_55"])
    Rust gRPC Server->>Redis: MGET features:user_10, features:user_55
    Redis-->>Rust gRPC Server: [json_str_10, json_str_55]
    Rust gRPC Server->>Rust gRPC Server: Parse JSON, Convert to Protobuf
    Rust gRPC Server-->>Python Client: OnlineFeaturesResponse
    Python Client->>Python Client: Preprocess features
    Python Client->>Keras Model: predict(np.array)
    Keras Model-->>Python Client: Predictions

这套架构并非没有局限性。首先,我们目前依赖手动运行 PySpark 脚本来同步数据到 Redis,在真实世界里,这应该是一个由 Airflow 调度的批处理任务,或者是一个 Flink/Spark Streaming 实时任务。其次,特征转换逻辑(如 one-hot encoding)目前散落在 Python 训练脚本和推理客户端中,存在不一致的风险。一个改进方向是将这些转换逻辑也下沉到 Rust 服务中,或者使用一个共享的配置(如 YAML)来定义转换规则,确保两端执行相同的操作。最后,系统的可观测性尚待完善,需要加入更详细的指标(如请求延迟、缓存命中率、错误率)和分布式追踪,以便在出现问题时快速定位。


  目录