首页 > 基础资料 博客日记

PyTorch KernelAgent 源码解读 ---(3)--- orchestrator

2026-05-16 22:00:02基础资料围观1

这篇文章介绍了PyTorch KernelAgent 源码解读 ---(3)--- orchestrator,分享给大家做个参考,收藏极客资料网收获更多编程知识

PyTorch KernelAgent 源码解读 ---(3)--- orchestrator

目录

0x00 摘要

orchestrator 是 KernelAgent 系统中的一个核心组件,负责协调和管理多个工作进程(worker),实现并行执行任务并从中选择最优结果。Fuser/orchestrator.py 文件实现了 Orchestrator 类,用于多进程协调任务执行。其功能用一句话概括:fork N 个Worker竞赛,首个 PASS胜出,其余终止,产物打包返回。

orchestrator 的精简架构如下:

orchestrator

0x01 Fuse

1.1 流水线

前文提到,KernelAgent 的完整流水线(pipeline)分为四个阶段?

  • FuserAgent – 保持 Python 语义的代码到代码融合,以 Orchestrator 为入口,将PyTorch 模块重写为可融合子模块
  • Extract(提取):形状推理与合约生成,让LLM分析产出subgraphs.json(子图列表)。
  • Dispatch(分发):将每个子图描述转化为Triton规格,并行派发给TritonKernelAgent生成内核。
  • Compose(组合):将验证通过的内核拼接成一个完整的composed_kernel.py`,包含自测代码。

其实,实际代码和官方文档有出入,实际代码中,Orchestrator 是 Extract 阶段的子组件,负责"重写"这一步,而 subgraph_extractor 是 Extract 阶段的完整入口。或者说,Orchestrator = 多 worker 竞赛生成融合代码,是 Pipeline 第一步 Extract 阶段的执行引擎。

即,Extract 阶段包含两步,Orchestrator 负责第一步:

Extract 阶段(subgraph_extractor.py 统一入口)
|
|-- Step 1a: Orchestrator 重写代码
|   输入:原始 PyTorch 问题文件
|   过程:多 Worker 竞赛,LLM 将代码重构为可融合子模块
|   输出:重构后的 code.py(仍是 PyTorch,但已拆分为 nn.Module 子模块)
|
|-- Step 1b: LLM 分析子图
    输入:原始问题 + 重构后的 code.py
    过程:单次 LLM 调用,提取 shapes/ops/dtypes
     输出:subgraphs.json(JSON 数组,描述每个子图的精确形状签名)

在 subgraph_extractor.py 的 extract_subgraphs_to_json() 函数中可以清晰看到:

  • 先创建 Orchestrator 并 orch.run() -> 得到 code.py
  • 再调用 LLM 分析 code.py -> 得到 subgraphs.json

1.2 功能

Orchestrator (Fuser/orchestrator.py) 的功能是 将 PyTorch 模块重写为可融合的子模块代码。Fuser 阶段的核心目的是在保持程序语义的前提下,重新组织模型结构以促进更有效的操作融合。通过将相关的连续操作打包成独立的模块,同时保留控制流结构,该阶段为后续的子图提取、Triton 内核生成和最终合成提供了优化的基础。这种结构化的方法既提高了性能潜力,又保持了代码的可读性和功能完整性。

具体来说,Orchestrator 的核心职责是:

  • 启动 N 个 Worker 进程(默认 4 个),每个 worker 独立请求 LLM 将原始 PyTorch
    问题文件重构为"融合友好"的代码
  • 每个 Worker 使用不同的提示词变体(4 种措辞)增加多样性
  • 第一个通过验证(run_tests() 打印 PASS)的 worker 胜出

Orchestrator 的关键机制如下:

  • winner_queue: 第一个成功的 worker 把结果放入队列
  • cancel_event: 收到 winner 后通知其他 worker 停止
  • console_mux: 聚合所有 worker 的 LLM 流式输出到终端
  • 跨 worker SHA256 去重(通过 register_digest),避免两个 worker 验证相同代码
  • 成功后打包 result.tar.gz(含重构后的 code.py)

输出:一个重构后的 Python 文件(仍然是 PyTorch 代码,但已拆分为 nn.Module 子模块),这个文件随后被 subgraph_extractor 分析产出 subgraphs.json 。

样例

Fuse 阶段是将原始模型中的多个连续操作融合成更少的子模块,以便后续生成更高效的 Triton 内核。以下面代码为例,其具体变化分析,是从线性结构到模块化结构。

输入:任意复杂度的原始 PyTorch 模型(包含多个单独的 nn 层),即单一 forward 方法,顺序执行多个操作

x → conv → bn → tanh → max_pool2d → norm

class Model(nn.Module):
    def forward(self, x):
        if x.sum() > 0: 
            x = self.conv(x)
            x = self.bn(x)
            x = torch.tanh(x)
            x = F.max_pool2d(x, 2)
        return self.norm(x)
  • 输出:结构化为可融合子图的模块(将相关操作打包成独立模块)

x → branch(融合了conv+bn+tanh+max_pool2d) → norm

class FusedModel(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        self.branch = ConvBnTanhMaxPool(channels=channels)  # 融合模块
        self.norm = ChannelwiseNorm(channels=channels)      # 独立模块

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.sum() > 0:  # 控制流 intact(保持不变)
            x = self.branch(x)  # 单个调用执行多个操作
        return self.norm(x)

具体操作

Fuse 阶段的具体操作如下:

  • 操作融合识别
    • 识别可融合操作链:conv → bn → tanh → max_pool2d
    • 保留控制流:if x.sum ()> 0: 条件判断保持不变
    • 分离独立操作:norm 作为一个独立的操作保留
  • 模块创建
    • 创建 ConvBnTanhMaxPool 模块,封装连续的操作
    • 将 norm 作为独立的 ChannelwiseNorm 模块
  • 控制流处理
    • 保持控制流完整:if x.sum ()> 0: 判断条件没有被融合,保持原有的控制逻辑
    • 条件分支内的操作被融合:在 if 块内的操作被打包成单一模块

为什么这样做?

为何这样做可以性能优化?

  • 减少内核启动开销

    • 原始情况:每个操作都需要一次 GPU 内核启动
    • 融合后:多个操作在一个内核中执行,减少启动开销
  • 提高内存访问效率

    • 减少中间张量:避免每个操作之间的临时张量存储
    • 更好的内存局部性:相关操作的数据访问模式更加集中
  • Triton 内核生成优化

    • 便于子图提取
      • 为 subgraph_extractor.py 提供清晰的子图边界
      • 每个 nn.Module 对应一个潜在的融合子图
    • 明确的输入输出边界
      • 每个融合模块有明确的输入输出形状定义
      • 便于 dispatch_kernel_agent.py 生成特定形状的 Triton 内核
  • 控制流保留的重要性

    • 语义完整性
      • 保持原始程序的执行逻辑
      • 确保融合不会改变程序行为
    • 适应动态执行路径
      • 条件分支允许根据输入数据动态决定是否执行某些操作
      • 保持模型的动态特性

1.3 Orchestrator 对后续阶段的影响

  • 对子图提取阶段的影响
    • 清晰的分割点:每个 nn.Module 代表一个潜在的融合单元
    • 形状信息明确:每个模块的输入输出形状更容易推断
  • 对 Triton 生成阶段的影响
    • 针对性优化:为特定的操作组合生成优化的 Triton 内核
    • 减少碎片化:避免为单个简单操作生成内核
  • 对合成阶段的影响
    • 模块化组装:可以更容易地将生成的 Triton 内核重新组装成完整模型

1.4 Orchestrator 内部工作机制

工作进程入口函数:_worker_process_main 函数

  • 作为子进程的实际执行入口
  • 在子进程中重新导入必要的类和模块
  • 将配置数据从字典还原为 WorkerConfig 对象

总结报告生成 ResultSummary 对象

  • 包含运行 ID、胜出者工作进程 ID、成果路径和原因说明
  • 提供完整的执行结果概览

持久化存储

  • 将总结信息保存为 summary.json 文件
  • 记录胜出者信息到专门的目录

0x02 多 Worker 并行机制详解

2.1 基础概念

KernelAgent 采用多 Worker 并行执行的架构设计,这是其核心的性能优化策略之一。每个 Worker 作为独立的进程运行相同的问题解决任务,但可能使用不同的参数或策略。

2.2 为何要多个 Worker 并行

应用场景

LLM 固有的不确定性

  • LLM 的输出具有一定的随机性
  • 不同的提示词变体可能产生不同的效果
  • 并行执行增加了找到高质量解决方案的概率

问题复杂性

  • 某些问题可能有多种有效的解决路径
  • 不同的方法可能在不同的情况下表现更好
  • 并行探索能够快速找到最适合的解决方案

探索不同解决方案

参数多样性

  • 每个 Worker 可能使用不同的 variant_index(变体索引)
  • 在 prompting.py 中定义了不同的提示词变体(VARIANT_WORDINGS)
  • 不同的 Worker 会使用不同的提示词变体来引导 LLM 生成不同的解决方案

随机性和探索性

  • 并行执行增加了探索解空间的广度
  • 不同的随机种子和初始条件可能导致不同的优化路径
  • 提高找到有效解决方案的概率

加快收敛速度

时间效率

  • 并行执行可以同时尝试多种策略
  • 避免串行试错的漫长过程
  • 最先成功的解决方案可以立即终止其他进程

资源利用

  • 充分利用多核 CPU 资源
  • 在资源允许范围内加速问题求解

2.3 为何要选出一个 winner

KernelAgent 的主要目标是快速获得可用的融合内核,而不是进行详尽的统计分析。而且,在大多数情况下,第一个成功解决方案已经满足需求。

避免资源浪费

早期停止机制

  • 一旦某个 Worker 成功找到解决方案,立即终止其他 worker
  • 节省计算资源和时间
    • 时间成本:GPU 计算资源昂贵,让所有 Worker 运行到结束会浪费大量计算时间

    • 能源成本:不必要的计算会消耗更多电力

    • 硬件资源:释放 GPU 资源供其他任务使用

    • 机会成本

      • 尽早获得可用解决方案,让开发人员可以继续其他工作
      • 避免在单一任务上占用过多资源
      • 开发周期通常有限,需要尽快获得可用结果
      • 过度追求“最佳”解决方案可能会延误整体进度
  • 在 orchestrator.py 中实现 winner_queue 机制

解决方案等价性

  • 在大多数情况下,所有 Worker 产生的解决方案在功能上是等价的
  • 一旦找到一个可行解,其他解的价值相对较低

统一输出

  • 确保最终只有一个解决方案被采纳
  • 避免多个解决方案之间的冲突

质量保证

选择最佳方案

  • 最先通过验证的方案通常是最优的
  • 通过 run_candidate 函数进行严格的验证
  • 验证标准包括 PASS 标记或 ALL_TESTS_PASSED 信标

确定性结果

  • 确保每次运行产生一致的输出
  • 避免随机性导致的结果不一致

2.4 winner 选择机制

KernelAgent 当前采用竞争性并行策略,即多个 Worker 同时运行,但只要有一个成功就立即停止其他进程。

竞争性队列

  • 使用共享的 winner_queue 来接收第一个成功完成的任务结果,所有 Worker 向同一个 winner_queue 发送结果
  • 所有工作进程都向同一个队列发送结果,但只有第一个成功的结果会被采纳,记录胜出的工作进程 ID 和相关信息
  • 一旦有结果,立即终止所有其他进程
  • 队列大小限制为 1(maxsize=1)

验证标准

优先级顺序:

  • 检查 _PASS_REGEX 是否在输出中(run_tests 打印 "PASS")
  • 检查 _SENTINEL信标("ALL_TESTS_PASSED")
  • 验证返回码是否为 0

选择逻辑位于 worker.py 中:

if rr.passed:
    self.logger.info("PASS at iter %d via %s", k, rr.validator_used)
    try:
        self.winner_queue.put({
            "worker_id": self.cfg.worker_id,
            "iter": k,
            "validator": rr.validator_used,
            "runs_dir": str(run_root),
            "artifacts_dir": str(self.dirs["artifacts"]),
        }, timeout=0.1)
    except queue.Full:
        pass

选择后的行为

立即终止其他进程

  • 设置 cancel_event 以通知所有其他 worker 停止工作
  • 通过 p.terminate()p.kill() 强制终止进程

结果包装和记录

  • 将胜出者的成果打包到压缩文件中,包含代码文件和运行目录内容

  • 记录胜出者信息到 summary.json 文件

  • 在专门的目录中保存胜出者信息

0x03 KernelAgent 中的 Prompt 特点分析

在 KernelAgent 系统中,prompt 是指导 LLM(大语言模型)生成特定解决方案的关键输入,特别是在 orchestrator 的 worker 过程中用于生成融合子图模块。

3.1 Prompt 策略

render_prompt 函数定义了 prompt。

构建 Prompt 的硬性要求列表如下:

  • 返回单个可运行 Python 文件,用 python 代码块包围
  • 每个融合子图必须表示为独立的 nn.Module 类
  • 包含 run_tests() 函数验证数值等价性
  • 成功时打印 "PASS" 并退出
  • 不允许网络或文件 I/O 操作

完整解决方案重发

  • 每次尝试都重新发出整个单文件解决方案
  • 避免增量修改导致的不一致性

动态错误上下文(迭代优化)

  • 在每次迭代中加入上一次的错误信息作为 error_context
  • 帮助模型了解之前失败的原因并进行修正
  • 修复后仍需重新发出整个文件
@dataclass(frozen=True)
class RenderedPrompt:
    system: str
    user: str
    extras: dict[str, Any]

def render_prompt(
    problem_path: Path,
    variant_index: int,
    attempt_index: int,
    error_context: str | None,
    enable_reasoning_extras: bool,
    seed: int | None = None,
    model_name: str | None = None,
) -> RenderedPrompt:
    """Render system+user prompts and extras for the Responses API (deterministic)."""
    content = problem_path.read_text(encoding="utf-8")
    user = build_user_prompt(
        attempt_index=attempt_index,
        problem_file_content=content,
        error_context=error_context,
        variant_index=variant_index,
    )
    extras: dict[str, Any] = {}
    if seed is not None:
        extras["seed"] = seed
    if enable_reasoning_extras:
        # Use high reasoning effort for GPT-5 per policy
        extras["reasoning"] = {"effort": "high"}
        # Align with Responses API text options for clearer outputs
        text_options: dict[str, Any] = {"format": {"type": "text"}}
        if model_name:
            if model_name.startswith("gpt-5"):
                text_options["verbosity"] = "high"
            elif model_name.startswith("o4-mini"):
                text_options["verbosity"] = "medium"
        extras["text"] = text_options
    return RenderedPrompt(system=SYSTEM_PROMPT, user=user, extras=extras)

3.2 Prompt 结构特点

System Prompt

简洁明了

  • SYSTEM_PROMPT"Return a single runnable Python file only."
  • 明确指示模型只返回一个可运行的 Python 文件

User Prompt

BASE_DEVELOPER_PROMPT 是 User Prompt。

此处做了角色设定:

  • 设定模型为专家级 PyTorch 工程师,专注于推理阶段的图融合
  • 强调专业领域和目标导向
BASE_DEVELOPER_PROMPT = (
    "You are an expert PyTorch engineer focused on inference-only graph fusion.\n\n"
    "Hard requirements:\n"
    "- Return ONE runnable Python file, fenced as a single ```python block.\n"
    "- Each fused subgraph must be represented by its own nn.Module class with a clearly documented forward; do not leave raw nn.* ops inline in the top-level Model.\n"
    "- Include a function run_tests() that validates numerical equivalence to the original using helpers in the problem file. "
    "On success, run_tests() must print 'PASS' and exit(0).\n"
    "- If you cannot implement run_tests(), then at minimum print the exact sentinel ALL_TESTS_PASSED and exit(0) when tests succeed.\n"
    "- No network or file I/O outside the current directory. Avoid extra dependencies.\n"
    "- Deterministic: set seeds where relevant.\n\n"
    "Fusion guidance:\n"
    "- Detect scaled dot-product attention patterns and aggressively fuse the entire block (QKV linears, splits/reshapes, scaled QK^T, causal masking, ReLU or gating, applying V, and head merge) into a single attention subgraph whenever feasible.\n"
    "- Only decompose attention into smaller subgraphs when you are certain fusion is impossible.\n\n"
    "Iteration contract:\n"
    "- On each attempt, re-emit the entire single-file solution.\n"
    "- When ERROR_CONTEXT is provided, carefully analyze and fix issues, then re-emit the whole file.\n"
)

具体构建代码如下。

def build_user_prompt(
    attempt_index: int,
    problem_file_content: str,
    error_context: str | None,
    variant_index: int,
) -> str:
    parts: list[str] = []
    parts.append(_variant_line(variant_index))
    parts.append("")
    parts.append(BASE_DEVELOPER_PROMPT)
    parts.append("")
    parts.append(f"ATTEMPT: {attempt_index}")
    if error_context:
        parts.append("")
        parts.append("ERROR_CONTEXT:")
        parts.append(error_context.strip())
    parts.append("")
    parts.append("PROBLEM_FILE_CONTENT:")
    parts.append(problem_file_content)
    return "\n".join(parts)

其中 _variant_line 代码如下。

def _variant_line(idx: int) -> str:
    i = idx % len(VARIANT_WORDINGS)
    return VARIANT_WORDINGS[i]

VARIANT_WORDINGS

VARIANT_WORDINGS 是在 prompting.py 文件中定义的一个元组,包含四个不同的提示词变体。

VARIANT_WORDINGS: tuple[str, str, str, str] = (
    "Rewrite the provided model into fusable subgraph modules with explicit input/output shapes.",
    "Refactor the given model into fusion-friendly submodules, specifying exact tensor shapes.",
    "Decompose the model into subgraphs suitable for fusion; document all input/output shapes.",
    "Split the model into fusable modules and clearly state the shape contracts for each.",
    "Every fused subgraph must be packaged as its own nn.Module (no inline nn.* ops at top level)",
)
# 注意这里有5个元素!

3.3 VARIANT_WORDINGS 的区别分析

这些变体用于在多 Worker 并行执行时,让不同的 Worker 使用略微不同的提示词来引导 LLM 生成多样化的解决方案。

设计意图

设计意图如下:

  • 多样性探索

    • 通过微调提示词措辞,鼓励 LLM 探索不同的解决方案路径
    • 增加找到成功解决方案的可能性
  • 避免重复

    • 不同的 Worker 使用不同的提示词,减少生成相似解决方案的机会
    • 提高并行执行的效率

这种设计巧妙地利用了提示词工程中的微妙差异来增加解决方案的多样性,同时保持了任务目标的一致性。

各变体详细分析如下。

第一个变体

第一个变体为:"Rewrite the provided model into fusable subgraph modules with explicit input/output shapes."

关键词汇

  • Rewrite:强调重写或重构现有模型
  • fusable subgraph modules:明确指出要创建可融合的子图模块
  • explicit input/output shapes:特别强调显式的输入/输出形状

侧重点

  • 重视模型重构的过程
  • 强调输入输出形状的明确性

第二个变体

第二个变体如下:"Refactor the given model into fusion-friendly submodules, specifying exact tensor shapes."

关键词汇

  • Refactor:强调重构,比 rewrite 更偏向于代码结构优化
  • fusion-friendly:强调为融合而优化的设计
  • specifying exact tensor shapes:强调精确的张量形状规格

侧重点

  • 更注重融合友好的设计
  • 强调精确性

第三个变体

"Decompose the model into subgraphs suitable for fusion; document all input/output shapes."

关键词汇

  • Decompose:强调分解,暗示将复杂模型拆分成更小的部分
  • suitable for fusion:强调适合融合的特性
  • document all input/output shapes:强调文档化所有形状

侧重点

  • 更强调整体到部分的分解过程
  • 强调文档化的重要性

第四个变体

"Split the model into fusable modules and clearly state the shape contracts for each."

关键词汇

  • Split:强调分离,比前几个词更简单直接
  • fusable modules:类似于前面的表述
  • clearly state the shape contracts:强调契约或合约的概念

侧重点

  • 更简洁直接的语言
  • 强调契约/合约概念

微妙差异总结

动词选择差异

  • Rewrite:重新编写
  • Refactor:重构优化
  • Decompose:分解拆分
  • Split:分离

术语差异

  • shapes vs tensor shapes vs shape contracts:不同层次的抽象
  • documentation vs specification vs stating:不同的明确程度

语义细微差别

  • 每个变体虽然传达相同的核心意图,但在措辞上略有不同
  • 这些差异可能影响 LLM 对任务的理解和解决方案的构建方式

3.4 Worker 进程数与 VARIANT_WORDINGS 数目关系分析

先说结论,Worker 进程数与 VARIANT_WORDINGS 数目不一定一致。

  • Worker 数量是可配置的,通常由用户在运行时指定
  • VARIANT_WORDINGS 数量是固定的(5 个,尽管注释说 4 个),实际只使用前 4 个变体,因为取模运算是对 4 取模而非 5
  • 分配使用取模运算,导致当 worker 数超过变体数时会循环使用

这种设计提供了灵活性,允许用户根据资源情况调整并发数量,而不受限于提示词变体的数量。

Worker 进程数配置

根据代码分析,Worker 进程数由 OrchestratorConfig 中的 workers 参数决定,这个参数是在运行时传入的。

VARIANT_WORDINGS 数量

VARIANT_WORDINGS 固定包含 4 个不同的提示词变体:

关系分析

取模运算分配

_make_worker_cfg 方法中,每个 worker 的 variant_index 是这样确定的:

variant_index = idx % 4  # 使用硬编码的 4

循环分配机制

  • 如果 Worker 数量小于或等于 VARIANT_WORDINGS 数量(5 个),则每个 Worker 可能获得不同的变体
  • 如果 Worker 数量超过 VARIANT_WORDINGS 数量,则开始循环重复使用变体
  • 使用 % 4 操作符意味着最多只会循环前 4 个变体,第 5 个变体实际上不会被使用

实际情况:不一定相等

  • Worker 进程数:可配置,通常在命令行参数中指定(例如 --workers 4
  • VARIANT_WORDINGS 数目:固定为 5 个(尽管注释说是 4 个)

分配策略

# 示例:如果有 6 个 workers
worker_0: variant_index = 0 % 4 = 0  # 使用第 1 个变体
worker_1: variant_index = 1 % 4 = 1  # 使用第 2 个变体
worker_2: variant_index = 2 % 4 = 2  # 使用第 3 个变体
worker_3: variant_index = 3 % 4 = 3  # 使用第 4 个变体
worker_4: variant_index = 4 % 4 = 0  # 重新使用第 1 个变体
worker_5: variant_index = 5 % 4 = 1  # 重新使用第 2 个变体

设计考虑

灵活性

  • 允许动态调整 worker 数量,而不受提示词变体数量的限制
  • 可以根据问题复杂度和可用资源调整并发度

多样性

  • 即使 worker 数量超过变体数量,也能提供一定程度的多样性
  • 不同的 worker 仍然会有不同的执行路径(由于随机性、迭代次数等因素)

3.5 迭代优化

关于迭代优化。我们可以用 KernelBench 的研究来看。KernelBench框架使模型能够在迭代优化过程中接收并利用反馈。这些真实信号包括NVCC编译器错误信息、执行统计数据(例如正确性检查和挂钟时间),以及PyTorch分析器(操作时间分解)。

动态迭代

他们在多轮过程中为模型提供每次生成的反馈:在初始生成后,向模型提供其之前的生成结果G,以及当前生成对应的编译器/执行反馈E和/或分析器输出P。

然后将每次生成及其后续反馈定义为一轮(turn),并在N轮内运行这一迭代优化过程。利用执行反馈有助于减少错误,并随时间提升整体加速效果。

研究人员发现迭代优化在不同模型和KernelBench的各个级别上均持续提升了性能。

此外,通过分析迭代优化轨迹,他们发现模型在执行反馈E的帮助下能更有效地自我纠正,尤其是在修复与执行错误相关的问题上。DeepSeek-R1在Level 1和Level 2上,经过10轮优化后,能在超过90%的任务中生成功能正确的内核。然而,剩余的错误内核几乎总是由于功能不正确而失败,这可能是因为正确性反馈的颗粒度不如执行失败信息细致。

0x04 worker.py 功能详解

worker.py 是 KernelAgent 系统中执行层面的核心组件,是单个工作进程的实现,负责执行特定的融合任务。

worker.py 实现了从问题描述到有效解决方案的完整迭代流程。它通过与 LLM 的交互、代码执行验证和错误反馈机制,不断改进解决方案,并参与多进程竞争以成为最终的获胜者。

4.1 核心职责

worker.py 的核心职责如下:

  • 从问题文件生成提示
  • 与 LLM 交互获取解决方案
  • 执行生成的代码并验证
  • 竞争成为获胜者(winner)

4.2 配置与参数

WorkerConfig

  • max_iters:最大迭代次数
  • llm_timeout_s:LLM 超时时间
  • run_timeout_s:代码执行超时时间
  • variant_index:使用的提示变体索引

并行控制参数

  • isolated:是否隔离环境
  • deny_network:是否禁止网络访问
  • cancel_event:取消事件用于提前终止

4.3 主要类和数据结构

WorkerState

WorkerState 数据类如下。

@dataclass
class WorkerState:
    worker_id: str
    iter_index: int
    last_response_id: str | None
    last_error: str | None
    passed: bool

字段说明

  • worker_id:工作进程标识符
  • iter_index:当前迭代次数
  • last_response_id:上次 LLM 响应 ID
  • last_error:上次错误信息
  • passed:是否已通过验证

Worker 类

Worker 定义如下。每个工作进程在自己的工作目录中运行,避免文件冲突,保留完整的执行历史。

class Worker:
    def __init__(
        self,
        cfg: WorkerConfig,
        problem_path: Path,
        winner_queue: Any,
        cancel_event: Any,
        on_delta: Callable[[str, None]] | None = None,
    ) -> None:
        self.cfg = cfg
        self.problem_path = problem_path
        self.winner_queue = winner_queue
        self.cancel_event = cancel_event
        self.on_delta = on_delta
        self.logger = setup_file_logger(
            cfg.workspace_dir / "logs" / "worker.log", name=f"worker-{cfg.worker_id}"
        )
        self.dirs = _ensure_dirs(cfg.workspace_dir)

Worker 的业务逻辑在 run 函数中实现。

代码提取

  • 从 LLM 响应中提取 Python 代码
  • 验证代码格式
extracted = extract_single_python_file(result.get("output_text", ""))

代码验证

首先会做重复检测, signature = ops 列表 + input shapes + output shapes + weight shapes + layout + dtype,JSON 序列化后比较。

sha = sha256_of_code(extracted.code)
status, owner = register_digest(
    self.cfg.shared_digests_dir, sha, self.cfg.worker_id, k
)
  • cross_worker_duplicate:不同工作进程间的重复
  • same_worker_duplicate:同一工作进程内的重复

此处子图去重(dedup)是基于shape signature而非代码文本。其理由如下:

  • 同一操作组合但不同变量名、注释、代码风格不应产生重复内核
  • 内核的性能和正确性主要取决于 shapes、dtypes、ops组合,而非源码文本
  • LLM可能对同一子图产出不同文本描述,但它们在计算语义上完全相同

代码执行

其次会调用 run_candidate 函数来执行:

  • 创建独立的运行环境
  • 执行生成的代码
  • 捕获 stdout/stderr
  • 验证结果

run_candidaterunner.py 中定义,其中验证标准如下:

  • 优先级 1:_PASS_REGEX.search(out_text)(查找 "PASS")
  • 优先级 2:_SENTINEL in out_text(查找 "ALL_TESTS_PASSED")
  • 退出码:必须为 0

获胜者竞争机制

获胜者队列为 winner_queue。

  • 队列大小为 1,确保只有一个胜利者
  • 第一个通过验证的解决方案获胜
self.winner_queue.put({
    "worker_id": self.cfg.worker_id,
    "iter": k,
    "validator": rr.validator_used,
    "runs_dir": str(run_root),
    "artifacts_dir": str(self.dirs["artifacts"]),
    }, timeout=0.1)

WorkerManager如何实现“个worker 成功后立刻停止其他worker“?具体如下:

使用multiprocessing.Event(self.success_event)作为跨进程信号:

  • 每个 worker 在每轮迭代开始前检查success_event.is_set(),若已设置则立即停止。
  • 主进程通过 self.result_queue轮询结果,发现某 worker 成功后调用 self.success_event.set()。
  • 其他还活着的 worker 进程在下一轮检查时发现事件已设置,返回 stopped_early:True
  • 主进程还会对所有worker 调用join(timeout=5.0),超时后terminate()

错误处理机制

错误上下文在 state.last_error 中传递。

state.last_error = f"RUN_FAIL:{rr.reason}\nSTDOUT_TAIL:\n{out_tail}\nSTDERR_TAIL:\n{err_tail}"

错误分类如下:

  • EXTRACT_FAIL:代码提取失败
  • RUN_FAIL:代码执行失败
  • 其他 LLM 相关错误

迭代恢复

  • 将错误信息作为上下文传递给下一次迭代
  • LLM 可以基于错误信息改进解决方案

与系统其他组件的交互

Orchestrator 的交互

  • 通过 winner_queue 向协调器报告成功

  • 响应 cancel_event 信号

与 Runner 的交互

  • 调用 run_candidate 执行代码

  • 接收 RunResult 验证结果

与 Prompting 系统的交互

  • 使用 render_prompt 生成提示

  • 利用 VARIANT_WORDINGS 的不同变体

4.4 代码

    def run(self) -> None:
        state = WorkerState(
            worker_id=self.cfg.worker_id,
            iter_index=0,
            last_response_id=None,
            last_error=None,
            passed=False,
        )
        _write_json(self.cfg.workspace_dir / "state.json", asdict(state))

        for k in range(1, self.cfg.max_iters + 1):
            if self.cancel_event.is_set():
                self.logger.info("cancel seen; exiting")
                return

            state.iter_index = k
            _write_json(self.cfg.workspace_dir / "state.json", asdict(state))

            # Render prompt
            rp = render_prompt(
                problem_path=self.problem_path,
                variant_index=self.cfg.variant_index,
                attempt_index=k,
                error_context=state.last_error,
                enable_reasoning_extras=self.cfg.enable_reasoning_extras,
                model_name=self.cfg.model,
            )
            prompt_path = self.dirs["prompts"] / f"iteration_{k}.txt"
            prompt_path.write_text(rp.user, encoding="utf-8")

            """
            Temporary MUX to support Relay while we migrate to OpenAI Responses
            API.

            Uses EventAdapter for OpenAI otherwise Provider inferface
            """
            provider = get_model_provider(self.cfg.model)
            if provider.name != "openai":
                # Call LLM directly using provider
                messages: list[dict[str, str]] = [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": rp.user},
                ]
                try:
                    response = provider.get_response(
                        self.cfg.model, messages, max_tokens=16000, **rp.extras
                    )
                    result = {
                        "output_text": response.content or "",
                        "response_id": response.response_id or None,
                        "error": None,
                    }
                except Exception as e:
                    error = f"stream_error: {e.__class__.__name__}: {e}"
                    result = {
                        "output_text": "",
                        "response_id": None,
                        "error": error,
                    }
            else:
                # Stream via EventAdapter
                jsonl_path = self.dirs["responses"] / f"iteration_{k}.stream.jsonl"
                adapter = EventAdapter(
                    model=self.cfg.model,
                    store_responses=self.cfg.store_responses,
                    timeout_s=self.cfg.llm_timeout_s,
                    jsonl_path=jsonl_path,
                    stop_event=self.cancel_event,
                    on_delta=self.on_delta,
                )
                result = adapter.stream(
                    system_prompt=SYSTEM_PROMPT, user_prompt=rp.user, extras=rp.extras
                )

            state.last_response_id = result.get("response_id")
            _write_json(
                self.dirs["responses"] / f"iteration_{k}.final.json",
                result,
            )

            if self.cancel_event.is_set():
                self.logger.info("cancel after streaming; exiting")
                return

            # Extract code
            try:
                extracted = extract_single_python_file(result.get("output_text", ""))
            except Exception as e:
                state.last_error = f"EXTRACT_FAIL: {e}"
                self.logger.warning("iteration %d: extract failed: %s", k, e)
                continue

            iter_art_dir = self.dirs["artifacts"] / f"iteration_{k}"
            latest_dir = self.dirs["artifacts"] / "latest"
            iter_art_dir.mkdir(parents=True, exist_ok=True)
            (iter_art_dir / "code.py").write_text(extracted.code, encoding="utf-8")
            (latest_dir / "code.py").write_text(extracted.code, encoding="utf-8")

            # Dedup registration
            sha = sha256_of_code(extracted.code)
            status, owner = register_digest(
                self.cfg.shared_digests_dir, sha, self.cfg.worker_id, k
            )
            if status == "duplicate_cross_worker":
                self.logger.info("duplicate across workers (owner=%s); exiting", owner)
                return
            if status == "duplicate_same_worker":
                self.logger.info("duplicate in same worker; continuing")
                continue

            # Execute
            run_root = self.dirs["runs"] / f"iteration_{k}"
            run_root.mkdir(parents=True, exist_ok=True)
            rr = run_candidate(
                artifacts_code_path=latest_dir / "code.py",
                run_root=run_root,
                timeout_s=self.cfg.run_timeout_s,
                isolated=self.cfg.isolated,
                deny_network=self.cfg.deny_network,
                cancel_event=self.cancel_event,
            )

            if rr.passed:
                self.logger.info("PASS at iter %d via %s", k, rr.validator_used)
                try:
                    self.winner_queue.put(
                        {
                            "worker_id": self.cfg.worker_id,
                            "iter": k,
                            "validator": rr.validator_used,
                            "runs_dir": str(run_root),
                            "artifacts_dir": str(self.dirs["artifacts"]),
                        },
                        timeout=0.1,
                    )
                except queue.Full:
                    pass
                state.passed = True
                _write_json(self.cfg.workspace_dir / "state.json", asdict(state))
                return

            # Build ERROR_CONTEXT and continue
            out_tail = _tail_text(rr.stdout_path)
            err_tail = _tail_text(rr.stderr_path)
            state.last_error = f"RUN_FAIL: {rr.reason}\nSTDOUT_TAIL:\n{out_tail}\nSTDERR_TAIL:\n{err_tail}"
            _write_json(self.cfg.workspace_dir / "state.json", asdict(state))

        # Done all iterations
        self.logger.info("exhausted max_iters without PASS")

0x05 runner.py 功能详解

runner.py 是 KernelAgent 系统中安全执行候选程序的模块,负责在隔离环境中运行生成的代码并验证其正确

5.1 核心职责

runner.py 是 KernelAgent 系统中至关重要的验证组件,它提供了安全、隔离的执行环境来测试生成的代码。通过多层次的验证机制和严格的资源限制,确保只有真正正确的解决方案才能被认为是成功的,从而保证了整个系统的可靠性和安全性。

runner.py 的职责如下:

  • 创建安全的执行环境

  • 运行生成的 Python 代码

  • 捕获和分析执行结果

  • 根据预定义规则判断是否通过验证

5.2 主要类和数据结构

RunResult 数据类

RunResult 如下。

@dataclass (frozen=True)
class RunResult:
    rc: int                # 返回码
    passed: bool           # 是否通过验证
    validator_used: str    # 使用的验证器类型
    reason: str            # 原因说明
    t_started: float       # 开始时间
    t_finished: float      # 结束时间
    stdout_path: Path      # 标准输出文件路径
    stderr_path: Path      # 标准错误文件路径

验证常量

验证常量如下:

  • _SENTINEL = "ALL_TESTS_PASSED" # 通用哨兵字符串

  • _PASS_REGEX = re.compile (r"\bPASS\b") # PASS 正则表达式

5.3 核心功能函数

run_candidate 函数为 runner.py 的主要逻辑。

输入参数

输入参数如下:

  • artifacts_code_path: 候选代码文件路径

  • run_root: 执行根目录

  • timeout_s: 超时时间(秒)

  • isolated: 是否使用隔离环境

  • deny_network: 是否禁用网络

  • cancel_event:取消事件

执行流程

进程管理

在子进程中使用 procss group

        p = subprocess.Popen(
            argv,
            cwd=str(run_dir),
            stdin=subprocess.DEVNULL,
            stdout=f_out,
            stderr=f_err,
            start_new_session=True,
            env=env,
        )

创建执行环境

  • 创建唯一标识的执行目录

  • 避免命名冲突

    run_dir = (
        run_root
        / f"attempt_{int(time.time() * 1000)}_{os.getpid()}_{random.randint(0, 9999):04d}"
    )
    run_dir.mkdir(parents=True, exist_ok=False)

准备执行文件

  • 复制候选代码到执行目录

  • 使用 candidate_main.p 避免与 Python 标准库的 code.py 冲突

    exec_filename = "candidate_main.py"
    code_dst = run_dir / exec_filename
    st = artifacts_code_path.lstat()
    if not stat.S_ISREG(st.st_mode) or stat.S_ISLNK(st.st_mode):
        raise ValueError("artifacts_code_path must be a regular file (no symlink)")
    shutil.copy2(artifacts_code_path, code_dst)

网络限制

如果启用网络限制,在执行目录中创建 sitecustomize.py 来阻止网络连接

    if deny_network:
        _write_sitecustomize_block_network(run_dir)
  • 环境准备

  • 构建执行命令行参数

  • 应用环境变量白名单

    argv = [sys.executable, "-u"]
    if isolated and not deny_network:
        argv.append("-I")
    argv.append(exec_filename)

    env = _allowlist_env()

验证逻辑

在 run_candidate 函数中实现优先级验证顺序

    if rc == 0:
        # Prefer explicit run_tests PASS if present in stdout
        if _PASS_REGEX.search(out_text): # 最高优先级
            passed = True
            validator = "run_tests"
            reason = "run_tests printed PASS and exited 0"
        elif _SENTINEL in out_text: # 次优先级
            passed = True
            validator = "sentinel"
            reason = "sentinel ALL_TESTS_PASSED found and exited 0"
        else: # 基础条件
            passed = False
            validator = "unknown"
            if scan_truncated:
                reason = (
                    "rc==0 but neither PASS nor sentinel found (scan_truncated=true)"
                )
            else:
                reason = "rc==0 but neither PASS nor sentinel found"
    else:
        passed = False
        reason = f"nonzero exit code: {rc}"

5.4 安全机制

_allowlist_env() 完成了环境隔离

def _allowlist_env() -> dict[str, str]:
    allow: dict[str, str] = {}
    for k, v in os.environ.items():
        if k == "PATH":
            allow[k] = v
        elif k == "PYTHONPATH":
            # sanitize: keep only absolute, existing dirs
            parts = [p for p in v.split(os.pathsep) if p]
            keep: list[str] = []
            for p in parts:
                try:
                    pp = os.path.abspath(p)
                    if os.path.isabs(pp) and os.path.isdir(pp):
                        keep.append(pp)
                except Exception:
                    continue
            if keep:
                allow["PYTHONPATH"] = os.pathsep.join(keep)
        elif k.startswith("LANG") or k.startswith("LC_"):
            allow[k] = v
    # Determinism and small resource caps
    allow["PYTHONHASHSEED"] = "0"
    allow.setdefault("OMP_NUM_THREADS", "1")
    allow.setdefault("MKL_NUM_THREADS", "1")
    allow.setdefault("OPENBLAS_NUM_THREADS", "1")
    return allow

_write_sitecustomize_block_network 完成了网络限制

def _write_sitecustomize_block_network(dst_dir: Path) -> None:
    code = (
        "import socket\n"
        "def _block(*a, **k):\n    raise RuntimeError('network disabled')\n"
        "class _Blocked(socket.socket):\n    def connect(self, *a, **k):\n        _block()\n    def connect_ex(self, *a, **k):\n        _block()\n"
        "socket.socket = _Blocked\n"
        "socket.create_connection = _block\n"
    )
    (dst_dir / "sitecustomize.py").write_text(code, encoding="utf-8")

时间限制

  • 使用 subprocess.TimeoutExpired 处理超时

  • 支持外部取消事件

5.5 与系统其他组件的交互

与 Worker 的交互

  • worker.py 调用 run_candidate 来验证生成的代码

  • 根据 RunResult 决定是否成功

与 Prompting 系统的交互

  • 验证结果影响后续提示的错误上下文

  • 未通过的执行结果会被用作改进提示的依据

5.4 代码

此处一个特点是:验证通过的判定逻辑是rc== 0 &&(PASS ∈ stdout || sentinel ∈ stdout)。为什么仅rc==0不够?其原因如下:

  • 误判风险:某些Python 程序即使内部测试失败也可能以rc=0退出(如try-except吞掉了异常)

  • 明确的成功信号:要求主动打印PASS或ALL_TESTS_PASSED是一种“肯定声明"(positiveassertion),证明测试逻辑确实执行了且判定为通过。

  • 防止空程序通过:一个空文件或只有pass的文件也会rc=0,但不会输出这些标识。

  • 渐进容错:两种 validator(run_tests识别PASS,sentinel识别ALL_TESTS_PASSED)提供了灵活性,适应不同格式的测试脚本。

此设计实质上是双重确认:进程退出码+输出内容共同决定验证结果。

def run_candidate(
    artifacts_code_path: Path,
    run_root: Path,
    timeout_s: int,
    isolated: bool,
    deny_network: bool,
    cancel_event: "threading.Event" | None = None,
) -> RunResult:
    """
    Execute a candidate program in a fresh run directory under run_root.
    - Copies artifacts_code_path to run_dir/candidate_main.py
    - Runs [sys.executable, '-u', 'candidate_main.py'] with optional -I (isolated)
    - If deny_network, injects sitecustomize.py to block sockets and do NOT use -I
    - Captures stdout/stderr to files; kills on timeout or cancel_event
    - Classifies pass/fail according to design precedence
    """
    run_dir = (
        run_root
        / f"attempt_{int(time.time() * 1000)}_{os.getpid()}_{random.randint(0, 9999):04d}"
    )
    run_dir.mkdir(parents=True, exist_ok=False)

    # Prepare working files. We intentionally avoid the name "code.py" here because
    # Python's stdlib exposes a module with that name, and PyTorch's import stack
    # (via pdb -> code) would accidentally load the candidate file instead of the
    # stdlib module, leading to partially initialised torch packages.
    exec_filename = "candidate_main.py"
    code_dst = run_dir / exec_filename
    st = artifacts_code_path.lstat()
    if not stat.S_ISREG(st.st_mode) or stat.S_ISLNK(st.st_mode):
        raise ValueError("artifacts_code_path must be a regular file (no symlink)")
    shutil.copy2(artifacts_code_path, code_dst)

    if deny_network:
        _write_sitecustomize_block_network(run_dir)

    stdout_path = run_dir / "stdout.txt"
    stderr_path = run_dir / "stderr.txt"

    argv = [sys.executable, "-u"]
    if isolated and not deny_network:
        argv.append("-I")
    argv.append(exec_filename)

    env = _allowlist_env()

    t_started = time.time()
    (run_dir / "EXEC_STARTED").write_text(str(t_started), encoding="utf-8")

    # Run the candidate (via subprocess or multiprocess)
    rc, t_finished = (
        _run_candidate(
            run_dir,
            argv,
            env,
            stdout_path,
            stderr_path,
            t_started,
            timeout_s,
            cancel_event,
        )
        if os.getenv("FUSER_COMPOSE_USE_SYS_EXECUTABLE", "1") == "1"
        else _run_candidate_multiprocess(
            exec_filename,
            run_dir,
            argv,
            env,
            stdout_path,
            stderr_path,
            t_started,
            timeout_s,
            cancel_event,
        )
    )

    # Read bounded scan for classification
    out_text, scan_truncated = _read_all_text_bounded(stdout_path, MAX_SCAN_BYTES)

    # Classification
    passed = False
    validator = "unknown"
    reason = ""
    if rc == 0:
        # Prefer explicit run_tests PASS if present in stdout
        if _PASS_REGEX.search(out_text):
            passed = True
            validator = "run_tests"
            reason = "run_tests printed PASS and exited 0"
        elif _SENTINEL in out_text:
            passed = True
            validator = "sentinel"
            reason = "sentinel ALL_TESTS_PASSED found and exited 0"
        else:
            passed = False
            validator = "unknown"
            if scan_truncated:
                reason = (
                    "rc==0 but neither PASS nor sentinel found (scan_truncated=true)"
                )
            else:
                reason = "rc==0 but neither PASS nor sentinel found"
    else:
        passed = False
        reason = f"nonzero exit code: {rc}"

    return RunResult(
        rc=rc,
        passed=passed,
        validator_used=validator,
        reason=reason,
        t_started=t_started,
        t_finished=t_finished,
        stdout_path=stdout_path,
        stderr_path=stderr_path,
    )

0xFF 参考

KernelFalcon: Autonomous GPU Kernel Generation via Deep Agents

基于 LLM 的 GPU 内核代码自动生成相关工作

Automating GPU Kernel Generation with DeepSeek-R1 and Inference Time Scaling

DeepSeek-R1自写CUDA内核跑分屠榜!斯坦福学霸狂飙GPU编程自动化挑战人类

CUDA、Triton 内核生成现状追踪

大模型能否为不同硬件平台生成高性能内核?南大、浙大提出跨平台内核生成评测框架MultiKernelBench

AKG kernel Agent:利用multi-agent进行kernel的生成和迁移

AKG KERNEL AGENT: A MULTI-AGENT FRAMEWORK FOR CROSS-PLATFORM KERNEL SYNTHESIS

AIKG -- 基于AI驱动的算子生成器

RL 猛刷 CUDA 核:CUDA-L1: Improving CUDA Optimization via Contrastive Reinforcement Learning

MultiKernelBench: A Multi-Platform Benchmark for Kernel Generation

Ouyang A, Guo S, Arora S, et al. Kernelbench: Can llms write efficient gpu kernels?[J]. arXiv preprint arXiv:2502.10517, 2025.

Baronio, Carlo, et al. "Kevin: Multi-turn rl for generating cuda kernels."arXiv preprint arXiv:2507.11948(2025).

Li, Shangzhan, et al. "Autotriton: Automatic triton programming with reinforcement learning in llms."arXiv preprint arXiv:2507.05687(2025).

Li, Jianling, et al. "Tritonbench: Benchmarking large language model capabilities for generating triton operators."Findings of the Association for Computational Linguistics: ACL 2025. 2025.

Tjarko Lange, Robert, et al. "Towards Robust Agentic CUDA Kernel Benchmarking, Verification, and Optimization."arXiv e-prints(2025): arXiv-2509.

Chen, Wentao, et al. "CUDA-LLM: LLMs Can Write Efficient CUDA Kernels."arXiv preprint arXiv:2506.09092(2025).


文章来源:https://www.cnblogs.com/rossiXYZ/p/20061518
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:jacktools123@163.com进行投诉反馈,一经查实,立即删除!

标签:

相关文章

本站推荐

标签云