neuron/docs/design.md
2026-06-03 21:40:29 +08:00

16 KiB
Raw Permalink Blame History

NeuroGraph 架构设计文档

类生物能量梯度深度学习框架


1. 项目定位

NeuroGraph 探索的是区别于传统反向传播Backpropagation, BP的全新训练模式。核心思想是大脑中的神经元不会做链式求导它们通过局部活动模式和能量最小化来学习。

1.1 与传统 BP 的区别

维度 传统 BP NeuroGraph
梯度来源 解析链式法则 能量景观松弛过程
信息流 前向 + 反向传播 双向活动传播(自由相 + 推动相)
学习信号 Loss 梯度 能量差 (∂E_nudged - ∂E_free)
拓扑结构 人工设计固定 自主剪枝 + 结构探索
调控机制 全局优化器 局部规则 + 全局奖罚调制

1.2 核心隐喻

把网络看作一个物理系统:给定输入后,神经元活动会像弹簧一样"松弛"到能量最低状态。学习不是计算梯度,而是比较"自由松弛态"和"被目标推动后的松弛态"之间的能量差。


2. 技术栈

2.1 核心依赖

JAX         → 自动微分 + JIT 编译 + 向量化
Equinox     → 神经网络模块系统eqx.Module
Optax       → 优化器链SGD, Adam, 自定义变换)
Gymnasium   → RL 环境标准接口
JaxPruner   → 稀疏掩码基础设施

2.2 为什么不用 C/C++ 手撸

  • 没有自动微分,能量梯度验证极其困难
  • 没有现成的 RL 环境Gymnasium 有 200+ 环境)
  • 没有 GPU 加速的张量运算库
  • 研究迭代速度会慢 10 倍以上

JAX 本身就是"手撸友好"的:它是 NumPy + 自动微分 + JIT不强迫你接受任何高层抽象。你可以从零定义能量函数、松弛动力学、学习规则同时拥有 GPU 加速和自动微分。

2.3 为什么不用 Julia

Julia 的科学计算生态很强,但:

  • 生物可塑性学习几乎没有现成参考实现
  • GPU 生态不如 JAX 成熟
  • 社区规模小得多

3. 核心算法

3.1 能量函数

网络的状态(活动量 + 权重)定义了一个能量景观。能量越低,状态越"好"。

E(θ, a, x, y) = w_data · E_data + w_reg · E_reg + w_sparse · E_sparse + w_struct · E_struct
公式 作用
E_data ‖f_θ(x) - y‖² 预测误差
E_reg λ · ‖θ‖² 权重衰减
E_sparse γ · ‖a‖₁ 激活稀疏性
E_struct η · ‖A‖² 拓扑复杂度

3.2 平衡传播Equilibrium Propagation, EqProp

这是 Scellier & Bengio (2017) 提出的算法,核心步骤:

Step 1: 自由相Free Phase
    输入 x 固定到可见层
    活动量沿能量梯度松弛a ← a - η_a · ∂E/∂a
    直到收敛 → 记录 a_free, E_free

Step 2: 推动相Nudged Phase
    输出层被目标 y 轻微推动a_out += β · (y - a_out)
    继续松弛直到收敛 → 记录 a_nudged, E_nudged

Step 3: 权重更新
    Δθ = -(∂E(θ, a_nudged)/∂θ - ∂E(θ, a_free)/∂θ) / β

关键洞察:权重更新量正比于两个平衡态之间的能量梯度差。当 β → 0 时EqProp 等价于 BP。

3.3 三因子学习规则(奖罚调制)

在生物大脑中,突触可塑性受神经递质(如多巴胺)调制:

Δw_ij = reward · e_ij · (pre_i · post_j)
  • pre_i:突触前神经元活动
  • post_j:突触后神经元活动
  • e_ijeligibility trace资格迹记录历史活动关联
  • reward:全局奖罚信号(+1 奖励,-1 惩罚)

这使得学习信号不再是解析梯度,而是"行为后果"的函数。

3.4 自主剪枝 + 结构可塑性

网络自主决定哪些连接有用:

删除边:|w_ij| < threshold 且 活动贡献度低
添加边:节点 i 和 j 的活动互信息高 → 建立新连接

这模拟了大脑的突触修剪synaptic pruning和突触发生synaptogenesis

3.5 可微图结构搜索DARTS 风格)

网络拓扑用邻接矩阵 A[i,j] 表示,每个位置有一组可学习的架构参数 α_ij

P(edge_ij exists) = sigmoid(α_ij)

架构参数和网络权重联合优化,训练结束后离散化:保留 sigmoid(α_ij) > 0.5 的边。


4. 架构设计

4.1 模块划分

src/neurograph/
│
├── core/                    # 核心引擎(最底层)
│   ├── energy.py            # 能量函数定义
│   ├── equilibrium.py       # 松弛循环动力学
│   ├── neurons.py           # 神经元模型
│   ├── surrogate.py         # 代理梯度
│   └── topology.py          # 图拓扑表示
│
├── learning/                # 学习规则(依赖 core
│   ├── eqprop.py            # 平衡传播更新
│   ├── reward.py            # 奖罚调制
│   ├── hebbian.py           # Hebbian 局部规则
│   └── optimizer.py         # Optax 包装
│
├── pruning/                 # 剪枝(依赖 core + learning
│   ├── magnitude.py         # 幅值剪枝
│   ├── activity.py          # 活跃度剪枝
│   └── structural.py        # 结构可塑性
│
├── architecture/            # 结构探索(依赖 pruning
│   ├── graph_nas.py         # 可微图搜索
│   ├── mutation.py          # 拓扑变异
│   └── supernet.py          # 超网络
│
├── env/                     # 环境接口
│   ├── wrapper.py           # Gymnasium 适配
│   └── simple/              # 内置简单环境
│
└── utils/                   # 工具
    ├── visualization.py     # 可视化
    ├── logging.py           # 日志
    └── metrics.py           # 指标

4.2 依赖关系

core ← learning ← pruning ← architecture
              ↑
              env (提供交互数据)

4.3 网络拓扑表示

用邻接矩阵 [N, N] 表示网络连接:

class NetworkTopology(eqx.Module):
    adjacency: jnp.ndarray        # [N, N] 邻接矩阵
    node_types: jnp.ndarray       # [N] 节点类型 (0=visible, 1=hidden, 2=output)
    edge_ops: jnp.ndarray         # [N, N] 操作类型索引
    node_states: jnp.ndarray      # [N, state_dim] 当前活动

选择理由

  • 可微(矩阵元素是连续值)
  • 直接支持矩阵乘法做消息传递
  • 支持 DARTS 风格优化(学习 edge_weights
  • 大网络时可切换为 jax.experimental.sparse.BCSR

4.4 数据流:完整训练步

┌─────────────────────────────────────────────────────────┐
│                    训练步 (training step)                  │
├─────────────────────────────────────────────────────────┤
│                                                          │
│  ① 输入钳位                                                │
│     activities[visible] = observations                   │
│                                                          │
│  ② 自由相松弛 (jax.lax.while_loop)                         │
│     for t in max_iters:                                  │
│       grad = jax.grad(energy, argnums=1)(params, acts)    │
│       acts -= activity_lr * grad                         │
│       if converged: break                                │
│     → activities_free, energy_free                       │
│                                                          │
│  ③ 推动相松弛                                              │
│     acts[output] += β * (target - acts[output])          │
│     同②松弛                                               │
│     → activities_nudged, energy_nudged                   │
│                                                          │
│  ④ 权重更新 (EqProp)                                      │
│     grad_free = jax.grad(energy, 0)(params, acts_free)   │
│     grad_nudged = jax.grad(energy, 0)(params, acts_nud)  │
│     delta_w = -(grad_nudged - grad_free) / β             │
│     delta_w *= reward_signal  # 如有奖罚调制              │
│     params = optax.apply_updates(params, delta_w)        │
│                                                          │
│  ⑤ 剪枝 (每 N 步)                                          │
│     mask = pruning_scheduler(params, activities, step)   │
│     params *= mask                                       │
│                                                          │
│  ⑥ 结构探索 (每 M 步)                                      │
│     topology = graph_nas.evaluate_and_mutate(topology)   │
│                                                          │
│  ⑦ 环境步进 (RL 模式)                                      │
│     action = read_output(acts_nudged)                    │
│     next_obs, reward, done = env.step(action)            │
│                                                          │
└─────────────────────────────────────────────────────────┘

整个 ②-④ 步用 jax.jit 编译为一个函数。⑤⑥ 步在 JIT 外部执行(改变参数树形状)。


5. 开发阶段

Phase 1能量学习引擎核心

目标MNIST >95% 准确率

任务 文件 说明
能量函数 core/energy.py 复合能量定义
松弛循环 core/equilibrium.py jax.lax.while_loop
神经元模型 core/neurons.py rate-based, LIF
代理梯度 core/surrogate.py jax.custom_jvp
拓扑表示 core/topology.py 邻接矩阵 + 节点类型
EqProp 更新 learning/eqprop.py 核心学习规则
Optax 包装 learning/optimizer.py 可组合优化器
MNIST 实验 experiments/phase1/mnist_eqprop.py 基线验证

Phase 2奖罚系统 + 环境集成

目标CartPole-v1 奖励 >400

任务 文件 说明
三因子规则 learning/reward.py reward × eligibility × gradient
资格迹 learning/reward.py 历史活动关联追踪
环境包装 env/wrapper.py Gymnasium → energy agent
CartPole 实验 experiments/phase2/cartpole_reward.py RL 基线

Phase 3自主剪枝 + 结构探索

目标CIFAR-10 >80%,参数量 <50%

任务 文件 说明
幅值剪枝 pruning/magnitude.py 权重幅值阈值
活跃度剪枝 pruning/activity.py 活动贡献度
结构可塑性 pruning/structural.py 增删边
图 NAS architecture/graph_nas.py 可微架构搜索
拓扑变异 architecture/mutation.py 离散变异算子
CIFAR 实验 experiments/phase3/pruning_dynamics.py 剪枝 + NAS

Phase 4全系统集成

  • 端到端流水线
  • 多基准测试 + 消融实验
  • 文档 + 教程

6. 关键 JAX API 使用指南

6.1 松弛循环

def relax(energy_fn, params, activities, inputs, targets, nudging, config):
    """松弛到能量最低态。"""

    def energy_at(acts):
        return energy_fn(params, acts, inputs, targets, nudging, config)

    def cond(state):
        iters, _, curr_energy = state
        new_e = energy_at(state[1])
        return (iters < config.max_iters) & (jnp.abs(curr_energy - new_e) > config.tol)

    def body(state):
        iters, acts, _ = state
        grad = jax.grad(energy_at)(acts)
        new_acts = acts - config.activity_lr * grad
        new_e = energy_at(new_acts)
        return iters + 1, new_acts, new_e

    init = (0, activities, jnp.inf)
    return jax.lax.while_loop(cond, body, init)

6.2 代理梯度LIF 脉冲神经元)

@jax.custom_jvp
def spike_fn(x, threshold=1.0):
    return jnp.where(x >= threshold, 1.0, 0.0)

@spike_fn.defjvp
def spike_fn_jvp(primals, tangents):
    x, = primals
    dx, = tangents
    # 高斯代理梯度
    surrogate = jnp.exp(-0.5 * ((x - threshold) / 0.5) ** 2)
    return spike_fn(x, threshold), surrogate * dx

6.3 编译策略

# 整个训练步 JIT 编译
@jax.jit
def train_step(params, activities, inputs, targets, reward, opt_state):
    # 自由相
    acts_free, E_free, _ = relax(energy_fn, params, activities, inputs, None, 0.0, cfg)
    # 推动相
    acts_nudged, E_nudged, _ = relax(energy_fn, params, acts_free, inputs, targets, β, cfg)
    # 权重更新
    grads = eqprop_grads(params, acts_free, acts_nudged, β)
    grads *= reward  # 奖罚调制
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, acts_nudged, opt_state, {"E_free": E_free, "E_nudged": E_nudged}

7. 验证策略

7.1 单元测试

测试 验证
能量单调下降 松弛过程中能量严格递减
EqProp ≈ BP 小网络上 EqProp 收敛到与 BP 相同解
代理梯度 spike_fn 的 jax.grad 非零且平滑
剪枝保精度 50% 幅值剪枝后准确率 >80% 原始
拓扑合法性 所有生成的图都是有效 DAG

7.2 集成测试基准

基准 目标
MNIST (EqProp) >95% 准确率, 20 epochs
MNIST (reward-modulated) >93% 准确率
CartPole-v1 500 episodes 内奖励 >400
CIFAR-10 + 剪枝 >80% 准确率, 参数量减半

7.3 与 BP 对比

基准 BP 基线 NeuroGraph 通过标准
MNIST >99% >97% 误差 <2%
CIFAR-10 >85% >80% 误差 <5%
CartPole ~500 episodes ~800 episodes 2x 内可解

7.4 诊断指标

  • 能量轨迹:能量 vs 松弛迭代(应单调递减)
  • 自由-推动间隙E_nudged - E_free(应与 loss 正相关)
  • 权重更新幅值‖Δw‖ 每层(应稳定,不爆炸/消失)
  • 稀疏度演化:零权重比例随训练步数
  • 拓扑演化:节点/边数量随训练步数

8. 风险与缓解

风险 严重性 缓解措施
平衡不收敛 阻尼更新 + 梯度裁剪 + 自适应 activity_lr
能量函数局部极小 小推动力 β + 良好初始化 + 复合正则
结构搜索时 JAX 形状不匹配 固定最大邻接矩阵 + 掩码
剪枝不可逆退化 渐进式剪枝 + 结构可塑性恢复边
奖罚信号太稀疏 奖励塑形 + 好奇心内在奖励
深层平衡梯度消失 残差连接 + 逐层松弛

9. 参考资源

核心论文

  1. Scellier & Bengio (2017). Equilibrium Propagation: Energy-based Learning for Feedforward and Recurrent Networks
  2. LeCun (2006). A Tutorial on Energy-Based Learning
  3. Millidge, Tsang, & Rao (2021). Predictive Coding: a Theoretical and Experimental Review
  4. Liu et al. (2019). DARTS: Differentiable Architecture Search
  5. Bi & Poo (2019). Biologically plausible learning in deep networks

参考实现

项目 语言 链接
Equilibrium-Propagation PyTorch github.com/Laborieux-Axel/Equilibrium-Propagation
EB-JEPA PyTorch github.com/facebookresearch/eb_jepa
BioTorch PyTorch github.com/jsalbert/biotorch
JPC (Predictive Coding) JAX github.com/thebuckleylab/jpc
snnTorch PyTorch github.com/jeshraghian/snntorch
JaxPruner JAX github.com/google-research/jaxpruner
awesome-ebm 列表 github.com/yataobian/awesome-ebm
awesome-biologically-motivated-learning 列表 github.com/jsalbert/awesome-biologically-motivated-learning

10. 术语表

术语 英文 说明
能量函数 Energy Function 定义网络状态的标量函数,越低越好
平衡传播 Equilibrium Propagation (EqProp) 通过两个平衡态的能量差更新权重
自由相 Free Phase 仅有输入,无目标推动的松弛过程
推动相 Nudged Phase 输出层被目标轻微推动后的松弛过程
松弛 Relaxation 活动量沿能量梯度下降的过程
三因子规则 Three-factor Rule pre × post × neuromodulator 的突触更新
资格迹 Eligibility Trace 记录历史活动关联的衰减记忆
代理梯度 Surrogate Gradient 不可导函数的近似梯度(如脉冲函数)
结构可塑性 Structural Plasticity 增删突触连接的能力
图结构搜索 Graph Neural Architecture Search 在图拓扑空间中搜索最优架构