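Agent Rollback Design: How to Recover When an Automated Task Corrupts Files

30-Second Summary

  • Core problem: when an Agent writes files autonomously, a single bad write can corrupt configuration, break database state, or introduce hard-to-trace bugs, and a simple retry will not fix any of these. Without a rollback mechanism, the cost of repair escalates from "run it again" to "manual forensic recovery".
  • Three strategies: Snapshot → save the full state before the operation and restore it on failure; Transaction → write-ahead log plus atomic commit, precise down to a single write; Compensation → run a semantic reverse operation, for external side effects that cannot be snapshotted.
  • Key design: rollback is not an after-the-fact remedy; the undo path must be ready before the write happens. The central claim of this article: every Agent write must register an executable undo handler with the system, otherwise the write is not executed.
  • What you will be able to do: design a layered rollback architecture for your Agent system that selects a snapshot, transaction, or compensation strategy based on the operation type, and builds systematic undo capability at the file, data, and environment levels, so that every write the Agent makes is reversible.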

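1. Why Rollback Is a First-Class Requirement for Autonomous Agents

A deployment Agent is executing "update the Nginx configuration on 14 servers and enable the new rate-limiting rules". On the 9th server it replaces /etc/nginx/nginx.conf, but one parameter of limit_req_zone is malformed: Nginx fails to restart, and the service on server 9 is now offline. Worse, those 9 servers were the first batch of a canary rollout and the remaining 5 have not been updated, so you are stuck with server 9 down and 5 servers you do not dare touch.

What you need here is not a retry: a retry writes the same broken configuration again and fails again. Version control alone is not enough either: you can restore the configuration file with git checkout nginx.conf, but the Agent may also have modified /etc/hosts, written a new systemd unit file, and added temporary iptables rules. Version control covers only the files it tracks, and the Agent's write scope extends far beyond that boundary.

What you need is a systematic rollback mechanism: before each write executes, the system has already prepared enough information and logic to undo it, whether the target is a file, a database row, a configuration entry, or an external API call.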

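Five Typical Ways an Agent Corrupts Files

Three traits make an autonomous Agent's writes inherently error-prone: high autonomy (the Agent decides what to write and how), high destructive power (a single write can affect production), and low visibility (the write happens inside the Agent's internal logic, and humans notice only once the results go wrong). The following five scenarios recur in production:

  1. Config Corruption: the Agent introduces a syntax error or a semantic contradiction while "optimizing" a configuration file. A typical example: the Agent edits a YAML file, and an indentation error causes the whole service configuration tree to be parsed as a completely different structure. Configuration files are not like code: most have no compile-time check, only a runtime crash. The bad write may differ by as little as one or two bytes (a space versus a tab), yet the service becomes unavailable.
  2. Silent Overwrite: the Agent generates new content and mistakenly overwrites an existing critical file. For example, the Agent is asked to "create README.md" and writes without checking whether the file exists, replacing the existing README with an empty template. A subtler case: the Agent writes a temporary file in step 3, mistakes it for another file and overwrites it in step 17, and by the time the error is discovered in step 42 the original content can no longer be traced.
  3. Cascading Write Failure: one bad write by the Agent triggers a chain reaction downstream. The Agent modifies a configuration record in the database; 5 microservices read that record and enter an error state one after another. Restoring the database field is not enough: you also need to restart the 5 microservices or restore their runtime state.
  4. Partial Mutation: the Agent runs a multi-step operation (such as "migrate user-service from JWT authentication to OAuth2") and fails at step 4 of 7. The first 3 steps have already changed code, configuration, and environment variables, leaving the system in an inconsistent state that is neither JWT nor OAuth2. A retry does not help, because the first 3 steps have already changed the preconditions.
  5. Non-Deterministic Write: the same Agent instruction produces different writes in different environments. The Agent writes a correctly formatted configuration file in staging, but writes a malformed one in production because of a dependency version difference. This cross-environment uncertainty invalidates the assumption that "passing in staging is enough"; rollback has to be the last line of defense in each environment independently.

Key insight: an Agent's write errors differ fundamentally from a human engineer's. A human usually checks before writing (cat to view the content, diff to see the change, --dry-run to rehearse); an Agent does none of this by default. More importantly, a human who hits an error can stop and think, whereas the Agent's "thinking" draws on the same attention budget as the rest of its work: the further into the context, the more error-prone its write decisions become. Rollback is not a nice-to-have for an Agent; it is the missing piece of the autonomous write model.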
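The pre-write checks a human performs (look at the file, diff the change, refuse to clobber an existing file) can be approximated in code. The sketch below is one hypothetical way to do it; preview_write, WritePreview, and create_new_file are names invented here for illustration and are not part of the listings later in this article.

```python
from __future__ import annotations

import difflib
from dataclasses import dataclass
from pathlib import Path


@dataclass
class WritePreview:
    """Result of a dry run: what would change, without touching the file."""
    target_exists: bool
    changed: bool
    diff: str


def preview_write(path: Path, new_content: str) -> WritePreview:
    """Compare new_content with the file on disk; never writes anything."""
    if not path.exists():
        return WritePreview(target_exists=False, changed=True, diff="")
    old = path.read_text(encoding="utf-8")
    diff = "".join(
        difflib.unified_diff(
            old.splitlines(keepends=True),
            new_content.splitlines(keepends=True),
            fromfile=f"{path} (current)",
            tofile=f"{path} (proposed)",
        )
    )
    # An empty diff means the proposed content is identical to the file.
    return WritePreview(target_exists=True, changed=bool(diff), diff=diff)


def create_new_file(path: Path, content: str) -> None:
    """Create a file, failing loudly instead of silently overwriting."""
    # Mode "x" raises FileExistsError if the path already exists.
    with open(path, "x", encoding="utf-8") as fh:
        fh.write(content)
```

A caller that means "create" uses create_new_file and treats FileExistsError as a signal to stop; a caller that means "modify" can log the diff from preview_write before committing, which makes a one-character indentation change visible.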

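Why Version Control and Retries Are Not Enough

Many engineers' first reaction is "don't we have git?" or "can't the Agent just retry?". In an Agent context, both intuitions are dangerous misconceptions.

The correct view is that rollback is a design dimension independent of version control and retries. Version control gives you the ability to return to a known-good state (coarse-grained). Retry gives you the ability to try again after a transient fault (and assumes the failed attempt left no state behind). Rollback gives you the ability to precisely undo one specific write (fine-grained, stateful, composable).

For the full picture of Agent failure recovery, see Agent Error Recovery; rollback sits at the most aggressive end of the recovery spectrum. Rollback also requires knowing about each write before it happens, which depends on observability infrastructure; see Agent Observability. Finally, an accurate audit log is the only reliable basis for deciding how far to roll back; see Agent Audit Log Design.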

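The Design Philosophy of Rollback: Prepare the Undo Path First, Then Write

The core design principle of this article fits in one sentence: a write without a registered undo handler should not be executed. In code, this means that before each Agent write the system must (1) capture the pre-write state, or record enough information to rebuild it, (2) register a callable undo function, and (3) verify that the undo function is usable, at minimum that it exists and its parameters are valid. Only after these three steps complete is the actual write allowed to run.

The code below shows a minimal implementation of this principle: an exception class and a matching context manager, which serve as the base abstraction for all the rollback strategies: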

from __future__ import annotations

import logging
import os
import shutil
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Optional

logger = logging.getLogger("agent.rollback")


# ---------------------------------------------------------------------------
# RollbackNeededError — any write operation failure throws this unified
# exception. Carries undo_handler so callers can execute rollback without
# knowing the operation's internal details.
# ---------------------------------------------------------------------------

@dataclass(eq=False)
class RollbackNeededError(Exception):
    """
    Raised when a write operation fails after making partial changes.
    Carries the undo handler so callers can revert without knowing
    the operation's internals.
    """
    message: str
    operation_id: str
    # undo() restores the pre-write state. Returns True on success.
    undo_handler: Callable[[], bool]

    def __post_init__(self) -> None:
        # The dataclass-generated __init__ does not call Exception.__init__,
        # so populate args here. eq=False keeps the exception hashable.
        super().__init__(self.message)

    def __str__(self) -> str:
        return f"RollbackNeededError(op={self.operation_id}): {self.message}"


# ---------------------------------------------------------------------------
# safe_agent_write — a context manager that enforces "prepare undo first"
# for any file write operation. Every write is wrapped in a
# snapshot → write → validate pipeline. If validation fails or an exception
# occurs, the snapshot is restored.
# ---------------------------------------------------------------------------

@dataclass
class WriteRecord:
    """Metadata for a single write operation within the context manager."""
    file_path: Path
    snapshot_path: Optional[Path] = None
    succeeded: bool = False
    undo_registered: bool = False


class WriteContext:
    """Accumulator for write records; caller inspects after the block."""
    def __init__(self) -> None:
        self.records: list[WriteRecord] = []

    def add(self, record: WriteRecord) -> None:
        self.records.append(record)

    @property
    def failed(self) -> list[WriteRecord]:
        return [r for r in self.records if not r.succeeded]

    @property
    def succeeded(self) -> list[WriteRecord]:
        return [r for r in self.records if r.succeeded]


@contextmanager
def safe_agent_write(
    file_path: str | Path,
    content: str | bytes,
    *,
    operation_id: str = "",
    staging_dir: str | Path = "",
    validator: Callable[[Path], bool] | None = None,
    ctx: WriteContext | None = None,
):
    """
    Context manager that guarantees a file write is undoable.

    Pipeline:
      1. SNAPSHOT — copy the target file to a temp location (if it exists).
      2. STAGE   — write content to a staging file (never in-place).
      3. VALIDATE — run optional validator on the staged content.
      4. ATOMIC REPLACE — os.replace() from staging to target.
      5. ON FAILURE — restore snapshot if anything goes wrong.

    Design decisions:
      - Stage-never-write-in-place: the target file is never partially
        modified. Either the staged file passes validation and atomically
        replaces the target, or the target is untouched.
      - Snapshot-on-first-touch: the snapshot is taken only when content
        differs from current file, avoiding unnecessary copies for no-op
        writes.
      - Validator hook: caller can supply a function that inspects the
        staged file (e.g., linter, schema check, dry-run) before the atomic
        replace. If the validator returns False or raises, the write is
        aborted and snapshot restored.

    Args:
        file_path:    Target file to write.
        content:      Content to write (str or bytes).
        operation_id: Unique ID for logging and error tracing.
        staging_dir:  Where to create staging files (default: target's dir).
        validator:    Optional callable(Path) -> bool.
        ctx:          Optional WriteContext to accumulate records.

    Yields:
        Path to the staging file (so caller can inspect before commit).

    Raises:
        RollbackNeededError: if any step fails, with undo_handler attached.
    """
    target = Path(file_path).resolve()
    staging_root = Path(staging_dir) if staging_dir else target.parent
    staging_root.mkdir(parents=True, exist_ok=True)

    record = WriteRecord(file_path=target)
    snapshot: Optional[Path] = None
    staging: Optional[Path] = None
    undo_registered = False

    try:
        # -- Step 1: Snapshot ------------------------------------------------
        # Only snapshot if the file exists AND content actually differs.
        # No-op writes should not generate snapshots.
        if target.exists():
            current_bytes = target.read_bytes()
            new_bytes = (
                content if isinstance(content, bytes)
                else content.encode("utf-8")
            )
            if current_bytes == new_bytes:
                # Content unchanged: skip the write entirely. The record is
                # added to ctx by the finally block below, so adding it here
                # as well would register it twice.
                record.succeeded = True
                yield target
                return

            # Content differs: create snapshot
            snapshot = staging_root / f".rollback-snap-{operation_id}-{target.name}"
            shutil.copy2(target, snapshot)
            logger.debug("Snapshot created: %s -> %s", target, snapshot)
        else:
            # File does not exist yet -- nothing to snapshot. Undo means
            # deleting the new file if creation fails validation.
            pass

        # -- Step 2: Stage ---------------------------------------------------
        staging = staging_root / f".rollback-stage-{operation_id}-{target.name}"
        if isinstance(content, str):
            staging.write_text(content, encoding="utf-8")
        else:
            staging.write_bytes(content)
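        # Report the on-disk size; len(content) would count characters,
        # not bytes, when content is a str.
        logger.debug(
            "Staged content: %s (%d bytes)", staging, staging.stat().st_size,
        )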

        # -- Step 3: Register undo handler BEFORE committing -----------------
        # Captured before the replace, so it reflects the pre-write state.
        existed_before = target.exists()

        def undo() -> bool:
            """Restore the pre-write state."""
            try:
                if snapshot is not None and snapshot.exists():
                    # Restore from snapshot
                    shutil.copy2(snapshot, target)
                    snapshot.unlink(missing_ok=True)
                    logger.info(
                        "Rollback: restored %s from snapshot %s",
                        target, snapshot,
                    )
                elif not existed_before:
                    # File was created by this write: undo means deleting it
                    # (a no-op if the replace never happened).
                    target.unlink(missing_ok=True)
                    logger.info(
                        "Rollback: removed newly created file %s", target,
                    )
                else:
                    # The file pre-existed and the snapshot is gone: either
                    # the replace never happened or the finally block already
                    # restored it. Never delete a pre-existing file here.
                    logger.info(
                        "Rollback: %s is already in its pre-write state",
                        target,
                    )
                return True
            except Exception as exc:
                logger.error("Rollback failed for %s: %s", target, exc)
                return False

        record.undo_registered = True
        undo_registered = True

        # -- Step 4: Validate staged content ---------------------------------
        if validator is not None:
            try:
                if not validator(staging):
                    raise RollbackNeededError(
                        message=f"Validation rejected staged content for {target}",
                        operation_id=operation_id,
                        undo_handler=undo,
                    )
            except RollbackNeededError:
                raise
            except Exception as exc:
                raise RollbackNeededError(
                    message=f"Validator raised exception: {exc}",
                    operation_id=operation_id,
                    undo_handler=undo,
                ) from exc

        # Yield staging path so caller can inspect before commit
        yield staging

        # -- Step 5: Atomic replace ------------------------------------------
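        # os.replace is atomic on POSIX when staging and target are on the
        # same filesystem (the default here, since staging_dir defaults to
        # the target's directory): the target is either the old file or the
        # new one, never a partial write.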
        os.replace(staging, target)
        staging = None  # prevent cleanup (already moved)

        # -- Step 6: Cleanup snapshot on success -----------------------------
        if snapshot is not None and snapshot.exists():
            snapshot.unlink()
            snapshot = None

        record.succeeded = True
        logger.info("Write committed: %s (op=%s)", target, operation_id)

    except RollbackNeededError:
        # Re-raise so caller gets the undo_handler
        raise

    except Exception as exc:
        # Any unexpected error becomes a RollbackNeededError. The handler must
        # be the undo callable itself (returning bool), so build it directly
        # with the factory instead of wrapping the factory in a lambda.
        raise RollbackNeededError(
            message=f"Write operation failed: {exc}",
            operation_id=operation_id,
            undo_handler=_make_cleanup_undo(snapshot, staging, target),
        ) from exc

    finally:
        # Always add the record so callers can inspect what happened
        if ctx:
            ctx.add(record)

        # Cleanup orphaned temp files (fail-safe, not-already-cleaned)
        if staging is not None:
            staging.unlink(missing_ok=True)
        if snapshot is not None and snapshot.exists():
            # Snapshot still exists -> write didn't succeed, restore it
            shutil.copy2(snapshot, target)
            snapshot.unlink(missing_ok=True)
            logger.warning(
                "Rollback: emergency restore of %s from snapshot", target,
            )


def _make_cleanup_undo(
    snapshot: Optional[Path],
    staging: Optional[Path],
    target: Path,
) -> Callable[[], bool]:
    """Factory for undo handler on unexpected errors."""
    def undo() -> bool:
        try:
            if staging is not None:
                staging.unlink(missing_ok=True)
            if snapshot is not None and snapshot.exists():
                shutil.copy2(snapshot, target)
                snapshot.unlink(missing_ok=True)
                logger.info(
                    "Cleanup rollback: restored %s from snapshot", target,
                )
            return True
        except Exception as exc:
            logger.error("Cleanup rollback failed: %s", exc)
            return False
    return undo


# ---------------------------------------------------------------------------
# Usage example — demonstrating the full lifecycle
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    import json

    # Setup: create a test file to write into
    test_file = Path("/tmp/agent-rollback-demo.conf")
    test_file.write_text('server { listen 80; }\n')

    ctx = WriteContext()

    # -- Successful write ----------------------------------------------------
    try:
        with safe_agent_write(
            test_file,
            'server { listen 443 ssl; }\n',
            operation_id="nginx-tls-upgrade",
            ctx=ctx,
            # Validator: staged file must contain "listen"
            validator=lambda p: "listen" in p.read_text(),
        ) as staged:
            print(f"Staged at: {staged}")
            # If we reach here without exception, atomic replace happens
    except RollbackNeededError as e:
        print(f"ROLLBACK TRIGGERED: {e}")
        success = e.undo_handler()
        print(f"Undo result: {success}")

    print(f"File content after write: {test_file.read_text()!r}")
    print(
        f"Write records: {len(ctx.records)} total, "
        f"{len(ctx.succeeded)} ok, {len(ctx.failed)} failed"
    )

    # -- Failed write (validator rejects) ------------------------------------
    ctx2 = WriteContext()
    try:
        with safe_agent_write(
            test_file,
            'bad config content\n',
            operation_id="nginx-bad-config",
            ctx=ctx2,
            # Validator: reject content that doesn't contain "server"
            validator=lambda p: "server" in p.read_text(),
        ):
            pass  # This should never be reached
    except RollbackNeededError as e:
        print(f"\nEXPECTED ROLLBACK: {e}")
        success = e.undo_handler()
        print(f"Undo result: {success}")

    # File should still contain the content from the successful write
    print(f"File content after failed write: {test_file.read_text()!r}")
    print(
        f"Write records: {len(ctx2.records)} total, "
        f"{len(ctx2.succeeded)} ok, {len(ctx2.failed)} failed"
    )

    # Cleanup
    test_file.unlink(missing_ok=True)

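The design of this context manager embodies the core idea of this article: the undo path is fully built before the actual write happens. Because RollbackNeededError carries the undo_handler, the caller can roll back without knowing the internals of the write: you do not need to know whether it was a file write or a database write, you just call the handler. This pattern of decoupling the rollback interface from the operation's implementation runs through every rollback strategy in the following sections.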
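Because every undo handler has the same shape (a callable returning bool), handlers from unrelated operations can be collected and run together. The following is a minimal sketch of that idea, not the orchestrator defined later: UndoStack and MissingUndoError are hypothetical names. It refuses any write that arrives without an undo handler and rolls back in reverse order.

```python
from __future__ import annotations

from typing import Callable, Optional


class MissingUndoError(RuntimeError):
    """Raised when a write is submitted without an undo handler."""


class UndoStack:
    """Hypothetical helper: runs a write only if an undo handler exists."""

    def __init__(self) -> None:
        self._handlers: list[tuple[str, Callable[[], bool]]] = []

    def run(
        self,
        name: str,
        write: Callable[[], None],
        undo: Optional[Callable[[], bool]],
    ) -> None:
        if undo is None:
            raise MissingUndoError(f"{name}: no undo handler, write refused")
        # Register first, so a write that fails halfway is still covered.
        self._handlers.append((name, undo))
        write()

    def rollback(self) -> list[str]:
        """Undo newest-first; return the names of handlers that failed."""
        failed: list[str] = []
        while self._handlers:
            name, undo = self._handlers.pop()
            try:
                ok = undo()
            except Exception:
                ok = False
            if not ok:
                failed.append(name)
        return failed
```

Reverse order matters because later writes may depend on earlier ones; undoing them first-in-first-out could restore a file that a later step's undo then overwrites.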

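Real scenarios are far more complex than a single file write. An Agent may operate on several files, modify database records, and call external APIs in the same task. The next section covers three basic rollback strategies (snapshot, transaction, and compensation) and how to select one automatically based on the operation type.

2. Rollback Strategies: Snapshot, Transaction, Compensation (Three Tools, Each with Its Strengths)

Not every operation can be rolled back the same way. Modifying a local file, performing a transactional database write, and creating a subscription through the Stripe API are rolled back by fundamentally different mechanisms. Forcing one strategy onto all operations leads to one of two outcomes: it is too expensive (taking snapshots for simple operations wastes disk), or it cannot roll back at all (a transaction cannot cover an external API).

This section defines the three basic rollback strategies and provides an orchestrator, RollbackOrchestrator, that selects a strategy automatically based on the operation type.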

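Strategy 1: Snapshot Rollback (Save the Full State Before the Operation)

Core idea: before making any modification, save the complete current state of the object being modified into a snapshot. If the operation fails or needs to be undone, overwrite the object with the snapshot.

Applicable scenarios:

Advantages: simple to implement and easy to understand; fast to roll back (a single copy operation); independent of the operation's semantics: you do not need to know what the Agent did, only how to return to the state before it.

Costs: high storage overhead, since every write saves a full copy. If the Agent writes to a 2GB database file 100 times, that is 200GB of snapshots. Multi-file operations need either one snapshot of the whole set or an independent snapshot per file.

When to choose: (1) the target object is small (< 100MB); (2) writes are infrequent (< 5 per minute); (3) the operation is an indivisible whole. On point (3): if restoring one file to its pre-operation state requires other files to be restored in step with it, a single-file snapshot is not enough, and the snapshot must cover the whole set.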
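As an illustration of a snapshot that covers a set of files as one unit, here is a small sketch. The function names snapshot_files and snapshot_cost_bytes are hypothetical, and the code assumes the files are small enough to copy in full.

```python
from __future__ import annotations

import shutil
from pathlib import Path
from typing import Callable, Optional


def snapshot_files(paths: list[Path], snap_dir: Path) -> Callable[[], None]:
    """Copy every existing file into snap_dir; return a restore function."""
    snap_dir.mkdir(parents=True, exist_ok=True)
    saved: dict[Path, Optional[Path]] = {}
    for index, path in enumerate(paths):
        if path.exists():
            # Index prefix avoids collisions between files with equal names.
            copy = snap_dir / f"{index}-{path.name}"
            shutil.copy2(path, copy)
            saved[path] = copy
        else:
            saved[path] = None  # did not exist: restoring means deleting

    def restore() -> None:
        # Restores the whole set, so related files stay consistent.
        for path, copy in saved.items():
            if copy is None:
                path.unlink(missing_ok=True)
            else:
                shutil.copy2(copy, path)

    return restore


def snapshot_cost_bytes(file_size: int, writes: int) -> int:
    """Worst-case storage if every write keeps its own full copy."""
    return file_size * writes
```

With this arithmetic, snapshot_cost_bytes(2 * 1024**3, 100) is 200 GiB, which is the storage cost described above for a 2GB file written 100 times.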

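Strategy 2: Transaction Rollback (Write-Ahead Log + Atomic Commit)

Core idea: instead of storing a full copy, record only "what I intend to do" (the intent log) and "what I did" (the commit log). Write an intent entry before the operation runs, and mark it committed afterwards. On rollback, run the reverse operation for committed entries and simply discard uncommitted ones.

Applicable scenarios:

Advantages: very low storage overhead (only the delta and the intent are stored, not a full copy); fine-grained rollback (one key rather than a whole file); transactional semantics, meaning all-or-nothing atomicity. From an audit perspective, the WAL is itself a queryable operation log; for the audit infrastructure, see Agent Audit Log Design.

Costs: high implementation complexity, since every data type needs its own intent format and reverse-operation logic. Atomicity requires every participant (files, databases, configuration) to support two-phase commit, and a cross-system transaction is a distributed-systems problem. Transactions also cannot cover external side effects: a charge made through the Stripe API really happened, and there is no "rollback" button.

When to choose: (1) the target is a large file or is written frequently, so snapshot cost is unacceptable; (2) atomicity is required: multiple files and rows must either all succeed or all be undone; (3) the operation has a clear before/after difference that a delta can describe without a full copy.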
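The intent/commit idea can be sketched at key level over a plain dict. This is an in-memory illustration only: KeyLevelLog and Intent are hypothetical names, and a real write-ahead log would persist each entry durably before applying the change.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any

_MISSING = object()  # sentinel: the key did not exist before the write


@dataclass
class Intent:
    key: str
    old_value: Any
    new_value: Any
    committed: bool = False


class KeyLevelLog:
    """In-memory intent/commit log over a dict (illustrative only)."""

    def __init__(self, store: dict[str, Any]) -> None:
        self.store = store
        self.entries: list[Intent] = []

    def write(self, key: str, value: Any) -> None:
        # 1. Record the intent, with the pre-image of this one key only.
        entry = Intent(key, self.store.get(key, _MISSING), value)
        self.entries.append(entry)
        # 2. Apply the change, then 3. mark the intent as committed.
        self.store[key] = value
        entry.committed = True

    def rollback(self) -> None:
        """Reverse committed entries newest-first; drop uncommitted ones."""
        while self.entries:
            entry = self.entries.pop()
            if not entry.committed:
                continue  # never applied, nothing to reverse
            if entry.old_value is _MISSING:
                self.store.pop(entry.key, None)
            else:
                self.store[entry.key] = entry.old_value
```

Each entry stores one old value rather than a copy of the whole store, which is where the storage saving over a snapshot comes from.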

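Strategy 3: Compensation Rollback (Run a Semantically Equivalent Reverse Operation)

Core idea: for operations that cannot be recovered by a snapshot or a transaction (usually operations with external side effects or complex semantics), write a dedicated compensation function. It does not restore the pre-operation state; it performs an action that semantically "cancels out" the effect of the original operation.

Applicable scenarios:

Advantages: covers what snapshots and transactions cannot handle: external side effects, irreversible state changes, and undo that needs business semantics ("refund" rather than "delete the database row").

Costs: compensation is not a true rollback, because the system does not return to exactly the same state as before the operation. The email has already been sent, and the Stripe fee on a refunded payment may not be returned. The compensation function can itself fail: the "cancel subscription" API may be down too. The reliability ceiling of a compensation strategy is set by the reliability of the external system.

When to choose: (1) the operation has irreversible external side effects (API calls, sent messages); (2) the operation needs undo in business terms rather than state restoration (a refund, an apology email, deleting a record); (3) as a supplement after a snapshot or transaction, for cases where the file state can be restored but the external API call has already happened.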
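A compensation flow might look like the sketch below. It assumes each step is paired with a compensating action supplied by the caller; Step, Outcome, and run_with_compensation are hypothetical names, and no real payment or email API is called. Because a compensator can itself fail, each one is retried a bounded number of times and any that never succeed are reported for human follow-up.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Callable


@dataclass
class Step:
    name: str
    action: Callable[[], None]
    compensate: Callable[[], None]


@dataclass
class Outcome:
    completed: bool
    # Steps whose compensation never succeeded (needs a human).
    unresolved: list[str] = field(default_factory=list)


def _compensate(step: Step, max_attempts: int) -> bool:
    """Try a compensator up to max_attempts times."""
    for _ in range(max_attempts):
        try:
            step.compensate()
            return True
        except Exception:
            continue
    return False


def run_with_compensation(steps: list[Step], max_attempts: int = 3) -> Outcome:
    """Run steps in order; on failure, compensate finished steps in reverse."""
    done: list[Step] = []
    try:
        for step in steps:
            step.action()
            done.append(step)
        return Outcome(completed=True)
    except Exception:
        # The failing step is not compensated: it never completed.
        unresolved = [
            step.name
            for step in reversed(done)
            if not _compensate(step, max_attempts)
        ]
        return Outcome(completed=False, unresolved=unresolved)
```

A non-empty unresolved list is the signal to escalate: the system is in a state that automation could not cancel out.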

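Rule of thumb: all Agent writes fall into three classes by rollback strategy: local file operations (snapshot first), database/state operations (transaction first), and external API operations (compensation as the backstop). When one operation spans several classes (for example "modify the config file + restart the service + call the monitoring API to confirm"), the strategies must be combined. This is the core responsibility of RollbackOrchestrator: analyze the operation, select strategies, and coordinate their execution in sequence.

Strategy Decision Matrix

The matrix below summarizes the standard decision path for choosing a strategy. For a given Agent write, check the conditions in this table from top to bottom: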


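  Operation characteristic                     │ Recommended                 │ Fallback
  ─────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────
  Local-only file write, file < 100MB          │ Snapshot                    │ Transaction (WAL)
  Local-only file write, file ≥ 100MB          │ Transaction (WAL)           │ Snapshot (if disk space allows)
  Atomic multi-file write (> 1 file)           │ Transaction (WAL)           │ Per-file snapshots
  Includes a database write                    │ Transaction (WAL)           │ Compensation
  Includes an external API call                │ Compensation                │ Not rollbackable (log + alert)
  High-frequency writes (> 5/min per file)     │ Transaction (WAL)           │ Snapshot + merge
  Needs semantic undo (e.g. a refund)          │ Compensation                │ Not rollbackable
  Irreversible target (e.g. physical deletion) │ Not rollbackable            │ Audit log + human intervention
  Distributed writes across Agents             │ Compensation / coordination │ Human intervention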
  

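RollbackOrchestrator: Strategy Selection and Execution Orchestration

The code below implements a rollback orchestrator: before an operation runs, it analyzes the operation type, selects a snapshot, transaction, or compensation strategy automatically, and manages the full lifecycle of undo handler registration and execution: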

from __future__ import annotations

import enum
import hashlib
import json
import os
import shutil
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Optional, Any, Protocol
from collections.abc import Sequence
import logging

logger = logging.getLogger("agent.rollback.orchestrator")


# ========================================================================
# Strategy enum — the three fundamental rollback approaches
# ========================================================================

class RollbackStrategy(enum.Enum):
    """
    Three mutually-exclusive rollback strategies. Each maps to a different
    recovery mechanism with distinct cost/speed/reliability tradeoffs.

    SNAPSHOT     — Pre-operation full-state capture. Fast restore, high
                   storage cost.
    TRANSACTION  — Write-ahead log + atomic commit. Low storage, complex
                   implementation.
    COMPENSATION — Semantic reverse operation. For external side effects
                   that cannot be "undone" at the storage level.
    """
    SNAPSHOT = "snapshot"
    TRANSACTION = "transaction"
    COMPENSATION = "compensation"


# ========================================================================
# Operation descriptor — what the Agent intends to do
# ========================================================================

class OperationKind(enum.Enum):
    """Categorization of what kind of mutation the Agent is performing."""
    FILE_WRITE = "file_write"          # Single file content replacement
    FILE_MULTI = "file_multi"          # Multiple file writes, one logical step
    FILE_DELETE = "file_delete"        # Removing a file entirely
    DB_WRITE = "db_write"              # Database INSERT/UPDATE/DELETE
    CONFIG_CHANGE = "config_change"    # Structured config modification
    API_CALL = "api_call"              # External API with side effects
    MIXED = "mixed"                    # Combination of the above


@dataclass
class OperationDescriptor:
    """
    Describes one atomic (from the Agent's perspective) operation.
    The orchestrator uses this to select the rollback strategy.
    """
    kind: OperationKind
    target: str                   # Human-readable target (path, table, URL)
    operation_id: str             # Unique ID for tracing
    # Optional hints for strategy selection
    estimated_size_bytes: int = 0     # For SNAPSHOT cost estimation
    is_reversible_api: bool = True    # For COMPENSATION: API supports undo?
    affected_files: list[str] = field(default_factory=list)
    # Custom strategy override — if set, orchestrator skips auto-selection
    preferred_strategy: RollbackStrategy | None = None


# ========================================================================
# Undo handler protocol — the universal rollback interface
# ========================================================================

class UndoHandler(Protocol):
    """
    Protocol for any undo operation. Every strategy must produce one.
    The orchestrator calls execute(), and the handler returns True if
    the rollback succeeded. If False, the orchestrator may escalate.
    """
    def execute(self) -> bool: ...
    def describe(self) -> str: ...


@dataclass
class SimpleUndo:
    """A concrete undo handler wrapping a callable + description."""
    _fn: Callable[[], bool]
    _desc: str

    def execute(self) -> bool:
        try:
            return self._fn()
        except Exception as exc:
            logger.error("Undo handler failed: %s — %s", self._desc, exc)
            return False

    def describe(self) -> str:
        return self._desc


# ========================================================================
# Strategy selection logic — the decision matrix as code
# ========================================================================

@dataclass
class StrategyDecision:
    """
    Output of the strategy selector. Contains the chosen strategy and
    a reason string for auditability.
    """
    strategy: RollbackStrategy
    reason: str
    # If True, orchestrator also registers a fallback undo handler
    needs_fallback: bool = False
    fallback_strategy: RollbackStrategy | None = None


def select_strategy(op: OperationDescriptor) -> StrategyDecision:
    """
    Decision matrix that maps OperationDescriptor → RollbackStrategy.

    Decision logic (in priority order):
      1. If caller explicitly set preferred_strategy, use it.
      2. FILE_WRITE: SNAPSHOT if file < 100MB, else TRANSACTION.
         Rationale: snapshot cost grows linearly with file size; beyond
         ~100MB, the IO overhead of copying dominates and WAL is cheaper.
      3. FILE_MULTI: always TRANSACTION.
         Rationale: multiple files need atomicity — snapshot would require
         per-file copies and the restore order matters.
      4. FILE_DELETE: SNAPSHOT (copy before delete) if small, else
         COMPENSATION (recreate from backup source).
      5. DB_WRITE: TRANSACTION (WAL-based undo via pre-image logging).
      6. API_CALL: COMPENSATION if is_reversible_api=True, else best-effort
         COMPENSATION (escalate to human if irrecoverable).
      7. CONFIG_CHANGE: TRANSACTION (structured configs have clear key/value
         deltas, making WAL the natural fit).
      8. MIXED: TRANSACTION as primary, COMPENSATION as fallback for
         API-call sub-operations within the mix.
    """
    if op.preferred_strategy is not None:
        return StrategyDecision(
            strategy=op.preferred_strategy,
            reason=f"Explicit caller preference for {op.target}",
        )

    kind = op.kind

    if kind == OperationKind.FILE_WRITE:
        if op.estimated_size_bytes < 100 * 1024 * 1024:  # < 100 MB
            return StrategyDecision(
                strategy=RollbackStrategy.SNAPSHOT,
                reason=(
                    f"File {op.target} is small "
                    f"(~{op.estimated_size_bytes}B), SNAPSHOT is cheap"
                ),
            )
        else:
            return StrategyDecision(
                strategy=RollbackStrategy.TRANSACTION,
                reason=(
                    f"File {op.target} is large "
                    f"(~{op.estimated_size_bytes}B), SNAPSHOT too expensive "
                    f"— using TRANSACTION"
                ),
                needs_fallback=True,
                fallback_strategy=RollbackStrategy.SNAPSHOT,
            )

    elif kind == OperationKind.FILE_MULTI:
        return StrategyDecision(
            strategy=RollbackStrategy.TRANSACTION,
            reason=(
                f"Multi-file operation ({len(op.affected_files)} files) "
                f"requires atomicity — TRANSACTION"
            ),
            needs_fallback=True,
            fallback_strategy=RollbackStrategy.SNAPSHOT,
        )

    elif kind == OperationKind.FILE_DELETE:
        if op.estimated_size_bytes < 100 * 1024 * 1024:
            return StrategyDecision(
                strategy=RollbackStrategy.SNAPSHOT,
                reason=f"Delete of {op.target}: SNAPSHOT before deletion",
            )
        else:
            return StrategyDecision(
                strategy=RollbackStrategy.COMPENSATION,
                reason=(
                    f"Delete of large file {op.target}: cannot snapshot — "
                    f"COMPENSATION (recreate from backup)"
                ),
            )

    elif kind == OperationKind.DB_WRITE:
        return StrategyDecision(
            strategy=RollbackStrategy.TRANSACTION,
            reason="Database writes: TRANSACTION via pre-image WAL",
        )

    elif kind == OperationKind.API_CALL:
        if op.is_reversible_api:
            # No fallback: a fallback identical to the primary strategy
            # would add nothing.
            return StrategyDecision(
                strategy=RollbackStrategy.COMPENSATION,
                reason=f"API call to {op.target}: COMPENSATION (reversible)",
            )
        else:
            # No automated fallback exists for an irreversible call; per the
            # decision matrix the remaining path is to record and alert.
            return StrategyDecision(
                strategy=RollbackStrategy.COMPENSATION,
                reason=(
                    f"API call to {op.target}: NOT reversible; "
                    f"COMPENSATION (best-effort, partial recovery expected, "
                    f"record and alert a human)"
                ),
            )

    elif kind == OperationKind.CONFIG_CHANGE:
        return StrategyDecision(
            strategy=RollbackStrategy.TRANSACTION,
            reason="Config changes: TRANSACTION via key-level delta logging",
        )

    elif kind == OperationKind.MIXED:
        return StrategyDecision(
            strategy=RollbackStrategy.TRANSACTION,
            reason=(
                "Mixed operation: TRANSACTION primary, "
                "COMPENSATION fallback for API sub-operations"
            ),
            needs_fallback=True,
            fallback_strategy=RollbackStrategy.COMPENSATION,
        )

    # Fallthrough — should never happen but be defensive
    return StrategyDecision(
        strategy=RollbackStrategy.SNAPSHOT,
        reason=f"Unknown operation kind {kind}: defaulting to SNAPSHOT",
    )


# ========================================================================
# Snapshot strategy implementation
# ========================================================================

@dataclass
class SnapshotUndo:
    """Undo handler that restores a file from a saved snapshot."""
    snapshot_path: Path
    target_path: Path

    def execute(self) -> bool:
        try:
            if not self.snapshot_path.exists():
                logger.error(
                    "Snapshot missing: %s — cannot restore %s",
                    self.snapshot_path, self.target_path,
                )
                return False
            shutil.copy2(self.snapshot_path, self.target_path)
            self.snapshot_path.unlink(missing_ok=True)
            logger.info(
                "Snapshot rollback: restored %s → %s",
                self.snapshot_path, self.target_path,
            )
            return True
        except Exception as exc:
            logger.error("Snapshot rollback failed: %s", exc)
            return False

    def describe(self) -> str:
        return f"Snapshot restore: {self.snapshot_path} → {self.target_path}"


def prepare_snapshot_undo(
    target: str,
    snapshot_dir: str | Path,
    operation_id: str,
) -> tuple[bool, Optional[SimpleUndo], Optional[str]]:
    """
    Creates a snapshot of `target` and returns an undo handler.
    Returns (success, undo_handler, error_message).
    """
    target_path = Path(target).resolve()
    snap_dir = Path(snapshot_dir)
    snap_dir.mkdir(parents=True, exist_ok=True)

    if not target_path.exists():
        # Nothing to snapshot — undo means deleting any created file
        def delete_if_created() -> bool:
            target_path.unlink(missing_ok=True)
            return True
        return (
            True,
            SimpleUndo(delete_if_created, f"Delete {target_path} if created"),
            None,
        )

    snapshot_path = snap_dir / f"snap-{operation_id}-{target_path.name}"
    try:
        shutil.copy2(target_path, snapshot_path)
        handler = SnapshotUndo(
            snapshot_path=snapshot_path, target_path=target_path,
        )
        return True, SimpleUndo(handler.execute, handler.describe()), None
    except Exception as exc:
        return False, None, str(exc)


# ========================================================================
# Transaction (WAL) strategy — simplified intent log
# ========================================================================

@dataclass
class WALEntry:
    """A single write-ahead log entry for file operations."""
    operation_id: str
    target: str
    pre_image_hash: str       # SHA-256 of file before write
    pre_image_path: str       # Path to saved pre-image
    timestamp: float = field(default_factory=time.time)


class WriteAheadLog:
    """
    Minimal write-ahead log for file operations.
    Saves a pre-image of each target before it is written. This sketch
    stores the full pre-image for simplicity; storing a diff instead would
    keep storage lower.

    In production, use a proper WAL with checksums, rotation, and crash
    recovery. This is the pattern, not the production code.
    """

    def __init__(self, wal_dir: str | Path) -> None:
        self.wal_dir = Path(wal_dir)
        self.wal_dir.mkdir(parents=True, exist_ok=True)
        self._log_path = self.wal_dir / "wal.jsonl"

    def record_intent(
        self, operation_id: str, target: str,
    ) -> Optional[SimpleUndo]:
        """
        Record "I intend to write to `target`".
        Saves a pre-image and returns an undo handler.
        Returns None if the pre-image capture fails.
        """
        target_path = Path(target).resolve()
        if not target_path.exists():
            # New file: intent is "will create". Undo = delete.
            entry = WALEntry(
                operation_id=operation_id,
                target=target,
                pre_image_hash="new_file",
                pre_image_path="",
            )
            self._append_entry(entry)

            def undo_new_file() -> bool:
                target_path.unlink(missing_ok=True)
                self._append_entry(WALEntry(
                    operation_id=f"{operation_id}-undo",
                    target=target,
                    pre_image_hash="undo_delete",
                    pre_image_path="",
                ))
                return True

            return SimpleUndo(
                undo_new_file, f"WAL undo: delete new file {target}",
            )

        # Existing file: save pre-image hash and content
        try:
            content = target_path.read_bytes()
            file_hash = hashlib.sha256(content).hexdigest()

            # Save pre-image (full content for simplicity; production
            # code would compute and store a diff)
            pre_image_path = (
                self.wal_dir / f"pre-{operation_id}-{target_path.name}"
            )
            pre_image_path.write_bytes(content)

            entry = WALEntry(
                operation_id=operation_id,
                target=target,
                pre_image_hash=file_hash,
                pre_image_path=str(pre_image_path),
            )
            self._append_entry(entry)

            pre_img = pre_image_path  # capture for closure
            def undo_wal() -> bool:
                try:
                    if pre_img.exists():
                        shutil.copy2(pre_img, target_path)
                        pre_img.unlink(missing_ok=True)
                        self._append_entry(WALEntry(
                            operation_id=f"{operation_id}-undo",
                            target=target,
                            pre_image_hash="undo_restore",
                            pre_image_path="",
                        ))
                        logger.info(
                            "WAL rollback: restored %s from pre-image",
                            target,
                        )
                        return True
                    return False
                except Exception as exc:
                    logger.error("WAL undo failed: %s", exc)
                    return False

            return SimpleUndo(
                undo_wal,
                f"WAL undo: restore {target} from pre-image",
            )

        except Exception as exc:
            logger.error(
                "WAL intent recording failed for %s: %s", target, exc,
            )
            return None

    def _append_entry(self, entry: WALEntry) -> None:
        """Append a JSON line to the WAL file."""
        with open(self._log_path, "a") as f:
            f.write(json.dumps({
                "operation_id": entry.operation_id,
                "target": entry.target,
                "pre_image_hash": entry.pre_image_hash,
                "pre_image_path": entry.pre_image_path,
                "timestamp": entry.timestamp,
            }) + "\n")

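    def get_entries(self, operation_id: str) -> list[dict]:
        """
        Retrieve the WAL entries for `operation_id`, including its
        "<operation_id>-undo" record. Exact matching avoids picking up
        unrelated IDs that merely share a prefix (op-1 vs. op-10).
        """
        if not self._log_path.exists():
            return []
        wanted = {operation_id, f"{operation_id}-undo"}
        entries = []
        with open(self._log_path) as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines
                entry = json.loads(line)
                if entry.get("operation_id") in wanted:
                    entries.append(entry)
        return entries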


# ========================================================================
# RollbackOrchestrator — the central coordinator
# ========================================================================

@dataclass
class RollbackRecord:
    """Tracks one rollback registration for observability and debugging."""
    operation_id: str
    strategy: RollbackStrategy
    target: str
    handler: SimpleUndo
    registered_at: float = field(default_factory=time.time)
    executed: bool = False


@dataclass
class RollbackOrchestrator:
    """
    Central coordinator for rollback across an Agent's entire operation.

    Responsibilities:
      1. Analyze each write operation → select strategy via select_strategy()
      2. Prepare undo handler BEFORE the write executes
      3. Register the handler in a stack (LIFO for composable rollback)
      4. On failure, pop handlers from stack and execute in reverse order
      5. Track rollback records for audit/observability

    Usage:
        orch = RollbackOrchestrator(
            snapshot_dir="/var/agent/snapshots",
            wal_dir="/var/agent/wal",
        )

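        # Before Agent writes (SNAPSHOT / TRANSACTION operations):
        undo = orch.prepare_undo(op_descriptor)
        if undo is None:
            raise RuntimeError(
                "Could not prepare rollback, aborting write"
            )
        # COMPENSATION operations also get None here; for those, run the
        # API call and then call orch.register_compensation(...).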

        # Agent performs the write...
        # If the write succeeds:
        orch.commit(operation_id)

        # If anything fails:
        orch.rollback_all()  # LIFO undo execution
    """

    snapshot_dir: Path
    wal_dir: Path
    _wal: WriteAheadLog | None = None
    _undo_stack: list[tuple[str, SimpleUndo]] = field(default_factory=list)
    _records: list[RollbackRecord] = field(default_factory=list)
    _committed: set[str] = field(default_factory=set)

    def __post_init__(self) -> None:
        self.snapshot_dir = Path(self.snapshot_dir)
        self.wal_dir = Path(self.wal_dir)
        self.snapshot_dir.mkdir(parents=True, exist_ok=True)
        self.wal_dir.mkdir(parents=True, exist_ok=True)
        self._wal = WriteAheadLog(self.wal_dir)

    @property
    def wal(self) -> WriteAheadLog:
        assert self._wal is not None
        return self._wal

    # -- Strategy selection + undo preparation -----------------------------

    def prepare_undo(self, op: OperationDescriptor) -> Optional[SimpleUndo]:
        """
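        Analyze the operation, select the best strategy, and prepare the
        undo handler. For SNAPSHOT and TRANSACTION, None means preparation
        failed and the caller MUST NOT proceed with the write. For
        COMPENSATION, None is expected: the caller supplies the handler
        through register_compensation() once the API result is known.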

        This is the single entry point for "register undo before write".
        """
        decision = select_strategy(op)

        logger.info(
            "Preparing undo for %s (kind=%s, strategy=%s): %s",
            op.operation_id, op.kind.value,
            decision.strategy.value, decision.reason,
        )

        handler: Optional[SimpleUndo] = None

        if decision.strategy == RollbackStrategy.SNAPSHOT:
            ok, undo, err = prepare_snapshot_undo(
                target=op.target,
                snapshot_dir=self.snapshot_dir,
                operation_id=op.operation_id,
            )
            if not ok:
                logger.error("Snapshot preparation failed: %s", err)
                # Try fallback if available
                if decision.fallback_strategy == RollbackStrategy.TRANSACTION:
                    logger.info(
                        "Falling back to TRANSACTION for %s",
                        op.operation_id,
                    )
                    handler = self.wal.record_intent(
                        op.operation_id, op.target,
                    )
            else:
                handler = undo

        elif decision.strategy == RollbackStrategy.TRANSACTION:
            handler = self.wal.record_intent(op.operation_id, op.target)
            if handler is None and decision.fallback_strategy is not None:
                logger.info(
                    "TRANSACTION prep failed for %s, falling back to %s",
                    op.operation_id, decision.fallback_strategy.value,
                )
                ok, undo, err = prepare_snapshot_undo(
                    target=op.target,
                    snapshot_dir=self.snapshot_dir,
                    operation_id=op.operation_id,
                )
                if ok:
                    handler = undo

        elif decision.strategy == RollbackStrategy.COMPENSATION:
            # Compensation handlers are supplied by the caller because they
            # require business semantics. The caller should call
            # register_compensation() with the actual undo function after
            # prepare_undo returns.
            handler = None  # Caller provides via register_compensation()

        else:
            logger.error("Unknown strategy: %s", decision.strategy)
            return None

        # Register only when a handler exists; COMPENSATION handlers are
        # registered later through register_compensation().
        if handler is not None:
            self._register(
                op.operation_id, handler, decision.strategy, op.target,
            )

        return handler

    def register_compensation(
        self,
        operation_id: str,
        compensation_fn: Callable[[], bool],
        description: str,
        op: OperationDescriptor,
    ) -> None:
        """
        Register a compensation (semantic undo) handler for an operation
        that was previously analyzed as needing COMPENSATION strategy.
        Call AFTER prepare_undo() returned None for an API_CALL operation.
        """
        decision = select_strategy(op)
        if decision.strategy != RollbackStrategy.COMPENSATION:
            logger.warning(
                "register_compensation called for op %s "
                "but strategy is %s — ignoring",
                operation_id, decision.strategy.value,
            )
            return

        handler = SimpleUndo(compensation_fn, description)
        self._register(
            operation_id, handler, RollbackStrategy.COMPENSATION, op.target,
        )
        logger.info(
            "Compensation registered for %s: %s", operation_id, description,
        )

    def _register(
        self,
        operation_id: str,
        handler: SimpleUndo,
        strategy: RollbackStrategy,
        target: str,
    ) -> None:
        """Push undo handler onto the LIFO stack and record for audit."""
        self._undo_stack.append((operation_id, handler))
        self._records.append(RollbackRecord(
            operation_id=operation_id,
            strategy=strategy,
            target=target,
            handler=handler,
        ))
        logger.debug(
            "Undo registered: op=%s strategy=%s target=%s (depth=%d)",
            operation_id, strategy.value, target, len(self._undo_stack),
        )

    # -- Commit / rollback lifecycle ---------------------------------------

    def commit(self, operation_id: str) -> None:
        """
        Mark an operation as committed. Its undo handler will NOT be
        executed during rollback_all() — only uncommitted operations
        are rolled back.
        """
        self._committed.add(operation_id)
        # Remove from stack top if present (common case)
        if self._undo_stack and self._undo_stack[-1][0] == operation_id:
            self._undo_stack.pop()
            logger.debug("Commit: removed %s from undo stack", operation_id)
        else:
            logger.debug(
                "Commit marker set for %s (handler not at stack top)",
                operation_id,
            )

    def rollback_all(self) -> list[str]:
        """
        Execute all uncommitted undo handlers in LIFO order.
        Returns list of operation_ids that failed to roll back.

        LIFO ordering is critical: if op3 depends on op2 which depends
        on op1, the undo must be op3 → op2 → op1 to avoid leaving
        inconsistent state.
        """
        failed: list[str] = []
        executed_count = 0

        # Iterate in reverse (LIFO), skipping committed operations
        for operation_id, handler in reversed(self._undo_stack):
            if operation_id in self._committed:
                continue

            logger.info(
                "Executing rollback for %s: %s",
                operation_id, handler.describe(),
            )
            try:
                success = handler.execute()
                if success:
                    executed_count += 1
                    for rec in self._records:
                        if rec.operation_id == operation_id:
                            rec.executed = True
                else:
                    failed.append(operation_id)
                    logger.error("Rollback FAILED for %s", operation_id)
            except Exception as exc:
                failed.append(operation_id)
                logger.exception(
                    "Rollback exception for %s: %s", operation_id, exc,
                )

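        # Count skips before clearing: _committed also holds popped ops.
        skipped_count = sum(
            1 for op_id, _ in self._undo_stack if op_id in self._committed
        )
        self._undo_stack.clear()
        logger.info(
            "Rollback complete: %d executed, %d failed, %d skipped",
            executed_count, len(failed), skipped_count,
        )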
        return failed

    def rollback_one(self, operation_id: str) -> bool:
        """Roll back a single operation by ID."""
        for i, (op_id, handler) in enumerate(self._undo_stack):
            if op_id == operation_id:
                logger.info(
                    "Rolling back single operation: %s", operation_id,
                )
                success = handler.execute()
                self._undo_stack.pop(i)
                for rec in self._records:
                    if rec.operation_id == operation_id:
                        rec.executed = success
                return success
        logger.warning("No undo handler found for %s", operation_id)
        return False

    # -- Observability -----------------------------------------------------

    @property
    def pending_undos(self) -> int:
        """Number of uncommitted undo handlers on the stack."""
        return sum(
            1 for op_id, _ in self._undo_stack
            if op_id not in self._committed
        )

    def summary(self) -> dict[str, Any]:
        """Return a summary of the orchestrator's state for monitoring."""
        return {
            "total_registered": len(self._records),
            "pending_undos": self.pending_undos,
            "committed": len(self._committed),
            "strategies_used": {
                s.value: sum(
                    1 for r in self._records if r.strategy == s
                )
                for s in RollbackStrategy
            },
            "records": [
                {
                    "operation_id": r.operation_id,
                    "strategy": r.strategy.value,
                    "target": r.target,
                    "executed": r.executed,
                }
                for r in self._records
            ],
        }


# ========================================================================
# Usage example
# ========================================================================

if __name__ == "__main__":
    # Setup
    test_file = Path("/tmp/agent-rollback-orch-demo.txt")
    test_file.write_text("original content v1\n")

    orch = RollbackOrchestrator(
        snapshot_dir="/tmp/agent-rollback-snaps",
        wal_dir="/tmp/agent-rollback-wal",
    )

    # -- Example 1: Small file write → SNAPSHOT ---------------------------
    op1 = OperationDescriptor(
        kind=OperationKind.FILE_WRITE,
        target=str(test_file),
        operation_id="write-v2",
        estimated_size_bytes=50,
    )

    undo1 = orch.prepare_undo(op1)
    assert undo1 is not None, "Undo prep failed!"

    # Simulate Agent writing the file
    test_file.write_text("modified content v2\n")

    # Agent decides the write was wrong — rollback!
    orch.rollback_one("write-v2")
    assert test_file.read_text() == "original content v1\n", (
        "Rollback failed!"
    )
    print("Example 1 passed: snapshot rollback")

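    # -- Example 2: Multi-file → TRANSACTION ------------------------------
    test_file2 = Path("/tmp/agent-rollback-orch-demo2.txt")
    test_file2.write_text("file2 original\n")

    # prepare_undo() captures a pre-image of op.target only, so register
    # one descriptor per affected file. A single descriptor with a logical
    # target name would leave both real files without an undo handler.
    batch_files = [str(test_file), str(test_file2)]
    for i, path in enumerate(batch_files):
        op2 = OperationDescriptor(
            kind=OperationKind.FILE_MULTI,
            target=path,
            operation_id=f"batch-v2-{i}",
            affected_files=batch_files,
        )
        assert orch.prepare_undo(op2) is not None, "Undo prep failed!"

    # Simulate writes
    test_file.write_text("file1 v3\n")
    test_file2.write_text("file2 v3\n")

    assert orch.rollback_all() == [], "Batch rollback failed!"
    assert test_file.read_text() == "original content v1\n"
    assert test_file2.read_text() == "file2 original\n"
    print("Example 2 passed: transaction rollback")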

    # -- Example 3: API call → COMPENSATION -------------------------------
    op3 = OperationDescriptor(
        kind=OperationKind.API_CALL,
        target="https://api.stripe.com/v1/subscriptions",
        operation_id="create-sub",
        is_reversible_api=True,
    )

    undo3 = orch.prepare_undo(op3)
    # undo3 is None for compensation — caller must register
    if undo3 is None:
        fake_subscription_id = "sub_abc123"

        def cancel_subscription() -> bool:
            # In reality: stripe.Subscription.delete(fake_subscription_id)
            logger.info(
                "Compensation: cancelled subscription %s",
                fake_subscription_id,
            )
            return True

        orch.register_compensation(
            operation_id="create-sub",
            compensation_fn=cancel_subscription,
            description=f"Cancel Stripe subscription {fake_subscription_id}",
            op=op3,
        )

    # Simulate failure after API call → rollback
    failed = orch.rollback_all()
    print(f"Example 3 passed: compensation rollback (failed={failed})")

    # Summary
    import json as _json
    print("\nOrchestrator summary:")
    print(_json.dumps(orch.summary(), indent=2, default=str))

    # Cleanup
    test_file.unlink(missing_ok=True)
    test_file2.unlink(missing_ok=True)
    import shutil as _shutil
    _shutil.rmtree("/tmp/agent-rollback-snaps", ignore_errors=True)
    _shutil.rmtree("/tmp/agent-rollback-wal", ignore_errors=True)

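Design decisions in the orchestrator

Several design decisions in the RollbackOrchestrator architecture deserve a closer look:

  1. LIFO rollback order: rollback_all() runs the undo handlers in stack order (last in, first out). This is not optional: if the Agent modified a file in step 3 that depends on what step 2 wrote, the rollback has to start at step 3 and work backwards to step 1. Rolling back in forward order would restore step 2's file first, leaving step 3's file pointing at step 2 content that no longer exists.
  2. Strategy degradation and fallback: when a strategy's preparation phase fails, the orchestrator does not fail immediately; it first tries the fallback_strategy. For a file right at the size boundary (say, exactly 110MB), preparing a snapshot may fail because the disk is full, and the orchestrator then degrades automatically to WAL transaction mode. This fallback chain is already encoded in the decision returned by select_strategy(): every strategy decision carries an optional fallback. If the fallback fails as well, prepare_undo() returns None and the write must not proceed.
  3. Deferred registration of compensation: for API calls, prepare_undo() returns None because the compensation function can only be built from the result of the API call (for example the subscription_id). The flow the orchestrator supports is: prepare_undo() confirms the strategy → the Agent executes the API call and obtains the result → register_compensation() builds the compensation function from that result. This two-phase design is a delay inherent to the compensation strategy: you cannot construct a cancel function without knowing which subscription_id Stripe returned. Note that None here means "handler pending", not "preparation failed", so the caller has to tell the two cases apart by strategy.
  4. commit() semantics: once an operation has succeeded and the Agent has confirmed that no rollback is needed, commit() takes it out of the rollback path. The handler is popped if it sits on top of the stack; otherwise the operation is marked as committed and rollback_all() skips it. This is not the same as deleting the snapshot or the WAL entries, which are kept for auditing. commit only affects the rollback stack. See Agent Audit Log Design for the audit retention policy.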
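The LIFO argument in point 1 can be illustrated with a deliberately tiny undo stack. This is a sketch under stated assumptions, not the RollbackOrchestrator itself: UndoStack, the in-memory `state` dict, and the file names are all hypothetical stand-ins.

```python
from typing import Callable


class UndoStack:
    """Hypothetical minimal LIFO undo stack (not the RollbackOrchestrator)."""

    def __init__(self) -> None:
        self._stack: list[tuple[str, Callable[[], object]]] = []

    def push(self, op_id: str, undo: Callable[[], object]) -> None:
        self._stack.append((op_id, undo))

    def rollback_all(self) -> list[str]:
        """Run undo handlers newest-first; return the order they ran in."""
        order: list[str] = []
        while self._stack:
            op_id, undo = self._stack.pop()   # pop() takes the newest entry
            undo()
            order.append(op_id)
        return order


# In-memory stand-in for a file system: step 2 builds on step 1's output.
state: dict[str, str] = {}
stack = UndoStack()

state["base.conf"] = "v1"
stack.push("step-1", lambda: state.pop("base.conf"))

state["derived.conf"] = "include " + state["base.conf"]
stack.push("step-2", lambda: state.pop("derived.conf"))

order = stack.rollback_all()
print(order, state)
```

Because the derived file is removed before the file it depends on, no intermediate state ever contains a reference to something that has already been undone.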

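For the safety constraints on an Agent inside its execution boundary (limiting the blast radius of a single write), see Agent Command Execution Safety. For using runtime isolation to widen rollback coverage (container-level snapshots), see Agent Runtime Isolation. For how release gates use rollback as a safety net, see Agent Release Gate Design.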

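3. File-level rollback: versioned snapshots and diff-based recovery

The RollbackOrchestrator in Section 2 answers the question of which strategy to choose, but the execution that follows the choice is the real core of a rollback mechanism. For local file operations, the Agent's most frequent and most error-prone kind of write, snapshot rollback cannot stop at shutil.copy2. Under high-frequency writes, taking a full snapshot of every file on every write quickly exhausts the disk. With related multi-file writes, independent snapshots carry no atomicity guarantee: if file A's snapshot is restored but file B's snapshot is missing, the system ends up in an inconsistent state.

This section proposes a complete file-level rollback scheme: a Copy-on-Write file wrapper for single-file atomicity, diff-based rollback for storage overhead, and staging directory management for lifecycle and cleanup. Combined, these three components give every file write by the Agent a fine-grained, storage-efficient undo capability.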

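Copy-on-Write file wrapper

Copy-on-Write (CoW) is a classic concept in file systems: btrfs and ZFS use it for snapshots, and Docker uses it for container layers. In the context of Agent rollback, the idea is not to modify the original file in place but to run a three-stage pipeline: write to a staging area, validate, then replace atomically. An early form of this pipeline already appeared in the safe_agent_write context manager in Section 1, but that implementation only handles a single write to a single file. A complete CoW wrapper has to handle multi-file batch writes, a handle-based commit/rollback lifecycle, and checksum-based diff tracking.

Each of the three stages has its own responsibility:

  1. Stage: write(path, content) does not touch the target file. The content goes into a temporary file in the staging directory, named .cow-stage-{uuid}-{filename}. The method returns a WriteHandle object holding the staged file path, the target path, the operation ID, and checksum information. The Agent can hold several handles and commit them as a batch, or run validation logic after staging without any risk to the target file.
  2. Validate: before commit, each staged file can pass through an optional validator chain: YAML syntax check, JSON schema validation, an nginx -t dry run, a Python compile check, and so on. A validation failure does not affect the target file, because the target file has never been modified. This stage is what separates Agent rollback from after-the-fact repair: validation is moved ahead of the point where the write takes effect.
  3. Atomic Replace: commit(handle) calls os.replace(), which is atomic on POSIX systems as long as the staged file and the target are on the same filesystem. The target file is either entirely the old content or entirely the new content; there is never a partially written intermediate state. If the target file did not exist before, the staged file is moved to the target path. If it did exist, a snapshot is saved automatically to .cow-snap-{uuid}-{filename} before the replace and serves as the restore source for rollback.

For multi-file operations the CoW wrapper provides a transactional batch commit: commit_all([handle1, handle2, ...]) checks that every staged file passes its validator chain before it performs any atomic replace. If a single file fails validation, none of the batch takes effect. This all-or-nothing rule applies to the validation phase; the replacements themselves run one file at a time, and if one fails midway the files already replaced are restored from their undo data. Together this brings the Agent's multi-file changes close to the atomicity of a database transaction.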
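The validate-everything-before-replacing-anything rule can be sketched in a few lines. This is an illustration under assumptions, not the CopyOnWriteAgent code: stage(), is_valid_json(), and commit_batch() are hypothetical helpers, the validator is a plain JSON parse, and no undo data is saved.

```python
import json
import os
import tempfile
from pathlib import Path
from typing import Callable


def stage(target: Path, content: str) -> Path:
    """Write content to a temp file next to the target (same filesystem)."""
    fd, tmp = tempfile.mkstemp(dir=target.parent, prefix=".stage-")
    with os.fdopen(fd, "w", encoding="utf-8") as f:
        f.write(content)
    return Path(tmp)


def is_valid_json(path: Path) -> bool:
    try:
        json.loads(path.read_text(encoding="utf-8"))
        return True
    except ValueError:          # json.JSONDecodeError subclasses ValueError
        return False


def commit_batch(
    writes: dict[Path, str], validator: Callable[[Path], bool],
) -> bool:
    """Stage everything, validate everything, and only then replace."""
    staged = {target: stage(target, text) for target, text in writes.items()}
    if not all(validator(tmp) for tmp in staged.values()):
        for tmp in staged.values():
            tmp.unlink(missing_ok=True)   # discard; targets never touched
        return False
    for target, tmp in staged.items():
        os.replace(tmp, target)           # atomic per file, not across files
    return True


workdir = Path(tempfile.mkdtemp())
a, b = workdir / "a.json", workdir / "b.json"
a.write_text('{"v": 1}')
b.write_text('{"v": 1}')

rejected = commit_batch({a: '{"v": 2}', b: '{"v": '}, is_valid_json)
a_after_reject = a.read_text()            # still the original content
accepted = commit_batch({a: '{"v": 2}', b: '{"v": 2}'}, is_valid_json)
```

The first batch contains one truncated JSON document, so neither file changes even though the other document is valid. The sketch leaves out the snapshot step, so it cannot undo a failure that happens between two os.replace() calls.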

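Diff-based rollback

The storage cost of full snapshots was analyzed in Section 2: it is unacceptable for large files under high-frequency writes. Diff-based rollback is an alternative: instead of keeping a full copy, keep the difference between the old and new content. Rolling back means applying that difference in reverse to recover the original content.

The implementation has three steps:

  1. Baseline checksum: before the first write, compute the file's SHA-256 checksum as the baseline for later diff computation. The baseline is stored in the staging directory, associated with the file name and the operation ID.
  2. Diff computation: for text files (code, configuration, data files), use difflib.unified_diff to compute the unified diff between the old and new content. The diff is stored in patch format and is usually far smaller than a full snapshot: for a 10MB log file, changing 3 lines produces a diff of only a few hundred bytes.
  3. Reverse application: on rollback, the patch is applied to the current file in reverse. Inside each hunk the removed and context lines are kept and the added lines are dropped; lines outside the hunks are copied from the current file; the result is written back to the target file. difflib.restore cannot do this, because it only accepts ndiff-style deltas and not unified diffs, so the hunks have to be walked explicitly. For binary files the diff approach does not apply, and rollback falls back to the full snapshot strategy.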
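The baseline-checksum and reversal steps can be tried out with the standard library alone. The sketch below uses difflib.ndiff together with difflib.restore, which is the delta format that restore actually supports. An ndiff delta contains both versions in full, so it shows the reversal and verification mechanics only and does not save any storage. The config lines are made-up sample data.

```python
import difflib
import hashlib

old_text = "worker_processes 4;\nkeepalive_timeout 65;\n"
new_text = "worker_processes 8;\nkeepalive_timeout 65;\n"

# Step 1: baseline checksum of the content before the write.
baseline = hashlib.sha256(old_text.encode("utf-8")).hexdigest()

# Step 2: an ndiff delta from old to new (may include "? " hint lines).
delta = list(difflib.ndiff(
    old_text.splitlines(keepends=True),
    new_text.splitlines(keepends=True),
))

# Step 3: restore(delta, 1) yields the lines of the first (old) sequence.
restored = "".join(difflib.restore(delta, 1))

# Verify the rollback against the baseline before trusting it.
restored_ok = (
    hashlib.sha256(restored.encode("utf-8")).hexdigest() == baseline
)
```

Checking the restored content against the baseline checksum is what turns "the patch applied" into "the file is back to its known state".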

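Diff-based rollback is not a universal answer. For highly structured files such as JSON or YAML, a semantic key-level diff is more meaningful than a line-level text diff, because a line diff can produce a large amount of noise from an indentation change alone. In those cases the diff computation should be delegated to a format-specific parser. In addition, the recovery time of a diff chain grows linearly (10 consecutive writes produce 10 patches), so high-frequency writes need the diffs to be squashed periodically into a new baseline snapshot.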
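A key-level diff for parsed configuration data might look like the following sketch. It assumes the file has already been parsed into a flat dict; key_diff(), revert(), and the sample rate-limit settings are hypothetical, and nested structures would need a recursive variant.

```python
from typing import Any

_MISSING = object()   # sentinel: the key did not exist on that side


def key_diff(
    old: dict[str, Any], new: dict[str, Any],
) -> dict[str, tuple[Any, Any]]:
    """Map each changed top-level key to (old_value, new_value)."""
    changes: dict[str, tuple[Any, Any]] = {}
    for key in old.keys() | new.keys():
        before = old.get(key, _MISSING)
        after = new.get(key, _MISSING)
        if before != after:
            changes[key] = (before, after)
    return changes


def revert(
    current: dict[str, Any], changes: dict[str, tuple[Any, Any]],
) -> dict[str, Any]:
    """Undo the recorded changes on a copy of the current data."""
    restored = dict(current)
    for key, (before, _after) in changes.items():
        if before is _MISSING:
            restored.pop(key, None)    # key was added by the write
        else:
            restored[key] = before     # key was changed or removed
    return restored


old_cfg = {"rate": 10, "burst": 20, "zone": "api"}
new_cfg = {"rate": 50, "burst": 20, "mode": "nodelay"}

changes = key_diff(old_cfg, new_cfg)
restored_cfg = revert(new_cfg, changes)
```

Re-indenting or reordering the serialized file produces no entries here, because the comparison happens on parsed values and not on text lines.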

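Staging directory management and retention policy

The CoW wrapper accumulates three kinds of files in the staging directory: staged files (.cow-stage-*), snapshot files (.cow-snap-*), and diff patches (.cow-diff-*). Left unmanaged, they grow without bound. They need a systematic cleanup and retention policy; the implementation in this section relies on three triggers: cleanup on commit, a time-to-live (TTL), and a capacity cap.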
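One possible shape for such a policy is a TTL pass followed by a capacity pass. The function below is a sketch under assumptions: cleanup_staging() and its parameters are hypothetical, file age is taken from mtime, and a real policy would also have to skip files that still belong to active, uncommitted handles.

```python
import os
import tempfile
import time
from pathlib import Path
from typing import Optional


def cleanup_staging(
    staging_dir: Path, *, ttl_seconds: float, max_bytes: int,
    now: Optional[float] = None,
) -> list[str]:
    """Delete expired .cow-* files, then oldest-first until under max_bytes."""
    now = time.time() if now is None else now
    removed: list[str] = []
    # Oldest first, so the capacity pass evicts the stalest undo data first.
    files = sorted(staging_dir.glob(".cow-*"), key=lambda p: p.stat().st_mtime)
    survivors: list[Path] = []
    for path in files:
        if now - path.stat().st_mtime > ttl_seconds:
            path.unlink(missing_ok=True)
            removed.append(path.name)
        else:
            survivors.append(path)
    total = sum(p.stat().st_size for p in survivors)
    for path in survivors:
        if total <= max_bytes:
            break
        total -= path.stat().st_size
        path.unlink(missing_ok=True)
        removed.append(path.name)
    return removed


demo = Path(tempfile.mkdtemp())
samples = [(".cow-snap-old", 7200, 10), (".cow-snap-a", 120, 600),
           (".cow-diff-b", 60, 600)]
for name, age_seconds, size in samples:
    p = demo / name
    p.write_bytes(b"x" * size)
    os.utime(p, (time.time() - age_seconds,) * 2)   # backdate atime/mtime

removed = cleanup_staging(demo, ttl_seconds=3600, max_bytes=1000)
```

In the demo, the two-hour-old snapshot exceeds the one-hour TTL, and the older of the two remaining files is evicted to get back under the 1000-byte cap.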

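As for where the staging directory should live: put it on the same filesystem as the target files, so that os.replace() stays an atomic rename and never has to move data across filesystems (where the call fails). For containerized Agents, the staging directory should sit on the same volume as the Agent's workspace. A container's ephemeral filesystem loses all staged content when the container restarts, so operations that were staged but not yet committed are lost for good. See Agent Runtime Isolation for filesystem layout recommendations for containerized Agents.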
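A quick startup check can catch a misplaced staging directory before the first write. The helper below is a heuristic sketch: same_filesystem() is a hypothetical name, and comparing st_dev may not capture every mount setup (bind mounts and subvolumes can behave differently).

```python
import os
import tempfile
from pathlib import Path


def same_filesystem(staging_dir: Path, target: Path) -> bool:
    """True if staging_dir and the target's parent share a device ID."""
    return os.stat(staging_dir).st_dev == os.stat(target.parent).st_dev


workspace = Path(tempfile.mkdtemp())
staging = workspace / ".staging"
staging.mkdir()

# The target itself does not need to exist; only its directory is checked.
check = same_filesystem(staging, workspace / "app.conf")
```

If the check fails, the safer choice is to refuse to start or to relocate the staging directory, since falling back to copy-then-delete would give up the atomic replace.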

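CopyOnWriteAgent implementation

The code below implements all the mechanisms discussed in this section (the CoW file wrapper, diff-based rollback, and staging directory management), combined into a single CopyOnWriteAgent class: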

from __future__ import annotations

import difflib
import hashlib
import os
import shutil
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Optional
import logging

logger = logging.getLogger("agent.rollback.cow")


# ==========================================================================
# WriteHandle — the token returned by write(), used for commit/rollback
# ==========================================================================

@dataclass
class WriteHandle:
    """
    Opaque handle returned by CopyOnWriteAgent.write().
    The caller uses this to commit or rollback the staged write.
    Attributes are intentionally not private — the orchestrator may
    inspect them for observability.
    """
    handle_id: str
    target_path: Path
    staged_path: Path
    operation_id: str
    content_checksum: str          # SHA-256 of staged content
    baseline_checksum: str = ""    # SHA-256 of original file ("" if new)
    is_binary: bool = False
    created_at: float = field(default_factory=time.time)


# ==========================================================================
# Diff-based helpers — compute and apply reverse diffs
# ==========================================================================

def compute_diff(
    old_content: str, new_content: str, filename: str,
) -> str:
    """
    Compute a unified diff from old → new for text content.
    This diff can be reversed to restore old content from new.
    """
    old_lines = old_content.splitlines(keepends=True)
    new_lines = new_content.splitlines(keepends=True)
    diff = difflib.unified_diff(
        old_lines, new_lines,
        fromfile=f"a/{filename}",
        tofile=f"b/{filename}",
    )
    return "".join(diff)


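def apply_reverse_diff(
    current_content: str, forward_diff: str,
) -> str:
    """
    Apply a forward unified diff in reverse to restore the original content.

    Walks the diff hunk by hunk: regions of the current file outside any
    hunk are copied unchanged; inside a hunk, "-" and " " lines are kept
    and "+" lines are skipped. Assumes both versions end with a newline
    (otherwise the joined diff from compute_diff() is ambiguous).
    """
    current_lines = current_content.splitlines(keepends=True)
    old_lines: list[str] = []
    cursor = 0          # index of the next unconsumed line in current_lines
    in_hunk = False
    for line in forward_diff.splitlines(keepends=True):
        if line.startswith("@@ "):
            # Header "@@ -a,b +c,d @@": c is 1-based unless d == 0
            start_s, _, len_s = line.split()[2][1:].partition(",")
            hunk_start = int(start_s) - (0 if len_s == "0" else 1)
            old_lines.extend(current_lines[cursor:hunk_start])
            cursor, in_hunk = hunk_start, True
        elif not in_hunk:
            continue                        # "---" / "+++" file headers
        elif line.startswith("-"):
            old_lines.append(line[1:])      # removed line → original
        elif line.startswith("+"):
            cursor += 1                     # added line → not in original
        elif line.startswith(" "):
            old_lines.append(line[1:])      # context line → original
            cursor += 1
    old_lines.extend(current_lines[cursor:])  # unchanged tail after last hunk
    return "".join(old_lines)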


# ==========================================================================
# CopyOnWriteAgent — file-level rollback engine
# ==========================================================================

class CopyOnWriteAgent:
    """
    File-level rollback via copy-on-write wrapping with diff-based storage.

    Workflow:
      1. write(path, content)  → returns WriteHandle (stages in temp dir)
      2. validate(handle, ...)  → optionally run validators on staged content
      3. commit(handle)         → atomic os.replace() + save diff/snapshot
      4. rollback(handle)       → restore from diff or snapshot

    Design decisions:
      - Text files (code, config, data) use diffs for minimal storage.
      - Binary files (images, archives) use full snapshots.
      - Multi-file batches use commit_all() for atomicity across files.
      - Staging directory is self-cleaning (TTL, capacity, commit cleanup).
    """

    def __init__(
        self,
        staging_dir: str | Path,
        *,
        snapshot_retention_seconds: int = 3600,   # 1 hour default
        max_staging_size_mb: int = 500,
        ttl_seconds: int = 86400,                  # 24 hours
    ) -> None:
        self.staging_dir = Path(staging_dir)
        self.staging_dir.mkdir(parents=True, exist_ok=True)
        self.snapshot_retention = snapshot_retention_seconds
        self.max_staging_bytes = max_staging_size_mb * 1024 * 1024
        self.ttl_seconds = ttl_seconds
        self._active_handles: dict[str, WriteHandle] = {}
        self._binary_extensions = {
            ".png", ".jpg", ".jpeg", ".gif", ".pdf",
            ".zip", ".tar", ".gz", ".bin", ".exe", ".so",
            ".dylib", ".dll", ".wasm", ".mp4", ".mp3",
            ".ico", ".woff", ".woff2", ".ttf", ".eot",
        }

    # -- Public API ----------------------------------------------------------

    def write(
        self,
        target: str | Path,
        content: str | bytes,
        *,
        operation_id: str = "",
    ) -> WriteHandle:
        """
        Stage a write to `target` with the given `content`.

        The target file is NOT modified. Content is written to a staging
        file. Returns a WriteHandle for later commit/rollback.

        If content is identical to the current file content, the write is
        treated as a no-op — handle is returned but commit does nothing.
        """
        target_path = Path(target).resolve()
        op_id = operation_id or str(uuid.uuid4())[:8]
        handle_id = f"cow-{op_id}-{uuid.uuid4().hex[:12]}"

        # Determine content encoding
        if isinstance(content, str):
            content_bytes = content.encode("utf-8")
            is_binary = False
        else:
            content_bytes = content
            is_binary = True

        content_checksum = hashlib.sha256(content_bytes).hexdigest()

        # Check if target exists and compute baseline
        baseline_checksum = ""
        if target_path.exists():
            target_bytes = target_path.read_bytes()
            target_checksum = hashlib.sha256(target_bytes).hexdigest()
            # No-op: content unchanged
            if target_checksum == content_checksum:
                logger.debug(
                    "No-op write for %s (content unchanged)", target_path,
                )
                handle = WriteHandle(
                    handle_id=handle_id,
                    target_path=target_path,
                    staged_path=target_path,  # Already there
                    operation_id=op_id,
                    content_checksum=content_checksum,
                    baseline_checksum=target_checksum,
                    is_binary=is_binary,
                )
                self._active_handles[handle_id] = handle
                return handle
            baseline_checksum = target_checksum

        # Determine if binary by extension (override if already known)
        suffix = target_path.suffix.lower()
        if suffix in self._binary_extensions:
            is_binary = True

        # Write staged file
        staged_path = (
            self.staging_dir / f".cow-stage-{handle_id}-{target_path.name}"
        )
        staged_path.write_bytes(content_bytes)
        logger.debug(
            "Staged write: %s → %s (%d bytes)",
            staged_path, target_path, len(content_bytes),
        )

        handle = WriteHandle(
            handle_id=handle_id,
            target_path=target_path,
            staged_path=staged_path,
            operation_id=op_id,
            content_checksum=content_checksum,
            baseline_checksum=baseline_checksum,
            is_binary=is_binary,
        )
        self._active_handles[handle_id] = handle
        return handle

    def commit(
        self,
        handle: WriteHandle,
        *,
        validator: Callable[[Path], bool] | None = None,
        keep_snapshot: bool = True,
    ) -> bool:
        """
        Atomically commit a staged write.

        1. Run optional validator on staged content.
        2. If file exists and differs, save snapshot/diff.
        3. Atomic replace: os.replace(staged, target).
        4. Cleanup staging (unless snapshot retained).

        Returns True on success. On failure, staged file is untouched
        and target file is unchanged.
        """
        if handle.handle_id not in self._active_handles:
            logger.error("Unknown handle: %s", handle.handle_id)
            return False

        target = handle.target_path
        staged = handle.staged_path

        # No-op case: staged_path == target_path
        if staged == target:
            logger.debug("Commit no-op for %s", target)
            self._active_handles.pop(handle.handle_id, None)
            return True

        # Validate staged content if validator provided
        if validator is not None:
            try:
                if not validator(staged):
                    logger.error(
                        "Commit rejected by validator for %s", target,
                    )
                    return False
            except Exception as exc:
                logger.error(
                    "Validator raised exception for %s: %s", target, exc,
                )
                return False

        # Save snapshot or diff before atomic replace
        if handle.baseline_checksum:
            # File existed before this write — save undo data
            self._save_undo_data(handle, keep_snapshot)

        # Atomic replace
        try:
            os.replace(staged, target)
            logger.info(
                "Committed: %s (op=%s, checksum=%s...)",
                target, handle.operation_id, handle.content_checksum[:8],
            )
        except OSError as exc:
            logger.error("Commit atomic replace failed for %s: %s", target, exc)
            return False

        # Cleanup handle
        self._active_handles.pop(handle.handle_id, None)

        # Trigger background cleanup if staging dir exceeds capacity
        self._maybe_cleanup()
        return True

    def rollback(self, handle: WriteHandle) -> bool:
        """
        Rollback a write — restore target file to pre-write state.

        For committed writes: restore from diff or snapshot.
        For uncommitted writes: delete staged file, target untouched.
        """
        if handle.handle_id not in self._active_handles:
            # May have been committed already — try snapshot restore
            return self._restore_from_undo(handle)

        target = handle.target_path
        staged = handle.staged_path

        if staged == target:
            # No-op write, nothing to rollback
            self._active_handles.pop(handle.handle_id, None)
            return True

        # Uncommitted write: just delete the staged file
        staged.unlink(missing_ok=True)
        self._active_handles.pop(handle.handle_id, None)
        logger.info(
            "Rollback (uncommitted): staged file removed for %s", target,
        )
        return True

    def commit_all(
        self,
        handles: list[WriteHandle],
        *,
        validators: dict[str, Callable[[Path], bool]] | None = None,
    ) -> dict[str, bool]:
        """
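        Commit multiple handles as one batch.

        All validators (keyed by operation_id) run first. If any fail, NO
        file is modified. If all pass, files are replaced one by one; each
        os.replace() is atomic, and a mid-batch failure triggers a
        best-effort restore of the files already replaced.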

        Returns a dict mapping handle_id → success.
        """
        results: dict[str, bool] = {}
        validators = validators or {}

        # Phase 1: Validate all staged files
        for h in handles:
            validator = validators.get(h.operation_id)
            if validator is not None:
                if h.staged_path != h.target_path:  # skip no-ops
                    try:
                        if not validator(h.staged_path):
                            logger.error(
                                "Batch commit: validator failed for %s",
                                h.target_path,
                            )
                            results[h.handle_id] = False
                    except Exception as exc:
                        logger.error(
                            "Batch commit: validator exception for %s: %s",
                            h.target_path, exc,
                        )
                        results[h.handle_id] = False
                else:
                    results[h.handle_id] = True

        # If any validation failed, abort entire batch
        if any(not v for v in results.values()):
            logger.error(
                "Batch commit aborted: %d/%d validations failed",
                sum(1 for v in results.values() if not v),
                len(handles),
            )
            # Fill in False for unchecked handles
            for h in handles:
                results.setdefault(h.handle_id, False)
            return results

        # Phase 2: for each handle, save its undo data, then replace the target
        committed: list[WriteHandle] = []
        try:
            for h in handles:
                if h.staged_path != h.target_path and h.baseline_checksum:
                    self._save_undo_data(h, keep_snapshot=True)
                os.replace(h.staged_path, h.target_path)
                committed.append(h)
                results[h.handle_id] = True
                self._active_handles.pop(h.handle_id, None)
        except Exception as exc:
            logger.exception("Batch commit failed mid-way: %s", exc)
            # Attempt to rollback already-committed files
            for h in committed:
                try:
                    self._restore_from_undo(h)
                except Exception as rollback_exc:
                    logger.error(
                        "Batch rollback failed for %s: %s",
                        h.target_path, rollback_exc,
                    )
            # The batch failed as a whole: report every handle as failed,
            # including those committed above and then rolled back
            for h in handles:
                results[h.handle_id] = False
            return results

        self._maybe_cleanup()
        return results

    # -- Internal: undo data management -------------------------------------

    def _save_undo_data(
        self, handle: WriteHandle, keep_snapshot: bool,
    ) -> None:
        """
        Save the data needed to undo this write.

        For text files: save a unified diff (old → new).
        For binary files: save a full snapshot copy.
        """
        target = handle.target_path
        handle_id = handle.handle_id

        if not target.exists():
            # Target disappeared between write() and commit() —
            # nothing to save snapshot of
            logger.warning(
                "Cannot save undo data: %s no longer exists", target,
            )
            return

        if handle.is_binary:
            # Binary: save full snapshot
            snap_path = (
                self.staging_dir / f".cow-snap-{handle_id}-{target.name}"
            )
            shutil.copy2(target, snap_path)
            logger.debug(
                "Snapshot saved for binary file: %s (%d bytes)",
                snap_path, snap_path.stat().st_size,
            )
            if not keep_snapshot:
                snap_path.unlink(missing_ok=True)
        else:
            # Text: compute and save diff
            try:
                old_text = target.read_text(encoding="utf-8")
                new_text = handle.staged_path.read_text(encoding="utf-8")
                diff = compute_diff(old_text, new_text, target.name)
                diff_path = (
                    self.staging_dir / f".cow-diff-{handle_id}-{target.name}.patch"
                )
                diff_path.write_text(diff, encoding="utf-8")
                logger.debug(
                    "Diff saved for %s (%d bytes)", target, len(diff),
                )
                if not keep_snapshot:
                    diff_path.unlink(missing_ok=True)
            except UnicodeDecodeError:
                # Fallback: treat as binary
                snap_path = (
                    self.staging_dir
                    / f".cow-snap-{handle_id}-{target.name}"
                )
                shutil.copy2(target, snap_path)
                logger.debug("Fallback snapshot for non-UTF-8 file: %s", target)

    def _restore_from_undo(self, handle: WriteHandle) -> bool:
        """Restore target file from saved undo data (diff or snapshot)."""
        target = handle.target_path
        handle_id = handle.handle_id

        # Try diff first (text files)
        diff_path = (
            self.staging_dir / f".cow-diff-{handle_id}-{target.name}.patch"
        )
        if diff_path.exists():
            try:
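                diff_text = diff_path.read_text(encoding="utf-8")
                if not target.exists():
                    # A reverse diff can only be applied to existing content;
                    # keep the patch file and report failure.
                    logger.error("Diff rollback failed: %s is missing", target)
                    return False
                current_text = target.read_text(encoding="utf-8")
                restored = apply_reverse_diff(current_text, diff_text)
                target.write_text(restored, encoding="utf-8")
                diff_path.unlink()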
                logger.info(
                    "Rollback via diff: %s restored", target,
                )
                return True
            except Exception as exc:
                logger.error("Diff rollback failed for %s: %s", target, exc)

        # Try snapshot (binary files or diff fallback)
        snap_path = (
            self.staging_dir / f".cow-snap-{handle_id}-{target.name}"
        )
        if snap_path.exists():
            try:
                shutil.copy2(snap_path, target)
                snap_path.unlink()
                logger.info(
                    "Rollback via snapshot: %s restored", target,
                )
                return True
            except Exception as exc:
                logger.error(
                    "Snapshot rollback failed for %s: %s", target, exc,
                )
                return False

        logger.error(
            "No undo data found for %s (handle=%s)", target, handle_id,
        )
        return False

    # -- Internal: staging directory cleanup --------------------------------

    def _maybe_cleanup(self) -> None:
        """Check if staging directory exceeds capacity and clean up if so."""
        total_size = sum(
            f.stat().st_size for f in self.staging_dir.iterdir()
            if f.is_file()
        )
        if total_size > self.max_staging_bytes:
            logger.warning(
                "Staging directory size %d MB exceeds limit %d MB — cleaning",
                total_size // (1024 * 1024),
                self.max_staging_bytes // (1024 * 1024),
            )
            self.cleanup(aggressive=True)

    def cleanup(self, *, aggressive: bool = False) -> int:
        """
        Remove stale and expired files from the staging directory.

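        In normal mode: removes files older than TTL.
        In aggressive mode: also removes the oldest remaining files (by
        mtime) until usage drops to 80% of capacity. Eviction does not
        distinguish staged files from undo data, so old rollback data
        may be lost.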

        Returns the number of files removed.
        """
        now = time.time()
        removed = 0
        files: list[tuple[float, Path]] = []

        for f in self.staging_dir.iterdir():
            if not f.is_file():
                continue
            age = now - f.stat().st_mtime
            if age > self.ttl_seconds:
                try:
                    f.unlink()
                    removed += 1
                except OSError:
                    pass
            elif aggressive:
                files.append((f.stat().st_mtime, f))

        # Aggressive mode: LRU eviction
        if aggressive and files:
            files.sort(key=lambda x: x[0])  # oldest first
            current_size = sum(
                f.stat().st_size for f in self.staging_dir.iterdir()
                if f.is_file()
            )
            for _, f in files:
                if current_size <= self.max_staging_bytes * 0.8:
                    break  # Stop at 80% capacity
                try:
                    size = f.stat().st_size
                    f.unlink()
                    current_size -= size
                    removed += 1
                except OSError:
                    pass

        if removed:
            logger.info(
                "Staging cleanup: %d files removed (%s mode)",
                removed,
                "aggressive" if aggressive else "normal",
            )
        return removed

    def staging_stats(self) -> dict:
        """Return statistics about the staging directory."""
        # Count regular files only, so file_count matches total_size
        files = [f for f in self.staging_dir.iterdir() if f.is_file()]
        total_size = sum(f.stat().st_size for f in files)
        return {
            "staging_dir": str(self.staging_dir),
            "file_count": len(files),
            "total_size_bytes": total_size,
            "total_size_mb": round(total_size / (1024 * 1024), 2),
            "active_handles": len(self._active_handles),
            "max_capacity_mb": self.max_staging_bytes // (1024 * 1024),
            "ttl_seconds": self.ttl_seconds,
        }


# ==========================================================================
# Usage example
# ==========================================================================

if __name__ == "__main__":
    import tempfile

    # Setup: temporary directories
    work_dir = Path(tempfile.mkdtemp(prefix="cow-demo-work-"))
    staging_dir = Path(tempfile.mkdtemp(prefix="cow-demo-staging-"))

    # Create a test file
    test_file = work_dir / "app.conf"
    test_file.write_text("server { listen 80; }\n")

    cow = CopyOnWriteAgent(
        staging_dir=staging_dir,
        snapshot_retention_seconds=300,
        max_staging_size_mb=50,
    )

    # -- Example 1: Write, validate, commit ---------------------------------
    handle = cow.write(
        test_file,
        "server { listen 443 ssl; }\n",
        operation_id="tls-upgrade",
    )
    print(f"Handle: {handle.handle_id}")
    print(f"Staged at: {handle.staged_path}")
    print(f"Current file (unchanged): {test_file.read_text()!r}")

    # Validate staged content
    def nginx_like_validator(p: Path) -> bool:
        content = p.read_text()
        return "listen" in content and "server" in content

    success = cow.commit(handle, validator=nginx_like_validator)
    print(f"Commit success: {success}")
    print(f"File after commit: {test_file.read_text()!r}")

    # -- Example 2: Rollback a committed write ------------------------------
    cow.rollback(handle)
    print(f"File after rollback: {test_file.read_text()!r}")

    # -- Example 3: Batch commit of two files (no failure injected) ---------
    file_a = work_dir / "a.txt"
    file_b = work_dir / "b.txt"
    file_a.write_text("A-original\n")
    file_b.write_text("B-original\n")

    h1 = cow.write(file_a, "A-new\n", operation_id="batch-test")
    h2 = cow.write(file_b, "B-new\n", operation_id="batch-test")

    results = cow.commit_all([h1, h2])
    print(f"Batch results: {results}")
    print(f"File A: {file_a.read_text()!r}")
    print(f"File B: {file_b.read_text()!r}")

    # Stats
    print(f"\nStaging stats: {cow.staging_stats()}")

    # Cleanup
    cow.cleanup(aggressive=True)
    shutil.rmtree(work_dir, ignore_errors=True)
    shutil.rmtree(staging_dir, ignore_errors=True)

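Design decision: why not just use filesystem snapshots?

An attentive reader may ask: when running on btrfs or ZFS, why not use filesystem-level snapshots directly? There are two reasons: granularity and portability. Filesystem snapshots are taken at the level of a whole subvolume (btrfs) or dataset (ZFS). You cannot snapshot a single file, and a snapshot records the state of everything at one instant rather than the effect of one operation. When the Agent has modified 3 files and only 1 needs to be rolled back, reverting the snapshot also discards the correct changes to the other 2; the alternative is to locate that one file inside the snapshot and copy it back by hand, with no record of which operation wrote it. CopyOnWriteAgent's diffs and snapshots, by contrast, are per file and per operation: each WriteHandle manages its own undo data independently. In addition, filesystem snapshots depend on a specific filesystem, and in production you will not migrate an entire cluster from ext4 to ZFS just to support Agent rollback. The design in this section is filesystem-agnostic and works on any POSIX filesystem, provided the staging directory and the target sit on the same filesystem so that os.replace remains an atomic rename.

As for how CopyOnWriteAgent relates to safe_agent_write from Section 1: the former is the multi-file, composable upgrade of the latter. safe_agent_write suits a single write to a single file (a context-manager model), while CopyOnWriteAgent suits an Agent that touches multiple files across multiple steps (a handle-based model). Both share the same pipeline philosophy of stage first, validate next, then replace atomically, but they choose different interface styles for lifecycle management.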
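
The shared "stage, validate, atomically replace" pipeline can be reduced to a few lines. The sketch below is only an illustration of that pipeline, not the article's safe_agent_write or CopyOnWriteAgent; the function name stage_validate_replace and its signature are hypothetical, and it assumes the staged file can live in the same directory as the target.

```python
import os
import tempfile
from pathlib import Path
from typing import Callable


def stage_validate_replace(
    target: Path,
    new_text: str,
    validator: Callable[[Path], bool],
) -> bool:
    """Write new_text to target only if the staged copy passes validator."""
    # Stage next to the target so the final rename stays on one filesystem.
    fd, tmp_name = tempfile.mkstemp(dir=target.parent, prefix=".staged-")
    staged = Path(tmp_name)
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(new_text)
            f.flush()
            os.fsync(f.fileno())
        if not validator(staged):
            return False  # the target was never touched
        os.replace(staged, target)  # atomic swap of old for new
        return True
    finally:
        # Removes the staged file after a rejection; a no-op after a
        # successful replace, because the staged name no longer exists.
        staged.unlink(missing_ok=True)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        conf = Path(d) / "app.conf"
        conf.write_text("listen 80;\n", encoding="utf-8")
        ok = stage_validate_replace(
            conf, "listen 443 ssl;\n", lambda p: "listen" in p.read_text()
        )
        print(ok, repr(conf.read_text()))
```

A context-manager interface and a handle-based interface are both wrappers around this same sequence; they differ only in who holds the staged path between the write and the commit.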

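4. Data-Level Rollback: Transactional Wrappers and State Checkpoints

File-level rollback handles the case where the Agent corrupts a file. But what if the Agent corrupts rows in a database, keys in Redis, or an in-memory state machine instead? Changes to structured data differ fundamentally from file changes: they are usually incremental (a one-row UPDATE rather than a whole-table replacement), frequent (dozens of writes per second), and carry implicit dependencies (a change to one order's status affects downstream billing logic). File-level snapshots do not help here: you cannot snapshot the entire database before every UPDATE users SET status='active'.

This section covers two core mechanisms for data-level rollback: the write-ahead log (WAL) provides operation-level undo, and the state checkpoint provides recovery at decision boundaries. The two are complementary: the WAL handles fine-grained operations, checkpoints handle coarse-grained state. The WAL answers "how do I undo the UPDATE from step 7", while a checkpoint answers "how do I return the Agent to its state before it processed task 12".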

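Write-ahead logging: from file level to data level

The WriteAheadLog in Section 2 demonstrated WAL applied to file operations: store the file's pre-image, and overwrite with the pre-image on rollback. A data-level WAL follows the same protocol, but its subject changes from files to structured data rows:

  1. Intent record (intent log): before modifying data, the Agent declares the operation it intends to perform, including the target (which table or key), the type (INSERT/UPDATE/DELETE/SET), and the preconditions (the expected current value). This intent record carries an intent_id and a description of the operation, and is written to the WAL file.
  2. Pre-image capture: for UPDATE and DELETE operations, read and save the current row or value before modifying it. For INSERT operations the pre-image is empty, and rollback means deleting the newly inserted row. The pre-image can be embedded in the WAL entry (small data) or stored in a separate pre-image file (large data).
  3. Commit marker: once the operation has executed successfully, the intent is marked committed. For crash recovery, a committed intent is known to have taken effect, so returning to an earlier state means running its inverse operation. An uncommitted intent is ambiguous, because the crash may have happened before, during, or after the write; it cannot simply be ignored and must be checked against the actual data, as described in the next step.
  4. Rollback execution: for a committed intent, run the inverse operation based on the pre-image: UPDATE replaces the current value with the pre-image, INSERT deletes the row, DELETE re-inserts the row. For an uncommitted intent, compare the current state with the intent: if the data already holds the new value, the write did take place and must be rolled back; if the data still matches the preconditions (the old value), the write never took effect and the entry can be discarded; if it matches neither, a later operation may have overwritten it, so log the case and raise an alert.

The key design choice in this protocol is that intent precedes execution: before the data is modified, the log records both "what I am about to do" and "what the current value is". This ensures that even if the Agent process crashes mid-write, the WAL holds enough information to restore consistency. By contrast, an "execute first, log afterwards" model can lose the record in a crash, leaving no way to tell which operations took effect.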
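
The ordering can be made concrete with a toy model. The sketch below is a deliberately small in-memory illustration, not the WriteAheadLog class implemented later in this section: TinyWal, its apply() and recover() methods, the existed field, and the crash_before_commit flag are all hypothetical names introduced here, and a plain Python list stands in for the append-only journal file.

```python
from __future__ import annotations

import json
from typing import Any


class TinyWal:
    """Toy intent -> execute -> commit log over a plain dict."""

    def __init__(self, store: dict[str, Any]) -> None:
        self.store = store
        self.journal: list[str] = []  # stands in for an append-only file
        self._next_id = 1

    def _log(self, **entry: Any) -> None:
        self.journal.append(json.dumps(entry))

    def apply(self, key: str, new_value: Any, *,
              crash_before_commit: bool = False) -> int:
        intent_id = self._next_id
        self._next_id += 1
        # 1. Intent and pre-image are logged BEFORE the data changes.
        self._log(id=intent_id, status="intended", key=key,
                  existed=key in self.store,
                  pre_image=self.store.get(key), new_value=new_value)
        # 2. Execute the mutation.
        self.store[key] = new_value
        # 3. Commit marker (skipped to simulate a crash).
        if not crash_before_commit:
            self._log(id=intent_id, status="committed")
        return intent_id

    def recover(self) -> list[int]:
        """Undo uncommitted intents whose write actually took effect."""
        entries = [json.loads(line) for line in self.journal]
        committed = {e["id"] for e in entries if e["status"] == "committed"}
        undone: list[int] = []
        for e in reversed(entries):
            if e["status"] != "intended" or e["id"] in committed:
                continue
            if self.store.get(e["key"]) != e["new_value"]:
                continue  # the write never landed; nothing to undo
            if e["existed"]:
                self.store[e["key"]] = e["pre_image"]
            else:
                self.store.pop(e["key"], None)
            undone.append(e["id"])
        return undone


if __name__ == "__main__":
    store = {"user:42": "inactive"}
    wal = TinyWal(store)
    wal.apply("user:42", "active")                           # completes
    wal.apply("user:7", "active", crash_before_commit=True)  # "crashes"
    print(wal.recover(), store)
```

Because the pre-image is in the journal before the store is touched, recover() can always tell an unfinished write from a finished one; with the order reversed, a crash between the write and the log entry would leave no trace to act on.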

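State checkpoints: saving full state at decision boundaries

The WAL handles rollback of individual operations, but when the Agent runs a complex 50-step task and fails at step 43, rolling back 42 WAL entries one by one is impractical: the length of the rollback chain grows linearly with the number of operations, the execution time is hard to predict, and the accumulated undo of intermediate steps can introduce new inconsistencies (for example, when entries touching the same key are undone in the wrong order). This is where the state checkpoint comes in.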
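
To see why a long undo chain is fragile, consider a minimal sketch. The helper undo_chain below is hypothetical and simplified to (key, pre_image) pairs; it only illustrates that the cost is one write per logged operation and that the order of replay decides whether the result is correct.

```python
from __future__ import annotations

from typing import Any


def undo_chain(
    store: dict[str, Any],
    pre_images: list[tuple[str, Any]],
    *,
    newest_first: bool = True,
) -> int:
    """Restore pre-images one entry at a time; return the number of writes."""
    ordered = reversed(pre_images) if newest_first else iter(pre_images)
    writes = 0
    for key, pre_image in ordered:
        store[key] = pre_image
        writes += 1  # cost grows linearly with the number of entries
    return writes


if __name__ == "__main__":
    # "quota" went 10 -> 20 -> 30, so the log holds pre-images 10, then 20.
    log = [("quota", 10), ("quota", 20)]
    right = {"quota": 30}
    wrong = {"quota": 30}
    undo_chain(right, log)                      # newest entry undone first
    undo_chain(wrong, log, newest_first=False)  # oldest entry undone first
    print(right, wrong)
```

Undoing newest-first lands on the original value, while the opposite order leaves an intermediate value behind. A checkpoint sidesteps both the cost and the ordering question by restoring one known state in a single step.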

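The core idea of a state checkpoint: at the Agent's decision boundaries (task start, subtask completion, before and after external calls, error-recovery points), serialize and save the Agent's complete runtime state. When a rollback is needed, restore directly to the most recent valid checkpoint, regardless of which operations happened in between. A checkpoint contains the following, matching the fields of the AgentCheckpoint class in the implementation later in this section:

  • Task state: the current task ID, the completed steps, and the pending steps.
  • Workspace manifest: file paths mapped to checksums, not the file contents.
  • Context digest: a compressed summary of the context window, not the full context.
  • External resource references: references to resources outside the workspace.
  • Extra state: any additional data the framework needs to carry.

Checkpoint triggering is declarative: it does not rely on the Agent "remembering" to save a checkpoint; the system triggers it automatically at every decision boundary. In the implementation, when the Agent framework runs agent.step() it checks whether a decision boundary has been crossed (the task ID changed, a subtask started or ended, a tool call returned) and calls CheckpointManager.save() at those moments. The Agent itself is unaware that checkpoints exist; they are transparent to it.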
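
A sketch of that boundary check might look as follows. StepEvent, BoundaryWatcher, the event kind strings, and the label format are all hypothetical; the save callable merely stands in for a call such as CheckpointManager.save(label), and a real framework would pass the full state along with the label.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class StepEvent:
    task_id: str
    kind: str  # e.g. "think", "tool_return", "subtask_start", "subtask_end"


# Event kinds that count as decision boundaries in this sketch
BOUNDARY_KINDS = {"tool_return", "subtask_start", "subtask_end"}


class BoundaryWatcher:
    """Calls save(label) whenever a step crosses a decision boundary."""

    def __init__(self, save: Callable[[str], object]) -> None:
        self._save = save
        self._last_task_id: Optional[str] = None

    def observe(self, event: StepEvent) -> bool:
        task_changed = event.task_id != self._last_task_id
        self._last_task_id = event.task_id
        if not (task_changed or event.kind in BOUNDARY_KINDS):
            return False  # ordinary step: no checkpoint
        reason = "task_start" if task_changed else event.kind
        self._save(f"{event.task_id}:{reason}")
        return True


if __name__ == "__main__":
    saved: list[str] = []
    watcher = BoundaryWatcher(saved.append)
    for ev in [StepEvent("t1", "think"), StepEvent("t1", "think"),
               StepEvent("t1", "tool_return"), StepEvent("t2", "think")]:
        watcher.observe(ev)
    print(saved)
```

The agent's own step logic never calls save; the watcher sits in the framework loop, which is what keeps checkpointing transparent to the Agent.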

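Log replay and crash recovery

WAL and checkpoints each work on their own, but in a crash-recovery scenario they have to cooperate:

  1. Determine the recovery starting point: after the Agent process restarts, CheckpointManager loads the most recent checkpoint, which establishes the "known good state" baseline.
  2. Identify pending operations: WriteAheadLog loads every WAL entry after the checkpoint's timestamp. Committed entries are operations that succeeded after the checkpoint and may not be reflected in the checkpoint state; uncommitted entries are operations whose outcome is unknown.
  3. Classify and handle:
    • Committed + reversible side effects: roll back, restoring the data to its pre-operation state.
    • Committed + irreversible side effects: log an alert and flag as "needs human review". For example, if an API call has already created an external resource, the only form of rollback available is a compensating operation.
    • Uncommitted: inspect the current data state. If the data already holds the new value, roll back; if it holds the old value, the operation never took effect and the WAL entry is simply discarded.
  4. Rebuild the execution context: restore the Agent's task queue and context digest from the checkpoint. The Agent resumes from the checkpoint's "next step", neither starting from scratch nor redoing work that was already completed.

This cooperative flow guarantees two things. Data consistency: after a crash, the data state is either fully restored to the checkpoint or explicitly flagged, with an explanation, for human handling. Execution continuity: the Agent does not re-execute operations that had already succeeded, which avoids disastrous outcomes such as double charges or duplicate emails.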
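
The classification in step 3 can be sketched as a pure planning function. This is an illustration under assumptions, not part of the implementation that follows: plan_recovery is a hypothetical helper, entries are assumed to be dicts shaped like WAL journal lines, and the reversible field is invented here to mark irreversible side effects (the journal format in this section has no such field).

```python
from __future__ import annotations

from typing import Any, Callable


def plan_recovery(
    entries: list[dict[str, Any]],
    read_current: Callable[[str], Any],
    checkpoint_ts: float,
) -> dict[str, list[str]]:
    """Sort post-checkpoint WAL entries into rollback / review / discard."""
    plan: dict[str, list[str]] = {
        "rollback": [], "manual_review": [], "discard": [],
    }
    for e in entries:
        if e["timestamp"] <= checkpoint_ts:
            continue  # already part of the checkpoint baseline
        if e["status"] == "committed":
            # "reversible" is a hypothetical flag; default to reversible
            reversible = e.get("reversible", True)
            bucket = "rollback" if reversible else "manual_review"
        elif read_current(e["target"]) == e["new_value"]:
            bucket = "rollback"  # uncommitted, but the write took effect
        else:
            bucket = "discard"   # uncommitted and never applied
        plan[bucket].append(e["intent_id"])
    return plan


if __name__ == "__main__":
    state = {"db:a": "new", "db:b": "old"}
    wal_entries = [
        {"intent_id": "i1", "target": "db:a", "new_value": "new",
         "status": "committed", "timestamp": 11.0},
        {"intent_id": "i2", "target": "api:ticket", "new_value": "created",
         "status": "committed", "timestamp": 12.0, "reversible": False},
        {"intent_id": "i3", "target": "db:b", "new_value": "new",
         "status": "intended", "timestamp": 13.0},
    ]
    print(plan_recovery(wal_entries, state.get, checkpoint_ts=10.0))
```

Keeping the planner free of side effects makes the recovery decision inspectable before anything is undone, which matters most for the entries that end up in the manual-review bucket.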

WriteAheadLog and CheckpointManager implementation

The code below implements a data-level WAL (with the intent/commit protocol) and a CheckpointManager (with serialization and restore):

from __future__ import annotations

import hashlib
import json
import os
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Optional, Protocol
import logging

logger = logging.getLogger("agent.rollback.data")


# ==========================================================================
# Data-level operation types — structured data, not files
# ==========================================================================

@dataclass
class DataOperation:
    """
    Describes one data-level mutation recognized by the WAL.
    This is NOT the same as Section 2's OperationDescriptor — that one
    is for orchestrating strategy selection. This one represents a single
    concrete data change (one row, one key, one document).
    """
    operation_id: str
    target: str                    # e.g. "db:users:row:42", "redis:session:abc"
    op_type: str                   # INSERT, UPDATE, DELETE, SET, HSET, LPUSH...
    pre_image: Any = None          # Value before the mutation
    new_value: Any = None          # Value being written (for audit)
    preconditions: dict[str, Any] = field(default_factory=dict)
    timestamp: float = field(default_factory=time.time)


# ==========================================================================
# Data store abstraction — the WAL needs to read/write actual data
# ==========================================================================

class DataStore(Protocol):
    """
    Protocol for any data store the WAL can operate on.
    Implementations: SQL database, Redis, in-memory dict, etc.
    """
    def read(self, target: str) -> Any: ...
    def write(self, target: str, value: Any) -> bool: ...
    def delete(self, target: str) -> bool: ...


class InMemoryStore:
    """Trivial in-memory store for demonstration."""
    def __init__(self) -> None:
        self._data: dict[str, Any] = {}

    def read(self, target: str) -> Any:
        return self._data.get(target)

    def write(self, target: str, value: Any) -> bool:
        self._data[target] = value
        return True

    def delete(self, target: str) -> bool:
        return self._data.pop(target, None) is not None


# ==========================================================================
# WriteAheadLog — data-level intent/commit protocol
# ==========================================================================

class WriteAheadLog:
    """
    Data-level Write-Ahead Log with intent/commit protocol.

    Protocol:
      1. log_intent(operation)   → returns intent_id; saves pre-image.
      2. Execute the operation on the data store.
      3. mark_committed(intent_id) → marks the intent as safe-to-undo.

    Crash recovery:
      - get_uncommitted() returns intents not marked committed.
      - get_committed_after(timestamp) returns committed intents that
        may need rollback.
    """

    def __init__(self, wal_dir: str | Path, store: DataStore) -> None:
        self.wal_dir = Path(wal_dir)
        self.wal_dir.mkdir(parents=True, exist_ok=True)
        self.store = store
        self._journal_path = self.wal_dir / "data-wal.jsonl"
        self._pre_image_dir = self.wal_dir / "pre-images"
        self._pre_image_dir.mkdir(exist_ok=True)

    # -- Intent logging ----------------------------------------------------

    def log_intent(
        self,
        target: str,
        op_type: str,
        new_value: Any,
        *,
        operation_id: str = "",
        preconditions: dict[str, Any] | None = None,
    ) -> str:
        """
        Log an intended data operation and capture the pre-image.

        Args:
            target: Data store target (e.g. "db:users:row:42")
            op_type: Operation type (INSERT, UPDATE, DELETE, SET...)
            new_value: The value being written.
            operation_id: Optional caller-supplied ID.
            preconditions: Expected current values. Recorded in the journal
                for audit and recovery; this class does not enforce them.

        Returns:
            intent_id — used later for mark_committed() or rollback().
        """
        intent_id = operation_id or f"wal-{uuid.uuid4().hex[:16]}"

        # Capture pre-image BEFORE any modification
        pre_image: Any = None
        if op_type in ("UPDATE", "DELETE", "SET", "HSET"):
            pre_image = self.store.read(target)

        # Save large pre-images to disk, small ones inline in the journal
        pre_image_ref: str | None = None
        pre_image_inline: Any = pre_image
        if isinstance(pre_image, (str, bytes)) and len(str(pre_image)) > 4096:
            # Large pre-image: save to file
            pre_image_ref = self._save_pre_image(intent_id, pre_image)
            pre_image_inline = None

        entry = {
            "intent_id": intent_id,
            "target": target,
            "op_type": op_type,
            # Audit copy only: stringified and truncated, so recovery code
            # should not rely on it for exact equality checks.
            "new_value": str(new_value)[:1024],
            "pre_image": pre_image_inline,
            "pre_image_ref": pre_image_ref,
            "preconditions": preconditions or {},
            "status": "intended",          # intended → committed → rolled_back
            "timestamp": time.time(),
            "checksum": hashlib.sha256(
                json.dumps({
                    "target": target,
                    "op_type": op_type,
                    "pre_image": str(pre_image)[:1024],
                }, sort_keys=True, default=str).encode()
            ).hexdigest()[:16],
        }
        self._append_journal(entry)

        logger.debug(
            "WAL intent logged: %s %s %s (pre_image=%s)",
            intent_id, op_type, target,
            type(pre_image).__name__ if pre_image is not None else "None",
        )
        return intent_id

    # -- Commit marking ----------------------------------------------------

    def mark_committed(self, intent_id: str) -> bool:
        """
        Mark a previously logged intent as committed.

        After this call:
          - The operation is considered successfully executed.
          - On rollback, the pre-image will be restored.
          - On crash recovery, this intent is a candidate for rollback
            (see get_committed_after()).

        Returns True if the intent was found and marked, False otherwise.
        """
        return self._update_status(intent_id, "committed")

    def mark_rolled_back(self, intent_id: str) -> bool:
        """Mark an intent as rolled back (for audit trail)."""
        return self._update_status(intent_id, "rolled_back")

    # -- Crash recovery queries --------------------------------------------

    def get_uncommitted(self) -> list[dict]:
        """
        Return all intents that were logged but never committed.

        These represent operations that were intended but may or may not
        have been executed. Crash recovery should:
          1. Check if the data store reflects the new_value.
          2. If yes → rollback (restore pre_image).
          3. If no → discard (operation never executed).
        """
        return self._query_by_status("intended")

    def get_committed_after(self, since_timestamp: float) -> list[dict]:
        """
        Return all committed intents after a given timestamp.

        Used after restoring a checkpoint: find all operations that
        succeeded after the checkpoint was saved.
        """
        entries = self._query_by_status("committed")
        return [
            e for e in entries
            if e.get("timestamp", 0) > since_timestamp
        ]

    def get_all_intents(self, since_timestamp: float = 0) -> list[dict]:
        """Return all journal entries (all statuses) after a timestamp."""
        if not self._journal_path.exists():
            return []
        entries = []
        with open(self._journal_path) as f:
            for line in f:
                if not line.strip():
                    continue
                try:
                    entry = json.loads(line)
                    if entry.get("timestamp", 0) >= since_timestamp:
                        entries.append(entry)
                except json.JSONDecodeError:
                    logger.warning(
                        "Corrupt WAL line skipped: %.80s...", line,
                    )
        return entries

    # -- Rollback execution ------------------------------------------------

    def rollback_intent(self, intent_id: str) -> bool:
        """
        Execute rollback for a single intent.

        Logic:
          - INSERT: delete the target.
          - UPDATE/SET: restore pre_image.
          - DELETE: re-insert pre_image.

        Returns True if rollback succeeded.
        """
        entry = self._find_entry(intent_id)
        if entry is None:
            logger.error("Rollback failed: intent %s not found", intent_id)
            return False

        op_type = entry["op_type"]
        target = entry["target"]

        try:
            if op_type == "INSERT":
                self.store.delete(target)
                logger.info(
                    "Rollback INSERT: deleted %s", target,
                )

            elif op_type in ("UPDATE", "SET", "HSET"):
                pre_image = entry.get("pre_image")
                if pre_image is None and entry.get("pre_image_ref"):
                    pre_image = self._load_pre_image(entry["pre_image_ref"])
                if pre_image is not None:
                    self.store.write(target, pre_image)
                    logger.info(
                        "Rollback %s: restored pre-image for %s",
                        op_type, target,
                    )
                else:
                    logger.warning(
                        "Rollback %s: no pre-image for %s, cannot restore",
                        op_type, target,
                    )
                    return False

            elif op_type == "DELETE":
                pre_image = entry.get("pre_image")
                if pre_image is None and entry.get("pre_image_ref"):
                    pre_image = self._load_pre_image(entry["pre_image_ref"])
                if pre_image is not None:
                    self.store.write(target, pre_image)
                    logger.info(
                        "Rollback DELETE: re-inserted %s", target,
                    )
                else:
                    logger.warning(
                        "Rollback DELETE: no pre-image for %s", target,
                    )
                    return False

            else:
                logger.error(
                    "Unknown op_type for rollback: %s", op_type,
                )
                return False

            self._update_status(intent_id, "rolled_back")
            return True

        except Exception as exc:
            logger.exception(
                "Rollback exception for intent %s: %s", intent_id, exc,
            )
            return False

    # -- Internal helpers --------------------------------------------------

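    def _append_journal(self, entry: dict) -> None:
        """Append a JSON line to the WAL journal (append-only)."""
        with open(self._journal_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(entry, default=str) + "\n")
            # Force the entry to disk before the caller mutates any data;
            # otherwise a crash could lose the intent.
            f.flush()
            os.fsync(f.fileno())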

    def _update_status(self, intent_id: str, new_status: str) -> bool:
        """
        Update the status of an intent in the journal.
        WAL is append-only for integrity, so this appends a new status entry.
        """
        entry = self._find_entry(intent_id)
        if entry is None:
            return False
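        # Carry the original fields forward (pre_image, pre_image_ref, ...)
        # so the latest entry for an intent stays self-contained: both
        # _find_entry() and _query_by_status() return only that latest entry,
        # and rollback_intent() needs the pre-image from it.
        status_entry = {
            **entry,
            "status": new_status,
            "timestamp": time.time(),
            "previous_status": entry.get("status", "unknown"),
        }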
        self._append_journal(status_entry)
        return True

    def _find_entry(self, intent_id: str) -> dict | None:
        """Find the most recent entry for a given intent_id."""
        if not self._journal_path.exists():
            return None
        found: dict | None = None
        with open(self._journal_path) as f:
            for line in f:
                if not line.strip():
                    continue
                try:
                    entry = json.loads(line)
                    if entry.get("intent_id") == intent_id:
                        found = entry
                except json.JSONDecodeError:
                    pass
        return found

    def _query_by_status(self, status: str) -> list[dict]:
        """Return all latest entries with a given status."""
        if not self._journal_path.exists():
            return []
        # Collect latest entry per intent_id
        latest: dict[str, dict] = {}
        with open(self._journal_path) as f:
            for line in f:
                if not line.strip():
                    continue
                try:
                    entry = json.loads(line)
                    iid = entry.get("intent_id", "")
                    # The journal is append-only, so a later line always
                    # supersedes an earlier one for the same intent_id
                    # (comparing timestamps can tie on coarse clocks).
                    latest[iid] = entry
                except json.JSONDecodeError:
                    pass
        return [
            e for e in latest.values()
            if e.get("status") == status
        ]

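    def _save_pre_image(self, intent_id: str, data: Any) -> str:
        """Save a large pre-image to a file, return the file path."""
        # The suffix records the original type so that _load_pre_image()
        # can hand back str for text and bytes for binary data.
        if isinstance(data, bytes):
            file_path = self._pre_image_dir / f"{intent_id}.pre.bin"
            file_path.write_bytes(data)
        else:
            file_path = self._pre_image_dir / f"{intent_id}.pre.txt"
            file_path.write_text(str(data), encoding="utf-8")
        return str(file_path)

    def _load_pre_image(self, ref: str) -> Any:
        """Load a pre-image from a file reference (str or bytes)."""
        file_path = Path(ref)
        if not file_path.exists():
            return None
        if file_path.suffix == ".bin":
            return file_path.read_bytes()
        return file_path.read_text(encoding="utf-8")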


# ==========================================================================
# CheckpointManager — full agent state serialization at decision boundaries
# ==========================================================================

@dataclass
class AgentCheckpoint:
    """
    A serialized snapshot of the Agent's full runtime state at a
    decision boundary.
    """
    checkpoint_id: str
    label: str                     # Human-readable label (e.g. "task-12-start")
    timestamp: float
    # Agent task state
    current_task_id: str = ""
    completed_steps: list[str] = field(default_factory=list)
    pending_steps: list[str] = field(default_factory=list)
    # Workspace manifest (file paths → checksums, not file contents)
    workspace_manifest: dict[str, str] = field(default_factory=dict)
    # Context window digest (compressed summary, not full context)
    context_digest: str = ""
    # External resource references
    external_resources: dict[str, str] = field(default_factory=dict)
    # Arbitrary extra state
    extra: dict[str, Any] = field(default_factory=dict)


class CheckpointManager:
    """
    Manages Agent state checkpoints — serialize and restore at
    decision boundaries.

    Decision boundaries:
      - Task/subtask start and end
      - Before and after external API calls
      - After every N steps (configurable, default 10)
      - On error (before error handler runs)
      - On explicit checkpoint request
    """

    def __init__(
        self,
        checkpoint_dir: str | Path,
        *,
        max_checkpoints: int = 50,
        auto_checkpoint_every_n_steps: int = 10,
    ) -> None:
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.max_checkpoints = max_checkpoints
        self.auto_checkpoint_interval = auto_checkpoint_every_n_steps
        self._step_counter: int = 0
        self._checkpoints: list[AgentCheckpoint] = []

    # -- Checkpoint lifecycle ----------------------------------------------

    def save(
        self,
        label: str,
        *,
        current_task_id: str = "",
        completed_steps: list[str] | None = None,
        pending_steps: list[str] | None = None,
        workspace_manifest: dict[str, str] | None = None,
        context_digest: str = "",
        external_resources: dict[str, str] | None = None,
        extra: dict[str, Any] | None = None,
    ) -> str:
        """
        Save an agent state checkpoint. Returns the checkpoint_id.

        This is the primary API — called by the Agent framework at every
        decision boundary. The checkpoint is immediately serialized to
        disk for crash safety.
        """
        checkpoint_id = f"ckpt-{uuid.uuid4().hex[:16]}"
        ckpt = AgentCheckpoint(
            checkpoint_id=checkpoint_id,
            label=label,
            timestamp=time.time(),
            current_task_id=current_task_id,
            completed_steps=list(completed_steps or []),
            pending_steps=list(pending_steps or []),
            workspace_manifest=dict(workspace_manifest or {}),
            context_digest=context_digest,
            external_resources=dict(external_resources or {}),
            extra=dict(extra or {}),
        )

        # Serialize to disk immediately (crash safety)
        self._write_checkpoint(ckpt)

        self._checkpoints.append(ckpt)
        logger.info(
            "Checkpoint saved: %s (label=%s, task=%s, %d steps)",
            checkpoint_id, label, current_task_id,
            len(ckpt.completed_steps),
        )

        # Enforce retention limit
        while len(self._checkpoints) > self.max_checkpoints:
            oldest = self._checkpoints.pop(0)
            self._delete_checkpoint(oldest.checkpoint_id)
            logger.debug(
                "Checkpoint pruned (retention): %s", oldest.checkpoint_id,
            )

        return checkpoint_id

    def restore(self, checkpoint_id: str | None = None) -> AgentCheckpoint | None:
        """
        Restore agent state from a checkpoint.

        If checkpoint_id is None, restore the latest checkpoint.
        Returns None if no checkpoint found.
        """
        if checkpoint_id is None:
            if not self._checkpoints:
                # Try loading from disk
                self._load_checkpoints()
            if not self._checkpoints:
                logger.warning("No checkpoints available for restore")
                return None
            checkpoint_id = self._checkpoints[-1].checkpoint_id

        # Search in-memory first, then disk
        for c in self._checkpoints:
            if c.checkpoint_id == checkpoint_id:
                logger.info(
                    "Checkpoint restored: %s (label=%s)",
                    c.checkpoint_id, c.label,
                )
                return c

        # Load from disk
        ckpt = self._read_checkpoint(checkpoint_id)
        if ckpt:
            logger.info(
                "Checkpoint restored from disk: %s (label=%s)",
                ckpt.checkpoint_id, ckpt.label,
            )
            return ckpt

        logger.error("Checkpoint not found: %s", checkpoint_id)
        return None

    def list_checkpoints(self) -> list[dict[str, Any]]:
        """Return a summary of all checkpoints."""
        if not self._checkpoints:
            self._load_checkpoints()
        return [
            {
                "checkpoint_id": c.checkpoint_id,
                "label": c.label,
                "timestamp": c.timestamp,
                "task_id": c.current_task_id,
                "completed_steps": len(c.completed_steps),
                "pending_steps": len(c.pending_steps),
                "files_in_manifest": len(c.workspace_manifest),
            }
            for c in self._checkpoints
        ]

    def delete_checkpoint(self, checkpoint_id: str) -> bool:
        """Delete a specific checkpoint."""
        self._checkpoints = [
            c for c in self._checkpoints
            if c.checkpoint_id != checkpoint_id
        ]
        return self._delete_checkpoint(checkpoint_id)

    # -- Auto-checkpoint hook ----------------------------------------------

    def on_step(self, agent_state: dict[str, Any]) -> str | None:
        """
        Called after each agent.step(). Automatically saves a checkpoint
        if the step counter reaches the configured interval.

        Returns checkpoint_id if saved, None otherwise.
        """
        self._step_counter += 1
        if self._step_counter % self.auto_checkpoint_interval != 0:
            return None

        return self.save(
            label=f"auto-step-{self._step_counter}",
            current_task_id=agent_state.get("task_id", ""),
            completed_steps=agent_state.get("completed_steps", []),
            pending_steps=agent_state.get("pending_steps", []),
            workspace_manifest=agent_state.get("workspace_manifest", {}),
            context_digest=agent_state.get("context_digest", ""),
            external_resources=agent_state.get("external_resources", {}),
        )

    # -- Serialization (disk I/O) ------------------------------------------

    def _checkpoint_path(self, checkpoint_id: str) -> Path:
        return self.checkpoint_dir / f"{checkpoint_id}.json"

    def _write_checkpoint(self, ckpt: AgentCheckpoint) -> None:
        """Serialize a checkpoint to a JSON file."""
        path = self._checkpoint_path(ckpt.checkpoint_id)
        data = {
            "checkpoint_id": ckpt.checkpoint_id,
            "label": ckpt.label,
            "timestamp": ckpt.timestamp,
            "current_task_id": ckpt.current_task_id,
            "completed_steps": ckpt.completed_steps,
            "pending_steps": ckpt.pending_steps,
            "workspace_manifest": ckpt.workspace_manifest,
            "context_digest": ckpt.context_digest,
            "external_resources": ckpt.external_resources,
            "extra": ckpt.extra,
        }
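        # Atomic write: write to a temp file, flush it to disk, then rename
        # over the target so a crash never leaves a half-written checkpoint.
        tmp_path = path.with_suffix(".tmp")
        with open(tmp_path, "w", encoding="utf-8") as fh:
            fh.write(
                json.dumps(data, indent=2, default=str, ensure_ascii=False)
            )
            fh.flush()
            os.fsync(fh.fileno())
        os.replace(tmp_path, path)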

    def _read_checkpoint(self, checkpoint_id: str) -> AgentCheckpoint | None:
        """Deserialize a checkpoint from a JSON file."""
        path = self._checkpoint_path(checkpoint_id)
        if not path.exists():
            return None
        try:
            data = json.loads(path.read_text(encoding="utf-8"))
            return AgentCheckpoint(
                checkpoint_id=data["checkpoint_id"],
                label=data["label"],
                timestamp=data["timestamp"],
                current_task_id=data.get("current_task_id", ""),
                completed_steps=data.get("completed_steps", []),
                pending_steps=data.get("pending_steps", []),
                workspace_manifest=data.get("workspace_manifest", {}),
                context_digest=data.get("context_digest", ""),
                external_resources=data.get("external_resources", {}),
                extra=data.get("extra", {}),
            )
        except (json.JSONDecodeError, KeyError) as exc:
            logger.error(
                "Corrupt checkpoint file %s: %s", path, exc,
            )
            return None

    def _delete_checkpoint(self, checkpoint_id: str) -> bool:
        """Delete a checkpoint file from disk."""
        path = self._checkpoint_path(checkpoint_id)
        if path.exists():
            path.unlink()
            return True
        return False

    def _load_checkpoints(self) -> None:
        """Load all checkpoints from disk into memory."""
        if not self.checkpoint_dir.exists():
            return
        loaded: list[AgentCheckpoint] = []
        for path in sorted(self.checkpoint_dir.glob("*.json")):
            ckpt_id = path.stem
            ckpt = self._read_checkpoint(ckpt_id)
            if ckpt:
                loaded.append(ckpt)
        self._checkpoints = sorted(loaded, key=lambda c: c.timestamp)
        logger.debug("Loaded %d checkpoints from disk", len(self._checkpoints))


# ==========================================================================
# Usage example — WAL + CheckpointManager crash recovery
# ==========================================================================

if __name__ == "__main__":
    import tempfile

    tmp = Path(tempfile.mkdtemp(prefix="data-rollback-demo-"))

    # -- WAL demonstration --------------------------------------------------
    store = InMemoryStore()
    store.write("db:users:row:1", {"name": "Alice", "status": "inactive"})
    print(f"Initial: {store.read('db:users:row:1')}")

    wal = WriteAheadLog(wal_dir=tmp / "wal", store=store)

    # Log intent, then "execute", then commit
    intent_id = wal.log_intent(
        target="db:users:row:1",
        op_type="UPDATE",
        new_value={"name": "Alice", "status": "active"},
    )
    print(f"Intent logged: {intent_id}")

    # Simulate Agent executing the update
    store.write(
        "db:users:row:1",
        {"name": "Alice", "status": "active"},
    )
    wal.mark_committed(intent_id)
    print(f"After commit: {store.read('db:users:row:1')}")

    # Rollback!
    success = wal.rollback_intent(intent_id)
    print(f"Rollback success: {success}")
    print(f"After rollback: {store.read('db:users:row:1')}")

    # -- CheckpointManager demonstration -----------------------------------
    ckpt_mgr = CheckpointManager(
        checkpoint_dir=tmp / "checkpoints",
        max_checkpoints=5,
    )

    # Simulate an Agent operating on a task
    ckpt_id = ckpt_mgr.save(
        label="task-42-before-processing",
        current_task_id="task-42",
        completed_steps=["step-1", "step-2"],
        pending_steps=["step-3", "step-4", "step-5"],
        workspace_manifest={
            "/work/config.yaml": "abc123def",
            "/work/output.json": "456789abc",
        },
        context_digest="User asked to update 5 config files...",
        external_resources={
            "stripe_sub_id": "sub_xyz789",
            "github_pr": "https://github.com/org/repo/pull/123",
        },
    )
    print(f"\nCheckpoint saved: {ckpt_id}")

    # Simulate crash... Agent restarts and restores
    ckpt = ckpt_mgr.restore()
    if ckpt:
        print(f"Restored checkpoint: {ckpt.label}")
        print(f"  Task: {ckpt.current_task_id}")
        print(f"  Completed: {ckpt.completed_steps}")
        print(f"  Pending: {ckpt.pending_steps}")
        print(f"  Workspace files: {len(ckpt.workspace_manifest)}")
        print(f"  External resources: {ckpt.external_resources}")

    # List checkpoints
    print(f"\nAll checkpoints: {json.dumps(ckpt_mgr.list_checkpoints(), indent=2)}")

    # Cleanup
    import shutil
    shutil.rmtree(tmp, ignore_errors=True)

Design decisions

The data-level WAL and the checkpoint manager embody several design trade-offs that are worth stating explicitly:

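  1. The WAL is append-only: the log file is never modified in place. State transitions (intended → committed → rolled_back) are written as newly appended lines. The design costs disk space (one intent can span several log lines) and buys crash safety: if the process crashes while writing a state-update line, the log loses at most that last line and the existing entries stay intact. For Agent rollback, which is infrequent but critical, that cost is acceptable.
  2. Where pre-images are stored: pre-images under 4 KB are embedded in the log line; larger ones go to a separate file. The threshold reflects two considerations. Log lines stay small enough to grep during manual investigation, and most row-level pre-images (user records, config values, cache keys) are far smaller than 4 KB. Separate pre-image files are meant for blob-like data such as JSON documents and serialized objects.
  3. Checkpoints do not contain file contents: workspace_manifest records only file paths and checksums. Rolling back file contents is the job of the CopyOnWriteAgent from Section 3. This split keeps checkpoint serialization from becoming a new storage bottleneck: an Agent may touch hundreds of files, yet the checkpoint holds only a list of checksums. On recovery, each file's current checksum is compared with its manifest entry, and any mismatch (the file changed after the checkpoint) is delegated to CopyOnWriteAgent for file-level rollback. The restore() method listed above returns the manifest as stored; the comparison runs in the recovery path that consumes it.
  4. Context digest vs. full context: saving the Agent's full LLM context window (typically 32K-200K tokens) in every checkpoint can inflate each checkpoint to megabytes or even tens of megabytes, which is unacceptable when dozens of checkpoints are retained. This article instead stores a compressed digest of the context (key decisions, prompts, conclusions) that serves auditing and orients the Agent on recovery. The Agent does not resume blind: it reads "my last task was X, and steps A, B, and C are done" from the checkpoint and rebuilds its working context from that. For the observability infrastructure behind this, including how decision-boundary information is captured, see Agent Observability.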
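
The comparison step in design decision 3 can be sketched as follows. This is a minimal illustration, not the article's implementation: it assumes the manifest maps file paths to SHA-256 hex digests (the article does not name the checksum algorithm), and the helper names file_digest, build_manifest, and find_drifted_files are hypothetical.

```python
import hashlib
import tempfile
from pathlib import Path


def file_digest(path: Path) -> str:
    """Return the SHA-256 hex digest of a file, read in chunks."""
    h = hashlib.sha256()
    with path.open("rb") as fh:
        for chunk in iter(lambda: fh.read(65536), b""):
            h.update(chunk)
    return h.hexdigest()


def build_manifest(paths: list[Path]) -> dict[str, str]:
    """Map each path to its digest (the assumed shape of workspace_manifest)."""
    return {str(p): file_digest(p) for p in paths}


def find_drifted_files(manifest: dict[str, str]) -> list[str]:
    """Return paths whose current content no longer matches the manifest.

    A missing file counts as drifted: it existed at checkpoint time.
    """
    drifted = []
    for path_str, expected in manifest.items():
        path = Path(path_str)
        if not path.is_file() or file_digest(path) != expected:
            drifted.append(path_str)
    return sorted(drifted)


def demo() -> list[str]:
    """Checkpoint two files, edit one, delete the other, report the drift."""
    with tempfile.TemporaryDirectory() as tmp:
        cfg = Path(tmp) / "config.yaml"
        out = Path(tmp) / "output.json"
        cfg.write_text("rate: 10\n", encoding="utf-8")
        out.write_text("{}", encoding="utf-8")
        manifest = build_manifest([cfg, out])
        cfg.write_text("rate: 99\n", encoding="utf-8")  # simulated Agent edit
        out.unlink()                                     # simulated Agent deletion
        return [Path(p).name for p in find_drifted_files(manifest)]
```

Every path returned by find_drifted_files is a candidate for file-level rollback; files that still match the manifest need no work at all, which is what keeps recovery cheap when the Agent touched only a few of the tracked files.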

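For using file-level and data-level rollback together inside an isolated runtime, for example applying CoW snapshots and the WAL in the same Docker container, see Agent Runtime Isolation. For how audit logs use WAL entries as immutable operation records, see Agent Audit Log Design. For limiting the blast radius of a single write at the execution-security layer, which shrinks what has to be rolled back, see Agent Command Execution Security.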

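5. Environment-Level Rollback: Container Snapshots and Immutable Infrastructure

Sections 3 and 4 covered rollback for files and data. Once an Agent's damage reaches past files and data into the runtime environment itself, those mechanisms fall short. Suppose an Agent runs an "optimize system performance" task: it edits /etc/sysctl.conf, restarts several services, adds iptables rules, changes systemd unit files, and installs custom scripts under /usr/local/bin. These changes span files, processes, kernel parameters, and package management. Undoing them one by one means recording every operation precisely and writing matching inverse logic, which is almost impossible to make fully reliable in practice. What you need is not a per-file, per-setting rollback tool but an "environment time capsule" that returns the whole runtime environment to its pre-operation state, whatever the Agent did in between.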

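That is the core idea of environment-level rollback: raise the rollback granularity from files and rows to containers and infrastructure stacks. Its premise is that if you cannot reliably track and undo each micro-operation of the Agent, you stop tracking at that level and take a whole-environment snapshot instead. Save a complete copy of the environment before the Agent starts; when it finishes, either commit the changes or discard the environment and restore the snapshot. With containers and infrastructure as code (IaC) now widespread, this strategy is not only feasible but often simpler and more reliable than per-operation rollback.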
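
The "snapshot first, then commit or discard" flow can be sketched as a context manager. This is an illustrative sketch only: environment_guard and ToyEnvironment are hypothetical names, and the toy environment is a plain dict standing in for a container, so no real snapshot mechanism is involved.

```python
from contextlib import contextmanager
from typing import Callable, Iterator


@contextmanager
def environment_guard(
    take_snapshot: Callable[[], str],
    restore_snapshot: Callable[[str], None],
    discard_snapshot: Callable[[str], None],
) -> Iterator[str]:
    """Snapshot before the block; roll back on error, commit otherwise."""
    snapshot_id = take_snapshot()
    try:
        yield snapshot_id
    except Exception:
        restore_snapshot(snapshot_id)  # throw away everything the Agent did
        raise
    else:
        discard_snapshot(snapshot_id)  # commit: the snapshot is no longer needed


class ToyEnvironment:
    """Stands in for a container: state is a dict, snapshots are copies."""

    def __init__(self) -> None:
        self.state: dict[str, str] = {"sysctl": "default"}
        self._snapshots: dict[str, dict[str, str]] = {}

    def take(self) -> str:
        snap_id = f"snap-{len(self._snapshots) + 1}"
        self._snapshots[snap_id] = dict(self.state)
        return snap_id

    def restore(self, snap_id: str) -> None:
        self.state = self._snapshots.pop(snap_id)

    def discard(self, snap_id: str) -> None:
        self._snapshots.pop(snap_id, None)


def run_demo() -> dict[str, str]:
    """The Agent changes two settings, then fails; the guard restores both."""
    env = ToyEnvironment()
    try:
        with environment_guard(env.take, env.restore, env.discard):
            env.state["sysctl"] = "tuned"
            env.state["iptables"] = "temp-rule"
            raise RuntimeError("service failed its health check")
    except RuntimeError:
        pass
    return env.state
```

The guard never inspects what the Agent changed. It only needs a way to take, restore, and discard a snapshot, which is the whole point of moving rollback to the environment level.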

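Docker Overlay Filesystem Snapshots: Zero-Copy Snapshots from Container Layers

Docker's overlay2 storage driver already provides copy-on-write layering as part of how it stacks image layers: each layer records only its delta against the parent, and the parent itself is immutable. Environment-level rollback can use this mechanism directly. Before the Agent starts, create an overlay snapshot layer over the container filesystem; every write the Agent makes lands in that layer; to roll back, discard the layer. The original filesystem layers are never touched.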

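In practice, docker commit saves a container's current filesystem as a new image, which works as a named snapshot. For Agent rollback, a more efficient path is to operate on the overlay filesystem directly and skip the Docker daemon's serialization overhead. An overlay mount has a read-only lower layer (lowerdir) holding the original state, an upper layer (upperdir) holding every modification, and a work directory (workdir) for overlayfs's internal use. Rolling back the Agent then comes down to one operation: with the container stopped and the overlay unmounted, delete everything in upperdir. The lower layer was never modified, so the next mount shows the filesystem exactly as it was before the Agent ran. Stopping first matters because the kernel documents changes to upperdir under a live overlay mount as undefined behavior.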
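
The upperdir reset described above can be sketched like this. The function name reset_upperdir is hypothetical, and the demo uses ordinary temporary directories named lower and upper rather than a real overlay mount, so it only illustrates the bookkeeping: the upper directory is emptied and the lower directory is left alone. On a real host the overlay must be unmounted first.

```python
import shutil
import tempfile
from pathlib import Path


def reset_upperdir(upper_dir: Path) -> int:
    """Delete every entry inside upper_dir, keeping the directory itself.

    Assumption: the overlay that uses this upperdir is NOT mounted
    (container stopped). Returns the number of top-level entries removed.
    """
    removed = 0
    for entry in upper_dir.iterdir():
        if entry.is_dir() and not entry.is_symlink():
            shutil.rmtree(entry)
        else:
            entry.unlink()  # regular files, symlinks, whiteout device nodes
        removed += 1
    return removed


def demo() -> tuple[int, list[str], list[str]]:
    """Simulate an Agent's writes in 'upper', then roll them back."""
    with tempfile.TemporaryDirectory() as tmp:
        lower = Path(tmp) / "lower"  # read-only original state
        upper = Path(tmp) / "upper"  # everything the Agent modified
        (lower / "etc").mkdir(parents=True)
        (lower / "etc" / "nginx.conf").write_text("worker_processes 1;\n")
        (upper / "etc").mkdir(parents=True)
        (upper / "etc" / "nginx.conf").write_text("broken config\n")
        (upper / "agent.log").write_text("step 9 failed\n")

        removed = reset_upperdir(upper)
        upper_left = [p.name for p in upper.iterdir()]
        lower_left = sorted(
            p.relative_to(lower).as_posix() for p in lower.rglob("*")
        )
        return removed, upper_left, lower_left
```

Unlike a shell rm -rf upperdir/*, iterating with iterdir() also removes dotfiles at the top level, which a glob would skip.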

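The overlay approach has three advantages. It is fast: rollback is an rm -rf upperdir/* (or a more targeted cleanup) with no file copying. It is storage-efficient: the overlay stores only the files the Agent actually modified, not a copy of the whole filesystem. And it fits container orchestration naturally: in Kubernetes, restarting or recreating a Pod is itself an environment-level rollback, because the old container is terminated, a new one starts from the immutable image, and the Agent's runtime modifications to the container filesystem disappear with the old container (data on mounted volumes persists). For isolation design for Agents running in containers, see Agent Runtime Isolation.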

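Container Checkpoint/Restore: CRIU Integration and Freezing Process State

Overlay snapshots handle filesystem rollback, but process state lives outside the filesystem: the Agent may have started background processes, opened network connections, or built temporary caches in memory. Overlay rollback alone cannot restore any of that. This is where CRIU (Checkpoint/Restore In Userspace) comes in: it serializes a container's complete process state (memory pages, file descriptors, network connections, signal handlers) to disk, and that state can later be restored exactly as it was.

CRIU serves Agent rollback in two ways. First, as a checkpoint: at key points in a task (for example, right before a high-risk operation), take a full checkpoint of the container. If a later operation goes wrong, restoring the checkpoint returns to the exact pre-operation state: not only file contents but also what the Agent process was doing, what was in its memory, and what state its network connections were in. Second, as cross-host migration: if the Agent's execution environment on one host is damaged beyond trust (for example, kernel parameters were modified), the container checkpoint can be moved to a clean host and execution resumed there.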

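CRIU has practical constraints in production. Taking a checkpoint freezes the process, and the freeze time grows with the process's memory footprint (roughly 1-3 seconds for a process using 2 GB, depending on storage speed). The checkpoint files are about as large as the process's memory, so a large Agent process can produce a multi-gigabyte checkpoint. Network connections may come back in the TCP_CLOSE state after restore and have to be re-established. These constraints make CRIU a poor fit for fine-grained, checkpoint-every-operation rollback. Its place is the coarse-grained environment checkpoint: once before each task starts, or when a high-risk sequence of operations is detected.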
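
A coarse-grained checkpoint policy and the matching Docker commands might look like the sketch below. It only builds argument lists and runs nothing. It assumes Docker's experimental checkpoint feature (docker checkpoint create and docker start --checkpoint, with CRIU installed on the host); check the flags against your Docker version before relying on them. The function names and the "high" risk label are hypothetical.

```python
from typing import Optional


def checkpoint_create_cmd(
    container: str,
    name: str,
    *,
    checkpoint_dir: Optional[str] = None,
    leave_running: bool = True,
) -> list[str]:
    """Build the argv for `docker checkpoint create` (experimental feature)."""
    cmd = ["docker", "checkpoint", "create"]
    if checkpoint_dir:
        cmd += ["--checkpoint-dir", checkpoint_dir]
    if leave_running:
        # Without this flag the container is stopped after the checkpoint.
        cmd.append("--leave-running")
    return cmd + [container, name]


def checkpoint_restore_cmd(
    container: str,
    name: str,
    *,
    checkpoint_dir: Optional[str] = None,
) -> list[str]:
    """Build the argv that restarts a stopped container from a checkpoint."""
    cmd = ["docker", "start", "--checkpoint", name]
    if checkpoint_dir:
        cmd += ["--checkpoint-dir", checkpoint_dir]
    return cmd + [container]


def should_checkpoint(step_risk: str, is_task_start: bool) -> bool:
    """Coarse-grained policy: once per task, plus before high-risk steps."""
    return is_task_start or step_risk == "high"


def plan_commands(container: str, steps: list[tuple[str, str]]) -> list[list[str]]:
    """Return the checkpoint commands a task would issue, in order.

    `steps` is a list of (step_name, risk) pairs; risk labels are
    illustrative ("low" / "high").
    """
    commands = []
    for index, (step_name, risk) in enumerate(steps):
        if should_checkpoint(risk, is_task_start=(index == 0)):
            commands.append(checkpoint_create_cmd(container, f"before-{step_name}"))
    return commands
```

For a task with steps read-config (low), rewrite-sysctl (high), and write-report (low), plan_commands issues two checkpoints: one at task start and one before rewrite-sysctl. That keeps the freeze cost proportional to the number of risky steps rather than the number of operations.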

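Immutable Deployment Patterns: Blue-Green Deployment and Canary Release

A conceptual step up for Agent environment rollback is to move from "repair the damaged environment" to "never repair, only replace". That is the central tenet of immutable infrastructure: an environment is never modified after creation, and every change is made by creating a new environment, verifying it, and then switching traffic over. When an Agent needs to "roll back", nobody tries to repair the current environment; traffic is switched back to the previous environment version, and the current one is destroyed afterwards.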

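Two immutable deployment patterns are especially well suited to environment rollback for Agent deployments: blue-green deployment and canary release.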

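What the two patterns share is that rollback requires no undo logic from the Agent itself; the Agent does not even need to know it is running under a blue-green or canary setup. The orchestration layer (orchestrator) performs the rollback by switching traffic. The Agent's faulty changes never have to be "repaired"; they are "replaced". For how the Agent state machine behaves under different deployment patterns, see Agent State Machine Design.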
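
Rollback by traffic switching can be sketched with a toy blue-green router. BlueGreenRouter is a hypothetical in-memory stand-in for the orchestrator; a real setup would flip a load balancer target or a Kubernetes Service selector instead of a string field.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class BlueGreenRouter:
    """Toy orchestrator: traffic points at exactly one environment."""

    live: str = "blue"
    previous: Optional[str] = None
    history: list[str] = field(default_factory=list)

    def switch_to(self, env: str, healthy: bool) -> bool:
        """Cut traffic over to `env` only if it passed verification."""
        if not healthy:
            self.history.append(f"rejected {env}")
            return False
        self.previous, self.live = self.live, env
        self.history.append(f"live={env}")
        return True

    def rollback(self) -> bool:
        """Point traffic back at the previous environment.

        No undo logic runs inside the faulty environment; it is simply
        taken out of rotation and can be destroyed afterwards.
        """
        if self.previous is None:
            return False
        self.live, self.previous = self.previous, None
        self.history.append(f"rollback live={self.live}")
        return True


def demo() -> list[str]:
    """Promote green, discover the Agent's change is bad, switch back."""
    router = BlueGreenRouter()
    router.switch_to("green", healthy=True)
    router.rollback()
    return router.history
```

The Agent appears nowhere in this code, which is the property the two patterns share: the rollback decision and its execution both live in the orchestration layer.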

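Infrastructure-as-Code Rollback: Terraform State Management

An Agent's environment changes are not confined to the inside of a container. It may create or modify cloud resources (AWS RDS instances, S3 buckets, IAM policies), and neither overlay snapshots nor CRIU checkpoints cover those. When the Agent manages infrastructure through an IaC tool such as Terraform, the rollback mechanism has to extend to the infrastructure state layer.

Terraform's state file (terraform.tfstate) is central to infrastructure rollback: it records metadata for every resource Terraform manages (resource IDs, attributes, dependencies). When an Agent has changed infrastructure with terraform apply and needs to roll back, there are three options: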

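  1. State rollback: back up the state file and the .tf configuration before terraform apply. To roll back, restore the previous configuration and run terraform apply; Terraform compares that configuration with the state and the real infrastructure, then plans and executes the changes that bring resources back. Be careful about overwriting the current state with the older backup: resources created by the failed apply are absent from the old state, so Terraform would stop tracking them and may try to create them a second time. The state backup is most useful when the current state file itself is lost or corrupted. This option assumes Terraform can compute the change from the current state to the target state and that applying it introduces no damage of its own.
  2. Inverse plan: save the plan generated before terraform apply (-out=plan.tfplan). To roll back, inspect the changes recorded in it, write a Terraform configuration that reverses them, and apply that. The saved plan is a binary file, so read it with terraform show -json plan.tfplan rather than a text diff tool; Terraform does not generate the inverse for you. This option is more precise than state rollback because the inverse is built from the known forward plan instead of relying on Terraform's inference from state.
  3. Destroy and recreate: the most thorough option. Before the Agent runs, save the complete Terraform configuration (.tf files) and the state file. To roll back, run terraform destroy (destroying all managed resources), then run terraform apply with the saved configuration to rebuild them; the saved state file serves as a record of what existed before. This is the ultimate weapon of environment rollback: rather than repairing changes, it tears down the whole infrastructure stack and rebuilds it. The price is downtime and the cost of recreating resources, and any data held in the destroyed resources is lost unless it was backed up separately.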
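
For the inverse-plan option, a first step is to turn a saved plan into a review list of reverse actions. The sketch below reads JSON in the shape produced by terraform show -json (a resource_changes list whose entries carry address and change.actions). The sample data and the function name summarize_inverse are made up for illustration, and the output is a checklist for a human or a follow-up tool, not a generated Terraform configuration.

```python
import json

# Hypothetical sample shaped like `terraform show -json <planfile>` output.
SAMPLE_PLAN_JSON = json.dumps({
    "resource_changes": [
        {"address": "aws_s3_bucket.logs",
         "change": {"actions": ["create"], "before": None,
                    "after": {"bucket": "agent-logs"}}},
        {"address": "aws_iam_policy.agent",
         "change": {"actions": ["update"], "before": {"name": "ro"},
                    "after": {"name": "rw"}}},
        {"address": "aws_db_instance.main",
         "change": {"actions": ["delete", "create"],
                    "before": {"engine": "postgres"},
                    "after": {"engine": "mysql"}}},
        {"address": "aws_vpc.core",
         "change": {"actions": ["no-op"], "before": {}, "after": {}}},
    ]
})

# What has to happen to reverse each single forward action.
INVERSE_ACTION = {
    "create": "destroy the resource",
    "delete": "recreate from the 'before' values",
    "update": "re-apply the 'before' values",
}


def summarize_inverse(plan_json: str) -> list[tuple[str, str]]:
    """List (resource address, reverse action) for every real change."""
    plan = json.loads(plan_json)
    result = []
    for rc in plan.get("resource_changes", []):
        actions = rc["change"]["actions"]
        if actions in (["no-op"], ["read"]):
            continue  # nothing was changed, nothing to reverse
        if len(actions) == 2:
            # A replacement is recorded as a delete/create pair.
            inverse = "replace again using the 'before' values"
        else:
            inverse = INVERSE_ACTION.get(actions[0], "review manually")
        result.append((rc["address"], inverse))
    return result
```

On the sample plan this yields three entries (the bucket, the IAM policy, and the database instance) and skips the untouched VPC. The database entry is the one to examine most carefully: a replacement destroys the old instance, so reversing it restores the configuration but not the data.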

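The most common mistake in Agent infrastructure rollback is backing up only the state file and not the plan of the preceding apply. A state file says what the resources are; it does not say what a given apply changed or how. Without that plan you cannot tell which resources were modified, or in what way. A good practice, then: before every Agent-triggered terraform apply, automatically run terraform plan -out=pre-apply.plan, keep the plan file, and apply exactly that file (terraform apply pre-apply.plan). The plan is both an audit record (what the Agent intended to do) and the basis for constructing the inverse.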

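Clean Recovery: Destroy and Recreate as the Last Resort

When file rollback, data rollback, container snapshots, and infrastructure rollback all fail to reach a known clean state (for example, the Agent's actions triggered a kernel panic, corrupted filesystem metadata, or modified several tightly coupled systems so that no safe recovery order can be determined), the only option left is destroy and recreate. It is not an elegant strategy. It is blunt, expensive, and costs downtime, but in some situations it is the only one that works.

Destroy and recreate has one precondition: a complete definition of the environment must exist. Without an executable environment definition (a Docker Compose file, Kubernetes manifests, Terraform configuration), nothing can be rebuilt after the teardown. The baseline requirement for environment-level rollback is therefore that the Agent's execution environment is managed by declarative configuration, and that this configuration is saved, before the Agent acts, somewhere the Agent's modifications cannot reach. Git is a natural fit for this role: before a task starts, the current declarative configuration and state file are committed to a separate branch or repository as an "insurance policy". Terraform state can contain secrets in plain text, so access to that repository should be restricted.

The code below implements a ContainerEnvironment class that combines four environment-level rollback mechanisms: overlay snapshots, container checkpoints, configuration backups, and destroy-and-recreate: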

from __future__ import annotations

import json
import os
import shutil
import subprocess
import tarfile
import tempfile
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Callable
import logging

logger = logging.getLogger("agent.rollback.environment")


# ==========================================================================
# ContainerEnvironment — environment-level rollback for containerized agents
# ==========================================================================

@dataclass
class EnvSnapshot:
    """
    Represents a complete environment snapshot at a point in time.

    Contains:
      - Overlay layer snapshot (container filesystem diff)
      - Container checkpoint (optional, via CRIU or docker checkpoint)
      - Infrastructure state backup (Terraform state, K8s manifests)
      - Metadata for identification and recovery
    """
    snapshot_id: str
    container_id: str
    label: str
    timestamp: float = field(default_factory=time.time)
    # Overlay snapshot
    overlay_upper_dir: Optional[Path] = None
    overlay_backup_path: Optional[Path] = None
    # Container checkpoint
    checkpoint_path: Optional[Path] = None
    checkpoint_available: bool = False
    # Infrastructure state
    terraform_state_backup: Optional[Path] = None
    terraform_plan_backup: Optional[Path] = None
    k8s_manifests_backup: Optional[Path] = None
    # Environment config (the declarative definition)
    env_config_path: Optional[Path] = None


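# Plain class (not a dataclass): all state is set up in the hand-written
# __init__, and a field-less dataclass __eq__ would make every instance equal.
class ContainerEnvironment: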
    """
    Environment-level rollback manager.

    Provides four levels of recovery, in increasing order of cost:
      1. Overlay rollback  — discard Agent's filesystem changes via overlay
                              layer cleanup. Fastest, lossless for files.
      2. Checkpoint restore — restore full container state (files + processes
                              + memory + network) from a CRIU checkpoint.
                              Includes process state that overlay misses.
      3. IaC rollback       — restore infrastructure state via Terraform state
                              rollback or inverse plan application.
      4. Destroy & recreate — nuclear option: destroy everything, rebuild
                              from declarative config. Most expensive, always
                              works (if config is intact).

    Design principle:
      Always try the cheapest mechanism first. Escalate only when cheaper
      mechanisms fail or don't apply to the current state.
    """

    def __init__(
        self,
        container_id: str = "",
        *,
        snapshot_dir: str | Path = "/var/agent/env-snapshots",
        overlay_root: str | Path = "/var/lib/docker/overlay2",
        max_snapshots: int = 10,
    ) -> None:
        self.container_id = container_id
        self.snapshot_dir = Path(snapshot_dir)
        self.snapshot_dir.mkdir(parents=True, exist_ok=True)
        self.overlay_root = Path(overlay_root)
        self.max_snapshots = max_snapshots
        self._snapshots: list[EnvSnapshot] = []

    # ------------------------------------------------------------------
    # Snapshot — save the current environment state
    # ------------------------------------------------------------------

    def snapshot(
        self,
        container_id: str,
        *,
        label: str = "",
        include_checkpoint: bool = False,
        backup_terraform: bool = False,
        terraform_dir: str | Path = "",
        backup_k8s_manifests: bool = False,
        k8s_manifest_paths: list[str] | None = None,
        env_config_path: str | Path | None = None,
    ) -> str:
        """
        Create a full environment snapshot.

        Steps:
          1. Snapshot the overlay filesystem (backup upperdir).
          2. Optionally create a CRIU container checkpoint.
          3. Optionally backup Terraform state + plan files.
          4. Optionally backup K8s manifest files.

        env_config_path points at the declarative environment definition
        (Compose file or Terraform .tf file) that destroy_and_recreate()
        rebuilds from. Without it, the "destroy" strategy cannot run.

        Returns the snapshot_id for later restore via rollback().
        """
        snapshot_id = f"env-snap-{uuid.uuid4().hex[:16]}"
        snap_label = label or f"snapshot-{time.strftime('%Y%m%d-%H%M%S')}"
        snap = EnvSnapshot(
            snapshot_id=snapshot_id,
            container_id=container_id,
            label=snap_label,
            env_config_path=Path(env_config_path) if env_config_path else None,
        )

        # -- Step 1: Overlay filesystem snapshot --------------------------
        try:
            overlay_snap_path = self._snapshot_overlay(container_id, snapshot_id)
            snap.overlay_backup_path = overlay_snap_path
            logger.info(
                "Overlay snapshot saved: %s (%s)",
                snapshot_id, overlay_snap_path,
            )
        except Exception as exc:
            logger.warning(
                "Overlay snapshot failed (non-fatal): %s", exc,
            )
            # Overlay snapshot failure is non-fatal — continue with other
            # snapshot types. During rollback, absence of overlay snapshot
            # means we must escalate to destroy & recreate.

        # -- Step 2: Container checkpoint (optional, CRIU/docker) --------
        if include_checkpoint:
            try:
                ckpt_path = self._create_checkpoint(container_id, snapshot_id)
                snap.checkpoint_path = ckpt_path
                snap.checkpoint_available = True
                logger.info(
                    "Container checkpoint saved: %s", ckpt_path,
                )
            except Exception as exc:
                logger.warning(
                    "Container checkpoint failed (non-fatal): %s", exc,
                )

        # -- Step 3: Terraform state + plan backup -----------------------
        if backup_terraform and terraform_dir:
            try:
                tf_dir = Path(terraform_dir)
                # Backup terraform.tfstate
                tf_state = tf_dir / "terraform.tfstate"
                if tf_state.exists():
                    backup_path = (
                        self.snapshot_dir
                        / f"{snapshot_id}-terraform.tfstate"
                    )
                    shutil.copy2(tf_state, backup_path)
                    snap.terraform_state_backup = backup_path

                # Backup the last plan file if available
                tf_plan = tf_dir / "pre-apply.plan"
                if tf_plan.exists():
                    backup_plan = (
                        self.snapshot_dir
                        / f"{snapshot_id}-terraform.plan"
                    )
                    shutil.copy2(tf_plan, backup_plan)
                    snap.terraform_plan_backup = backup_plan

                logger.info(
                    "Terraform state backup saved for %s", snapshot_id,
                )
            except Exception as exc:
                logger.warning(
                    "Terraform backup failed (non-fatal): %s", exc,
                )

        # -- Step 4: K8s manifests backup --------------------------------
        if backup_k8s_manifests and k8s_manifest_paths:
            try:
                manifests_tar = (
                    self.snapshot_dir
                    / f"{snapshot_id}-k8s-manifests.tar.gz"
                )
                with tarfile.open(manifests_tar, "w:gz") as tar:
                    for mp in k8s_manifest_paths:
                        p = Path(mp)
                        if p.exists():
                            tar.add(p, arcname=p.name)
                snap.k8s_manifests_backup = manifests_tar
                logger.info(
                    "K8s manifests backup saved: %d files",
                    len(k8s_manifest_paths),
                )
            except Exception as exc:
                logger.warning(
                    "K8s manifests backup failed (non-fatal): %s", exc,
                )

        self._snapshots.append(snap)

        # Enforce retention
        while len(self._snapshots) > self.max_snapshots:
            oldest = self._snapshots.pop(0)
            self._cleanup_snapshot(oldest)

        logger.info(
            "Environment snapshot complete: %s (label=%s, "
            "overlay=%s, checkpoint=%s, terraform=%s)",
            snapshot_id, snap_label,
            snap.overlay_backup_path is not None,
            snap.checkpoint_available,
            snap.terraform_state_backup is not None,
        )
        return snapshot_id

    # ------------------------------------------------------------------
    # Rollback — restore environment from a snapshot
    # ------------------------------------------------------------------

    def rollback(
        self,
        snapshot_id: str,
        *,
        # One of: "auto", "overlay", "checkpoint", "terraform", "destroy"
        strategy: str = "auto",
        container_id: str = "",
    ) -> bool:
        """
        Restore the environment to a previously saved snapshot.

        Strategy "auto" tries mechanisms in this order:
          1. Overlay rollback (cheapest, no process disruption).
          2. Checkpoint restore (includes process state).
          3. IaC state rollback (terraform state revert).
          4. Destroy & recreate (last resort).

        Each mechanism returns True/False. On failure, the next mechanism
        is tried. On success, returns True immediately.

        Returns True if any mechanism succeeded, False if all failed.
        """
        snap = self._find_snapshot(snapshot_id)
        if snap is None:
            logger.error(
                "Rollback failed: snapshot %s not found", snapshot_id,
            )
            return False

        cid = container_id or snap.container_id

        # Determine which strategies to try
        strategies_to_try: list[str]
        if strategy == "auto":
            strategies_to_try = [
                "overlay", "checkpoint", "terraform", "destroy",
            ]
        else:
            strategies_to_try = [strategy]

        for strat in strategies_to_try:
            try:
                if strat == "overlay" and snap.overlay_backup_path:
                    success = self._rollback_overlay(
                        cid, snap.overlay_backup_path,
                    )
                    if success:
                        logger.info(
                            "Rollback via overlay: %s restored", snapshot_id,
                        )
                        return True

                elif strat == "checkpoint" and snap.checkpoint_available:
                    success = self._restore_checkpoint(
                        cid, snap.checkpoint_path,
                    )
                    if success:
                        logger.info(
                            "Rollback via checkpoint: %s restored",
                            snapshot_id,
                        )
                        return True

                elif strat == "terraform" and snap.terraform_state_backup:
                    success = self._rollback_terraform(
                        snap.terraform_state_backup,
                        snap.terraform_plan_backup,
                    )
                    if success:
                        logger.info(
                            "Rollback via terraform: %s restored",
                            snapshot_id,
                        )
                        return True

                elif strat == "destroy":
                    success = self.destroy_and_recreate(snap)
                    if success:
                        logger.info(
                            "Rollback via destroy & recreate: %s",
                            snapshot_id,
                        )
                        return True

                else:
                    logger.debug(
                        "Skipping strategy %s (data not available)", strat,
                    )

            except Exception as exc:
                logger.error(
                    "Rollback strategy %s failed for %s: %s",
                    strat, snapshot_id, exc,
                )
                # Continue to next strategy

        logger.error(
            "All rollback strategies failed for snapshot %s", snapshot_id,
        )
        return False

    # ------------------------------------------------------------------
    # Destroy & Recreate — the nuclear option
    # ------------------------------------------------------------------

    def destroy_and_recreate(
        self,
        snapshot_or_config: EnvSnapshot | str | Path,
    ) -> bool:
        """
        Destroy the current environment and recreate from declarative config.

        This is the most expensive recovery path:
          1. Stop and remove the container (docker rm -f).
          2. Clean up volumes, networks, overlay layers.
          3. Restore declarative config from snapshot backup.
          4. Rebuild and restart (docker compose up, terraform apply, etc.).

        Prerequisites:
          - Declarative environment config must exist (Docker Compose file,
            Terraform config, K8s manifests).
          - Config must NOT have been modified by the Agent (stored in a
            snapshot or git repo outside the Agent's reach).

        Returns True if the environment was successfully recreated.
        """
        cid = self.container_id
        config_path: Optional[Path] = None

        if isinstance(snapshot_or_config, EnvSnapshot):
            snap = snapshot_or_config
            cid = snap.container_id
            config_path = snap.env_config_path
        elif isinstance(snapshot_or_config, (str, Path)):
            config_path = Path(snapshot_or_config)

        if config_path is None:
            logger.error(
                "Destroy & recreate: no config path available",
            )
            return False

        logger.warning(
            "Destroy & recreate initiated for container %s. "
            "This will cause SERVICE DISRUPTION.", cid,
        )

        try:
            # Phase 1: Destroy the current environment
            # Stop and remove the container
            subprocess.run(
                ["docker", "rm", "-f", cid],
                capture_output=True, check=False, timeout=60,
            )
            logger.info("Container %s destroyed", cid)

            # Clean up associated overlay layers (best-effort)
            self._purge_overlay_layers(cid)

            # Phase 2: Recreate from declarative config
            if config_path.suffix in (".yml", ".yaml"):
                # Assume Docker Compose
                result = subprocess.run(
                    ["docker", "compose", "-f", str(config_path), "up", "-d"],
                    capture_output=True, text=True, timeout=120,
                )
                if result.returncode != 0:
                    logger.error(
                        "Docker Compose up failed: %s", result.stderr,
                    )
                    return False
            elif config_path.suffix == ".tf":
                # Assume Terraform root module directory
                tf_dir = config_path.parent
                result = subprocess.run(
                    ["terraform", "apply", "-auto-approve"],
                    cwd=str(tf_dir),
                    capture_output=True, text=True, timeout=300,
                )
                if result.returncode != 0:
                    logger.error(
                        "Terraform apply failed: %s", result.stderr,
                    )
                    return False
            else:
                logger.error(
                    "Unknown config format: %s", config_path.suffix,
                )
                return False

            logger.info(
                "Environment recreated successfully from %s", config_path,
            )
            return True

        except subprocess.TimeoutExpired as exc:
            logger.error("Destroy & recreate timed out: %s", exc)
            return False
        except Exception as exc:
            logger.exception("Destroy & recreate failed: %s", exc)
            return False

    # ------------------------------------------------------------------
    # Internal: overlay filesystem snapshot & rollback
    # ------------------------------------------------------------------

    def _get_overlay_upper_dir(self, container_id: str) -> Optional[Path]:
        """
        Locate the overlay2 upperdir for a Docker container.

        Parses docker inspect output to find the GraphDriver data.
        Falls back to scanning /var/lib/docker/overlay2 if inspect fails.
        """
        try:
            result = subprocess.run(
                [
                    "docker", "inspect", container_id,
                    "--format", "{{.GraphDriver.Data.UpperDir}}",
                ],
                capture_output=True, text=True, timeout=10,
            )
            if result.returncode == 0 and result.stdout.strip():
                upper_dir = Path(result.stdout.strip())
                if upper_dir.exists():
                    return upper_dir
        except Exception:
            pass

        # Fallback: scan overlay2 directory for the container's layer
        # Docker stores container ID → layer ID mapping in
        # /var/lib/docker/image/overlay2/layerdb/mounts/<container-id>/mount-id
        try:
            mount_id_path = (
                self.overlay_root.parent
                / "image/overlay2/layerdb/mounts"
                / container_id / "mount-id"
            )
            if mount_id_path.exists():
                layer_id = mount_id_path.read_text().strip()
                upper = self.overlay_root / layer_id / "diff"
                if upper.exists():
                    return upper
        except Exception:
            pass

        return None

    def _snapshot_overlay(
        self, container_id: str, snapshot_id: str,
    ) -> Path:
        """
        Save a backup of the container's overlay upperdir.

        The backup is a tar.gz archive of the upperdir contents.
        During rollback, the upperdir is cleaned and the archive is
        extracted back.
        """
        upper_dir = self._get_overlay_upper_dir(container_id)
        if upper_dir is None:
            raise RuntimeError(
                f"Cannot locate overlay upperdir for container {container_id}"
            )

        backup_path = (
            self.snapshot_dir / f"{snapshot_id}-overlay.tar.gz"
        )

        # Create tar.gz of upperdir contents (excluding workdir artifacts)
        with tarfile.open(backup_path, "w:gz") as tar:
            for item in upper_dir.iterdir():
                # Skip overlay2 work directory artifacts
                if item.name.startswith(".") and item.is_dir():
                    continue
                tar.add(item, arcname=item.name)

        logger.debug(
            "Overlay snapshot: %s (%d bytes)",
            backup_path, backup_path.stat().st_size,
        )
        return backup_path

    def _rollback_overlay(
        self, container_id: str, backup_path: Optional[Path],
    ) -> bool:
        """
        Restore the overlay upperdir from a snapshot backup.

        Steps:
          1. Stop the container (filesystem must be quiesced).
          2. Clean the upperdir.
          3. Extract snapshot backup into upperdir.
          4. Restart the container.
        """
        if backup_path is None or not backup_path.exists():
            logger.warning(
                "Overlay backup not available: %s", backup_path,
            )
            return False

        upper_dir = self._get_overlay_upper_dir(container_id)
        if upper_dir is None:
            logger.error(
                "Cannot locate overlay upperdir for container %s",
                container_id,
            )
            return False

        try:
            # Stop container to quiesce filesystem
            subprocess.run(
                ["docker", "stop", container_id],
                capture_output=True, timeout=30, check=False,
            )

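            # Clean upperdir (preserve the directory itself)
            for item in upper_dir.iterdir():
                # Hidden directories are skipped by _snapshot_overlay, so
                # deleting them here would lose data the backup cannot restore
                if item.name.startswith(".") and item.is_dir():
                    continue
                # is_dir() follows symlinks; a symlink must be unlinked,
                # never passed to rmtree (which refuses symlinks)
                if item.is_dir() and not item.is_symlink():
                    shutil.rmtree(item, ignore_errors=True)
                else:
                    item.unlink(missing_ok=True)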

            # Extract snapshot backup into upperdir
            with tarfile.open(backup_path, "r:gz") as tar:
                tar.extractall(path=upper_dir)

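            # Restart container and verify that it actually came back up
            start = subprocess.run(
                ["docker", "start", container_id],
                capture_output=True, text=True, timeout=30, check=False,
            )
            if start.returncode != 0:
                logger.error(
                    "Container %s failed to start after overlay rollback: %s",
                    container_id, start.stderr,
                )
                return False

            logger.info(
                "Overlay rollback complete for container %s", container_id,
            )
            return True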

        except Exception as exc:
            logger.exception("Overlay rollback failed: %s", exc)
            return False

    # ------------------------------------------------------------------
    # Internal: container checkpoint (CRIU / docker checkpoint)
    # ------------------------------------------------------------------

    def _create_checkpoint(
        self, container_id: str, snapshot_id: str,
    ) -> Path:
        """
        Create a Docker checkpoint (uses CRIU under the hood).

        Requires Docker experimental features enabled and CRIU installed
        on the host. The checkpoint is saved to a tar.gz file.

        Docker checkpoint saves:
          - Process tree state (memory pages, registers, file descriptors)
          - Network connection state
          - Filesystem state (already captured by overlay snapshot)
        """
        checkpoint_dir = (
            self.snapshot_dir / f"{snapshot_id}-checkpoint"
        )
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        try:
            result = subprocess.run(
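                [
                    "docker", "checkpoint", "create",
                    # Without --leave-running, Docker stops the container
                    # once the checkpoint has been written
                    "--leave-running",
                    "--checkpoint-dir", str(checkpoint_dir),
                    container_id, snapshot_id,
                ],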
                capture_output=True, text=True, timeout=60,
            )
            if result.returncode != 0:
                raise RuntimeError(
                    f"Docker checkpoint create failed: {result.stderr}"
                )

            # Package checkpoint into a tar.gz for storage
            ckpt_tar = (
                self.snapshot_dir / f"{snapshot_id}-checkpoint.tar.gz"
            )
            with tarfile.open(ckpt_tar, "w:gz") as tar:
                tar.add(checkpoint_dir, arcname=".")

            # Cleanup the unpacked directory
            shutil.rmtree(checkpoint_dir, ignore_errors=True)

            return ckpt_tar

        except subprocess.TimeoutExpired:
            raise RuntimeError("Docker checkpoint create timed out")
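        except FileNotFoundError:
            # subprocess raises FileNotFoundError when the docker binary
            # itself is missing from PATH
            raise RuntimeError(
                "docker CLI not found on PATH; checkpointing also requires "
                "experimental features (dockerd --experimental) and CRIU "
                "installed on host",
            )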

    def _restore_checkpoint(
        self, container_id: str, checkpoint_path: Optional[Path],
    ) -> bool:
        """
        Restore a container from a Docker checkpoint.

        Steps:
          1. Stop the container if running.
          2. Extract checkpoint tar.gz.
          3. Start container from checkpoint (docker start --checkpoint).
        """
        if checkpoint_path is None or not checkpoint_path.exists():
            logger.warning("Checkpoint file not available: %s", checkpoint_path)
            return False

        checkpoint_dir = checkpoint_path.parent / f"{checkpoint_path.stem}"
        checkpoint_dir.mkdir(exist_ok=True)

        try:
            # Extract checkpoint
            with tarfile.open(checkpoint_path, "r:gz") as tar:
                tar.extractall(path=checkpoint_dir)

            # Stop container if running
            subprocess.run(
                ["docker", "stop", container_id],
                capture_output=True, timeout=30, check=False,
            )

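            # Restore from checkpoint. Path.stem strips only ".gz", which
            # would yield "<snapshot_id>.tar", so trim the full suffix instead.
            suffix = "-checkpoint.tar.gz"
            file_name = checkpoint_path.name
            checkpoint_name = (
                file_name[: -len(suffix)]
                if file_name.endswith(suffix)
                else checkpoint_path.stem
            )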
            result = subprocess.run(
                [
                    "docker", "start", "--checkpoint",
                    checkpoint_name, container_id,
                    "--checkpoint-dir", str(checkpoint_dir),
                ],
                capture_output=True, text=True, timeout=60,
            )

            if result.returncode != 0:
                logger.error(
                    "Docker checkpoint restore failed: %s", result.stderr,
                )
                return False

            logger.info(
                "Container %s restored from checkpoint %s",
                container_id, checkpoint_name,
            )
            return True

        except Exception as exc:
            logger.exception("Checkpoint restore failed: %s", exc)
            return False
        finally:
            shutil.rmtree(checkpoint_dir, ignore_errors=True)

    # ------------------------------------------------------------------
    # Internal: Terraform state rollback
    # ------------------------------------------------------------------

    def _rollback_terraform(
        self,
        state_backup: Optional[Path],
        plan_backup: Optional[Path],
    ) -> bool:
        """
        Rollback Terraform-managed infrastructure.

        If plan_backup is available: analyze the forward plan to generate
        an inverse plan, then apply it.

        If only state_backup is available: restore the state file and
        run terraform apply to reconcile actual resources with old state.
        """
        if state_backup is None:
            logger.warning(
                "Terraform state backup not available for rollback",
            )
            return False

        tf_dir = state_backup.parent
        tf_state_path = tf_dir / "terraform.tfstate"

        try:
            # Strategy A: Inverse plan (most precise)
            if plan_backup is not None and plan_backup.exists():
                # Show the plan in JSON format for analysis
                result = subprocess.run(
                    ["terraform", "show", "-json", str(plan_backup)],
                    capture_output=True, text=True, timeout=30, cwd=str(tf_dir),
                )
                if result.returncode == 0:
                    plan_data = json.loads(result.stdout)
                    # In production: analyze resource_changes and generate
                    # an inverse Terraform config. For this implementation,
                    # we fall through to state-based rollback.
                    logger.debug(
                        "Plan analysis: %d resource changes detected",
                        len(plan_data.get("resource_changes", [])),
                    )

            # Strategy B: State-based rollback
            # Backup current state first (defense in depth)
            if tf_state_path.exists():
                backup_current = tf_state_path.with_suffix(
                    ".tfstate.backup-current",
                )
                shutil.copy2(tf_state_path, backup_current)

            # Restore the saved state
            shutil.copy2(state_backup, tf_state_path)

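            # Reconcile: terraform apply converges real resources toward the
            # configuration in tf_dir, diffed against the restored state, so
            # this assumes the configuration itself has already been reverted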
            result = subprocess.run(
                ["terraform", "apply", "-auto-approve"],
                capture_output=True, text=True, timeout=300,
                cwd=str(tf_dir),
            )
            if result.returncode != 0:
                logger.error(
                    "Terraform apply (rollback) failed: %s", result.stderr,
                )
                return False

            logger.info("Terraform state rollback applied successfully")
            return True

        except subprocess.TimeoutExpired:
            logger.error("Terraform rollback timed out")
            return False
        except Exception as exc:
            logger.exception("Terraform rollback failed: %s", exc)
            return False

    # ------------------------------------------------------------------
    # Internal: helper methods
    # ------------------------------------------------------------------

    def _find_snapshot(self, snapshot_id: str) -> Optional[EnvSnapshot]:
        """Find a snapshot by ID in the in-memory list."""
        for snap in self._snapshots:
            if snap.snapshot_id == snapshot_id:
                return snap
        return None

    def _cleanup_snapshot(self, snap: EnvSnapshot) -> None:
        """Delete all files associated with a snapshot."""
        for path_attr in [
            "overlay_backup_path",
            "checkpoint_path",
            "terraform_state_backup",
            "terraform_plan_backup",
            "k8s_manifests_backup",
            "env_config_path",
        ]:
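            p = getattr(snap, path_attr, None)
            if not p:
                continue
            path = Path(p)
            # Some backups may be directories rather than single files
            if path.is_dir():
                shutil.rmtree(path, ignore_errors=True)
            else:
                path.unlink(missing_ok=True)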

    def _purge_overlay_layers(self, container_id: str) -> None:
        """
        Best-effort cleanup of overlay layers for a destroyed container.
        Docker normally handles this, but this is a safety net.
        """
        try:
            upper_dir = self._get_overlay_upper_dir(container_id)
            if upper_dir and upper_dir.exists():
                shutil.rmtree(upper_dir, ignore_errors=True)
                logger.debug("Purged overlay layers for %s", container_id)
        except Exception:
            pass

    def list_snapshots(self) -> list[dict]:
        """Return metadata about all stored snapshots."""
        return [
            {
                "snapshot_id": s.snapshot_id,
                "label": s.label,
                "container_id": s.container_id,
                "timestamp": s.timestamp,
                "overlay": s.overlay_backup_path is not None,
                "checkpoint": s.checkpoint_available,
                "terraform_state": s.terraform_state_backup is not None,
            }
            for s in self._snapshots
        ]


# ==========================================================================
# Usage example
# ==========================================================================

if __name__ == "__main__":
    import tempfile

    tmp = Path(tempfile.mkdtemp(prefix="env-rollback-demo-"))
    snap_dir = tmp / "snapshots"
    snap_dir.mkdir()

    env = ContainerEnvironment(
        container_id="agent-container-001",
        snapshot_dir=snap_dir,
        overlay_root=tmp / "fake-overlay2",
    )

    # Simulate: create some "overlay" files for snapshotting
    fake_overlay_upper = tmp / "fake-overlay2" / "abc123" / "diff"
    fake_overlay_upper.mkdir(parents=True)
    (fake_overlay_upper / "nginx.conf").write_text(
        "server { listen 80; }\n"
    )
    (fake_overlay_upper / "app.env").write_text("DEBUG=false\n")

    # Without a Docker daemon, _get_overlay_upper_dir() falls back to the
    # layerdb mount-id file, so the demo has to provide one
    mount_id_file = (
        tmp / "image/overlay2/layerdb/mounts"
        / "agent-container-001" / "mount-id"
    )
    mount_id_file.parent.mkdir(parents=True)
    mount_id_file.write_text("abc123\n")

    print(f"Before: nginx.conf = {(fake_overlay_upper / 'nginx.conf').read_text()!r}")

    # Snapshot the environment
    snap_id = env.snapshot(
        container_id="agent-container-001",
        label="pre-agent-task",
        include_checkpoint=False,  # CRIU not available in demo
    )
    print(f"Snapshot created: {snap_id}")

    # Simulate Agent making destructive changes
    (fake_overlay_upper / "nginx.conf").write_text(
        "server { listen 443 ssl; typo_here; }\n"
    )
    (fake_overlay_upper / "bad-file.tmp").write_text("corrupt data\n")
    print(f"After agent: nginx.conf = {(fake_overlay_upper / 'nginx.conf').read_text()!r}")

    # Rollback via overlay (simulated — in real scenario docker stop/start
    # would be invoked)
    # We simulate the rollback by restoring from overlay backup
    backup_path = snap_dir / f"{snap_id}-overlay.tar.gz"
    if backup_path.exists():
        # Clean upperdir
        for item in list(fake_overlay_upper.iterdir()):
            if item.is_file():
                item.unlink()
        # Restore from backup
        with tarfile.open(backup_path, "r:gz") as tar:
            tar.extractall(path=fake_overlay_upper)
        print(f"Rollback: nginx.conf = {(fake_overlay_upper / 'nginx.conf').read_text()!r}")
        # bad-file.tmp should be gone (wasn't in the snapshot)
        bad_file = fake_overlay_upper / "bad-file.tmp"
        print(f"bad-file.tmp exists: {bad_file.exists()}")

    # List snapshots
    print(f"\nSnapshots: {json.dumps(env.list_snapshots(), indent=2, default=str)}")

    # Cleanup
    import shutil as _shutil
    _shutil.rmtree(tmp, ignore_errors=True)

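Design decision: why not take an environment-level snapshot before every operation?

ContainerEnvironment's four-layer recovery strategy reflects an important design decision: environment-level rollback is triggered per task, not per operation. In Sections 2-4, every file write and every database operation registers its own undo handler; that is fine-grained rollback. Environment-level rollback works at a different granularity: it takes a snapshot before the Agent starts a task, and when the task finishes it either commits the changes (discarding the snapshot) or rolls back as a whole (restoring the snapshot). There are two fundamental reasons for this design:

  1. Cost: a file-level snapshot (CoW) copies only a few KB to a few MB of data, whereas an environment-level snapshot has to back up the overlay layer (possibly hundreds of MB) and optionally create a process checkpoint (several GB). The I/O cost of an environment-level snapshot is on the order of 100 to 1,000 times that of a file-level snapshot, so snapshotting the environment before every operation would slow the Agent to an unacceptable degree.
  2. Isolation assumption: environment-level rollback relies on one key assumption, namely that the Agent's changes to the environment within a single task are self-contained. If task A modifies system configuration and task B depends on those modifications, then rolling back task A alone while keeping task B's changes is not feasible: the environment is not a stateless, independent layer, and tasks carry implicit dependencies. The semantics of environment-level rollback are therefore all-or-nothing: either the whole task is committed, or the entire environment returns to its state before the task began.

In practice, the recommended strategy is double insurance: take an environment-level snapshot before each task (a coarse-grained safety net), and still register a fine-grained undo handler for every write (precise, targeted rollback). If a single file write fails at step 43, fine-grained rollback needs only one file restore; there is no need to roll the whole environment back to the start of the task. But if by step 43 the Agent has already performed 42 untrackable side-effect operations (modified system configuration, installed packages, changed kernel parameters), the environment-level snapshot is the last line of defense. For how to use the Agent state machine to define rollback-capable task boundaries, that is, when an environment-level snapshot should be triggered, see Agent State Machine Design. For how to decide between fine-grained and environment-level rollback during error recovery, see Agent Error Recovery.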
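
The double-insurance strategy can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the ContainerEnvironment implementation: a plain directory copy stands in for the environment-level snapshot, and run_task, write_file, and fail are hypothetical names introduced only for this example.

```python
import shutil
import tempfile
from pathlib import Path
from typing import Callable


def run_task(workdir: Path, steps: list[Callable[[list], None]]) -> str:
    """Run steps with a coarse snapshot plus fine-grained undo handlers.

    Returns "committed", "fine-grained-rollback", or "snapshot-rollback".
    """
    # Coarse safety net: a directory copy stands in for an environment snapshot.
    snapshot = Path(tempfile.mkdtemp(prefix="task-snap-")) / "copy"
    shutil.copytree(workdir, snapshot)
    undos: list[Callable[[], bool]] = []
    try:
        for step in steps:
            step(undos)  # each step registers its undo before writing
        return "committed"
    except Exception:
        # Cheap path first: run the undo handlers in LIFO order.
        if all(undo() for undo in reversed(undos)):
            return "fine-grained-rollback"
        # An undo failed: fall back to restoring the whole snapshot.
        shutil.rmtree(workdir)
        shutil.copytree(snapshot, workdir)
        return "snapshot-rollback"
    finally:
        shutil.rmtree(snapshot.parent, ignore_errors=True)


def write_file(path: Path, text: str) -> Callable[[list], None]:
    """Build a step that overwrites path and knows how to restore it."""
    def step(undos: list) -> None:
        old = path.read_text() if path.exists() else None

        def undo() -> bool:
            if old is None:
                path.unlink(missing_ok=True)
            else:
                path.write_text(old)
            return True

        undos.append(undo)  # register the undo BEFORE the write
        path.write_text(text)
    return step


def fail(undos: list) -> None:
    """A step that always fails, to trigger rollback."""
    raise RuntimeError("simulated failure")


work = Path(tempfile.mkdtemp(prefix="task-work-"))
conf = work / "app.conf"
conf.write_text("debug=false\n")
outcome = run_task(work, [write_file(conf, "debug=true\n"), fail])
print(outcome, repr(conf.read_text()))
```

When every undo handler succeeds, the snapshot is never read and is simply discarded. It is only restored when a handler reports failure, which models the case where fine-grained tracking has lost sight of what the task changed.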

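6. Stack-based rollback: composing and sequencing undo operations

The previous five sections addressed, at different layers, how to undo a single operation: file level (CoW + diff), data level (WAL + checkpoints), and environment level (overlay snapshots + destroy and recreate). But when an Agent performs dozens of operations in one task, each registering an undo handler at a different layer with a different strategy, a new problem surfaces: how do you make sure those undo handlers are invoked in the correct order, with their dependencies respected? Rollback is not a matter of calling a few undo functions and hoping they restore the right state; in the wrong order, rollback can introduce an inconsistency worse than the original error.

This is where stack-based rollback comes in: every undo handler is pushed onto a last-in, first-out (LIFO) stack, and rollback pops and executes handlers from the top down. LIFO is not optional; it is a structural reflection of how operations depend on one another. The write in step 3 references the file created in step 2, so rollback must first undo step 3 (removing the reference to the step-2 file) and then undo step 2 (deleting the file). Rolling back in forward order would delete the step-2 file first, and the undo for step 3 would then fail because it tries to reference a file that no longer exists.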

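Stack semantics: why LIFO rather than arbitrary order

The dependency chain between Agent operations is what forces rollback to be LIFO. Consider three consecutive Agent operations:

  1. op-1: create the file /etc/app/feature-flags.yaml with the content enable_new_auth: false.
  2. op-2: modify /etc/app/main.conf, adding a reference to feature-flags.yaml in its include directive.
  3. op-3: restart app.service so the new configuration takes effect.

During rollback the correct order is op-3 → op-2 → op-1: first undo the restart (if the restart introduced a runtime problem), then remove the include reference from the configuration file, and finally delete feature-flags.yaml. If the order were op-1 → op-2 → op-3 (forward order), the first step would delete feature-flags.yaml while main.conf still contains an include that references it, and the next service restart would fail with file not found.

The stack structure maps this dependency natively: each time the Agent registers an undo handler, it is pushed onto the top of the stack. When rollback is triggered, pop takes handlers off the top one at a time and executes them, which guarantees that later-registered handlers run first. There is no need to declare dependencies between operations explicitly; the LIFO property automatically ensures that whatever depends on earlier work is undone first.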
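
To make the ordering concrete, the three operations can be replayed against an in-memory model. This is a toy sketch, not part of RollbackStack: the files dict is a hypothetical stand-in for the filesystem, the restart in op-3 is modelled as a no-op, and is_valid is an assumed check that every included file exists.

```python
from typing import Callable

MAIN = "/etc/app/main.conf"
FLAGS = "/etc/app/feature-flags.yaml"


def build() -> tuple[dict[str, str], list[tuple[str, Callable[[], object]]]]:
    """Apply op-1..op-3 and return the resulting files plus their undos."""
    files = {MAIN: "listen 80\n"}
    undos: list[tuple[str, Callable[[], object]]] = []

    # op-1: create feature-flags.yaml (undo: delete it)
    undos.append(("op-1", lambda: files.pop(FLAGS)))
    files[FLAGS] = "enable_new_auth: false\n"

    # op-2: add the include line (undo: restore the previous main.conf)
    old_conf = files[MAIN]
    undos.append(("op-2", lambda: files.update({MAIN: old_conf})))
    files[MAIN] = old_conf + f"include {FLAGS}\n"

    # op-3: restart the service (modelled as a no-op; undo restarts again)
    undos.append(("op-3", lambda: None))
    return files, undos


def is_valid(files: dict[str, str]) -> bool:
    """main.conf is loadable only if every included file exists."""
    for line in files[MAIN].splitlines():
        if line.startswith("include ") and line.split()[1] not in files:
            return False
    return True


def rollback(lifo: bool) -> list[tuple[str, bool]]:
    """Undo everything and record config validity after each undo."""
    files, undos = build()
    order = list(reversed(undos)) if lifo else undos
    trace = []
    for name, undo in order:
        undo()
        trace.append((name, is_valid(files)))
    return trace


print("LIFO:", rollback(lifo=True))
print("FIFO:", rollback(lifo=False))
```

In the LIFO trace the configuration is loadable after every single undo. In the forward-order trace it is broken immediately after op-1 is undone, which is exactly the window in which a restart would fail.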

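Dependency-aware rollback: when files reference files

LIFO handles dependencies that follow temporal order, but it does not handle semantic dependencies: two operations may not be adjacent in time, yet the object of one references the object of the other. For example, the Agent creates A.py in step 5 and B.py in step 12, and B.py contains import A. The two operations are 7 steps apart in time, but semantically B depends on A.

Semantic dependencies make pure LIFO rollback insufficient: if only step 5 is rolled back (deleting A.py), the B.py from step 12 breaks because it imports a module that no longer exists. This is the problem dependency-aware rollback has to solve: when rolling back an operation, also check whether any later operation depends on its output, and if so either roll those later operations back as well or refuse the rollback and report the dependency conflict.

The pragmatic way to implement dependency awareness is not to build a full dependency graph (which amounts to implementing a build system) but to group operations: the Agent inserts explicit group boundaries around a set of semantically related operations (such as all the files added for one new feature). Operations in a group share a group ID and are rolled back as a unit: either the whole group is rolled back or the whole group is kept. Dependencies between groups can be declared by manual annotation or inferred by analyzing tool-call logs: if the arguments of tool call T2 contain the path of a file created by T1, then T2 depends on T1. For collecting and analyzing tool-call logs, see Agent Observability.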
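
The log-based inference rule can be sketched as follows. The ToolCall record and its args and created fields are a hypothetical shape for a tool-call log entry, assumed here only for illustration; a real log schema will differ, and the substring match on paths is deliberately crude.

```python
from dataclasses import dataclass, field


@dataclass
class ToolCall:
    call_id: str
    group_id: str
    args: list[str]                                    # raw argument strings
    created: list[str] = field(default_factory=list)   # paths this call created


def infer_group_deps(calls: list[ToolCall]) -> dict[str, set[str]]:
    """Map each group to the earlier groups it depends on.

    Rule: if a later call's arguments mention a path created by an earlier
    call in a different group, the later group depends on the earlier one.
    """
    creator: dict[str, str] = {}        # path -> group that created it
    deps: dict[str, set[str]] = {}
    for call in calls:                  # calls are assumed to be in execution order
        deps.setdefault(call.group_id, set())
        for path, owner in creator.items():
            if owner != call.group_id and any(path in arg for arg in call.args):
                deps[call.group_id].add(owner)
        for path in call.created:
            creator.setdefault(path, call.group_id)
    return deps


def groups_to_roll_back(target: str, deps: dict[str, set[str]]) -> set[str]:
    """Return target plus every group that transitively depends on it."""
    result = {target}
    changed = True
    while changed:
        changed = False
        for group, needs in deps.items():
            if group not in result and needs & result:
                result.add(group)
                changed = True
    return result


calls = [
    ToolCall("T1", "auth", ["src/A.py"], created=["src/A.py"]),
    ToolCall("T2", "api", ["src/B.py", "--imports", "src/A.py"], created=["src/B.py"]),
    ToolCall("T3", "docs", ["README.md"]),
]
deps = infer_group_deps(calls)
print(deps)
print(groups_to_roll_back("auth", deps))
```

Rolling back the auth group pulls in the api group because T2 mentions a file that T1 created, while docs is left alone. A caller that prefers to block rather than cascade can compare the returned set against the requested group and report a conflict when they differ.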

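RollbackStack: a composable undo-handler stack implementation

The code below implements a complete RollbackStack. It supports LIFO push/pop, dependency-aware group rollback, partial rollback to a named checkpoint, and an error-handling policy for failures during rollback. It is the most abstract rollback component in this article: it does not care whether an undo handler is backed by a snapshot, a WAL, or a compensation, only about execution order and failure handling: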

from __future__ import annotations

import enum
import time
import uuid
from dataclasses import dataclass, field
from typing import Callable, Optional, Any
import logging

logger = logging.getLogger("agent.rollback.stack")


# ==========================================================================
# Error handling policy during rollback — what happens when undo itself fails?
# ==========================================================================

class UndoErrorPolicy(enum.Enum):
    """
    Defines what the RollbackStack does when an undo handler fails.

    ABORT_ON_FAILURE  — Stop immediately when any undo fails. Leave remaining
                        handlers on the stack. This is the safest default
                        because continuing after a failed undo risks
                        compounding the inconsistency.
    CONTINUE_ON_FAILURE — Log the failure and continue executing remaining
                           undo handlers. Use only when handlers are known
                           to be independent (no cross-dependencies).
                           This is riskier but may recover more state.
    RETRY_ONCE         — Retry the failed undo handler once, then apply
                         ABORT_ON_FAILURE or CONTINUE_ON_FAILURE.
    """
    ABORT_ON_FAILURE = "abort"
    CONTINUE_ON_FAILURE = "continue"
    RETRY_ONCE = "retry_once"


# ==========================================================================
# Undo handler — the composable unit pushed onto the stack
# ==========================================================================

@dataclass
class UndoEntry:
    """
    One entry on the rollback stack. Contains the undo callable and
    metadata for observability, grouping, and checkpoint targeting.
    """
    entry_id: str                        # Unique ID for this undo entry
    operation_id: str                    # ID of the operation this undoes
    # Fields without defaults must precede fields with defaults, otherwise
    # @dataclass raises TypeError at class-definition time
    handler: Callable[[], bool] = field(repr=False)  # The undo function
    group_id: str = ""                   # Optional group for dependency-aware rollback
    description: str = ""                # Human-readable description
    created_at: float = field(default_factory=time.time)
    metadata: dict[str, Any] = field(default_factory=dict)  # Extra context


# ==========================================================================
# Checkpoint — a marker on the stack for partial rollback
# ==========================================================================

@dataclass
class StackCheckpoint:
    """
    A named marker pushed onto the stack. Not an undo handler itself:
    it serves as a target for rollback_to(). When rollback_to(checkpoint)
    is called, handlers are popped and executed until this checkpoint
    is reached. The marker is never executed, and rollback_to() removes
    it only when called with inclusive=True.
    """
    checkpoint_id: str
    label: str
    created_at: float = field(default_factory=time.time)
    metadata: dict[str, Any] = field(default_factory=dict)


# ==========================================================================
# RollbackStack — the main composable undo stack
# ==========================================================================

class RollbackStack:
    """
    Stack-based rollback engine with composable undo handlers.

    Features:
      - LIFO push/pop for correct undo ordering.
      - Group-based rollback for dependency-aware partial undos.
      - Checkpoint markers for rollback_to() targeting.
      - Pluggable error handling policy.
      - Observability via undo_records (full execution history).

    Lifecycle:
      1. push(undo_handler)     - register undo before each Agent write.
      2. mark_checkpoint(label) - insert a named marker on the stack.
      3. commit(entry_id)       - mark a handler as confirmed good; it stays
                                  on the stack for audit but is skipped
                                  during rollback.
      4. rollback_to(checkpoint_id) - pop and execute handlers until
                                       checkpoint is reached.
      5. rollback_all()         - pop and execute all handlers on the stack.

    Design decision: the stack stores BOTH UndoEntry (handlers) and
    StackCheckpoint (markers). Handlers are executed when popped; markers
    are simply removed. This unified stack allows rollback_to() to target
    any checkpoint regardless of how many handlers were pushed after it.
    """

    def __init__(
        self,
        *,
        error_policy: UndoErrorPolicy = UndoErrorPolicy.ABORT_ON_FAILURE,
        max_stack_depth: int = 10_000,
    ) -> None:
        self._stack: list[UndoEntry | StackCheckpoint] = []
        self.error_policy = error_policy
        self.max_stack_depth = max_stack_depth
        # Records of all executed undos for audit
        self.undo_records: list[dict[str, Any]] = []
        # Set of committed entry_ids (committed handlers are skipped during
        # rollback; they stay on the stack for audit but don't execute)
        self._committed: set[str] = set()

    # ------------------------------------------------------------------
    # Push — register undo handler before executing a write
    # ------------------------------------------------------------------

    def push(
        self,
        handler: Callable[[], bool],
        *,
        operation_id: str = "",
        group_id: str = "",
        description: str = "",
        metadata: dict[str, Any] | None = None,
    ) -> str:
        """
        Push an undo handler onto the stack.

        Call this BEFORE the Agent executes the corresponding write
        operation. If the write fails, the handler is already on the
        stack and ready for rollback.

        Args:
            handler:     Callable returning True if undo succeeded.
            operation_id: ID of the write operation this undoes.
            group_id:    Optional group for dependency-aware rollback.
            description: Human-readable for logs and debugging.
            metadata:    Arbitrary key-value context (file paths, row IDs).

        Returns:
            entry_id — use this with commit() or pop_one().

        Raises:
            RuntimeError: if stack exceeds max_stack_depth.
        """
        if len(self._stack) >= self.max_stack_depth:
            raise RuntimeError(
                f"RollbackStack overflow: {len(self._stack)} >= "
                f"{self.max_stack_depth}. Committed entries stay on the "
                "stack for audit, so long-running tasks must pop or roll "
                "back past them to free space."
            )

        entry_id = f"undo-{uuid.uuid4().hex[:12]}"
        entry = UndoEntry(
            entry_id=entry_id,
            operation_id=operation_id or entry_id,
            group_id=group_id,
            handler=handler,
            description=description,
            metadata=metadata or {},
        )
        self._stack.append(entry)

        logger.debug(
            "Push undo: %s (op=%s, group=%s, depth=%d) — %s",
            entry_id, operation_id, group_id, len(self._stack),
            description or "(no description)",
        )
        return entry_id

    # ------------------------------------------------------------------
    # Checkpoint — mark a position on the stack for targeted rollback
    # ------------------------------------------------------------------

    def mark_checkpoint(
        self,
        label: str = "",
        *,
        metadata: dict[str, Any] | None = None,
    ) -> str:
        """
        Place a named checkpoint marker on the stack.

        Later, rollback_to(checkpoint_id) will pop and execute all
        handlers above this checkpoint.

        Use this at decision boundaries: task start, before risky
        operations, after a batch of successful writes.
        """
        checkpoint_id = f"ckpt-{uuid.uuid4().hex[:12]}"
        ckpt = StackCheckpoint(
            checkpoint_id=checkpoint_id,
            label=label,
            metadata=metadata or {},
        )
        self._stack.append(ckpt)
        logger.info(
            "Checkpoint %s: '%s' at depth %d",
            checkpoint_id, label, len(self._stack),
        )
        return checkpoint_id

    # ------------------------------------------------------------------
    # Commit — mark a handler as "confirmed good", don't execute on rollback
    # ------------------------------------------------------------------

    def commit(self, entry_id: str) -> bool:
        """
        Mark an undo handler as committed.

        Committed handlers are NOT executed during rollback_all() or
        rollback_to(). They remain on the stack as markers for audit
        purposes.

        Returns True if the entry was found, False otherwise.
        """
        for item in self._stack:
            if isinstance(item, UndoEntry) and item.entry_id == entry_id:
                self._committed.add(entry_id)
                logger.debug(
                    "Commit: %s (op=%s) — skipped during rollback",
                    entry_id, item.operation_id,
                )
                return True
        logger.warning("Commit: entry %s not found on stack", entry_id)
        return False

    # ------------------------------------------------------------------
    # Pop — execute and remove the top handler
    # ------------------------------------------------------------------

    def pop(self) -> bool:
        """
        Execute and remove the top undo handler from the stack.

        If the top item is a checkpoint marker, it is simply removed
        (no execution). If the top item is a committed handler, it is
        also removed without execution.

        Returns True if the undo succeeded, or if there was nothing to
        execute (empty stack, checkpoint marker, committed handler).
        Returns False only if the undo handler failed.
        """
        if not self._stack:
            logger.debug("Pop: stack is empty, nothing to undo")
            return True

        top = self._stack[-1]

        # Checkpoint marker — just remove it
        if isinstance(top, StackCheckpoint):
            self._stack.pop()
            logger.debug(
                "Pop: removed checkpoint %s ('%s')",
                top.checkpoint_id, top.label,
            )
            return True

        # Undo handler — execute it
        if isinstance(top, UndoEntry):
            if top.entry_id in self._committed:
                self._stack.pop()
                logger.debug(
                    "Pop: skipped committed handler %s", top.entry_id,
                )
                return True

            # Execute the undo handler
            success = self._execute_undo(top)
            self._stack.pop()
            return success

        return False

    # ------------------------------------------------------------------
    # Pop one — execute and remove a specific handler by entry_id
    # ------------------------------------------------------------------

    def pop_one(self, entry_id: str) -> bool:
        """
        Find and execute a specific undo handler by its entry_id,
        then remove it from the stack.

        This breaks LIFO ordering and should ONLY be used when you
        are certain the handler has no dependencies with entries above
        or below it on the stack.

        Returns True if the handler was found and executed successfully.
        """
        for i, item in enumerate(self._stack):
            if isinstance(item, UndoEntry) and item.entry_id == entry_id:
                if item.entry_id in self._committed:
                    self._stack.pop(i)
                    logger.debug("Pop-one: skipped committed %s", entry_id)
                    return True
                success = self._execute_undo(item)
                self._stack.pop(i)
                logger.info(
                    "Pop-one: %s %s (group=%s)",
                    "SUCCESS" if success else "FAILED",
                    entry_id, item.group_id,
                )
                return success

        logger.warning("Pop-one: entry %s not found on stack", entry_id)
        return False

    # ------------------------------------------------------------------
    # Rollback by group — undo all handlers in a dependency group
    # ------------------------------------------------------------------

    def rollback_group(self, group_id: str) -> dict[str, bool]:
        """
        Rollback all undo handlers belonging to the given group_id.

        Handlers are executed in LIFO order within the group (i.e.,
        the last-pushed handler in the group is undone first).

        Non-group entries are skipped (left on the stack).

        Returns a dict mapping entry_id → success.
        """
        results: dict[str, bool] = {}
        group_entries: list[tuple[int, UndoEntry]] = []

        # Collect all entries in the group (with their stack positions)
        for i, item in enumerate(self._stack):
            if isinstance(item, UndoEntry) and item.group_id == group_id:
                if item.entry_id not in self._committed:
                    group_entries.append((i, item))

        if not group_entries:
            logger.info(
                "Rollback group '%s': no uncommitted entries found",
                group_id,
            )
            return results

        # Execute in reverse index order (LIFO within group)
        # We must remove from highest index to lowest to not invalidate
        # indices of entries we haven't processed yet
        group_entries.sort(key=lambda x: x[0], reverse=True)

        for idx, entry in group_entries:
            try:
                success = self._execute_undo(entry)
                # Remove from stack (indices below idx stay valid)
                self._stack.pop(idx)
            except Exception as exc:
                logger.exception(
                    "Rollback group: exception in %s: %s",
                    entry.entry_id, exc,
                )
                success = False
            results[entry.entry_id] = success
            # Honor the abort policy for a False return as well as an exception
            if (
                not success
                and self.error_policy == UndoErrorPolicy.ABORT_ON_FAILURE
            ):
                break

        logger.info(
            "Rollback group '%s': %d/%d succeeded",
            group_id,
            sum(1 for v in results.values() if v),
            len(results),
        )
        return results

    # ------------------------------------------------------------------
    # Rollback to checkpoint — partial rollback to a named marker
    # ------------------------------------------------------------------

    def rollback_to(
        self,
        checkpoint_id: str,
        *,
        inclusive: bool = False,
    ) -> dict[str, bool]:
        """
        Pop and execute all undo handlers from the top of the stack
        until the specified checkpoint marker is reached.

        Args:
            checkpoint_id: The checkpoint marker to rollback to.
            inclusive: If True, also pop and remove the checkpoint
                       marker itself. If False (default), the checkpoint
                       remains on the stack for future use.

        Returns:
            Dict mapping entry_id → success for all executed undo handlers.

        If the checkpoint is not found on the stack, raises ValueError.
        If an undo fails and error_policy is ABORT_ON_FAILURE, remaining
        handlers are NOT executed and stay on the stack.
        """
        # Find the checkpoint position
        ckpt_index: Optional[int] = None
        for i, item in enumerate(self._stack):
            if (
                isinstance(item, StackCheckpoint)
                and item.checkpoint_id == checkpoint_id
            ):
                ckpt_index = i
                break

        if ckpt_index is None:
            raise ValueError(
                f"Checkpoint {checkpoint_id} not found on stack. "
                f"Available checkpoints: "
                f"{[item.checkpoint_id for item in self._stack if isinstance(item, StackCheckpoint)]}"
            )

        results: dict[str, bool] = {}
        # Pop from top down to (and optionally including) the checkpoint
        stop_index = ckpt_index if inclusive else ckpt_index + 1

        while len(self._stack) > stop_index:
            item = self._stack[-1]

            if isinstance(item, StackCheckpoint):
                # Intermediate checkpoint — remove it, don't execute
                self._stack.pop()
                logger.debug(
                    "Rollback_to: removing intermediate checkpoint %s",
                    item.checkpoint_id,
                )
                continue

            if isinstance(item, UndoEntry):
                if item.entry_id in self._committed:
                    self._stack.pop()
                    continue

                try:
                    success = self._execute_undo(item)
                    results[item.entry_id] = success
                except Exception as exc:
                    logger.exception(
                        "Rollback_to: exception in %s: %s",
                        item.entry_id, exc,
                    )
                    results[item.entry_id] = False

                self._stack.pop()

                # Check error policy
                if (
                    not results.get(item.entry_id, False)
                    and self.error_policy == UndoErrorPolicy.ABORT_ON_FAILURE
                ):
                    logger.error(
                        "Rollback_to: ABORTING after failed undo %s. "
                        "%d handlers remain on stack.",
                        item.entry_id,
                        len(self._stack) - stop_index,
                    )
                    break

        logger.info(
            "Rollback_to checkpoint %s: %d undos executed, %d succeeded",
            checkpoint_id,
            len(results),
            sum(1 for v in results.values() if v),
        )
        return results

    # ------------------------------------------------------------------
    # Rollback all — LIFO execution of every handler on the stack
    # ------------------------------------------------------------------

    def rollback_all(self) -> dict[str, bool]:
        """
        Pop and execute ALL undo handlers on the stack in LIFO order.

        Checkpoint markers are removed without execution.
        Committed handlers are removed without execution.

        Returns:
            Dict mapping entry_id → success for all executed handlers.

        Error handling follows self.error_policy:
          - ABORT_ON_FAILURE: stop at first failure.
          - CONTINUE_ON_FAILURE: log + continue.
          - RETRY_ONCE: retry the failed handler once; if the retry also
            fails, log it and continue with the remaining handlers.
        """
        results: dict[str, bool] = {}

        while self._stack:
            item = self._stack[-1]

            if isinstance(item, StackCheckpoint):
                self._stack.pop()
                logger.debug(
                    "Rollback_all: removing checkpoint %s", item.checkpoint_id,
                )
                continue

            if isinstance(item, UndoEntry):
                if item.entry_id in self._committed:
                    self._stack.pop()
                    continue

                success = self._execute_undo(item)
                results[item.entry_id] = success
                self._stack.pop()

                if not success:
                    if self.error_policy == UndoErrorPolicy.ABORT_ON_FAILURE:
                        logger.error(
                            "Rollback_all: ABORTING after failed undo %s. "
                            "%d handlers remain on stack.",
                            item.entry_id, len(self._stack),
                        )
                        break
                    elif self.error_policy == UndoErrorPolicy.RETRY_ONCE:
                        logger.warning(
                            "Rollback_all: RETRYING failed undo %s",
                            item.entry_id,
                        )
                        # Retry once through _execute_undo so that an
                        # exception is caught and the attempt is audited.
                        retry_success = self._execute_undo(item)
                        results[item.entry_id] = retry_success
                        if not retry_success:
                            logger.error(
                                "Rollback_all: retry of %s also failed; "
                                "continuing with remaining handlers",
                                item.entry_id,
                            )
                    # CONTINUE_ON_FAILURE needs no branch: already logged

        logger.info(
            "Rollback_all: %d handlers executed, %d succeeded, %d remaining",
            len(results),
            sum(1 for v in results.values() if v),
            len(self._stack),
        )
        return results

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _execute_undo(self, entry: UndoEntry) -> bool:
        """
        Execute a single undo handler and record the result.

        Handles:
          - Normal execution (handler returns True/False).
          - Exceptions from the handler (caught, logged, treated as failure).
          - Recording to undo_records for audit trail.
        """
        logger.info(
            "Executing undo: %s (op=%s) — %s",
            entry.entry_id, entry.operation_id,
            entry.description or "(no description)",
        )

        start_time = time.monotonic()
        try:
            result = entry.handler()
            elapsed = time.monotonic() - start_time
            status = "success" if result else "failed"
        except Exception as exc:
            elapsed = time.monotonic() - start_time
            result = False
            status = "exception"
            logger.exception(
                "Undo handler %s raised exception: %s", entry.entry_id, exc,
            )

        # Record for audit
        self.undo_records.append({
            "entry_id": entry.entry_id,
            "operation_id": entry.operation_id,
            "group_id": entry.group_id,
            "description": entry.description,
            "status": status,
            "elapsed_ms": round(elapsed * 1000, 2),
            "timestamp": time.time(),
        })

        # Handlers may return any truthy/falsy value; normalize to bool
        return bool(result)

    # ------------------------------------------------------------------
    # Observability
    # ------------------------------------------------------------------

    @property
    def depth(self) -> int:
        """Number of items on the stack (handlers + checkpoints)."""
        return len(self._stack)

    @property
    def pending_undo_count(self) -> int:
        """Number of uncommitted undo handlers on the stack."""
        return sum(
            1 for item in self._stack
            if isinstance(item, UndoEntry)
            and item.entry_id not in self._committed
        )

    @property
    def checkpoints(self) -> list[dict[str, Any]]:
        """List all checkpoint markers currently on the stack."""
        return [
            {
                "checkpoint_id": item.checkpoint_id,
                "label": item.label,
                "depth_from_top": len(self._stack) - i - 1,
            }
            for i, item in enumerate(self._stack)
            if isinstance(item, StackCheckpoint)
        ]

    def summary(self) -> dict[str, Any]:
        """Return a complete summary for monitoring and debugging."""
        entries = [
            {
                "type": "handler" if isinstance(item, UndoEntry) else "checkpoint",
                "id": (
                    item.entry_id if isinstance(item, UndoEntry)
                    else item.checkpoint_id
                ),
                "description": (
                    item.description if isinstance(item, UndoEntry)
                    else f"checkpoint: {item.label}"
                ),
                "committed": (
                    item.entry_id in self._committed
                    if isinstance(item, UndoEntry)
                    else False
                ),
                "group_id": (
                    item.group_id if isinstance(item, UndoEntry)
                    else ""
                ),
            }
            for item in self._stack
        ]
        return {
            "total_items": len(self._stack),
            "pending_undos": self.pending_undo_count,
            "committed": len(self._committed),
            "checkpoint_count": len(self.checkpoints),
            "undo_records_count": len(self.undo_records),
            "stack_snapshot": entries,
            "recent_failures": [
                r for r in self.undo_records[-20:]
                if r["status"] != "success"
            ],
        }


# ==========================================================================
# Usage example — demonstrating LIFO ordering, checkpoint rollback,
# group rollback, and error handling.
# ==========================================================================

if __name__ == "__main__":
    # Setup: a simulated filesystem (in-memory dict for demonstration)
    fs: dict[str, str] = {
        "/etc/app/base.conf": "log_level=info\n",
    }

    stack = RollbackStack(error_policy=UndoErrorPolicy.ABORT_ON_FAILURE)

    # -- Example 1: Basic LIFO push/pop -------------------------------------
    print("=== Example 1: Basic LIFO push/pop ===")

    # op-1: Create a feature flags file
    stack.push(
        handler=lambda: fs.pop("/etc/app/feature-flags.yaml", None) or True,
        operation_id="op-1",
        group_id="feature-flags",
        description="Delete feature-flags.yaml if created",
        metadata={"path": "/etc/app/feature-flags.yaml"},
    )
    fs["/etc/app/feature-flags.yaml"] = "enable_new_auth: true\n"
    print(f"  After op-1: {fs}")

    # op-2: Modify base.conf to include feature-flags
    original_base = fs["/etc/app/base.conf"]

    def undo_base_conf():
        fs["/etc/app/base.conf"] = original_base
        return True

    stack.push(
        handler=undo_base_conf,
        operation_id="op-2",
        description="Restore base.conf to pre-include state",
        metadata={"path": "/etc/app/base.conf"},
    )
    fs["/etc/app/base.conf"] = (
        original_base + "include /etc/app/feature-flags.yaml\n"
    )
    print(f"  After op-2: base.conf = {fs['/etc/app/base.conf']!r}")

    # op-3: "Restart" the service (simulated — no-op handler)
    def undo_restart():
        print("  Undo restart: service rolled back to previous version")
        return True

    stack.push(
        handler=undo_restart,
        operation_id="op-3",
        description="Restart service with previous config",
    )
    print(f"  After op-3: stack depth = {stack.depth}")

    # Rollback all — should execute op-3, op-2, op-1 in that order (LIFO)
    print("\n  --- Rollback all ---")
    results = stack.rollback_all()
    print(f"  Results: {results}")
    print(f"  Final fs: {fs}")
    print(f"  Stack depth after rollback: {stack.depth}")
    assert fs["/etc/app/base.conf"] == "log_level=info\n", (
        "LIFO rollback failed!"
    )
    assert "/etc/app/feature-flags.yaml" not in fs, (
        "Feature flags file not removed!"
    )
    print("  PASSED: LIFO rollback restored original state\n")

    # -- Example 2: Checkpoint-based partial rollback -----------------------
    print("=== Example 2: Checkpoint partial rollback ===")
    stack2 = RollbackStack()

    fs["/etc/app/base.conf"] = "log_level=info\n"

    # Save checkpoint before any writes
    ckpt1 = stack2.mark_checkpoint("pre-task-start")
    print(f"  Checkpoint 1: {ckpt1}")

    # Execute 3 operations
    for i in range(3):
        stack2.push(
            handler=lambda i=i: (
                fs.pop(f"/tmp/step-{i}.txt", None) or True
            ),
            operation_id=f"step-{i}",
            description=f"Undo step {i}",
        )
        fs[f"/tmp/step-{i}.txt"] = f"content {i}\n"

    print(f"  After 3 steps: stack depth = {stack2.depth}, fs = {list(fs.keys())}")

    # Rollback to checkpoint 1 — should undo all 3 steps
    results = stack2.rollback_to(ckpt1)
    print(f"  Rollback_to results: {results}")
    assert all(f"/tmp/step-{i}.txt" not in fs for i in range(3)), (
        "Checkpoint rollback failed!"
    )
    print("  PASSED: Checkpoint rollback restored to pre-task state\n")

    # -- Example 3: Group-based rollback ------------------------------------
    print("=== Example 3: Group-based rollback ===")
    stack3 = RollbackStack()

    fs["/etc/app/base.conf"] = "log_level=info\n"

    # Group "auth-feature": two related operations
    stack3.push(
        handler=lambda: fs.pop("/etc/app/auth.conf", None) or True,
        operation_id="auth-1",
        group_id="auth-feature",
        description="Remove auth.conf",
    )
    fs["/etc/app/auth.conf"] = "oauth2_enabled: true\n"

    stack3.push(
        handler=lambda: fs.pop("/etc/app/auth-keys.pem", None) or True,
        operation_id="auth-2",
        group_id="auth-feature",
        description="Remove auth-keys.pem",
    )
    fs["/etc/app/auth-keys.pem"] = "-----BEGIN RSA KEY-----\n"

    # Unrelated operation (different group)
    stack3.push(
        handler=lambda: fs.pop("/etc/app/logging.conf", None) or True,
        operation_id="logging-1",
        group_id="logging-feature",
        description="Remove logging.conf",
    )
    fs["/etc/app/logging.conf"] = "verbose: true\n"

    print(f"  Before rollback: auth.conf={('/etc/app/auth.conf' in fs)}, "
          f"auth-keys.pem={('/etc/app/auth-keys.pem' in fs)}, "
          f"logging.conf={('/etc/app/logging.conf' in fs)}")

    # Rollback only the auth-feature group
    results = stack3.rollback_group("auth-feature")
    print(f"  Group rollback results: {results}")

    # Auth files should be gone, logging should still be there
    assert "/etc/app/auth.conf" not in fs, "Auth conf not removed!"
    assert "/etc/app/auth-keys.pem" not in fs, "Auth keys not removed!"
    assert "/etc/app/logging.conf" in fs, "Logging conf wrongly removed!"
    print(f"  After rollback: auth.conf={('/etc/app/auth.conf' in fs)}, "
          f"auth-keys.pem={('/etc/app/auth-keys.pem' in fs)}, "
          f"logging.conf={('/etc/app/logging.conf' in fs)}")
    print("  PASSED: Group rollback only removed auth-feature files\n")

    # -- Example 4: Error handling during rollback --------------------------
    print("=== Example 4: Error handling ===")
    stack4 = RollbackStack(error_policy=UndoErrorPolicy.CONTINUE_ON_FAILURE)

    counter = {"value": 0}

    # Handler that succeeds
    def succeed_handler():
        counter["value"] += 1
        return True

    # Handler that fails by raising
    def fail_handler():
        counter["value"] += 1
        raise RuntimeError("Simulated undo failure!")

    # Second handler that succeeds
    def succeed_handler2():
        counter["value"] += 1
        return True

    # Push order is the reverse of execution order (LIFO):
    # good-1 runs first, then bad-1, then good-2.
    stack4.push(succeed_handler2, operation_id="good-2",
                description="Second good handler")
    stack4.push(fail_handler, operation_id="bad-1",
                description="Failing handler")
    stack4.push(succeed_handler, operation_id="good-1",
                description="First good handler")

    print(f"  Stack depth before rollback: {stack4.depth}")
    results = stack4.rollback_all()
    print(f"  Rollback results: {results}")
    print(f"  counter = {counter['value']} (expected 3 — all handlers executed)")
    print(f"  Stack depth after: {stack4.depth}")
    assert counter["value"] == 3, (
        f"CONTINUE_ON_FAILURE: expected all 3 executed, got {counter['value']}"
    )
    print("  PASSED: CONTINUE_ON_FAILURE executed all handlers despite failure\n")

    # -- Example 5: ABORT on failure ----------------------------------------
    print("=== Example 5: ABORT on failure ===")
    stack5 = RollbackStack(error_policy=UndoErrorPolicy.ABORT_ON_FAILURE)
    counter5 = {"value": 0}

    stack5.push(
        lambda: (counter5.update({"value": counter5["value"] + 1}) or True),
        operation_id="last",
        description="Last handler (should not execute)",
    )
    stack5.push(
        lambda: (_ for _ in ()).throw(RuntimeError("fail!")),
        operation_id="middle",
        description="Middle handler (will fail)",
    )
    stack5.push(
        lambda: (counter5.update({"value": counter5["value"] + 1}) or True),
        operation_id="first",
        description="First handler (LIFO: executes first, should succeed)",
    )

    results = stack5.rollback_all()
    print(f"  ABORT results: {results}")
    print(f"  counter5 = {counter5['value']} (expected 1 — only first executed)")
    print(f"  Stack depth after: {stack5.depth} (expected 1 — last handler remains)")
    assert counter5["value"] == 1, (
        f"ABORT: expected only first executed, got {counter5['value']}"
    )
    assert stack5.depth == 1, (
        f"ABORT: expected 1 remaining, got {stack5.depth}"
    )
    print("  PASSED: ABORT_ON_FAILURE stopped at first failure")

    # Final summary
    print(f"\n  Stack summary: {json.dumps(stack.summary(), indent=2, default=str)}")

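Choosing an error-handling policy

RollbackStack provides three error-handling policies, ABORT_ON_FAILURE, CONTINUE_ON_FAILURE, and RETRY_ONCE, each reflecting a different recovery philosophy and risk appetite. Which one to use is not a matter of taste; it depends on the nature of the operation.

A recommended practice is to choose the policy dynamically by group_id. When the agent framework pushes an undo handler onto the stack, it can set group_id according to the operation type: for example, the "critical-config" group uses ABORT and the "cache-files" group uses CONTINUE. This granularity lets the system balance safety against rollback coverage.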
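
A minimal sketch of per-group policy selection follows. It assumes a lookup table keyed by group_id; `ErrorPolicy`, `GROUP_POLICIES`, `policy_for_group`, and the "remote-sync" group are hypothetical names introduced for illustration, with `ErrorPolicy` standing in for the article's UndoErrorPolicy so the snippet runs on its own.

```python
from enum import Enum


class ErrorPolicy(Enum):
    """Stand-in for UndoErrorPolicy, redefined so the sketch is self-contained."""
    ABORT_ON_FAILURE = "abort"
    CONTINUE_ON_FAILURE = "continue"
    RETRY_ONCE = "retry_once"


# Hypothetical mapping from group_id to policy; the group names are examples.
GROUP_POLICIES = {
    "critical-config": ErrorPolicy.ABORT_ON_FAILURE,
    "cache-files": ErrorPolicy.CONTINUE_ON_FAILURE,
    "remote-sync": ErrorPolicy.RETRY_ONCE,
}


def policy_for_group(group_id, default=ErrorPolicy.ABORT_ON_FAILURE):
    """Return the policy for a group, falling back to the strictest one."""
    return GROUP_POLICIES.get(group_id, default)
```

Defaulting unknown groups to ABORT_ON_FAILURE is a deliberate choice in this sketch: an unclassified operation is treated as critical until someone says otherwise.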

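Stack growth and memory management

An easily overlooked issue that causes serious trouble in practice is unbounded stack growth. If an agent keeps writing throughout a long task, pushing one undo handler per write (5,000 writes is not unusual), and never calls commit(), the stack keeps growing until memory runs out. This is why the max_stack_depth parameter exists: once the stack depth exceeds the limit (10,000 by default), a new push raises an exception, forcing the caller either to commit completed work with commit() or to redesign the agent's task granularity.

In practice, the best time to call commit() is at the end of each subtask: once the agent has finished an independent subtask and verified that its output is correct, all of that subtask's undo handlers can be marked as committed. These handlers are not removed from the stack (they are kept for auditing), but rollback_all() and rollback_to() skip them without executing them. Combine this with mark_checkpoint() at subtask boundaries so that a failure in a later subtask does not roll back subtasks that have already completed. For how to define task and subtask boundaries, see Agent State Machine Design.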

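Rollback time limits and progressive recovery

One last problem has to be faced directly: rollback itself takes time. If an agent performed 100 writes and each undo handler takes 200 ms on average (restoring files, database rows, API calls), a full rollback takes 100 × 200 ms = 20 seconds. During those 20 seconds the system is in an intermediate "rollback in progress" state; downstream services may already have noticed the problem and started their own recovery, racing with the agent's rollback.

The mitigation is tiered rollback: classify undo handlers into three tiers by recovery speed. Tier 1, "seconds": restore critical configuration and service routing (execution time < 2 seconds). Tier 2, "minutes": restore data state and file contents (execution time < 30 seconds). Tier 3, "hours": compensate external API calls and cross-system side effects (may need human intervention). In the tiered model, rollback_all() no longer runs every undo in one pass; it runs tier 1 first (stop the bleeding quickly), then tier 2 (repair state), and finally tier 3 (coordinate compensation).

A stack implementation of tiered rollback can use either a multi-stack model or priority tags: each undo handler carries a tier tag (1/2/3) when pushed, and rollback_tier(n) runs only the handlers at that tier and above. This is an advanced topic; the RollbackStack in this article provides the basic LIFO and checkpoint mechanisms, and tiered rollback can be built on top of them. For how error-recovery flows cooperate with the rollback stack, see Agent Error Recovery.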
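
The priority-tag variant could be sketched as below. This is an illustration under assumptions, not part of RollbackStack: `TieredUndo` and `rollback_by_tier` are hypothetical names, and the function simply runs tier 1, then tier 2, then tier 3, keeping LIFO order inside each tier.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class TieredUndo:
    name: str
    tier: int                      # 1 = seconds, 2 = minutes, 3 = hours
    handler: Callable[[], bool]


def rollback_by_tier(entries: list[TieredUndo]) -> list[str]:
    """Run tier 1 first, then 2, then 3; LIFO within each tier.

    `entries` is in push order (oldest first). Returns the names in the
    order the handlers were executed.
    """
    executed = []
    for tier in (1, 2, 3):
        for entry in reversed(entries):
            if entry.tier == tier:
                entry.handler()
                executed.append(entry.name)
    return executed


pushed = [
    TieredUndo("restore-file", 2, lambda: True),
    TieredUndo("restore-route", 1, lambda: True),
    TieredUndo("refund-api", 3, lambda: True),
    TieredUndo("restore-config", 1, lambda: True),
]
order = rollback_by_tier(pushed)
```

Running tiers in priority order gives up strict LIFO across tiers, so this sketch is only safe when handlers in different tiers do not depend on each other's results.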

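7. Validate before commit: the Commit-and-Validate pattern

The previous six sections built a complete rollback mechanism: strategy selection (Section 2), the file-level implementation (Section 3), the data-level wrapper (Section 4), environment-level snapshots (Section 5), and composable stack-based rollback (Section 6). That whole chain answers the question "how do we undo it once something has gone wrong?" A better design goes one step further and asks: can we block the bad write before it happens? That is the core idea of the Commit-and-Validate pattern: move validation ahead of the commit and demote rollback from "after-the-fact remedy" to "last line of defense".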

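Two-phase commit: stage → validate → finalize

The CopyOnWriteAgent in Section 3 already contains the beginnings of this pipeline: write() stages content, commit() invokes the validator callback, and os.replace() performs the atomic replacement only after validation passes. But validation in production is far more involved than "one callback function". A mature Commit-and-Validate pipeline has the following stages:

  1. Stage: none of the agent's writes modify the target file directly; they go to a staging area instead. The staging area is isolated from the production file system, and the target file stays unchanged for the whole staging period. The output of this stage is a set of WriteHandle objects, each holding the staged file path, the target path, a checksum, and the operation ID.
  2. Validate: run the validator chain against the staged content, not the target file. The validator chain is an ordered sequence of checks; each validator returns pass/fail, and the first failure stops the chain. A typical chain includes syntax checks (YAML/JSON/TOML parsers), schema validation (JSON Schema, XML Schema), compile checks (python -m py_compile, nginx -t), security scanning (sensitive-data detection, injection-pattern matching), and functional tests (a unit or integration test suite run against the staged files).
  3. Finalize: once every validator passes, perform the atomic commit: os.replace() swaps the staged file into place, or the database transaction commits. If any validator fails, all staged files are discarded and the target file is never modified.

The core value of this three-stage pipeline is isolation: whatever goes wrong during validation (a validator crash, OOM, a timeout) cannot affect production. Even during finalization, atomic replacement guarantees that the target file is either left completely unchanged or completely replaced by the new content, with no "half-written" intermediate state (provided the staged file and the target are on the same file system). Compare this with "write first, validate later" (the agent modifies the file directly and then runs nginx -t): if nginx -t fails, the file is already damaged and you are already performing a rollback. Staged commit turns validation from "the thing that triggers rollback" into "the thing that prevents the write", which reduces how often rollback is needed in the first place.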
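
The three stages can be sketched for a single file as follows. This is a simplified illustration, not the article's CopyOnWriteAgent: `staged_write`, `is_valid_json`, and `has_no_private_key` are hypothetical helpers, and the staged file is created next to the target so that os.replace() stays on one file system.

```python
import json
import os
import tempfile
from typing import Callable

Validator = Callable[[str], bool]   # receives staged text, returns pass/fail


def is_valid_json(text: str) -> bool:
    try:
        json.loads(text)
        return True
    except ValueError:
        return False


def has_no_private_key(text: str) -> bool:
    return "PRIVATE KEY" not in text


def staged_write(target: str, content: str, validators: list[Validator]) -> bool:
    """Stage next to the target, validate the staged copy, then atomically replace."""
    directory = os.path.dirname(os.path.abspath(target))
    fd, staged = tempfile.mkstemp(dir=directory, suffix=".staged")
    try:
        # Stage: write the new content to a temporary file only.
        with os.fdopen(fd, "w", encoding="utf-8") as fh:
            fh.write(content)
            fh.flush()
            os.fsync(fh.fileno())
        # Validate: read back the staged copy and run the chain on it.
        with open(staged, encoding="utf-8") as fh:
            staged_text = fh.read()
        # all() short-circuits, so the chain stops at the first failure.
        if not all(check(staged_text) for check in validators):
            return False
        # Finalize: atomic when both paths are on the same file system.
        os.replace(staged, target)
        return True
    finally:
        if os.path.exists(staged):
            os.remove(staged)        # discard the staged file on any failure path
```

A rejected write returns False and leaves the target byte-for-byte unchanged, which is the property the pipeline is meant to guarantee.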

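Canary writes: apply to a subset first, propagate after validation

For operations with a wide blast radius, such as modifying a configuration template shared by 50 servers, passing validation in the staging area still leaves the risk that the validation environment does not exactly match production. A canary write addresses that risk: instead of atomically committing to every target at once, commit to a minimal subset first (1-2 targets), observe, and roll out to the remaining targets step by step only after confirming that nothing is wrong.

The canary write flow can be modeled as: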
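
One way to express that flow in code is the sketch below. It is an assumption-laden illustration: `canary_write` and its callback parameters (`apply_fn`, `undo_fn`, `healthy`) are hypothetical, and the observation window is reduced to a single health check per canary.

```python
from typing import Callable


def canary_write(
    targets: list[str],
    apply_fn: Callable[[str], None],
    undo_fn: Callable[[str], None],
    healthy: Callable[[str], bool],
    canary_size: int = 1,
) -> dict[str, str]:
    """Apply to a small subset, observe, then propagate or undo."""
    canaries, rest = targets[:canary_size], targets[canary_size:]
    status = {t: "untouched" for t in targets}

    for t in canaries:
        apply_fn(t)
        status[t] = "applied"

    # Observation window: a real system would wait and watch metrics here.
    if not all(healthy(t) for t in canaries):
        for t in reversed(canaries):
            undo_fn(t)
            status[t] = "rolled_back"
        return status                 # the remaining targets were never touched

    for t in rest:
        apply_fn(t)
        status[t] = "applied"
    return status
```

The returned status map makes the outcome explicit: on a failed canary, only the canary targets ever changed, and they are the only ones that need undoing.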

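Canary writes are not free: they add latency (the observation window) and complexity (decision logic, monitoring integration). For low-risk operations (editing a comment, formatting code) they are over-engineering. For high-risk operations (changing production configuration, updating a database schema, changing service routing rules) they are worth it: they reduce "one bad write affects 100% of targets" to "one bad write affects only 2% of targets and is automatically recovered within 5 minutes".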

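Validation criteria: from checksums to functional tests

The quality of the validator chain determines how effective the Commit-and-Validate pattern is: validators that are too weak let errors through, and validators that are too strict reject too often (hurting agent throughput). The validation criteria below are classified in order of increasing strength:

The recommended strategy in practice is cascading validation: run the cheap validators first (levels 1-2) and the expensive ones (levels 3-4) only after those pass. Most errors are caught at levels 1-2, and only a few edge cases need levels 3-4. Enable level 5 only when the agent is modifying code files (such as Python/Go/JS source files), because the startup cost and non-determinism of functional tests make them unsuitable as a general-purpose validator.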

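Human-in-the-loop validation gates

The top of the Commit-and-Validate pattern is the human-in-the-loop validation gate: a manual approval node placed at the end of the automated validator chain. When all automated validators pass but the operation's risk level exceeds a threshold (for example, modifying finance configuration that involves monetary amounts, dropping a database table, or changing authentication rules), the system pauses the commit and requests human approval.

Designing the gate involves two dimensions: when to trigger human approval (risk grading) and how to present the information (the approval interface). Risk can be scored by operation type (file modification vs. database deletion), scope of impact (1 file vs. 50 files), and reversibility (recoverable from a snapshot vs. compensation only); a score above the threshold triggers approval. The approval interface should show the concrete content of the operation (the diff), the results of the validator chain, the scope of impact (the list of affected targets), and the plan for "how do we roll back if this goes wrong after approval". For approval workflow design, see Agent Human Approval Workflow; for combining validation gates with the release process, see Agent Release Gate Design.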
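
The risk-grading half of the gate might look like the sketch below. Every number in it is an assumption: the weights, the scope buckets, and `APPROVAL_THRESHOLD` are hypothetical values chosen only to show how the three dimensions combine into one score.

```python
from dataclasses import dataclass

# Hypothetical weights and threshold; tune them to your own risk model.
OPERATION_WEIGHT = {"file_modify": 1, "config_change": 3, "db_delete": 5}
REVERSIBILITY_WEIGHT = {"snapshot": 0, "transaction": 1, "compensation_only": 4}
APPROVAL_THRESHOLD = 6


@dataclass
class PendingOperation:
    op_type: str          # key of OPERATION_WEIGHT
    target_count: int     # number of files or hosts affected
    reversibility: str    # key of REVERSIBILITY_WEIGHT


def risk_score(op: PendingOperation) -> int:
    """Sum of operation-type, scope, and reversibility weights."""
    if op.target_count <= 1:
        scope = 0
    elif op.target_count < 50:
        scope = 2
    else:
        scope = 4
    return OPERATION_WEIGHT[op.op_type] + scope + REVERSIBILITY_WEIGHT[op.reversibility]


def needs_human_approval(op: PendingOperation) -> bool:
    """Pause the commit when the score reaches the threshold."""
    return risk_score(op) >= APPROVAL_THRESHOLD
```

An additive score is the simplest possible model; a real gate may prefer hard rules for some categories (for instance, always requiring approval for compensation-only operations).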

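8. Limitations, trade-offs, and the cases where rollback cannot save you

The previous seven sections built a multi-layer rollback system, from single-file snapshots through composable stack-based undo to pre-commit validation, and its coverage looks fairly complete. But this article has to end with an honest self-assessment: rollback is not a silver bullet, and some damage is irreversible once it has happened. This section analyzes five fundamental limitations of rollback and the practical strategy for each.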

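External side effects: a fired bullet cannot be recalled

This is rollback's most fundamental boundary: any operation that has already left the agent's system boundary cannot truly be "rolled back". Suppose the agent called the Stripe API to create a charge. A compensation strategy can issue a refund, but the charge has happened: Stripe's fee may not be returned, the user's bank may already have incurred a currency-conversion loss, and tax records have been generated. Suppose the agent sent an email. A compensation strategy can send a correction, but the original is already in the recipient's inbox and may have been read, forwarded, or screenshotted. Suppose the agent published a message to Kafka or Redis Pub/Sub. Downstream consumers may already have consumed it and made business decisions based on it.

For external side effects, the only honest strategy is compensation plus transparent alerting. The compensation function performs the semantic inverse (a refund, a correction notice, a dead-letter message) and, at the same time, sends an explicit alert to system administrators and downstream stakeholders stating what the original operation was, what the compensation did, and what impact remains. Never tell upstream systems "rolled back"; say "compensated" instead. For the detailed design of the compensation strategy, revisit the COMPENSATION branch of RollbackOrchestrator in Section 2. For how to record compensation operations in the audit log for later tracing, see Agent Audit Log Design.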

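Distributed writes: the cross-agent consistency challenge

When several agents write to the same set of resources at the same time, rollback runs into the classic consistency challenge of distributed systems. Suppose Agent-A modified file F1 and Agent-B then modified file F2 based on F1's current content, and now Agent-A's operation has to be rolled back. Restoring the old snapshot of F1 leaves F2 built on a premise that no longer exists, so F2 becomes logically inconsistent. This does not happen in a single-agent system (a single agent's operations can be linearized), but it is almost unavoidable in a multi-agent system.

Mitigation strategies fall into three layers:

In practice, prefer layer 1: avoid cross-agent resource contention by design. When layer 1 is not feasible (agents really do need to modify shared resources together), use layer 3 (cascading alerts) as the safety net. Layer 2 (distributed transactions) is academically elegant but risky in engineering terms: the timeout and partition-tolerance problems of two-phase commit in a cross-agent context limit its practicality. For multi-agent collaboration and orchestration, see Multi-Agent Orchestration.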

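Human-in-the-loop operations: irreversible human decisions

When an agent's operation includes a human approval step, rollback faces a more philosophical problem: a human decision that has been made cannot be "undone". Suppose the agent proposed dropping a database table, the approver agreed, and the table was dropped. Even if a snapshot restores the table's data, the approver's decision to "approve the drop" is a historical fact that cannot be erased. More importantly, the approver may have adjusted future approval thresholds based on this experience ("the last table drop I approved was safe, so the next request of this kind can be auto-approved"); rollback erases the result of the operation but leaves the influence of the decision.

The practical impact varies by scenario. For purely technical decisions ("approve changing the rate-limit threshold in the Nginx configuration"), the impact after rollback is minimal: the approver's decision depended on the environment, and a rollback means the environmental premise has changed. For business decisions ("approve a $500 refund to the user"), rollback (the refund has been reversed by a compensation operation) leaves a contradiction in the audit record: there is a refund, followed by a reversal of that refund, and business analysis has to understand the whole chain.

The answer is to make the decision chain traceable rather than reversible: every human approval should record the full decision context (the agent's rationale, the approver's identity, the approval time, and a snapshot of the system state at that moment), so that a later "why was this approved?" can be answered from evidence. After a rollback, append a "rolled back" marker and the rollback reason to the original decision record instead of deleting or overwriting it. For the complete approval workflow design, see Agent Human Approval Workflow.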

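Time-bounded rollback: the trade-off between snapshot retention cost and recovery value

The effective rollback window is limited by how long snapshots and undo data are retained. The CopyOnWriteAgent in Section 3 keeps snapshots for 1 hour by default, and the checkpoint manager in Section 4 can be configured for N days or N versions. At bottom this is an economic trade-off: keeping snapshots longer increases the potential recovery value (earlier mistakes can be rolled back), but it also increases storage cost and the attack surface (more historical data means more sensitive information that could leak).

The guiding principles are business-driven rather than technical:

An easily overlooked dimension is secure cleanup of snapshot data: snapshots contain complete file or database contents, which may include PII, keys, or credentials. When a snapshot expires, deleting the file is not enough; it should also be securely erased (secure deletion). In cloud environments, relying on storage-layer encryption and access control is usually sufficient; on local disks, at minimum make sure the snapshot directory is readable only by its owner.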

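An honest note: perfect rollback does not exist, so design for partial recovery

The rollback system built in Sections 1 through 7 can achieve "100% recovery to the pre-operation state" under ideal conditions (every operation is reversible, every undo handler succeeds, and there are no external side effects or cross-agent dependency conflicts). Production is never ideal. Here is the most honest assessment of what rollback really delivers:

Given these realities, the article's core recommendation changes from "design a system that can be rolled back" to "design a system that can be partially recovered: accept that in some scenarios rollback restores only 80%, 90%, or 99% of the state, and design degradation paths and human-intervention procedures on that basis". A workable architecture is not "fully automatic rollback" but a three-layer recovery model of "automatic rollback + automatic alerting + human fallback": automatic rollback covers roughly 95% of cases (the common scenarios), automatic alerting covers the next 4% or so (rollback partially failed and needs human confirmation), and human fallback covers the last 1% (rollback failed completely or produced complex cross-system effects).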

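Meta-rollback: what if the rollback itself fails?

The final irony of rollback is that a rollback operation is itself a write: it modifies files, updates databases, and calls APIs, so it can fail too. If your rollback code throws an exception, or a snapshot file is corrupted and cannot be read, you are stuck with the meta-problem of "we need to roll back, but the rollback is broken too".

The strategy for handling meta-rollback failure is an escalation chain:

  1. Retry the rollback: the first reaction to a failed rollback should be to retry, because the failure may come from a transient fault (blocked disk I/O, a network timeout). Cap the number of attempts (at most 3) to avoid an endless retry loop. The RETRY_ONCE policy of RollbackStack in Section 6 is a single-retry implementation of this layer.
  2. Degraded rollback: if the original rollback strategy fails, try a degraded one: snapshot restore failed → try restoring from a diff patch; WAL recovery failed → try restoring from the previous checkpoint; file-level rollback failed → try environment-level rollback (restore the container to its previous checkpoint). Degraded rollback recovers at a coarser granularity (some state may be lost), but that is better than not recovering at all.
  3. Safe stop: if degraded rollback also fails, enter "safe stop" mode: stop all further writes by the agent (freeze the workspace), send an urgent alert to human operators (PagerDuty/DingTalk/Feishu), and preserve all residual state (damaged files, the list of undo handlers that did not finish, WAL logs) for manual diagnosis.
  4. Forensic recovery: the last resort is to restore state from system-level backups (such as database PITR or scheduled file-system snapshots). This is the most expensive form of recovery (hours of data may be lost), but it is also the final safety net. The agent's rollback system should not replace system-level backups; it provides finer-grained recovery within the backup window.

The design principle for meta-rollback is: never assume rollback will succeed, and reserve a recovery path for its failure. At the code level, this means wrapping every rollback operation in try/except, and the except branch must not just log and give up; it should trigger the escalation chain. For the full picture of agent error recovery (including the recovery flow after a failed rollback), see Agent Error Recovery.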
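
The first three levels of the escalation chain can be sketched as a single function. This is an illustration only: `recover_with_escalation` and its parameters are hypothetical, each step is modeled as a callable that returns True on success, and level 4 (restoring from system backups) is left as a manual procedure.

```python
from typing import Callable

Step = Callable[[], bool]


def recover_with_escalation(
    primary: Step,
    degraded: list[Step],
    safe_stop: Callable[[str], None],
    max_retries: int = 3,
) -> str:
    """Walk the escalation chain and report which level ended it."""

    def attempt(step: Step) -> bool:
        try:
            return bool(step())
        except Exception:
            return False   # an exception counts as a failed attempt

    # Level 1: retry the original rollback a bounded number of times.
    for _ in range(max_retries):
        if attempt(primary):
            return "rolled_back"

    # Level 2: coarser fallbacks, tried in order.
    for fallback in degraded:
        if attempt(fallback):
            return "degraded_rollback"

    # Level 3: freeze writes and page a human. Level 4 (restore from
    # system backups) is a manual procedure outside this function.
    safe_stop("rollback and all fallbacks failed")
    return "safe_stop"
```

Returning the level that ended the chain, rather than a plain boolean, lets the caller report "degraded" honestly instead of claiming a full rollback.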

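Frequently asked questions

1. How is rollback different from version control (git)?

Version control (git) tracks the file versions you explicitly commit; it knows nothing about what happened to a file between two commits. An agent may perform dozens of file writes between two git commits, any of which can cause damage, and git can only restore the last commit, which means losing every valid change made since then. Rollback operates at the granularity of a single write: it captures state before the write and can undo exactly one write without affecting the others. In addition, git only tracks files that are in the repository; agents frequently modify things outside it (configuration files, environment files, database records), and those operations need file-level snapshots and transactional rollback rather than version control. The two are complementary: git manages long-term version history, and rollback manages the undo path for each recent write.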

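2. How expensive is snapshot storage? When should snapshots expire?

Storage cost depends on three variables: file size, write frequency, and retention window. A 50 MB log file written by the agent 10 times per hour, with snapshots kept for 24 hours, costs 50 MB × 10 × 24 = 12 GB. For text files, diff patches cut the cost dramatically: for a 10 MB file with 3 changed lines, the diff may be only 200 bytes instead of 10 MB. Set snapshot expiry by operation type: keep file-write snapshots for 1-24 hours (errors are usually found within 1 hour), keep database checkpoints until the next backup completes, and keep environment-level snapshots (containers/VMs) for 1-4 hours (they are large and the rollback window is short). The best practice is to set a capacity cap (for example, at most 500 MB for the staging directory) and automatic TTL cleanup (for example, a background thread that removes snapshots older than 24 hours) rather than managing snapshot lifecycles by hand. For high-frequency writes, use WAL/diff-based rollback instead of full snapshots, which reduces storage overhead from O(file size × number of writes) to O(diff size × number of writes).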
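
The arithmetic above can be checked with a few lines of code. The function names are hypothetical, and the 200-byte patch size is the figure from the answer above applied, as an assumption, to every write of the 50 MB example.

```python
def snapshot_storage_mb(file_mb: float, writes_per_hour: int, retention_hours: int) -> float:
    """Full-snapshot cost: one copy of the file per write, kept for the window."""
    return file_mb * writes_per_hour * retention_hours


def diff_storage_mb(diff_bytes: int, writes_per_hour: int, retention_hours: int) -> float:
    """Diff-based cost: one patch per write, kept for the window."""
    return diff_bytes * writes_per_hour * retention_hours / 1_000_000


full = snapshot_storage_mb(50, 10, 24)   # 12,000 MB, i.e. 12 GB
patched = diff_storage_mb(200, 10, 24)   # assuming ~200-byte patches per write
```

Under these assumptions the diff-based approach stores well under 1 MB for the same window, which is why diffs or a WAL are preferred for high-frequency writes to large text files.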

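3. Can I use both the snapshot and the transaction strategy on the same operation?

Yes, and that is exactly what RollbackOrchestrator is designed for: every strategy decision carries an optional fallback_strategy. The most common combinations in practice are: (1) Transaction first, snapshot as fallback: record changes to a large file with a WAL transaction, and fall back to a full snapshot if saving the WAL pre-image fails. This is more reliable than either strategy alone. (2) Snapshot plus compensation: for an operation such as "modify a configuration file and restart the service", roll back the file with a snapshot and the restart with compensation (restart into the old version). RollbackOrchestrator lets each sub-operation choose its own strategy, and the LIFO stack guarantees the correct rollback order. (3) A combination to avoid: taking a snapshot of a file and recording a WAL for it at the same time, which wastes storage for no additional benefit. Choose the strategy from the characteristics of the operation (file size, write frequency, whether atomicity is required), not from "which strategy is better": no strategy is universally better, only better suited to the operation at hand.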

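4. The agent has already called an external API. How do I "roll back" that call?

You cannot truly roll back an external API call; you can only compensate for it. The difference is that rollback restores state (file contents, database rows), while compensation performs the semantic inverse (a refund, an unsubscribe, a correction notice). How reliable compensation is depends on whether the external API offers an inverse endpoint: Stripe's refund API and AWS's resource-deletion APIs are reversible (though still with side effects, such as fees that are not returned), while an email sent to a user is irreversible (you can only send a correction; the original has been delivered). For irreversible external API calls, the best practices are: (1) require human confirmation before execution, raising the bar for irreversible operations; (2) perform a reversible pre-operation first, for example create a database record marked "pending" and call the external API only after it has been confirmed; (3) plan a fallback for the compensation itself: if the unsubscribe API also fails, escalate to human intervention. Always tell upstream systems "compensated", never "rolled back"; it is the honest description of what actually happened.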

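5. What if the rollback itself fails? Do I need to design a "meta-rollback"?

Yes. Rollback failure is a real and dangerous scenario: a rollback is itself a write (it modifies files, updates databases, calls APIs), so it can fail in the same ways. The response is a four-level escalation chain: (1) retry the rollback, at most 3 times, for transient faults; (2) degraded rollback: snapshot restore failed → try a diff patch; file-level rollback failed → try environment-level rollback (restore from a container checkpoint); (3) safe stop: freeze all further writes by the agent, send an urgent alert to human operators, and preserve all residual state for diagnosis; (4) forensic recovery: restore from system-level backups (database PITR, scheduled file-system snapshots), the most expensive option but the final backstop. Meta-rollback is not "another rollback system", which would lead to infinite recursion. It is the design of degradation paths: a less elegant recovery route, reserved for when rollback fails, that still gets you back to a known state. For the complete recovery flow, see Agent Error Recovery.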

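6. How many checkpoints should I keep? Are there rules of thumb for retention?

The number of checkpoints to keep is determined by three factors: storage budget, rollback needs, and checkpoint interval. Rules of thumb: (1) Retain by count: keep the most recent N checkpoints, with N between 5 and 20. This is the simplest and most common policy. For file-level operations (Section 3), keep the last 10 snapshots or diff patches; for agent state checkpoints (Section 4), keep the last 3-5 (serialized state is larger). Checkpoints beyond N are deleted automatically. (2) Retain by time window: keep every checkpoint from the past T hours, with T between 1 and 24. This suits workloads with irregular write frequency, and it guarantees that you can always roll back to any checkpoint taken in the past T hours. (3) Tiered retention, combining the two: keep the 5 most recent checkpoints (so recent operations can be rolled back), plus one checkpoint per hour for the past 24 hours (for coverage over time), plus one per day for the past 7 days (for long-term traceability). The key principle is that the value of a checkpoint decays rapidly with age. As a rule of thumb, a checkpoint from 1 hour ago is about an order of magnitude less likely to be used than one from 10 minutes ago, and one from a day ago is less likely still. Spend most of the storage budget on recent checkpoints; older ones can be compressed and merged.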
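
Tiered retention (rule 3) can be sketched as a function that decides which checkpoint timestamps survive a cleanup pass. `checkpoints_to_keep` is a hypothetical helper, and the bucket boundaries (clock hours and calendar days) are one possible reading of "one per hour" and "one per day".

```python
from datetime import datetime, timedelta


def checkpoints_to_keep(timestamps: list[datetime], now: datetime) -> set[datetime]:
    """Tiered retention: latest 5, plus one per hour for 24 h, plus one per day for 7 d."""
    ordered = sorted(timestamps, reverse=True)       # newest first
    keep = set(ordered[:5])

    hourly_seen, daily_seen = set(), set()
    for ts in ordered:
        age = now - ts
        hour_bucket = ts.replace(minute=0, second=0, microsecond=0)
        day_bucket = ts.date()
        if age <= timedelta(hours=24) and hour_bucket not in hourly_seen:
            hourly_seen.add(hour_bucket)
            keep.add(ts)                              # newest checkpoint in that hour
        if age <= timedelta(days=7) and day_bucket not in daily_seen:
            daily_seen.add(day_bucket)
            keep.add(ts)                              # newest checkpoint on that day
    return keep
```

Everything not in the returned set is a candidate for deletion or for merging into a compressed archive.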