16 KiB
NeuroGraph 架构设计文档
类生物能量梯度深度学习框架
1. 项目定位
NeuroGraph 探索的是区别于传统反向传播(Backpropagation, BP)的全新训练模式。核心思想是:大脑中的神经元不会做链式求导,它们通过局部活动模式和能量最小化来学习。
1.1 与传统 BP 的区别
| 维度 | 传统 BP | NeuroGraph |
|---|---|---|
| 梯度来源 | 解析链式法则 | 能量景观松弛过程 |
| 信息流 | 前向 + 反向传播 | 双向活动传播(自由相 + 推动相) |
| 学习信号 | Loss 梯度 | 能量差 (∂E_nudged - ∂E_free) |
| 拓扑结构 | 人工设计固定 | 自主剪枝 + 结构探索 |
| 调控机制 | 全局优化器 | 局部规则 + 全局奖罚调制 |
1.2 核心隐喻
把网络看作一个物理系统:给定输入后,神经元活动会像弹簧一样"松弛"到能量最低状态。学习不是计算梯度,而是比较"自由松弛态"和"被目标推动后的松弛态"之间的能量差。
2. 技术栈
2.1 核心依赖
JAX → 自动微分 + JIT 编译 + 向量化
Equinox → 神经网络模块系统(eqx.Module)
Optax → 优化器链(SGD, Adam, 自定义变换)
Gymnasium → RL 环境标准接口
JaxPruner → 稀疏掩码基础设施
2.2 为什么不用 C/C++ 手撸
- 没有自动微分,能量梯度验证极其困难
- 没有现成的 RL 环境(Gymnasium 有 200+ 环境)
- 没有 GPU 加速的张量运算库
- 研究迭代速度会慢 10 倍以上
JAX 本身就是"手撸友好"的:它是 NumPy + 自动微分 + JIT,不强迫你接受任何高层抽象。你可以从零定义能量函数、松弛动力学、学习规则,同时拥有 GPU 加速和自动微分。
2.3 为什么不用 Julia
Julia 的科学计算生态很强,但:
- 生物可塑性学习几乎没有现成参考实现
- GPU 生态不如 JAX 成熟
- 社区规模小得多
3. 核心算法
3.1 能量函数
网络的状态(活动量 + 权重)定义了一个能量景观。能量越低,状态越"好"。
E(θ, a, x, y) = w_data · E_data + w_reg · E_reg + w_sparse · E_sparse + w_struct · E_struct
| 项 | 公式 | 作用 |
|---|---|---|
| E_data | ‖f_θ(x) - y‖² | 预测误差 |
| E_reg | λ · ‖θ‖² | 权重衰减 |
| E_sparse | γ · ‖a‖₁ | 激活稀疏性 |
| E_struct | η · ‖A‖² | 拓扑复杂度 |
3.2 平衡传播(Equilibrium Propagation, EqProp)
这是 Scellier & Bengio (2017) 提出的算法,核心步骤:
Step 1: 自由相(Free Phase)
输入 x 固定到可见层
活动量沿能量梯度松弛:a ← a - η_a · ∂E/∂a
直到收敛 → 记录 a_free, E_free
Step 2: 推动相(Nudged Phase)
输出层被目标 y 轻微推动:a_out += β · (y - a_out)
继续松弛直到收敛 → 记录 a_nudged, E_nudged
Step 3: 权重更新
Δθ = -(∂E(θ, a_nudged)/∂θ - ∂E(θ, a_free)/∂θ) / β
关键洞察:权重更新量正比于两个平衡态之间的能量梯度差。当 β → 0 时,EqProp 等价于 BP。
3.3 三因子学习规则(奖罚调制)
在生物大脑中,突触可塑性受神经递质(如多巴胺)调制:
Δw_ij = reward · e_ij · (pre_i · post_j)
pre_i:突触前神经元活动post_j:突触后神经元活动e_ij:eligibility trace(资格迹),记录历史活动关联reward:全局奖罚信号(+1 奖励,-1 惩罚)
这使得学习信号不再是解析梯度,而是"行为后果"的函数。
3.4 自主剪枝 + 结构可塑性
网络自主决定哪些连接有用:
删除边:|w_ij| < threshold 且 活动贡献度低
添加边:节点 i 和 j 的活动互信息高 → 建立新连接
这模拟了大脑的突触修剪(synaptic pruning)和突触发生(synaptogenesis)。
3.5 可微图结构搜索(DARTS 风格)
网络拓扑用邻接矩阵 A[i,j] 表示,每个位置有一组可学习的架构参数 α_ij:
P(edge_ij exists) = sigmoid(α_ij)
架构参数和网络权重联合优化,训练结束后离散化:保留 sigmoid(α_ij) > 0.5 的边。
4. 架构设计
4.1 模块划分
src/neurograph/
│
├── core/ # 核心引擎(最底层)
│ ├── energy.py # 能量函数定义
│ ├── equilibrium.py # 松弛循环动力学
│ ├── neurons.py # 神经元模型
│ ├── surrogate.py # 代理梯度
│ └── topology.py # 图拓扑表示
│
├── learning/ # 学习规则(依赖 core)
│ ├── eqprop.py # 平衡传播更新
│ ├── reward.py # 奖罚调制
│ ├── hebbian.py # Hebbian 局部规则
│ └── optimizer.py # Optax 包装
│
├── pruning/ # 剪枝(依赖 core + learning)
│ ├── magnitude.py # 幅值剪枝
│ ├── activity.py # 活跃度剪枝
│ └── structural.py # 结构可塑性
│
├── architecture/ # 结构探索(依赖 pruning)
│ ├── graph_nas.py # 可微图搜索
│ ├── mutation.py # 拓扑变异
│ └── supernet.py # 超网络
│
├── env/ # 环境接口
│ ├── wrapper.py # Gymnasium 适配
│ └── simple/ # 内置简单环境
│
└── utils/ # 工具
├── visualization.py # 可视化
├── logging.py # 日志
└── metrics.py # 指标
4.2 依赖关系
core ← learning ← pruning ← architecture
↑
env (提供交互数据)
4.3 网络拓扑表示
用邻接矩阵 [N, N] 表示网络连接:
class NetworkTopology(eqx.Module):
adjacency: jnp.ndarray # [N, N] 邻接矩阵
node_types: jnp.ndarray # [N] 节点类型 (0=visible, 1=hidden, 2=output)
edge_ops: jnp.ndarray # [N, N] 操作类型索引
node_states: jnp.ndarray # [N, state_dim] 当前活动
选择理由:
- 可微(矩阵元素是连续值)
- 直接支持矩阵乘法做消息传递
- 支持 DARTS 风格优化(学习 edge_weights)
- 大网络时可切换为
jax.experimental.sparse.BCSR
4.4 数据流:完整训练步
┌─────────────────────────────────────────────────────────┐
│ 训练步 (training step) │
├─────────────────────────────────────────────────────────┤
│ │
│ ① 输入钳位 │
│ activities[visible] = observations │
│ │
│ ② 自由相松弛 (jax.lax.while_loop) │
│ for t in max_iters: │
│ grad = jax.grad(energy, argnums=1)(params, acts) │
│ acts -= activity_lr * grad │
│ if converged: break │
│ → activities_free, energy_free │
│ │
│ ③ 推动相松弛 │
│ acts[output] += β * (target - acts[output]) │
│ 同②松弛 │
│ → activities_nudged, energy_nudged │
│ │
│ ④ 权重更新 (EqProp) │
│ grad_free = jax.grad(energy, 0)(params, acts_free) │
│ grad_nudged = jax.grad(energy, 0)(params, acts_nud) │
│ delta_w = -(grad_nudged - grad_free) / β │
│ delta_w *= reward_signal # 如有奖罚调制 │
│ params = optax.apply_updates(params, delta_w) │
│ │
│ ⑤ 剪枝 (每 N 步) │
│ mask = pruning_scheduler(params, activities, step) │
│ params *= mask │
│ │
│ ⑥ 结构探索 (每 M 步) │
│ topology = graph_nas.evaluate_and_mutate(topology) │
│ │
│ ⑦ 环境步进 (RL 模式) │
│ action = read_output(acts_nudged) │
│ next_obs, reward, done = env.step(action) │
│ │
└─────────────────────────────────────────────────────────┘
整个 ②-④ 步用 jax.jit 编译为一个函数。⑤⑥ 步在 JIT 外部执行(改变参数树形状)。
5. 开发阶段
Phase 1:能量学习引擎(核心)
目标:MNIST >95% 准确率
| 任务 | 文件 | 说明 |
|---|---|---|
| 能量函数 | core/energy.py |
复合能量定义 |
| 松弛循环 | core/equilibrium.py |
jax.lax.while_loop |
| 神经元模型 | core/neurons.py |
rate-based, LIF |
| 代理梯度 | core/surrogate.py |
jax.custom_jvp |
| 拓扑表示 | core/topology.py |
邻接矩阵 + 节点类型 |
| EqProp 更新 | learning/eqprop.py |
核心学习规则 |
| Optax 包装 | learning/optimizer.py |
可组合优化器 |
| MNIST 实验 | experiments/phase1/mnist_eqprop.py |
基线验证 |
Phase 2:奖罚系统 + 环境集成
目标:CartPole-v1 奖励 >400
| 任务 | 文件 | 说明 |
|---|---|---|
| 三因子规则 | learning/reward.py |
reward × eligibility × gradient |
| 资格迹 | learning/reward.py |
历史活动关联追踪 |
| 环境包装 | env/wrapper.py |
Gymnasium → energy agent |
| CartPole 实验 | experiments/phase2/cartpole_reward.py |
RL 基线 |
Phase 3:自主剪枝 + 结构探索
目标:CIFAR-10 >80%,参数量 <50%
| 任务 | 文件 | 说明 |
|---|---|---|
| 幅值剪枝 | pruning/magnitude.py |
权重幅值阈值 |
| 活跃度剪枝 | pruning/activity.py |
活动贡献度 |
| 结构可塑性 | pruning/structural.py |
增删边 |
| 图 NAS | architecture/graph_nas.py |
可微架构搜索 |
| 拓扑变异 | architecture/mutation.py |
离散变异算子 |
| CIFAR 实验 | experiments/phase3/pruning_dynamics.py |
剪枝 + NAS |
Phase 4:全系统集成
- 端到端流水线
- 多基准测试 + 消融实验
- 文档 + 教程
6. 关键 JAX API 使用指南
6.1 松弛循环
def relax(energy_fn, params, activities, inputs, targets, nudging, config):
"""松弛到能量最低态。"""
def energy_at(acts):
return energy_fn(params, acts, inputs, targets, nudging, config)
def cond(state):
iters, _, curr_energy = state
new_e = energy_at(state[1])
return (iters < config.max_iters) & (jnp.abs(curr_energy - new_e) > config.tol)
def body(state):
iters, acts, _ = state
grad = jax.grad(energy_at)(acts)
new_acts = acts - config.activity_lr * grad
new_e = energy_at(new_acts)
return iters + 1, new_acts, new_e
init = (0, activities, jnp.inf)
return jax.lax.while_loop(cond, body, init)
6.2 代理梯度(LIF 脉冲神经元)
@jax.custom_jvp
def spike_fn(x, threshold=1.0):
return jnp.where(x >= threshold, 1.0, 0.0)
@spike_fn.defjvp
def spike_fn_jvp(primals, tangents):
x, = primals
dx, = tangents
# 高斯代理梯度
surrogate = jnp.exp(-0.5 * ((x - threshold) / 0.5) ** 2)
return spike_fn(x, threshold), surrogate * dx
6.3 编译策略
# 整个训练步 JIT 编译
@jax.jit
def train_step(params, activities, inputs, targets, reward, opt_state):
# 自由相
acts_free, E_free, _ = relax(energy_fn, params, activities, inputs, None, 0.0, cfg)
# 推动相
acts_nudged, E_nudged, _ = relax(energy_fn, params, acts_free, inputs, targets, β, cfg)
# 权重更新
grads = eqprop_grads(params, acts_free, acts_nudged, β)
grads *= reward # 奖罚调制
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
return params, acts_nudged, opt_state, {"E_free": E_free, "E_nudged": E_nudged}
7. 验证策略
7.1 单元测试
| 测试 | 验证 |
|---|---|
| 能量单调下降 | 松弛过程中能量严格递减 |
| EqProp ≈ BP | 小网络上 EqProp 收敛到与 BP 相同解 |
| 代理梯度 | spike_fn 的 jax.grad 非零且平滑 |
| 剪枝保精度 | 50% 幅值剪枝后准确率 >80% 原始 |
| 拓扑合法性 | 所有生成的图都是有效 DAG |
7.2 集成测试基准
| 基准 | 目标 |
|---|---|
| MNIST (EqProp) | >95% 准确率, 20 epochs |
| MNIST (reward-modulated) | >93% 准确率 |
| CartPole-v1 | 500 episodes 内奖励 >400 |
| CIFAR-10 + 剪枝 | >80% 准确率, 参数量减半 |
7.3 与 BP 对比
| 基准 | BP 基线 | NeuroGraph | 通过标准 |
|---|---|---|---|
| MNIST | >99% | >97% | 误差 <2% |
| CIFAR-10 | >85% | >80% | 误差 <5% |
| CartPole | ~500 episodes | ~800 episodes | 2x 内可解 |
7.4 诊断指标
- 能量轨迹:能量 vs 松弛迭代(应单调递减)
- 自由-推动间隙:
E_nudged - E_free(应与 loss 正相关) - 权重更新幅值:
‖Δw‖每层(应稳定,不爆炸/消失) - 稀疏度演化:零权重比例随训练步数
- 拓扑演化:节点/边数量随训练步数
8. 风险与缓解
| 风险 | 严重性 | 缓解措施 |
|---|---|---|
| 平衡不收敛 | 高 | 阻尼更新 + 梯度裁剪 + 自适应 activity_lr |
| 能量函数局部极小 | 高 | 小推动力 β + 良好初始化 + 复合正则 |
| 结构搜索时 JAX 形状不匹配 | 高 | 固定最大邻接矩阵 + 掩码 |
| 剪枝不可逆退化 | 中 | 渐进式剪枝 + 结构可塑性恢复边 |
| 奖罚信号太稀疏 | 中 | 奖励塑形 + 好奇心内在奖励 |
| 深层平衡梯度消失 | 高 | 残差连接 + 逐层松弛 |
9. 参考资源
核心论文
- Scellier & Bengio (2017). Equilibrium Propagation: Energy-based Learning for Feedforward and Recurrent Networks
- LeCun (2006). A Tutorial on Energy-Based Learning
- Millidge, Tsang, & Rao (2021). Predictive Coding: a Theoretical and Experimental Review
- Liu et al. (2019). DARTS: Differentiable Architecture Search
- Bi & Poo (2019). Biologically plausible learning in deep networks
参考实现
| 项目 | 语言 | 链接 |
|---|---|---|
| Equilibrium-Propagation | PyTorch | github.com/Laborieux-Axel/Equilibrium-Propagation |
| EB-JEPA | PyTorch | github.com/facebookresearch/eb_jepa |
| BioTorch | PyTorch | github.com/jsalbert/biotorch |
| JPC (Predictive Coding) | JAX | github.com/thebuckleylab/jpc |
| snnTorch | PyTorch | github.com/jeshraghian/snntorch |
| JaxPruner | JAX | github.com/google-research/jaxpruner |
| awesome-ebm | 列表 | github.com/yataobian/awesome-ebm |
| awesome-biologically-motivated-learning | 列表 | github.com/jsalbert/awesome-biologically-motivated-learning |
10. 术语表
| 术语 | 英文 | 说明 |
|---|---|---|
| 能量函数 | Energy Function | 定义网络状态的标量函数,越低越好 |
| 平衡传播 | Equilibrium Propagation (EqProp) | 通过两个平衡态的能量差更新权重 |
| 自由相 | Free Phase | 仅有输入,无目标推动的松弛过程 |
| 推动相 | Nudged Phase | 输出层被目标轻微推动后的松弛过程 |
| 松弛 | Relaxation | 活动量沿能量梯度下降的过程 |
| 三因子规则 | Three-factor Rule | pre × post × neuromodulator 的突触更新 |
| 资格迹 | Eligibility Trace | 记录历史活动关联的衰减记忆 |
| 代理梯度 | Surrogate Gradient | 不可导函数的近似梯度(如脉冲函数) |
| 结构可塑性 | Structural Plasticity | 增删突触连接的能力 |
| 图结构搜索 | Graph Neural Architecture Search | 在图拓扑空间中搜索最优架构 |