基于 Saga 模式构建一个包含 NumPy 计算任务的分布式事务工作流


我们面临一个棘手的场景:一个风险评估流程被拆分成了三个独立的微服务。第一个服务(risk-model-service)负责执行一个计算密集型的蒙特卡洛模拟,它严重依赖 NumPy 进行大规模矩阵运算,这个过程可能耗时几十秒甚至数分钟。第二个服务(portfolio-service)负责在计算成功后,将风险敞口数据持久化到数据库。第三个服务(notification-service)则在数据入库后,向交易终端发送一个关键通知。

整个流程必须具备事务性。如果通知服务失败,那么已经入库的数据必须回滚。如果数据库写入失败,那么长时间运行的计算任务产生的结果需要被视为无效,并进行清理。传统的两阶段提交(2PC)在这里完全不适用,因为它要求所有资源在整个事务期间被锁定,对于一个长达数分钟的计算任务,这会造成灾难性的资源占用和系统雪崩。

最初的尝试是使用一个简单的任务队列配合大量的重试和状态检查逻辑,但这很快就演变成了一个难以维护的状态泥潭。我们需要一个更清晰的、能够处理长周期、多步骤业务流程并保证最终一致性的模式。Saga 模式,特别是其中的“编排式 Saga”(Orchestration),成为了我们的最终选择。它通过一个中央协调器来调用各个参与者服务,并在某个步骤失败时,调用之前已成功步骤的补偿操作。

这里的核心挑战在于,如何将一个纯计算型、本身无事务概念的 NumPy 任务,无缝地融入到一个分布式事务工作流中。本文将复盘我们从零开始,使用 Python 构建一个轻量级 Saga 编排器的全过程,并展示如何将一个 NumPy 计算服务作为其中的一个事务参与者进行管理。

技术选型与编排器设计

我们决定自己实现一个轻量级的 Saga 编排器,而不是引入像 Temporal 或 Cadence 这样重型的框架。主要原因是我们的业务场景相对固定,自研可以让我们对执行逻辑、状态持久化和失败处理有百分之百的控制力,并且能与现有的 Python 技术栈(FastAPI, NumPy)完美融合。

编排器的核心设计思路如下:

  1. Saga 定义: 一个 Saga 工作流被定义为一个步骤(Step)的有序列表。
  2. 步骤 (Step): 每个步骤包含一个正向操作(Action)和一个补偿操作(Compensation)。
  3. 操作 (Action/Compensation): 本质上是一个对其他微服务的 API 调用,包含 HTTP 方法、URL、Payload 等信息。
  4. 状态机: Saga 的执行过程是一个状态机。其状态包括 PENDING, RUNNING, COMPLETED, COMPENSATING, FAILED
  5. 持久化: 编排器自身必须是无状态的,Saga 的执行状态和步骤结果必须持久化(我们选择了 Redis),这样即使编排器崩溃重启,也能从中断的地方继续执行或补偿。

下面是这个 Saga 模型的 Python 数据结构定义。我们使用 Pydantic 来确保类型安全和数据校验。

# src/saga/models.py

import uuid
from typing import List, Dict, Any, Optional
from enum import Enum
from pydantic import BaseModel, Field

class HttpMethod(str, Enum):
    POST = "POST"
    PUT = "PUT"
    DELETE = "DELETE"
    GET = "GET" # 虽然不常见,但某些查询操作也可能作为Saga的一部分

class SagaStatus(str, Enum):
    PENDING = "PENDING"
    RUNNING = "RUNNING"
    COMPLETED = "COMPLETED"
    COMPENSATING = "COMPENSATING"
    FAILED = "FAILED"

class Operation(BaseModel):
    """定义一个具体的操作,即对某个服务的API调用"""
    method: HttpMethod
    url: str
    payload: Optional[Dict[str, Any]] = None
    headers: Optional[Dict[str, str]] = Field(default_factory=dict)

class Step(BaseModel):
    """定义Saga的一个步骤,包含正向操作和补偿操作"""
    name: str
    action: Operation
    compensation: Operation

class StepExecutionRecord(BaseModel):
    """记录每个步骤的执行结果"""
    step_name: str
    action_executed: bool = False
    action_response: Optional[Dict[str, Any]] = None
    compensation_executed: bool = False
    compensation_response: Optional[Dict[str, Any]] = None

class SagaExecutionInstance(BaseModel):
    """代表一个完整的Saga执行实例"""
    instance_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    saga_name: str
    status: SagaStatus = SagaStatus.PENDING
    steps: List[Step]
    execution_records: List[StepExecutionRecord] = Field(default_factory=list)
    current_step_index: int = 0
    # 存储整个Saga流程的共享上下文,用于步骤间数据传递
    context: Dict[str, Any] = Field(default_factory=dict)

这个数据模型为我们的编排器提供了清晰的结构。SagaExecutionInstance 就像一张执行清单,详细记录了每一步的计划和实际执行情况。context 字段至关重要,它允许后一步骤使用前一步骤的输出,例如,portfolio-service 需要使用 risk-model-service 计算出的结果。

Saga 编排器的实现

编排器的核心是一个循环,它根据当前 Saga 实例的状态驱动整个流程前进或后退(补偿)。我们将它封装在 SagaOrchestrator 类中。

# src/saga/orchestrator.py

import httpx
import redis
import json
import logging
from .models import SagaExecutionInstance, Step, SagaStatus, StepExecutionRecord

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class SagaOrchestrator:
    def __init__(self, redis_url: str):
        self.redis_client = redis.StrictRedis.from_url(redis_url, decode_responses=True)
        # 使用 httpx 支持异步和连接池,这在生产环境中至关重要
        self.http_client = httpx.Client(timeout=180.0) # 计算任务可能耗时较长,设置较长超时

    def _save_state(self, instance: SagaExecutionInstance):
        """将Saga实例的当前状态持久化到Redis"""
        self.redis_client.set(f"saga:{instance.instance_id}", instance.model_dump_json())
        logger.info(f"Saga instance {instance.instance_id} state saved. Status: {instance.status}")

    def _load_state(self, instance_id: str) -> SagaExecutionInstance:
        """从Redis加载Saga实例状态"""
        state_json = self.redis_client.get(f"saga:{instance_id}")
        if not state_json:
            raise ValueError(f"Saga instance {instance_id} not found.")
        return SagaExecutionInstance.model_validate_json(state_json)

    def create_and_start_saga(self, saga_name: str, steps: List[Step], initial_context: dict) -> SagaExecutionInstance:
        """创建并启动一个新的Saga实例"""
        instance = SagaExecutionInstance(saga_name=saga_name, steps=steps, context=initial_context)
        instance.status = SagaStatus.RUNNING
        self._save_state(instance)
        logger.info(f"New saga '{saga_name}' instance {instance.instance_id} created and started.")
        self.run(instance.instance_id)
        return self._load_state(instance.instance_id)

    def run(self, instance_id: str):
        """主执行逻辑,驱动Saga状态机"""
        instance = self._load_state(instance_id)
        
        if instance.status == SagaStatus.RUNNING:
            self._execute_forward(instance)
        elif instance.status == SagaStatus.COMPENSATING:
            self._execute_backward(instance)
        else:
            logger.warning(f"Saga instance {instance.instance_id} is in a terminal state: {instance.status}. No action taken.")

    def _execute_forward(self, instance: SagaExecutionInstance):
        """正向执行逻辑"""
        while instance.current_step_index < len(instance.steps):
            step = instance.steps[instance.current_step_index]
            record = StepExecutionRecord(step_name=step.name)
            
            try:
                logger.info(f"Executing action for step '{step.name}' in saga {instance.instance_id}")
                
                # 从上下文中动态替换payload中的占位符
                action_payload = self._render_payload(step.action.payload, instance.context)
                
                response = self.http_client.request(
                    method=step.action.method.value,
                    url=step.action.url,
                    json=action_payload,
                    headers=step.action.headers
                )
                response.raise_for_status() # 关键:非2xx状态码会抛出异常
                
                response_data = response.json()
                record.action_executed = True
                record.action_response = response_data
                instance.execution_records.append(record)
                
                # 将当前步骤的输出合并到全局上下文中
                instance.context.update(response_data)
                
                instance.current_step_index += 1
                self._save_state(instance)

            except httpx.HTTPStatusError as e:
                logger.error(f"Action for step '{step.name}' failed with status {e.response.status_code}. Starting compensation for saga {instance.instance_id}")
                record.action_executed = False
                record.action_response = {"error": str(e), "status_code": e.response.status_code}
                instance.execution_records.append(record)
                instance.status = SagaStatus.COMPENSATING
                self._save_state(instance)
                self._execute_backward(instance)
                return # 终止正向流程
            except Exception as e:
                logger.error(f"An unexpected error occurred during action for step '{step.name}': {e}. Starting compensation.")
                record.action_executed = False
                record.action_response = {"error": str(e)}
                instance.execution_records.append(record)
                instance.status = SagaStatus.COMPENSATING
                self._save_state(instance)
                self._execute_backward(instance)
                return
        
        # 所有步骤成功
        instance.status = SagaStatus.COMPLETED
        self._save_state(instance)
        logger.info(f"Saga instance {instance.instance_id} completed successfully.")

    def _execute_backward(self, instance: SagaExecutionInstance):
        """反向补偿逻辑"""
        # 从最后一个成功的步骤开始补偿
        # 注意: current_step_index 指向的是失败的或下一个要执行的步骤,所以补偿要从它前面一个开始
        # 同时,execution_records 记录了所有已尝试的步骤
        last_executed_step_index = len(instance.execution_records) - 1

        for i in range(last_executed_step_index, -1, -1):
            record = instance.execution_records[i]
            # 只对已成功执行正向操作的步骤进行补偿
            if record.action_executed and not record.compensation_executed:
                step = next((s for s in instance.steps if s.name == record.step_name), None)
                if not step: continue

                try:
                    logger.warning(f"Executing compensation for step '{step.name}' in saga {instance.instance_id}")
                    
                    compensation_payload = self._render_payload(step.compensation.payload, instance.context)

                    response = self.http_client.request(
                        method=step.compensation.method.value,
                        url=step.compensation.url,
                        json=compensation_payload,
                        headers=step.compensation.headers
                    )
                    response.raise_for_status()
                    
                    record.compensation_executed = True
                    record.compensation_response = response.json()
                    
                except Exception as e:
                    # 补偿失败是一个严重问题,需要人工介入
                    logger.critical(f"Compensation for step '{step.name}' FAILED: {e}. Saga {instance.instance_id} is now in a failed state requiring manual intervention.")
                    instance.status = SagaStatus.FAILED
                    self._save_state(instance)
                    return # 停止进一步补偿
            # 更新状态
            self.redis_client.set(f"saga:{instance.instance_id}", instance.model_dump_json())

        instance.status = SagaStatus.FAILED
        self._save_state(instance)
        logger.error(f"Saga instance {instance.instance_id} failed and compensation process finished.")

    def _render_payload(self, payload: Optional[Dict], context: Dict) -> Optional[Dict]:
        """使用Saga上下文动态渲染payload模板"""
        if payload is None:
            return None
        
        # 简单的模板替换,例如 "user_id": "{initial_user_id}"
        payload_str = json.dumps(payload)
        for key, value in context.items():
            placeholder = "{" + str(key) + "}"
            # 注意:JSON序列化会把非字符串值转换,这里只替换原始字符串里的占位符
            if isinstance(value, (str, int, float)):
                 payload_str = payload_str.replace(placeholder, str(value))
        
        return json.loads(payload_str)

这个编排器实现了Saga的核心逻辑:顺序执行、状态持久化、失败检测和反向补偿。_render_payload 方法虽然简单,但对于实现步骤间数据传递至关重要。

将 NumPy 计算服务作为事务参与者

现在,我们来构建那个特殊的参与者:risk-model-service。它使用 FastAPI 框架,并提供一个执行计算的端点。

# services/risk_model_service.py

from fastapi import FastAPI, HTTPException, Body
from pydantic import BaseModel
import numpy as np
import time
import uuid
import logging

# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

app = FastAPI()

# 模拟一个存储,用于存放计算产生的临时结果
# 在真实项目中,这可能是S3, HDFS, 或一个临时文件系统
TEMP_RESULTS_STORE = {}

class ModelInput(BaseModel):
    portfolio_id: str
    simulation_count: int = 100000
    # 模拟一个会触发失败的开关,用于测试补偿逻辑
    force_failure: bool = False

@app.post("/calculate")
async def calculate_risk(data: ModelInput):
    """
    执行计算密集型任务
    """
    calculation_id = str(uuid.uuid4())
    logger.info(f"Starting risk calculation {calculation_id} for portfolio {data.portfolio_id}...")

    try:
        # 这是一个耗时的NumPy操作
        # 模拟生成随机资产回报率
        # 在真实场景中,这会是更复杂的金融模型
        returns = np.random.normal(0.001, 0.02, (data.simulation_count, 100))
        
        # 模拟复杂的矩阵运算
        covariance_matrix = np.cov(returns, rowvar=False)
        # 模拟一个长时间的计算过程
        time.sleep(5) 
        
        if data.force_failure:
             raise ValueError("Forced failure for testing compensation.")

        # 计算风险价值 (VaR) - 这是一个简化示例
        portfolio_value = np.sum(returns, axis=1)
        var_95 = np.percentile(portfolio_value, 5)

        result = {
            "calculation_id": calculation_id,
            "portfolio_id": data.portfolio_id,
            "var_95": var_95,
            "status": "completed"
        }
        
        # 将结果存入临时存储,等待Saga后续步骤确认
        TEMP_RESULTS_STORE[calculation_id] = result
        logger.info(f"Calculation {calculation_id} completed successfully. Result stored temporarily.")
        
        return result

    except Exception as e:
        logger.error(f"Calculation {calculation_id} failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.delete("/calculate/{calculation_id}")
async def cleanup_calculation(calculation_id: str):
    """
    补偿操作:清理计算产生的临时结果
    这个接口必须是幂等的。
    """
    logger.warning(f"Executing compensation: cleaning up calculation {calculation_id}.")
    if calculation_id in TEMP_RESULTS_STORE:
        del TEMP_RESULTS_STORE[calculation_id]
        logger.info(f"Cleanup for calculation {calculation_id} successful.")
        return {"status": "cleaned_up", "calculation_id": calculation_id}
    else:
        # 即使资源已经不存在,也应该返回成功,保证补偿操作的幂等性
        logger.warning(f"Calculation {calculation_id} not found during cleanup, assuming already cleaned.")
        return {"status": "not_found_or_already_cleaned", "calculation_id": calculation_id}

这个服务的设计关键点在于:

  1. 正向操作 (/calculate): 它执行计算,并将结果保存在一个临时位置 (TEMP_RESULTS_STORE)。它返回一个 calculation_id,这个 ID 是后续步骤(包括补偿)定位这个计算结果的句柄。
  2. 补偿操作 (/calculate/{calculation_id}): 它的职责是清理正向操作产生的“副作用”。在这里就是删除临时存储中的结果。这个操作被设计为幂等的,多次调用同一个 calculation_id 不会产生错误,这是健壮的补偿操作的基本要求。

组装和运行完整的 Saga 工作流

其他两个服务 (portfolio-servicenotification-service) 的实现比较常规,这里只给出它们的接口定义和补偿逻辑。

Portfolio Service:

  • POST /portfolios/data: 接收计算结果并存入数据库,返回一个 record_id
  • DELETE /portfolios/data/{record_id}: 补偿操作,根据 record_id 删除数据库记录。

Notification Service:

  • POST /notifications/send: 发送通知。
  • POST /notifications/retract/{notification_id}: 补偿操作,发送一个“撤回”通知。在某些场景下,补偿并非删除,而是执行一个反向的业务操作。

现在,我们可以定义一个完整的 Saga 工作流并执行它。

# run_saga.py

import os
from src.saga.models import Step, Operation, HttpMethod
from src.saga.orchestrator import SagaOrchestrator

# --- 服务地址配置 ---
# 在生产环境中,这些应该来自环境变量或配置中心
RISK_SERVICE_URL = os.getenv("RISK_SERVICE_URL", "http://127.0.0.1:8001")
PORTFOLIO_SERVICE_URL = os.getenv("PORTFOLIO_SERVICE_URL", "http://127.0.0.1:8002")
NOTIFICATION_SERVICE_URL = os.getenv("NOTIFICATION_SERVICE_URL", "http://127.0.0.1:8003")
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")

# --- Saga 定义 ---
risk_analysis_saga_steps = [
    Step(
        name="CalculateRiskModel",
        action=Operation(
            method=HttpMethod.POST,
            url=f"{RISK_SERVICE_URL}/calculate",
            payload={"portfolio_id": "{portfolio_id}", "force_failure": "{force_failure}"}
        ),
        compensation=Operation(
            method=HttpMethod.DELETE,
            url=f"{RISK_SERVICE_URL}/calculate/{{calculation_id}}"
        )
    ),
    Step(
        name="PersistPortfolioData",
        action=Operation(
            method=HttpMethod.POST,
            url=f"{PORTFOLIO_SERVICE_URL}/portfolios/data",
            # 使用上一步的输出作为输入
            payload={
                "calculation_id": "{calculation_id}",
                "risk_metrics": {"var_95": "{var_95}"}
            }
        ),
        compensation=Operation(
            method=HttpMethod.DELETE,
            url=f"{PORTFOLIO_SERVICE_URL}/portfolios/data/{{record_id}}"
        )
    ),
    Step(
        name="SendTradeTerminalNotification",
        action=Operation(
            method=HttpMethod.POST,
            url=f"{NOTIFICATION_SERVICE_URL}/notifications/send",
            payload={
                "portfolio_id": "{portfolio_id}",
                "message": "Risk analysis for portfolio {portfolio_id} is complete."
            }
        ),
        compensation=Operation(
            method=HttpMethod.POST,
            url=f"{NOTIFICATION_SERVICE_URL}/notifications/retract/{{notification_id}}",
            payload={"reason": "Risk analysis process was rolled back."}
        )
    )
]

def main():
    orchestrator = SagaOrchestrator(redis_url=REDIS_URL)

    # 场景一: 成功执行
    print("--- SCENARIO 1: SUCCESSFUL EXECUTION ---")
    initial_context_success = {"portfolio_id": "P-123", "force_failure": False}
    success_instance = orchestrator.create_and_start_saga(
        saga_name="RiskAnalysisWorkflow",
        steps=risk_analysis_saga_steps,
        initial_context=initial_context_success
    )
    print(f"Final Saga State (Success): {success_instance.status.value}")
    print(f"Final Context: {success_instance.context}")
    print("-" * 40)

    # 场景二: 在持久化步骤失败,触发补偿
    print("--- SCENARIO 2: FAILURE AT PERSISTENCE STEP ---")
    # 假设 portfolio_service 的 /portfolios/data 接口被修改为,当看到特定ID时返回500
    initial_context_fail = {"portfolio_id": "P-FAIL-DB", "force_failure": False}
    # 注入一个模拟失败的逻辑到portfolio-service中,此处省略服务代码,仅演示Saga流程
    # 在真实测试中,会使用mock或真实的服务故障
    # 这里我们通过修改action的payload来模拟
    fail_steps = risk_analysis_saga_steps[:]
    fail_steps[1].action.payload["trigger_db_failure"] = True # 假设服务会识别这个标志
    
    fail_instance = orchestrator.create_and_start_saga(
        saga_name="RiskAnalysisWorkflow",
        steps=fail_steps,
        initial_context=initial_context_fail
    )
    print(f"Final Saga State (Failure): {fail_instance.status.value}")
    print("Execution Records:")
    for record in fail_instance.execution_records:
        print(f"  - Step: {record.step_name}, Action Executed: {record.action_executed}, Compensation Executed: {record.compensation_executed}")
    print("-" * 40)


if __name__ == "__main__":
    # 需要先启动三个模拟服务和Redis
    main()

为了让上面的 run_saga.py 可运行,你需要实现 portfolio-servicenotification-service 的简单 FastAPI 应用,它们逻辑与 risk-model-service 类似,提供正向和补偿端点即可。

下面是整个流程的Mermaid图示,清晰地展示了成功和失败两种路径。

sequenceDiagram
    participant Orchestrator
    participant RiskModelService
    participant PortfolioService
    participant NotificationService

    %% -- 成功路径 --
    Orchestrator->>RiskModelService: POST /calculate (portfolio_id: P-123)
    activate RiskModelService
    Note right of RiskModelService: NumPy calculation runs (5s)
    RiskModelService-->>Orchestrator: 200 OK {calculation_id: C1, ...}
    deactivate RiskModelService
    
    Orchestrator->>PortfolioService: POST /portfolios/data (calculation_id: C1)
    activate PortfolioService
    PortfolioService-->>Orchestrator: 200 OK {record_id: R1}
    deactivate PortfolioService

    Orchestrator->>NotificationService: POST /notifications/send (portfolio_id: P-123)
    activate NotificationService
    NotificationService-->>Orchestrator: 200 OK {notification_id: N1}
    deactivate NotificationService
    Note over Orchestrator: Saga COMPLETED

    %% -- 失败与补偿路径 --
    Orchestrator->>RiskModelService: POST /calculate (portfolio_id: P-FAIL-DB)
    activate RiskModelService
    Note right of RiskModelService: NumPy calculation runs (5s)
    RiskModelService-->>Orchestrator: 200 OK {calculation_id: C2, ...}
    deactivate RiskModelService

    Orchestrator->>PortfolioService: POST /portfolios/data (calculation_id: C2)
    activate PortfolioService
    Note right of PortfolioService: Database error occurs
    PortfolioService-->>Orchestrator: 500 Internal Server Error
    deactivate PortfolioService

    Note over Orchestrator: Step failed, starting compensation
    Orchestrator->>RiskModelService: DELETE /calculate/C2
    activate RiskModelService
    Note right of RiskModelService: Cleans up temporary results
    RiskModelService-->>Orchestrator: 200 OK {status: cleaned_up}
    deactivate RiskModelService
    Note over Orchestrator: Saga FAILED

遗留问题与未来迭代

这个自研的 Saga 编排器虽然解决了我们的核心问题,但它并非一个完备的生产级解决方案。当前实现存在几个明显的局限性:

  1. 编排器单点故障: 尽管 Saga 状态被持久化,但编排器进程本身是单点的。如果它在执行一个 API 调用期间崩溃,状态可能不会被及时更新。生产环境需要一个高可用的编排器集群,通过分布式锁(如 RedLock)来确保同一时间只有一个实例在处理某个 Saga。

  2. 同步执行: 当前的 run 方法是同步阻塞的。对于大量并发的 Saga,这会耗尽线程资源。可以将其改造为基于消息队列的异步模型。例如,create_and_start_saga 只是向队列中投递一个 “start” 消息,由一个或多个 worker 进程消费消息并驱动 Saga 状态机。

  3. 补偿操作的可靠性: 我们假设补偿操作总能成功。但在现实中,补偿操作本身也可能失败。当前实现会将 Saga 标记为 FAILED 并停止,这需要人工介入。更复杂的系统可能会引入重试机制、备用补偿策略,或者将失败的补偿任务推送到一个“死信队列”中等待修复。

  4. 对参与者服务的侵入性: Saga 模式要求参与者服务必须提供补偿接口。这是一种设计上的耦合。对于无法修改的第三方服务,可能需要引入“防腐层”(Anti-Corruption Layer)来适配。

尽管存在这些局限,这次构建过程证明了 Saga 模式在处理包含非传统事务资源(如一个长时间运行的 NumPy 计算任务)的分布式流程中的价值。它通过牺牲隔离性和原子性,换取了业务流程的最终一致性和系统的整体可用性,这正是在微服务架构中常常需要做出的权衡。


  目录