模型训练和线上推理之间的数据一致性,是所有 MLOps 体系中绕不过的坎。训练时,我们用 Spark 从数据湖批量拉取数周的样本,生成特征;而在线上,模型需要毫秒级的响应,从某个键值存储中获取最新的实时特征。这两套路径,从数据源、计算引擎到存储介质都完全不同,由此引入的特征偏斜(Feature Skew)常常是模型效果衰减的根源。我们团队最近的痛点就源于此,一个欺诈检测模型的离线 AUC 能做到 0.95,但一上线就掉到 0.8,复盘下来,问题就出在一个关键特征的计算口径在 Python 批处理脚本和 Java 实时流处理任务中出现了细微偏差。
为了根治这个问题,我们决定自建一个统一的特征平台。初步构想是,它必须包含一个离线特征仓库和一个在线特征服务。离线部分负责存储海量的、可回溯的历史特征,用于模型训练和迭代;在线部分则提供低延迟的特征查询接口,供线上模型实时调用。
技术选型阶段经过了不少讨论。离线存储我们很快锁定了 Apache Iceberg。它的事务性、Schema 演进和时间旅行(Time Travel)能力,对于保证训练数据的可复现性和质量至关重要。在真实项目中,特征 Schema 的变更非常频繁,Iceberg 能优雅地处理这些变更,避免了重写整个数据集的痛苦。
真正的抉择点在于在线特征服务层。它必须满足几个严苛的条件:极低的响应延迟(p99 < 10ms)、高并发下的稳定性、以及内存安全,我们不能接受一个核心服务因为 GC 停顿或内存泄漏而抖动。团队最初倾向于使用 Vert.x 或 Netty 等成熟的 JVM 方案,但对 GC 的顾虑始终存在。这时,Rust 进入了我们的视野。其无 GC、所有权系统带来的内存安全保障,以及接近 C++ 的性能,使其成为构建这类基础组件的理想选择。在 Web 框架方面,Axum 因其与 Tokio 生态的无缝集成、优雅的中间件设计以及类型安全的编程体验而胜出。
于是,最终的架构蓝图清晰了:
- 离线特征存储: 使用 Apache Spark 定期计算特征,并写入 Apache Iceberg 表中。
- 在线特征缓存: 使用 Redis 或 ScyllaDB 存储需要实时访问的热特征。
- 特征服务层: 构建一个基于 Axum 和 Tonic (gRPC) 的 Rust 服务,负责从在线缓存中高效地拉取特征。
- 模型消费: Python 端的 Keras 模型通过 gRPC 客户端调用 Rust 服务,获取推理所需的特征。
整个系统的核心,就是这个 Rust 特征服务层。它不仅是性能瓶颈所在,也是连接数据工程和算法模型的桥梁。
graph TD
subgraph "离线训练流程 (Offline Training Flow)"
A[原始数据源] --> B{Apache Spark};
B --> C[Apache Iceberg 表];
C --> D[Python 训练脚本];
D --> E[Keras 模型 v1.0];
end
subgraph "在线推理流程 (Online Inference Flow)"
F[实时事件流] --> G{实时计算引擎};
G --> H[在线缓存 Redis];
I[业务请求] --> J[Keras 模型服务];
J -- gRPC 请求 --> K((Rust 特征服务));
K -- 查询 --> H;
K -- gRPC 响应 --> J;
end
style K fill:#f9f,stroke:#333,stroke-width:2px
第一步:定义服务契约 (gRPC & Protobuf)
跨语言的服务调用,第一步永远是定义一个清晰、稳定、向后兼容的接口。gRPC 和 Protobuf 是这里的标准答案。我们定义一个 FeatureStore 服务,提供一个 GetOnlineFeatures 方法。
这里的坑在于请求和响应的设计。一个常见的错误是让 FeatureRequest 只接受一个实体 ID。但在真实场景中,我们可能需要一次性为多个实体(比如一个 mini-batch 的用户)批量获取特征。因此,请求应该接受一个 ID 列表。同样,返回的特征值应该用 map 来组织,以支持多种不同类型的特征。
proto/feature_store.proto:
syntax = "proto3";
package feature_store;
import "google/protobuf/struct.proto";
// 定义特征服务
service FeatureStore {
// 批量获取在线特征
rpc GetOnlineFeatures(OnlineFeaturesRequest) returns (OnlineFeaturesResponse);
}
// 请求体: 包含一个或多个实体ID
message OnlineFeaturesRequest {
// 要查询的实体ID列表,例如 user_id, device_id 等
repeated string entity_ids = 1;
// 指定需要拉取的特征名称列表,如果为空则返回所有可用特征
repeated string feature_names = 2;
}
// 响应体: 每个实体ID对应一组特征
message OnlineFeaturesResponse {
// key 是实体ID, value 是该实体的特征集合
map<string, FeatureSet> results = 1;
}
// 单个实体的特征集合
message FeatureSet {
// key 是特征名, value 是特征值
// 使用 google.protobuf.Value 来支持多种数据类型 (string, number, bool, etc.)
map<string, google.protobuf.Value> features = 1;
}
使用 google.protobuf.Value 是一个关键决策。它让我们不必为每一种特征类型(浮点、整型、字符串)都定义一个字段,从而使 API 更具灵活性,以应对未来新增的各种特征类型。
第二步:构建 Axum + Tonic gRPC 服务
现在开始构建 Rust 服务。首先是项目结构和依赖。
Cargo.toml:
[package]
name = "feature-server"
version = "0.1.0"
edition = "2021"
[dependencies]
axum = "0.7"
tokio = { version = "1", features = ["full"] }
tonic = "0.11"
prost = "0.12"
redis = { version = "0.25", features = ["tokio-comp"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
config = "0.14"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
[build-dependencies]
tonic-build = "0.11"
我们需要一个 build.rs 文件来在编译时根据 .proto 文件自动生成 Rust 代码。
build.rs:
fn main() -> Result<(), Box<dyn std::error::Error>> {
tonic_build::configure()
.build_server(true)
.compile(&["proto/feature_store.proto"], &["proto/"])?;
Ok(())
}
接下来是服务的主体实现。我们将配置、日志、Redis 连接池和 gRPC 服务逻辑清晰地分离开。
src/main.rs:
use std::net::SocketAddr;
use tonic::transport::Server;
use tracing::{info, error};
mod settings;
mod feature_service;
// 引入由 tonic-build 生成的代码
mod proto {
pub mod feature_store {
tonic::include_proto!("feature_store");
}
}
use feature_service::MyFeatureStore;
use proto::feature_store::feature_store_server::FeatureStoreServer;
use settings::Settings;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// 1. 初始化日志
tracing_subscriber::fmt()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.init();
// 2. 加载配置
let settings = match Settings::new() {
Ok(s) => s,
Err(e) => {
error!("Failed to load configuration: {}", e);
return Err(e.into());
}
};
info!("Configuration loaded: {:?}", settings);
// 3. 创建 Redis 连接池
let redis_client = redis::Client::open(settings.redis.url)?;
let redis_pool = redis::aio::ConnectionManager::new(redis_client).await?;
info!("Redis connection pool created.");
// 4. 实例化 gRPC 服务实现
let feature_store = MyFeatureStore::new(redis_pool);
let addr: SocketAddr = settings.server.address.parse()?;
info!("Feature server listening on {}", addr);
// 5. 启动 Tonic gRPC 服务器
Server::builder()
.add_service(FeatureStoreServer::new(feature_store))
.serve(addr)
.await?;
Ok(())
}
我们将配置逻辑封装在 settings.rs 中。在生产级项目中,硬编码配置是不可接受的。
src/settings.rs:
use config::{Config, ConfigError, File};
use serde::Deserialize;
#[derive(Debug, Deserialize)]
pub struct Server {
pub address: String,
}
#[derive(Debug, Deserialize)]
pub struct Redis {
pub url: String,
}
#[derive(Debug, Deserialize)]
pub struct Settings {
pub server: Server,
pub redis: Redis,
}
impl Settings {
pub fn new() -> Result<Self, ConfigError> {
let s = Config::builder()
// 从 `config/default.toml` 加载默认配置
.add_source(File::with_name("config/default"))
// 可以通过环境变量覆盖,例如 APP_SERVER_ADDRESS=0.0.0.0:50052
.add_source(config::Environment::with_prefix("APP").separator("__"))
.build()?;
s.try_deserialize()
}
}
config/default.toml:
[server]
address = "0.0.0.0:50051"
[redis]
url = "redis://127.0.0.1/"
这是服务的核心业务逻辑,即从 Redis 中获取特征。
src/feature_service.rs:
use std::collections::HashMap;
use tonic::{Request, Response, Status};
use prost_types::Value as ProstValue;
use serde_json::Value as JsonValue;
use tracing::{info, warn};
// 引入生成的代码
use crate::proto::feature_store::{
feature_store_server::FeatureStore,
FeatureSet, OnlineFeaturesRequest, OnlineFeaturesResponse,
};
// 使用类型别名来简化连接池的类型
type RedisPool = redis::aio::ConnectionManager;
#[derive(Debug)]
pub struct MyFeatureStore {
// Redis 连接池
redis_pool: RedisPool,
}
impl MyFeatureStore {
pub fn new(pool: RedisPool) -> Self {
Self { redis_pool: pool }
}
}
#[tonic::async_trait]
impl FeatureStore for MyFeatureStore {
async fn get_online_features(
&self,
request: Request<OnlineFeaturesRequest>,
) -> Result<Response<OnlineFeaturesResponse>, Status> {
let req = request.into_inner();
info!("Received request for {} entities", req.entity_ids.len());
if req.entity_ids.is_empty() {
return Err(Status::invalid_argument("Entity IDs cannot be empty."));
}
let mut conn = self.redis_pool.clone();
// 在真实项目中,这里应该用 MGET 批量获取,以获得最佳性能。
// 为简化示例,这里使用循环,但已为批量操作准备好。
let mut results = HashMap::new();
for entity_id in &req.entity_ids {
let redis_key = format!("features:{}", entity_id);
let result: Result<String, redis::RedisError> = redis::cmd("GET")
.arg(&redis_key)
.query_async(&mut conn)
.await;
match result {
Ok(json_str) => {
// Redis 中存储的是 JSON 字符串
let features: HashMap<String, JsonValue> = serde_json::from_str(&json_str)
.unwrap_or_else(|e| {
warn!("Failed to parse JSON for entity {}: {}", entity_id, e);
HashMap::new()
});
let mut feature_map = HashMap::new();
// 这里需要将 serde_json::Value 转换为 prost_types::Value
for (name, value) in features {
// 如果请求指定了 feature_names,则进行过滤
if req.feature_names.is_empty() || req.feature_names.contains(&name) {
feature_map.insert(name, json_to_prost_value(value));
}
}
results.insert(entity_id.clone(), FeatureSet { features: feature_map });
},
Err(e) => {
warn!("Failed to get features for entity {}: {}", entity_id, e);
// 即使某个 key 失败,也不应该中断整个请求,而是返回一个空结果
results.insert(entity_id.clone(), FeatureSet { features: HashMap::new() });
}
}
}
let response = OnlineFeaturesResponse { results };
Ok(Response::new(response))
}
}
// 辅助函数: 将 serde_json::Value 转换为 prost_types::Value
fn json_to_prost_value(json_value: JsonValue) -> ProstValue {
match json_value {
JsonValue::Null => ProstValue { kind: Some(prost_types::value::Kind::NullValue(0)) },
JsonValue::Bool(b) => ProstValue { kind: Some(prost_types::value::Kind::BoolValue(b)) },
JsonValue::Number(n) => ProstValue { kind: Some(prost_types::value::Kind::NumberValue(n.as_f64().unwrap_or(0.0))) },
JsonValue::String(s) => ProstValue { kind: Some(prost_types::value::Kind::StringValue(s)) },
JsonValue::Array(a) => {
let values = a.into_iter().map(json_to_prost_value).collect();
ProstValue { kind: Some(prost_types::value::Kind::ListValue(prost_types::ListValue { values })) }
}
JsonValue::Object(o) => {
let fields = o.into_iter().map(|(k, v)| (k, json_to_prost_value(v))).collect();
ProstValue { kind: Some(prost_types::value::Kind::StructValue(prost_types::Struct { fields })) }
}
}
}
这里的错误处理逻辑体现了生产级服务的考量:单个实体查询失败不应该导致整个批量请求失败。我们记录警告日志,并为该实体返回一个空的特征集,让调用方来决定如何处理部分失败。此外,json_to_prost_value 这个转换函数是连接数据存储格式 (JSON) 和 API 传输格式 (Protobuf) 的关键胶水代码。
第三步:离线特征的生产与加载
Rust 服务已经就绪,但它需要数据。我们需要一个流程来生成特征并将其写入 Iceberg(用于训练)和 Redis(用于在线服务)。这里使用 PySpark 来完成。
scripts/generate_features.py:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit, rand, to_json, struct
import redis
def write_to_iceberg(df, table_name):
print(f"Writing data to Iceberg table: {table_name}")
df.write \
.format("iceberg") \
.mode("overwrite") \
.save(table_name)
print("Write to Iceberg complete.")
def write_to_redis(df):
print("Writing data to Redis...")
r = redis.Redis(host='localhost', port=6379, db=0)
# 在生产环境中,应该使用 pipeline 来批量写入
with r.pipeline() as pipe:
for row in df.collect():
entity_id = row.entity_id
# 将除 entity_id 之外的所有列转换为 JSON
features_df = row.asDict()
del features_df['entity_id']
redis_key = f"features:{entity_id}"
pipe.set(redis_key, to_json(struct([col(c) for c in features_df])).first()[0])
# 模拟 to_json(struct(...)) 的行为
# 真实项目中 to_json 是 Spark 函数,这里为了能在 collect 后使用,手动转换
pdf = df.toPandas()
with r.pipeline() as pipe:
for _, row in pdf.iterrows():
entity_id = row['entity_id']
redis_key = f"features:{entity_id}"
feature_dict = row.drop('entity_id').to_dict()
pipe.set(redis_key, json.dumps(feature_dict))
pipe.execute()
print("Write to Redis complete.")
if __name__ == "__main__":
spark = SparkSession.builder \
.appName("FeatureGeneration") \
.config("spark.sql.extensions", "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions") \
.config("spark.sql.catalog.spark_catalog", "org.apache.iceberg.spark.SparkSessionCatalog") \
.config("spark.sql.catalog.spark_catalog.type", "hive") \
.getOrCreate()
# 1. 生成模拟特征数据
data = [(f"user_{i}", i * 0.1, i % 10, "A" if i % 2 == 0 else "B") for i in range(100)]
columns = ["entity_id", "feature_a", "feature_b", "feature_c"]
df = spark.createDataFrame(data, columns)
# 2. 写入 Iceberg 表,用于模型训练
iceberg_table = "default.user_features"
write_to_iceberg(df, iceberg_table)
# 3. 写入 Redis,用于在线服务
# 在实际场景中,可能会有单独的流处理任务来做这一步
write_to_redis(df)
spark.stop()
这个脚本模拟了特征工程的两个输出:一份写入 Iceberg 用于模型分析和训练,另一份写入 Redis 用于在线推理。这清晰地展示了“离线”和“在线”两条路径。
第四步:Keras 模型侧的 gRPC 客户端
最后,我们需要让 Keras 模型能够消费这个服务。我们需要在 Python 环境中生成 gRPC 客户端代码,并编写一个简单的客户端来调用它。
首先,生成 Python 代码:python -m grpc_tools.protoc -I./proto --python_out=. --grpc_python_out=. ./proto/feature_store.proto
然后,编写客户端和模拟的 Keras 推理函数。
client/keras_client.py:
import grpc
import json
import numpy as np
# 假设 Keras 模型已经加载
# from tensorflow import keras
# 导入生成的 gRPC 代码
import feature_store_pb2
import feature_store_pb2_grpc
# 模拟一个已经加载的 Keras 模型
class MockKerasModel:
def predict(self, features):
print(f"Mock model predicting with features: {features}")
# 简单逻辑:基于 feature_a 的值进行预测
return [0.9 if f[0] > 5.0 else 0.1 for f in features]
def get_features_from_server(stub, entity_ids):
"""通过 gRPC 从特征服务器获取特征"""
request = feature_store_pb2.OnlineFeaturesRequest(entity_ids=entity_ids)
try:
response = stub.GetOnlineFeatures(request, timeout=0.5) # 设置超时
# 将 Protobuf 响应转换为更易于使用的 Python 字典
feature_dict = {}
for entity_id, feature_set in response.results.items():
features = {}
for name, value in feature_set.features.items():
# 这个转换逻辑需要根据 `google.protobuf.Value` 的类型来做
if value.HasField("string_value"):
features[name] = value.string_value
elif value.HasField("number_value"):
features[name] = value.number_value
elif value.HasField("bool_value"):
features[name] = value.bool_value
# ... 其他类型
feature_dict[entity_id] = features
return feature_dict
except grpc.RpcError as e:
print(f"RPC failed: {e.code()} - {e.details()}")
return {eid: {} for eid in entity_ids} # 返回空字典以保证流程继续
def online_predict(model, stub, entity_ids):
"""在线推理的完整流程"""
# 1. 从特征服务获取特征
features_map = get_features_from_server(stub, entity_ids)
# 2. 特征预处理和排序,确保输入模型的顺序是固定的
# 这是一个非常容易出错的地方,必须保证这里的顺序和训练时一致
feature_order = ['feature_a', 'feature_b'] # 假设模型需要这两个特征
batch_features = []
valid_ids = []
for entity_id in entity_ids:
user_features = features_map.get(entity_id, {})
# 检查所需特征是否存在,如果不存在则跳过该用户的预测
if all(key in user_features for key in feature_order):
# 注意:这里的 one-hot encoding 或其他转换逻辑必须和训练时完全一致
processed = [user_features[name] for name in feature_order]
batch_features.append(processed)
valid_ids.append(entity_id)
else:
print(f"Skipping prediction for {entity_id} due to missing features.")
if not batch_features:
print("No valid features found for prediction.")
return {}
# 3. 模型预测
predictions = model.predict(np.array(batch_features))
# 4. 组合结果
results = {uid: float(pred) for uid, pred in zip(valid_ids, predictions)}
return results
if __name__ == '__main__':
with grpc.insecure_channel('localhost:50051') as channel:
stub = feature_store_pb2_grpc.FeatureStoreStub(channel)
model = MockKerasModel()
# 模拟一个批量推理请求
user_ids_to_predict = ["user_10", "user_55", "user_99", "user_101"]
final_predictions = online_predict(model, stub, user_ids_to_predict)
print("\n--- Final Predictions ---")
print(json.dumps(final_predictions, indent=2))
这个 Python 客户端代码暴露了 MLOps 中另一个关键且棘手的问题:在线推理时的特征预处理逻辑必须与离线训练时严格对齐。feature_order 的硬编码、缺失值的处理方式、数值的归一化,任何一个环节的微小差异都可能导致模型表现的断崖式下跌。这也是为什么很多成熟的特征平台会提供一个 SDK,将这部分逻辑也封装起来,确保一致性。
sequenceDiagram
participant Python Client
participant Rust gRPC Server
participant Redis
Python Client->>Rust gRPC Server: GetOnlineFeatures(["user_10", "user_55"])
Rust gRPC Server->>Redis: MGET features:user_10, features:user_55
Redis-->>Rust gRPC Server: [json_str_10, json_str_55]
Rust gRPC Server->>Rust gRPC Server: Parse JSON, Convert to Protobuf
Rust gRPC Server-->>Python Client: OnlineFeaturesResponse
Python Client->>Python Client: Preprocess features
Python Client->>Keras Model: predict(np.array)
Keras Model-->>Python Client: Predictions
这套架构并非没有局限性。首先,我们目前依赖手动运行 PySpark 脚本来同步数据到 Redis,在真实世界里,这应该是一个由 Airflow 调度的批处理任务,或者是一个 Flink/Spark Streaming 实时任务。其次,特征转换逻辑(如 one-hot encoding)目前散落在 Python 训练脚本和推理客户端中,存在不一致的风险。一个改进方向是将这些转换逻辑也下沉到 Rust 服务中,或者使用一个共享的配置(如 YAML)来定义转换规则,确保两端执行相同的操作。最后,系统的可观测性尚待完善,需要加入更详细的指标(如请求延迟、缓存命中率、错误率)和分布式追踪,以便在出现问题时快速定位。