全部学科
Python全栈
python
NodeJS全栈
nodejs
小程序首页
📅 2026-05-19 10 分钟 ✍️ juanwangdev

Python AST抽象语法树

AST(Abstract Syntax Tree)是Python代码的树形表示,可用于代码分析、转换和生成。

ast 模块基础

Python
import ast

# Parse source text into an AST.
# The sample program must be a triple-quoted string: it spans several lines
# and itself contains double quotes.
source = """
def greet(name):
    return f"Hello, {name}"

greet("World")
"""

tree = ast.parse(source)
print(ast.dump(tree, indent=2))
Python
# AST of a single expression: parsing with mode='eval' gives an Expression
# root whose body is the expression node.
expression_text = '2 + 3 * 4'
expr = ast.parse(expression_text, mode='eval')
dumped_expr = ast.dump(expr)
print(dumped_expr)
# Expression(body=BinOp(left=Constant(value=2), op=Add(), right=BinOp(...)))

AST节点类型

Python
import ast

# Walk a small program and print the type of every node it contains.
# The sample program spans several lines, so it needs a triple-quoted string.
source = """
x = 10
y = x + 5
print(y)
"""

tree = ast.parse(source)

for node in ast.walk(tree):
    # Truncate each dump to 50 characters to keep the listing readable.
    print(f"{node.__class__.__name__}: {ast.dump(node)[:50]}...")

# Common node types:
# Module, FunctionDef, Assign, Name, Constant, BinOp, Call, Expr

主要节点类型:

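| 类别 | 节点类型 | 说明 |
| --- | --- | --- |
| 语句 | Assign, AugAssign | 赋值 |
| 语句 | FunctionDef, ClassDef | 函数/类定义 |
| 语句 | If, For, While | 控制流 |
| 语句 | Return, Raise | 返回/异常 |
| 表达式 | BinOp, UnaryOp | 二元/一元运算 |
| 表达式 | Call | 函数调用 |
| 表达式 | Name, Attribute | 变量/属性访问 |
| 表达式 | Constant, List, Dict | 常量/容器 |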

遍历AST

Python
import ast

class Visitor(ast.NodeVisitor):
    """Print every function definition, assignment and call found in a tree.

    Each ``visit_*`` hook prints one line for its node and then calls
    ``generic_visit`` so that nested nodes are reported as well.
    """

    def visit_FunctionDef(self, node):
        """Print the function's name, then descend into its children."""
        print(f"函数: {node.name}")
        self.generic_visit(node)  # keep walking the child nodes

    def visit_Assign(self, node):
        """Print the assignment targets as source text, then descend."""
        # ast.unparse renders any target (attribute, subscript, ...) as
        # readable source; for a plain Name it is just the identifier.
        targets = [ast.unparse(t) for t in node.targets]
        print(f"赋值: {targets}")
        self.generic_visit(node)

    def visit_Call(self, node):
        """Print the called expression as source text, then descend."""
        # Unparsing the callee shows e.g. ``obj.method`` instead of an
        # opaque object repr when the callee is not a bare name.
        func_name = ast.unparse(node.func)
        print(f"调用: {func_name}")
        self.generic_visit(node)

# Sample program for the visitor; it spans several lines, so it needs a
# triple-quoted string.
source = """
def add(a, b):
    return a + b

x = add(1, 2)
"""

tree = ast.parse(source)
Visitor().visit(tree)
Python
# ast.NodeTransformer: modify the AST.
class SimpleTransformer(ast.NodeTransformer):
    """Strip the ``var_`` prefix from names and prefix every string constant."""

    def visit_Name(self, node):
        """Rename ``var_<rest>`` to ``<rest>``; leave other names unchanged."""
        marker = 'var_'
        if node.id.startswith(marker):
            node.id = node.id[len(marker):]
        return node

    def visit_Constant(self, node):
        """Prepend ``prefix_`` to string constants; other constants pass through."""
        if not isinstance(node.value, str):
            return node
        node.value = f"prefix_{node.value}"
        return node

# Sample program for the transformer; it spans several lines and contains
# double quotes, so it needs a triple-quoted string.
source = """
var_x = "hello"
var_y = var_x + " world"
"""

tree = ast.parse(source)
modified = SimpleTransformer().visit(tree)
ast.fix_missing_locations(modified)  # fill in missing line/column info

code = compile(modified, '<modified>', 'exec')
exec(code)
# Nothing is printed. After exec: x == "prefix_hello" and
# y == "prefix_helloprefix_ world" (names lost the var_ prefix and every
# string constant gained the "prefix_" prefix).

代码分析与检查

Python
import ast

class CodeAnalyzer(ast.NodeVisitor):
    """Collect function definitions, assigned names and imports from a tree.

    After ``visit(tree)``:
      * ``functions`` is a list of dicts with keys 'name', 'args', 'lineno';
      * ``variables`` is the set of names that appear in a Store context;
      * ``imports`` is a list of imported names as dotted strings.
    """

    def __init__(self):
        self.functions = []
        self.variables = set()
        self.imports = []

    def visit_FunctionDef(self, node):
        """Record the function's name, parameter names and line number."""
        # Only node.args.args is read, so positional-only, keyword-only,
        # *args and **kwargs parameters are not listed.
        self.functions.append({
            'name': node.name,
            'args': [arg.arg for arg in node.args.args],
            'lineno': node.lineno,
        })
        self.generic_visit(node)

    def visit_Name(self, node):
        """Record names that are being assigned to (Store context)."""
        if isinstance(node.ctx, ast.Store):
            self.variables.add(node.id)
        self.generic_visit(node)

    def visit_Import(self, node):
        """Record each module named in an ``import a, b`` statement."""
        self.imports.extend(alias.name for alias in node.names)
        self.generic_visit(node)

    def visit_ImportFrom(self, node):
        """Record each ``from module import name`` as ``module.name``.

        Relative imports keep their leading dots: ``from .pkg import c`` is
        recorded as ``.pkg.c`` and ``from .. import b`` as ``..b``.
        """
        module = node.module or ''
        # node.level is the number of leading dots (0 for absolute imports).
        package = '.' * (node.level or 0) + module
        # With no module part (``from . import x``) the dots already
        # separate package and name, so no extra '.' is inserted.
        separator = '.' if module else ''
        for alias in node.names:
            self.imports.append(f"{package}{separator}{alias.name}")
        self.generic_visit(node)

# Sample program for the analyzer; it spans several lines, so it needs a
# triple-quoted string.
source = """
import os
from sys import path

def process(data):
    result = data + 1
    return result

x = process(10)
"""

tree = ast.parse(source)
analyzer = CodeAnalyzer()
analyzer.visit(tree)

print(f"函数: {analyzer.functions}")
print(f"变量: {analyzer.variables}")
print(f"导入: {analyzer.imports}")

AST安全检查

Python
import ast

class SafeCodeChecker(ast.NodeVisitor):
    """Flag calls whose callee name appears in ``dangerous_calls``.

    A call is reported when it is made through a bare name (``eval(...)``)
    or through an attribute with a listed name (``obj.open(...)``). After
    ``visit(tree)``, ``is_safe`` is False if anything was reported and
    ``violations`` holds one message per reported call.
    """

    dangerous_calls = ['eval', 'exec', 'compile', 'open', '__import__']

    def __init__(self):
        self.violations = []
        self.is_safe = True

    def _record(self, message):
        """Mark the code as unsafe and remember why."""
        self.is_safe = False
        self.violations.append(message)

    def visit_Call(self, node):
        """Check one call node, then continue into its children."""
        callee = node.func
        if isinstance(callee, ast.Name):
            if node.func.id in self.dangerous_calls:
                self._record(f"禁止调用: {node.func.id} @ {node.lineno}")
        elif isinstance(callee, ast.Attribute):
            attr = callee.attr
            if attr in self.dangerous_calls:
                self._record(f"禁止调用属性: {attr} @ {node.lineno}")

        # Nested calls, e.g. in arguments, are checked too.
        self.generic_visit(node)

def check_code(source):
    """Parse *source* and run a SafeCodeChecker over it.

    Returns a ``(is_safe, violations)`` tuple read from the checker.
    """
    parsed = ast.parse(source)
    inspector = SafeCodeChecker()
    inspector.visit(parsed)
    verdict = (inspector.is_safe, inspector.violations)
    return verdict

# Try the checker on one harmless and one dangerous snippet.
safe_source = "x = 1 + 2"
unsafe_source = "eval('1 + 2')"

for snippet in (safe_source, unsafe_source):
    print(check_code(snippet))
# Expected output:
# (True, [])
# (False, ['禁止调用: eval @ 1'])

AST代码生成

Python
import ast

# Build an AST by hand.
def create_add_function():
    """Return a Module AST holding ``def add(a, b): return a + b``."""

    # Parameter list: two plain positional parameters, no defaults.
    parameters = ast.arguments(
        posonlyargs=[],
        args=[ast.arg(arg=param) for param in ('a', 'b')],
        kwonlyargs=[],
        kw_defaults=[],
        defaults=[],
    )

    # Body: return a + b
    lhs = ast.Name(id='a', ctx=ast.Load())
    rhs = ast.Name(id='b', ctx=ast.Load())
    return_stmt = ast.Return(value=ast.BinOp(left=lhs, op=ast.Add(), right=rhs))

    # Function definition wrapping the parameters and body.
    add_def = ast.FunctionDef(
        name='add',
        args=parameters,
        body=[return_stmt],
        decorator_list=[],
        returns=None,
    )

    # Module containing only that function.
    module = ast.Module(body=[add_def], type_ignores=[])

    # Hand-built nodes carry no position info; fill it in.
    ast.fix_missing_locations(module)

    return module

# Compile the generated module and execute it in a separate namespace.
tree = create_add_function()
code = compile(tree, '<generated>', 'exec')

namespace = {}
exec(code, namespace)
result = namespace['add'](3, 4)
print(result)  # 7

从AST到源码

Python
import ast

# ast.unparse (Python 3.9+)
# The sample program spans several lines, so it needs a triple-quoted string.
source = """
def factorial(n):
    if n <= 1:
        return 1
    return n * factorial(n - 1)
"""

tree = ast.parse(source)
restored = ast.unparse(tree)
print(restored)  # regenerated source (formatting may differ from the input)

# Third-party alternative: astor
# pip install astor
# import astor
# astor.to_source(tree)

要点总结

  1. **ast.parse()** 将源码解析为AST树
  2. **ast.NodeVisitor** 用于遍历AST,**ast.NodeTransformer** 用于修改AST
  3. 常用节点:FunctionDef, Assign, Call, BinOp, Name, Constant
  4. 新建或修改节点后、编译之前,需调用 **ast.fix_missing_locations()** 补全行号等位置信息
  5. **ast.unparse()**(Python 3.9+)可将AST还原为源码

📝 发现内容有误?点击此处直接编辑

← 上一篇 Python数据结构性能优化
下一篇 → Python __init_subclass__
想查看更多题目和详细解析?
小程序提供完整的题库、模拟考试和详细解析
马上就来

长按或扫描二维码,立即体验

扫码体验小程序
马上就来
使用微信扫描二维码
立即体验完整题库