# NeuroGraph 架构设计文档 > 类生物能量梯度深度学习框架 --- ## 1. 项目定位 NeuroGraph 探索的是区别于传统反向传播(Backpropagation, BP)的全新训练模式。核心思想是:大脑中的神经元不会做链式求导,它们通过局部活动模式和能量最小化来学习。 ### 1.1 与传统 BP 的区别 | 维度 | 传统 BP | NeuroGraph | |------|---------|-----------| | 梯度来源 | 解析链式法则 | 能量景观松弛过程 | | 信息流 | 前向 + 反向传播 | 双向活动传播(自由相 + 推动相) | | 学习信号 | Loss 梯度 | 能量差 (∂E_nudged - ∂E_free) | | 拓扑结构 | 人工设计固定 | 自主剪枝 + 结构探索 | | 调控机制 | 全局优化器 | 局部规则 + 全局奖罚调制 | ### 1.2 核心隐喻 把网络看作一个物理系统:给定输入后,神经元活动会像弹簧一样"松弛"到能量最低状态。学习不是计算梯度,而是比较"自由松弛态"和"被目标推动后的松弛态"之间的能量差。 --- ## 2. 技术栈 ### 2.1 核心依赖 ``` JAX → 自动微分 + JIT 编译 + 向量化 Equinox → 神经网络模块系统(eqx.Module) Optax → 优化器链(SGD, Adam, 自定义变换) Gymnasium → RL 环境标准接口 JaxPruner → 稀疏掩码基础设施 ``` ### 2.2 为什么不用 C/C++ 手撸 - 没有自动微分,能量梯度验证极其困难 - 没有现成的 RL 环境(Gymnasium 有 200+ 环境) - 没有 GPU 加速的张量运算库 - 研究迭代速度会慢 10 倍以上 JAX 本身就是"手撸友好"的:它是 NumPy + 自动微分 + JIT,不强迫你接受任何高层抽象。你可以从零定义能量函数、松弛动力学、学习规则,同时拥有 GPU 加速和自动微分。 ### 2.3 为什么不用 Julia Julia 的科学计算生态很强,但: - 生物可塑性学习几乎没有现成参考实现 - GPU 生态不如 JAX 成熟 - 社区规模小得多 --- ## 3. 核心算法 ### 3.1 能量函数 网络的状态(活动量 + 权重)定义了一个能量景观。能量越低,状态越"好"。 ``` E(θ, a, x, y) = w_data · E_data + w_reg · E_reg + w_sparse · E_sparse + w_struct · E_struct ``` | 项 | 公式 | 作用 | |----|------|------| | E_data | ‖f_θ(x) - y‖² | 预测误差 | | E_reg | λ · ‖θ‖² | 权重衰减 | | E_sparse | γ · ‖a‖₁ | 激活稀疏性 | | E_struct | η · ‖A‖² | 拓扑复杂度 | ### 3.2 平衡传播(Equilibrium Propagation, EqProp) 这是 Scellier & Bengio (2017) 提出的算法,核心步骤: ``` Step 1: 自由相(Free Phase) 输入 x 固定到可见层 活动量沿能量梯度松弛:a ← a - η_a · ∂E/∂a 直到收敛 → 记录 a_free, E_free Step 2: 推动相(Nudged Phase) 输出层被目标 y 轻微推动:a_out += β · (y - a_out) 继续松弛直到收敛 → 记录 a_nudged, E_nudged Step 3: 权重更新 Δθ = -(∂E(θ, a_nudged)/∂θ - ∂E(θ, a_free)/∂θ) / β ``` **关键洞察**:权重更新量正比于两个平衡态之间的能量梯度差。当 β → 0 时,EqProp 等价于 BP。 ### 3.3 三因子学习规则(奖罚调制) 在生物大脑中,突触可塑性受神经递质(如多巴胺)调制: ``` Δw_ij = reward · e_ij · (pre_i · post_j) ``` - `pre_i`:突触前神经元活动 - `post_j`:突触后神经元活动 - `e_ij`:eligibility trace(资格迹),记录历史活动关联 - `reward`:全局奖罚信号(+1 奖励,-1 惩罚) 这使得学习信号不再是解析梯度,而是"行为后果"的函数。 ### 3.4 自主剪枝 + 结构可塑性 网络自主决定哪些连接有用: ``` 删除边:|w_ij| < threshold 且 活动贡献度低 添加边:节点 i 和 j 的活动互信息高 → 建立新连接 ``` 这模拟了大脑的突触修剪(synaptic pruning)和突触发生(synaptogenesis)。 ### 3.5 可微图结构搜索(DARTS 风格) 网络拓扑用邻接矩阵 A[i,j] 表示,每个位置有一组可学习的架构参数 α_ij: ``` P(edge_ij exists) = sigmoid(α_ij) ``` 架构参数和网络权重联合优化,训练结束后离散化:保留 sigmoid(α_ij) > 0.5 的边。 --- ## 4. 架构设计 ### 4.1 模块划分 ``` src/neurograph/ │ ├── core/ # 核心引擎(最底层) │ ├── energy.py # 能量函数定义 │ ├── equilibrium.py # 松弛循环动力学 │ ├── neurons.py # 神经元模型 │ ├── surrogate.py # 代理梯度 │ └── topology.py # 图拓扑表示 │ ├── learning/ # 学习规则(依赖 core) │ ├── eqprop.py # 平衡传播更新 │ ├── reward.py # 奖罚调制 │ ├── hebbian.py # Hebbian 局部规则 │ └── optimizer.py # Optax 包装 │ ├── pruning/ # 剪枝(依赖 core + learning) │ ├── magnitude.py # 幅值剪枝 │ ├── activity.py # 活跃度剪枝 │ └── structural.py # 结构可塑性 │ ├── architecture/ # 结构探索(依赖 pruning) │ ├── graph_nas.py # 可微图搜索 │ ├── mutation.py # 拓扑变异 │ └── supernet.py # 超网络 │ ├── env/ # 环境接口 │ ├── wrapper.py # Gymnasium 适配 │ └── simple/ # 内置简单环境 │ └── utils/ # 工具 ├── visualization.py # 可视化 ├── logging.py # 日志 └── metrics.py # 指标 ``` ### 4.2 依赖关系 ``` core ← learning ← pruning ← architecture ↑ env (提供交互数据) ``` ### 4.3 网络拓扑表示 用邻接矩阵 `[N, N]` 表示网络连接: ```python class NetworkTopology(eqx.Module): adjacency: jnp.ndarray # [N, N] 邻接矩阵 node_types: jnp.ndarray # [N] 节点类型 (0=visible, 1=hidden, 2=output) edge_ops: jnp.ndarray # [N, N] 操作类型索引 node_states: jnp.ndarray # [N, state_dim] 当前活动 ``` **选择理由**: - 可微(矩阵元素是连续值) - 直接支持矩阵乘法做消息传递 - 支持 DARTS 风格优化(学习 edge_weights) - 大网络时可切换为 `jax.experimental.sparse.BCSR` ### 4.4 数据流:完整训练步 ``` ┌─────────────────────────────────────────────────────────┐ │ 训练步 (training step) │ ├─────────────────────────────────────────────────────────┤ │ │ │ ① 输入钳位 │ │ activities[visible] = observations │ │ │ │ ② 自由相松弛 (jax.lax.while_loop) │ │ for t in max_iters: │ │ grad = jax.grad(energy, argnums=1)(params, acts) │ │ acts -= activity_lr * grad │ │ if converged: break │ │ → activities_free, energy_free │ │ │ │ ③ 推动相松弛 │ │ acts[output] += β * (target - acts[output]) │ │ 同②松弛 │ │ → activities_nudged, energy_nudged │ │ │ │ ④ 权重更新 (EqProp) │ │ grad_free = jax.grad(energy, 0)(params, acts_free) │ │ grad_nudged = jax.grad(energy, 0)(params, acts_nud) │ │ delta_w = -(grad_nudged - grad_free) / β │ │ delta_w *= reward_signal # 如有奖罚调制 │ │ params = optax.apply_updates(params, delta_w) │ │ │ │ ⑤ 剪枝 (每 N 步) │ │ mask = pruning_scheduler(params, activities, step) │ │ params *= mask │ │ │ │ ⑥ 结构探索 (每 M 步) │ │ topology = graph_nas.evaluate_and_mutate(topology) │ │ │ │ ⑦ 环境步进 (RL 模式) │ │ action = read_output(acts_nudged) │ │ next_obs, reward, done = env.step(action) │ │ │ └─────────────────────────────────────────────────────────┘ ``` 整个 ②-④ 步用 `jax.jit` 编译为一个函数。⑤⑥ 步在 JIT 外部执行(改变参数树形状)。 --- ## 5. 开发阶段 ### Phase 1:能量学习引擎(核心) **目标**:MNIST >95% 准确率 | 任务 | 文件 | 说明 | |------|------|------| | 能量函数 | `core/energy.py` | 复合能量定义 | | 松弛循环 | `core/equilibrium.py` | jax.lax.while_loop | | 神经元模型 | `core/neurons.py` | rate-based, LIF | | 代理梯度 | `core/surrogate.py` | jax.custom_jvp | | 拓扑表示 | `core/topology.py` | 邻接矩阵 + 节点类型 | | EqProp 更新 | `learning/eqprop.py` | 核心学习规则 | | Optax 包装 | `learning/optimizer.py` | 可组合优化器 | | MNIST 实验 | `experiments/phase1/mnist_eqprop.py` | 基线验证 | ### Phase 2:奖罚系统 + 环境集成 **目标**:CartPole-v1 奖励 >400 | 任务 | 文件 | 说明 | |------|------|------| | 三因子规则 | `learning/reward.py` | reward × eligibility × gradient | | 资格迹 | `learning/reward.py` | 历史活动关联追踪 | | 环境包装 | `env/wrapper.py` | Gymnasium → energy agent | | CartPole 实验 | `experiments/phase2/cartpole_reward.py` | RL 基线 | ### Phase 3:自主剪枝 + 结构探索 **目标**:CIFAR-10 >80%,参数量 <50% | 任务 | 文件 | 说明 | |------|------|------| | 幅值剪枝 | `pruning/magnitude.py` | 权重幅值阈值 | | 活跃度剪枝 | `pruning/activity.py` | 活动贡献度 | | 结构可塑性 | `pruning/structural.py` | 增删边 | | 图 NAS | `architecture/graph_nas.py` | 可微架构搜索 | | 拓扑变异 | `architecture/mutation.py` | 离散变异算子 | | CIFAR 实验 | `experiments/phase3/pruning_dynamics.py` | 剪枝 + NAS | ### Phase 4:全系统集成 - 端到端流水线 - 多基准测试 + 消融实验 - 文档 + 教程 --- ## 6. 关键 JAX API 使用指南 ### 6.1 松弛循环 ```python def relax(energy_fn, params, activities, inputs, targets, nudging, config): """松弛到能量最低态。""" def energy_at(acts): return energy_fn(params, acts, inputs, targets, nudging, config) def cond(state): iters, _, curr_energy = state new_e = energy_at(state[1]) return (iters < config.max_iters) & (jnp.abs(curr_energy - new_e) > config.tol) def body(state): iters, acts, _ = state grad = jax.grad(energy_at)(acts) new_acts = acts - config.activity_lr * grad new_e = energy_at(new_acts) return iters + 1, new_acts, new_e init = (0, activities, jnp.inf) return jax.lax.while_loop(cond, body, init) ``` ### 6.2 代理梯度(LIF 脉冲神经元) ```python @jax.custom_jvp def spike_fn(x, threshold=1.0): return jnp.where(x >= threshold, 1.0, 0.0) @spike_fn.defjvp def spike_fn_jvp(primals, tangents): x, = primals dx, = tangents # 高斯代理梯度 surrogate = jnp.exp(-0.5 * ((x - threshold) / 0.5) ** 2) return spike_fn(x, threshold), surrogate * dx ``` ### 6.3 编译策略 ```python # 整个训练步 JIT 编译 @jax.jit def train_step(params, activities, inputs, targets, reward, opt_state): # 自由相 acts_free, E_free, _ = relax(energy_fn, params, activities, inputs, None, 0.0, cfg) # 推动相 acts_nudged, E_nudged, _ = relax(energy_fn, params, acts_free, inputs, targets, β, cfg) # 权重更新 grads = eqprop_grads(params, acts_free, acts_nudged, β) grads *= reward # 奖罚调制 updates, opt_state = optimizer.update(grads, opt_state) params = optax.apply_updates(params, updates) return params, acts_nudged, opt_state, {"E_free": E_free, "E_nudged": E_nudged} ``` --- ## 7. 验证策略 ### 7.1 单元测试 | 测试 | 验证 | |------|------| | 能量单调下降 | 松弛过程中能量严格递减 | | EqProp ≈ BP | 小网络上 EqProp 收敛到与 BP 相同解 | | 代理梯度 | spike_fn 的 jax.grad 非零且平滑 | | 剪枝保精度 | 50% 幅值剪枝后准确率 >80% 原始 | | 拓扑合法性 | 所有生成的图都是有效 DAG | ### 7.2 集成测试基准 | 基准 | 目标 | |------|------| | MNIST (EqProp) | >95% 准确率, 20 epochs | | MNIST (reward-modulated) | >93% 准确率 | | CartPole-v1 | 500 episodes 内奖励 >400 | | CIFAR-10 + 剪枝 | >80% 准确率, 参数量减半 | ### 7.3 与 BP 对比 | 基准 | BP 基线 | NeuroGraph | 通过标准 | |------|---------|-----------|----------| | MNIST | >99% | >97% | 误差 <2% | | CIFAR-10 | >85% | >80% | 误差 <5% | | CartPole | ~500 episodes | ~800 episodes | 2x 内可解 | ### 7.4 诊断指标 - **能量轨迹**:能量 vs 松弛迭代(应单调递减) - **自由-推动间隙**:`E_nudged - E_free`(应与 loss 正相关) - **权重更新幅值**:`‖Δw‖` 每层(应稳定,不爆炸/消失) - **稀疏度演化**:零权重比例随训练步数 - **拓扑演化**:节点/边数量随训练步数 --- ## 8. 风险与缓解 | 风险 | 严重性 | 缓解措施 | |------|--------|---------| | 平衡不收敛 | 高 | 阻尼更新 + 梯度裁剪 + 自适应 activity_lr | | 能量函数局部极小 | 高 | 小推动力 β + 良好初始化 + 复合正则 | | 结构搜索时 JAX 形状不匹配 | 高 | 固定最大邻接矩阵 + 掩码 | | 剪枝不可逆退化 | 中 | 渐进式剪枝 + 结构可塑性恢复边 | | 奖罚信号太稀疏 | 中 | 奖励塑形 + 好奇心内在奖励 | | 深层平衡梯度消失 | 高 | 残差连接 + 逐层松弛 | --- ## 9. 参考资源 ### 核心论文 1. Scellier & Bengio (2017). *Equilibrium Propagation: Energy-based Learning for Feedforward and Recurrent Networks* 2. LeCun (2006). *A Tutorial on Energy-Based Learning* 3. Millidge, Tsang, & Rao (2021). *Predictive Coding: a Theoretical and Experimental Review* 4. Liu et al. (2019). *DARTS: Differentiable Architecture Search* 5. Bi & Poo (2019). *Biologically plausible learning in deep networks* ### 参考实现 | 项目 | 语言 | 链接 | |------|------|------| | Equilibrium-Propagation | PyTorch | github.com/Laborieux-Axel/Equilibrium-Propagation | | EB-JEPA | PyTorch | github.com/facebookresearch/eb_jepa | | BioTorch | PyTorch | github.com/jsalbert/biotorch | | JPC (Predictive Coding) | JAX | github.com/thebuckleylab/jpc | | snnTorch | PyTorch | github.com/jeshraghian/snntorch | | JaxPruner | JAX | github.com/google-research/jaxpruner | | awesome-ebm | 列表 | github.com/yataobian/awesome-ebm | | awesome-biologically-motivated-learning | 列表 | github.com/jsalbert/awesome-biologically-motivated-learning | --- ## 10. 术语表 | 术语 | 英文 | 说明 | |------|------|------| | 能量函数 | Energy Function | 定义网络状态的标量函数,越低越好 | | 平衡传播 | Equilibrium Propagation (EqProp) | 通过两个平衡态的能量差更新权重 | | 自由相 | Free Phase | 仅有输入,无目标推动的松弛过程 | | 推动相 | Nudged Phase | 输出层被目标轻微推动后的松弛过程 | | 松弛 | Relaxation | 活动量沿能量梯度下降的过程 | | 三因子规则 | Three-factor Rule | pre × post × neuromodulator 的突触更新 | | 资格迹 | Eligibility Trace | 记录历史活动关联的衰减记忆 | | 代理梯度 | Surrogate Gradient | 不可导函数的近似梯度(如脉冲函数) | | 结构可塑性 | Structural Plasticity | 增删突触连接的能力 | | 图结构搜索 | Graph Neural Architecture Search | 在图拓扑空间中搜索最优架构 |