Python AST抽象语法树
AST(Abstract Syntax Tree)是Python代码的树形表示,可用于代码分析、转换和生成。
ast 模块基础
Python
import ast
# Parse source code into an AST and print its indented dump
source = """
def greet(name):
    return f"Hello, {name}"
greet("World")
"""
tree = ast.parse(source)
print(ast.dump(tree, indent=2))
Python
# A lone expression parsed in "eval" mode yields an ast.Expression root
expr = ast.parse('2 + 3 * 4', mode='eval')
dumped_expr = ast.dump(expr)
print(dumped_expr)
# Expression(body=BinOp(left=Constant(value=2), op=Add(), right=BinOp(...)))
# "*" binds tighter than "+", so 3 * 4 is the nested right-hand BinOp
AST节点类型
Python
import ast
# Walk every node of a small module and show its type plus a short dump
source = """
x = 10
y = x + 5
print(y)
"""
tree = ast.parse(source)
for node in ast.walk(tree):
    dumped = ast.dump(node)
    # Only append "..." when the dump was actually cut off
    preview = dumped if len(dumped) <= 50 else dumped[:50] + "..."
    print(f"{type(node).__name__}: {preview}")
# Node types seen in typical code:
# Module, FunctionDef, Assign, Name, Constant, BinOp, Call, Expr
主要节点类型:
| 类别 | 节点类型 | 说明 |
|---|---|---|
| 语句 | Assign, AugAssign | 赋值 |
| 语句 | FunctionDef, ClassDef | 函数/类定义 |
| 语句 | If, For, While | 控制流 |
| 语句 | Return, Raise | 返回/异常 |
| 表达式 | BinOp, UnaryOp | 二元/一元运算 |
| 表达式 | Call | 函数调用 |
| 表达式 | Name, Attribute | 变量/属性访问 |
| 表达式 | Constant, List, Dict | 常量/容器 |
遍历AST
Python
import ast
class Visitor(ast.NodeVisitor):
    """Print every function definition, assignment and call in a tree.

    ``ast.NodeVisitor.visit`` dispatches each node to ``visit_<ClassName>``;
    each handler here calls ``generic_visit`` so nested nodes are still
    visited (e.g. the call on the right-hand side of an assignment).
    """

    def visit_FunctionDef(self, node):
        print(f"函数: {node.name}")
        self.generic_visit(node)  # keep walking into the function body

    def visit_Assign(self, node):
        # ast.unparse renders any target kind (Name, Attribute, Subscript,
        # Tuple) as source text; for a plain Name it is just the identifier.
        targets = [ast.unparse(t) for t in node.targets]
        print(f"赋值: {targets}")
        self.generic_visit(node)

    def visit_Call(self, node):
        # "add" for add(...), "obj.run" for obj.run(...) -- instead of the
        # repr of an ast.Attribute object
        func_name = ast.unparse(node.func)
        print(f"调用: {func_name}")
        self.generic_visit(node)
# Run the visitor over a small module
source = """
def add(a, b):
    return a + b
x = add(1, 2)
"""
tree = ast.parse(source)
Visitor().visit(tree)
Python
# ast.NodeTransformer: rewrite the AST
class SimpleTransformer(ast.NodeTransformer):
    """Strip the ``var_`` prefix from names and prefix every string constant.

    Each handler mutates the node in place and returns it, so the node stays
    where it was in the tree (a handler returning None would remove it).
    """

    def visit_Name(self, node):
        # removeprefix leaves names without the "var_" prefix unchanged
        node.id = node.id.removeprefix('var_')
        return node

    def visit_Constant(self, node):
        # Only str constants are rewritten; numbers, None, etc. pass through
        value = node.value
        if isinstance(value, str):
            node.value = f"prefix_{value}"
        return node
# Rewrite a small module with SimpleTransformer, then compile and run it
source = """
var_x = "hello"
var_y = var_x + " world"
print(var_y)
"""
tree = ast.parse(source)
modified = SimpleTransformer().visit(tree)
# Fill in lineno/col_offset on any node that lacks them before compile()
ast.fix_missing_locations(modified)
code = compile(modified, '<modified>', 'exec')
# Run in a fresh namespace so the rewritten x / y don't overwrite this
# module's own globals
exec(code, {})
# Output: prefix_helloprefix_ world
# (var_x / var_y became x / y, and both string literals got the prefix)
代码分析与检查
Python
import ast
class CodeAnalyzer(ast.NodeVisitor):
    """Collect function definitions, bound variable names and imports.

    After ``visit(tree)``:

    - ``functions``: one dict per ``def`` / ``async def`` with keys ``name``,
      ``args`` (positional parameter names, positional-only first) and
      ``lineno``.
    - ``variables``: names of ``Name`` nodes in Store context (assignment
      targets, loop targets, ...). Parameters are ``ast.arg`` nodes, so they
      are not included.
    - ``imports``: dotted names in visiting order. Relative imports keep
      their leading dots, e.g. ``from ..pkg import x`` -> ``"..pkg.x"``.
    """

    def __init__(self):
        self.functions = []
        self.variables = set()
        self.imports = []

    def visit_FunctionDef(self, node):
        params = node.args
        self.functions.append({
            'name': node.name,
            # positional-only parameters (before "/") come first in the signature
            'args': [arg.arg for arg in params.posonlyargs + params.args],
            'lineno': node.lineno,
        })
        self.generic_visit(node)

    # "async def" nodes carry the same name/args/lineno fields as "def"
    visit_AsyncFunctionDef = visit_FunctionDef

    def visit_Name(self, node):
        if isinstance(node.ctx, ast.Store):
            self.variables.add(node.id)
        self.generic_visit(node)

    def visit_Import(self, node):
        for alias in node.names:
            self.imports.append(alias.name)
        self.generic_visit(node)

    def visit_ImportFrom(self, node):
        # level = number of leading dots of a relative import (may be None on
        # hand-built nodes); module is None for "from . import x"
        base = '.' * (node.level or 0) + (node.module or '')
        sep = '' if base.endswith('.') else '.'
        for alias in node.names:
            self.imports.append(f"{base}{sep}{alias.name}")
        self.generic_visit(node)
# Analyze a small module and print what was found
source = """
import os
from sys import path
def process(data):
    result = data + 1
    return result
x = process(10)
"""
tree = ast.parse(source)
analyzer = CodeAnalyzer()
analyzer.visit(tree)
print(f"函数: {analyzer.functions}")  # [{'name': 'process', 'args': ['data'], 'lineno': 4}]
print(f"变量: {analyzer.variables}")  # {'result', 'x'} (set order may vary)
print(f"导入: {analyzer.imports}")    # ['os', 'sys.path']
AST安全检查
Python
import ast
class SafeCodeChecker(ast.NodeVisitor):
    """Report calls to a small deny-list of dangerous functions.

    A call is flagged when it is made by bare name (``eval(...)``) or as an
    attribute (``builtins.open(...)``). Only the syntactic call site is
    inspected, so this is a teaching example, not a sandbox: aliasing
    (``f = eval; f(...)``) or ``getattr`` tricks are not detected.
    """

    dangerous_calls = ['eval', 'exec', 'compile', 'open', '__import__']

    def __init__(self):
        self.is_safe = True
        self.violations = []

    def visit_Call(self, node):
        callee = node.func
        message = None
        if isinstance(callee, ast.Name) and callee.id in self.dangerous_calls:
            message = f"禁止调用: {callee.id} @ {node.lineno}"
        elif isinstance(callee, ast.Attribute) and callee.attr in self.dangerous_calls:
            message = f"禁止调用属性: {callee.attr} @ {node.lineno}"
        if message is not None:
            self.is_safe = False
            self.violations.append(message)
        # Nested calls (e.g. inside arguments) are checked too
        self.generic_visit(node)
def check_code(source):
    """Parse *source* and run SafeCodeChecker over the resulting tree.

    Returns a tuple ``(is_safe, violations)``. A SyntaxError from
    ``ast.parse`` propagates to the caller.
    """
    syntax_tree = ast.parse(source)
    checker = SafeCodeChecker()
    checker.visit(syntax_tree)
    return (checker.is_safe, checker.violations)
# Demo: one harmless and one flagged snippet
safe_source = "x = 1 + 2"
unsafe_source = "eval('1 + 2')"
for snippet in (safe_source, unsafe_source):
    print(check_code(snippet))
# (True, [])
# (False, ['禁止调用: eval @ 1'])
AST代码生成
Python
import ast
# Build an AST by hand
def create_add_function():
    """Return an ``ast.Module`` equivalent to ``def add(a, b): return a + b``.

    Missing line/column information is filled in with
    ``ast.fix_missing_locations`` so the module can go straight to
    ``compile``.
    """
    def load(identifier):
        # A Name read in Load context, i.e. a plain variable reference
        return ast.Name(id=identifier, ctx=ast.Load())

    signature = ast.arguments(
        posonlyargs=[],
        args=[ast.arg(arg=param) for param in ('a', 'b')],
        kwonlyargs=[],
        kw_defaults=[],
        defaults=[],
    )
    return_stmt = ast.Return(
        value=ast.BinOp(left=load('a'), op=ast.Add(), right=load('b'))
    )
    add_def = ast.FunctionDef(
        name='add',
        args=signature,
        body=[return_stmt],
        decorator_list=[],
        returns=None,
    )
    # fix_missing_locations returns the same node it was given
    return ast.fix_missing_locations(ast.Module(body=[add_def], type_ignores=[]))
# Compile the hand-built module and run it in its own namespace
tree = create_add_function()
code = compile(tree, filename='<generated>', mode='exec')
namespace = {}
exec(code, namespace)
add_fn = namespace['add']
print(add_fn(3, 4))  # 7
从AST到源码
Python
import ast
# ast.unparse (Python 3.9+): turn an AST back into source code
source = """
def factorial(n):
    if n <= 1:
        return 1
    return n * factorial(n - 1)
"""
tree = ast.parse(source)
restored = ast.unparse(tree)
# The result is equivalent code, not a byte-for-byte copy: comments are not
# part of the AST and are lost, and formatting/quote style is normalized
print(restored)
# Before Python 3.9 (no ast.unparse), the third-party astor package offers
# the same AST -> source conversion (it does not preserve comments either):
#   pip install astor
#   import astor
#   astor.to_source(tree)
要点总结
- **ast.parse()** 将源码解析为AST树
- **ast.NodeVisitor** 用于遍历AST,**ast.NodeTransformer** 用于修改AST
- 常用节点:FunctionDef, Assign, Call, BinOp, Name, Constant
- **ast.fix_missing_locations()**:新建或替换节点后、调用 compile() 之前使用,为缺少位置信息(lineno 等)的节点补全行号
- **ast.unparse()**(Python 3.9+)可将AST还原为源码(注释不会保留)
📝 发现内容有误?点击此处直接编辑