448 lines
16 KiB
Markdown
448 lines
16 KiB
Markdown
# NeuroGraph 架构设计文档
|
||
|
||
> 类生物能量梯度深度学习框架
|
||
|
||
---
|
||
|
||
## 1. 项目定位
|
||
|
||
NeuroGraph 探索的是区别于传统反向传播(Backpropagation, BP)的全新训练模式。核心思想是:大脑中的神经元不会做链式求导,它们通过局部活动模式和能量最小化来学习。
|
||
|
||
### 1.1 与传统 BP 的区别
|
||
|
||
| 维度 | 传统 BP | NeuroGraph |
|
||
|------|---------|-----------|
|
||
| 梯度来源 | 解析链式法则 | 能量景观松弛过程 |
|
||
| 信息流 | 前向 + 反向传播 | 双向活动传播(自由相 + 推动相) |
|
||
| 学习信号 | Loss 梯度 | 能量差 (∂E_nudged - ∂E_free) |
|
||
| 拓扑结构 | 人工设计固定 | 自主剪枝 + 结构探索 |
|
||
| 调控机制 | 全局优化器 | 局部规则 + 全局奖罚调制 |
|
||
|
||
### 1.2 核心隐喻
|
||
|
||
把网络看作一个物理系统:给定输入后,神经元活动会像弹簧一样"松弛"到能量最低状态。学习不是计算梯度,而是比较"自由松弛态"和"被目标推动后的松弛态"之间的能量差。
|
||
|
||
---
|
||
|
||
## 2. 技术栈
|
||
|
||
### 2.1 核心依赖
|
||
|
||
```
|
||
JAX → 自动微分 + JIT 编译 + 向量化
|
||
Equinox → 神经网络模块系统(eqx.Module)
|
||
Optax → 优化器链(SGD, Adam, 自定义变换)
|
||
Gymnasium → RL 环境标准接口
|
||
JaxPruner → 稀疏掩码基础设施
|
||
```
|
||
|
||
### 2.2 为什么不用 C/C++ 手撸
|
||
|
||
- 没有自动微分,能量梯度验证极其困难
|
||
- 没有现成的 RL 环境(Gymnasium 有 200+ 环境)
|
||
- 没有 GPU 加速的张量运算库
|
||
- 研究迭代速度会慢 10 倍以上
|
||
|
||
JAX 本身就是"手撸友好"的:它是 NumPy + 自动微分 + JIT,不强迫你接受任何高层抽象。你可以从零定义能量函数、松弛动力学、学习规则,同时拥有 GPU 加速和自动微分。
|
||
|
||
### 2.3 为什么不用 Julia
|
||
|
||
Julia 的科学计算生态很强,但:
|
||
- 生物可塑性学习几乎没有现成参考实现
|
||
- GPU 生态不如 JAX 成熟
|
||
- 社区规模小得多
|
||
|
||
---
|
||
|
||
## 3. 核心算法
|
||
|
||
### 3.1 能量函数
|
||
|
||
网络的状态(活动量 + 权重)定义了一个能量景观。能量越低,状态越"好"。
|
||
|
||
```
|
||
E(θ, a, x, y) = w_data · E_data + w_reg · E_reg + w_sparse · E_sparse + w_struct · E_struct
|
||
```
|
||
|
||
| 项 | 公式 | 作用 |
|
||
|----|------|------|
|
||
| E_data | ‖f_θ(x) - y‖² | 预测误差 |
|
||
| E_reg | λ · ‖θ‖² | 权重衰减 |
|
||
| E_sparse | γ · ‖a‖₁ | 激活稀疏性 |
|
||
| E_struct | η · ‖A‖² | 拓扑复杂度 |
|
||
|
||
### 3.2 平衡传播(Equilibrium Propagation, EqProp)
|
||
|
||
这是 Scellier & Bengio (2017) 提出的算法,核心步骤:
|
||
|
||
```
|
||
Step 1: 自由相(Free Phase)
|
||
输入 x 固定到可见层
|
||
活动量沿能量梯度松弛:a ← a - η_a · ∂E/∂a
|
||
直到收敛 → 记录 a_free, E_free
|
||
|
||
Step 2: 推动相(Nudged Phase)
|
||
输出层被目标 y 轻微推动:a_out += β · (y - a_out)
|
||
继续松弛直到收敛 → 记录 a_nudged, E_nudged
|
||
|
||
Step 3: 权重更新
|
||
Δθ = -(∂E(θ, a_nudged)/∂θ - ∂E(θ, a_free)/∂θ) / β
|
||
```
|
||
|
||
**关键洞察**:权重更新量正比于两个平衡态之间的能量梯度差。当 β → 0 时,EqProp 等价于 BP。
|
||
|
||
### 3.3 三因子学习规则(奖罚调制)
|
||
|
||
在生物大脑中,突触可塑性受神经递质(如多巴胺)调制:
|
||
|
||
```
|
||
Δw_ij = reward · e_ij · (pre_i · post_j)
|
||
```
|
||
|
||
- `pre_i`:突触前神经元活动
|
||
- `post_j`:突触后神经元活动
|
||
- `e_ij`:eligibility trace(资格迹),记录历史活动关联
|
||
- `reward`:全局奖罚信号(+1 奖励,-1 惩罚)
|
||
|
||
这使得学习信号不再是解析梯度,而是"行为后果"的函数。
|
||
|
||
### 3.4 自主剪枝 + 结构可塑性
|
||
|
||
网络自主决定哪些连接有用:
|
||
|
||
```
|
||
删除边:|w_ij| < threshold 且 活动贡献度低
|
||
添加边:节点 i 和 j 的活动互信息高 → 建立新连接
|
||
```
|
||
|
||
这模拟了大脑的突触修剪(synaptic pruning)和突触发生(synaptogenesis)。
|
||
|
||
### 3.5 可微图结构搜索(DARTS 风格)
|
||
|
||
网络拓扑用邻接矩阵 A[i,j] 表示,每个位置有一组可学习的架构参数 α_ij:
|
||
|
||
```
|
||
P(edge_ij exists) = sigmoid(α_ij)
|
||
```
|
||
|
||
架构参数和网络权重联合优化,训练结束后离散化:保留 sigmoid(α_ij) > 0.5 的边。
|
||
|
||
---
|
||
|
||
## 4. 架构设计
|
||
|
||
### 4.1 模块划分
|
||
|
||
```
|
||
src/neurograph/
|
||
│
|
||
├── core/ # 核心引擎(最底层)
|
||
│ ├── energy.py # 能量函数定义
|
||
│ ├── equilibrium.py # 松弛循环动力学
|
||
│ ├── neurons.py # 神经元模型
|
||
│ ├── surrogate.py # 代理梯度
|
||
│ └── topology.py # 图拓扑表示
|
||
│
|
||
├── learning/ # 学习规则(依赖 core)
|
||
│ ├── eqprop.py # 平衡传播更新
|
||
│ ├── reward.py # 奖罚调制
|
||
│ ├── hebbian.py # Hebbian 局部规则
|
||
│ └── optimizer.py # Optax 包装
|
||
│
|
||
├── pruning/ # 剪枝(依赖 core + learning)
|
||
│ ├── magnitude.py # 幅值剪枝
|
||
│ ├── activity.py # 活跃度剪枝
|
||
│ └── structural.py # 结构可塑性
|
||
│
|
||
├── architecture/ # 结构探索(依赖 pruning)
|
||
│ ├── graph_nas.py # 可微图搜索
|
||
│ ├── mutation.py # 拓扑变异
|
||
│ └── supernet.py # 超网络
|
||
│
|
||
├── env/ # 环境接口
|
||
│ ├── wrapper.py # Gymnasium 适配
|
||
│ └── simple/ # 内置简单环境
|
||
│
|
||
└── utils/ # 工具
|
||
├── visualization.py # 可视化
|
||
├── logging.py # 日志
|
||
└── metrics.py # 指标
|
||
```
|
||
|
||
### 4.2 依赖关系
|
||
|
||
```
|
||
core ← learning ← pruning ← architecture
|
||
↑
|
||
env (提供交互数据)
|
||
```
|
||
|
||
### 4.3 网络拓扑表示
|
||
|
||
用邻接矩阵 `[N, N]` 表示网络连接:
|
||
|
||
```python
|
||
class NetworkTopology(eqx.Module):
|
||
adjacency: jnp.ndarray # [N, N] 邻接矩阵
|
||
node_types: jnp.ndarray # [N] 节点类型 (0=visible, 1=hidden, 2=output)
|
||
edge_ops: jnp.ndarray # [N, N] 操作类型索引
|
||
node_states: jnp.ndarray # [N, state_dim] 当前活动
|
||
```
|
||
|
||
**选择理由**:
|
||
- 可微(矩阵元素是连续值)
|
||
- 直接支持矩阵乘法做消息传递
|
||
- 支持 DARTS 风格优化(学习 edge_weights)
|
||
- 大网络时可切换为 `jax.experimental.sparse.BCSR`
|
||
|
||
### 4.4 数据流:完整训练步
|
||
|
||
```
|
||
┌─────────────────────────────────────────────────────────┐
|
||
│ 训练步 (training step) │
|
||
├─────────────────────────────────────────────────────────┤
|
||
│ │
|
||
│ ① 输入钳位 │
|
||
│ activities[visible] = observations │
|
||
│ │
|
||
│ ② 自由相松弛 (jax.lax.while_loop) │
|
||
│ for t in max_iters: │
|
||
│ grad = jax.grad(energy, argnums=1)(params, acts) │
|
||
│ acts -= activity_lr * grad │
|
||
│ if converged: break │
|
||
│ → activities_free, energy_free │
|
||
│ │
|
||
│ ③ 推动相松弛 │
|
||
│ acts[output] += β * (target - acts[output]) │
|
||
│ 同②松弛 │
|
||
│ → activities_nudged, energy_nudged │
|
||
│ │
|
||
│ ④ 权重更新 (EqProp) │
|
||
│ grad_free = jax.grad(energy, 0)(params, acts_free) │
|
||
│ grad_nudged = jax.grad(energy, 0)(params, acts_nud) │
|
||
│ delta_w = -(grad_nudged - grad_free) / β │
|
||
│ delta_w *= reward_signal # 如有奖罚调制 │
|
||
│ params = optax.apply_updates(params, delta_w) │
|
||
│ │
|
||
│ ⑤ 剪枝 (每 N 步) │
|
||
│ mask = pruning_scheduler(params, activities, step) │
|
||
│ params *= mask │
|
||
│ │
|
||
│ ⑥ 结构探索 (每 M 步) │
|
||
│ topology = graph_nas.evaluate_and_mutate(topology) │
|
||
│ │
|
||
│ ⑦ 环境步进 (RL 模式) │
|
||
│ action = read_output(acts_nudged) │
|
||
│ next_obs, reward, done = env.step(action) │
|
||
│ │
|
||
└─────────────────────────────────────────────────────────┘
|
||
```
|
||
|
||
整个 ②-④ 步用 `jax.jit` 编译为一个函数。⑤⑥ 步在 JIT 外部执行(改变参数树形状)。
|
||
|
||
---
|
||
|
||
## 5. 开发阶段
|
||
|
||
### Phase 1:能量学习引擎(核心)
|
||
|
||
**目标**:MNIST >95% 准确率
|
||
|
||
| 任务 | 文件 | 说明 |
|
||
|------|------|------|
|
||
| 能量函数 | `core/energy.py` | 复合能量定义 |
|
||
| 松弛循环 | `core/equilibrium.py` | jax.lax.while_loop |
|
||
| 神经元模型 | `core/neurons.py` | rate-based, LIF |
|
||
| 代理梯度 | `core/surrogate.py` | jax.custom_jvp |
|
||
| 拓扑表示 | `core/topology.py` | 邻接矩阵 + 节点类型 |
|
||
| EqProp 更新 | `learning/eqprop.py` | 核心学习规则 |
|
||
| Optax 包装 | `learning/optimizer.py` | 可组合优化器 |
|
||
| MNIST 实验 | `experiments/phase1/mnist_eqprop.py` | 基线验证 |
|
||
|
||
### Phase 2:奖罚系统 + 环境集成
|
||
|
||
**目标**:CartPole-v1 奖励 >400
|
||
|
||
| 任务 | 文件 | 说明 |
|
||
|------|------|------|
|
||
| 三因子规则 | `learning/reward.py` | reward × eligibility × gradient |
|
||
| 资格迹 | `learning/reward.py` | 历史活动关联追踪 |
|
||
| 环境包装 | `env/wrapper.py` | Gymnasium → energy agent |
|
||
| CartPole 实验 | `experiments/phase2/cartpole_reward.py` | RL 基线 |
|
||
|
||
### Phase 3:自主剪枝 + 结构探索
|
||
|
||
**目标**:CIFAR-10 >80%,参数量 <50%
|
||
|
||
| 任务 | 文件 | 说明 |
|
||
|------|------|------|
|
||
| 幅值剪枝 | `pruning/magnitude.py` | 权重幅值阈值 |
|
||
| 活跃度剪枝 | `pruning/activity.py` | 活动贡献度 |
|
||
| 结构可塑性 | `pruning/structural.py` | 增删边 |
|
||
| 图 NAS | `architecture/graph_nas.py` | 可微架构搜索 |
|
||
| 拓扑变异 | `architecture/mutation.py` | 离散变异算子 |
|
||
| CIFAR 实验 | `experiments/phase3/pruning_dynamics.py` | 剪枝 + NAS |
|
||
|
||
### Phase 4:全系统集成
|
||
|
||
- 端到端流水线
|
||
- 多基准测试 + 消融实验
|
||
- 文档 + 教程
|
||
|
||
---
|
||
|
||
## 6. 关键 JAX API 使用指南
|
||
|
||
### 6.1 松弛循环
|
||
|
||
```python
|
||
def relax(energy_fn, params, activities, inputs, targets, nudging, config):
|
||
"""松弛到能量最低态。"""
|
||
|
||
def energy_at(acts):
|
||
return energy_fn(params, acts, inputs, targets, nudging, config)
|
||
|
||
def cond(state):
|
||
iters, _, curr_energy = state
|
||
new_e = energy_at(state[1])
|
||
return (iters < config.max_iters) & (jnp.abs(curr_energy - new_e) > config.tol)
|
||
|
||
def body(state):
|
||
iters, acts, _ = state
|
||
grad = jax.grad(energy_at)(acts)
|
||
new_acts = acts - config.activity_lr * grad
|
||
new_e = energy_at(new_acts)
|
||
return iters + 1, new_acts, new_e
|
||
|
||
init = (0, activities, jnp.inf)
|
||
return jax.lax.while_loop(cond, body, init)
|
||
```
|
||
|
||
### 6.2 代理梯度(LIF 脉冲神经元)
|
||
|
||
```python
|
||
@jax.custom_jvp
|
||
def spike_fn(x, threshold=1.0):
|
||
return jnp.where(x >= threshold, 1.0, 0.0)
|
||
|
||
@spike_fn.defjvp
|
||
def spike_fn_jvp(primals, tangents):
|
||
x, = primals
|
||
dx, = tangents
|
||
# 高斯代理梯度
|
||
surrogate = jnp.exp(-0.5 * ((x - threshold) / 0.5) ** 2)
|
||
return spike_fn(x, threshold), surrogate * dx
|
||
```
|
||
|
||
### 6.3 编译策略
|
||
|
||
```python
|
||
# 整个训练步 JIT 编译
|
||
@jax.jit
|
||
def train_step(params, activities, inputs, targets, reward, opt_state):
|
||
# 自由相
|
||
acts_free, E_free, _ = relax(energy_fn, params, activities, inputs, None, 0.0, cfg)
|
||
# 推动相
|
||
acts_nudged, E_nudged, _ = relax(energy_fn, params, acts_free, inputs, targets, β, cfg)
|
||
# 权重更新
|
||
grads = eqprop_grads(params, acts_free, acts_nudged, β)
|
||
grads *= reward # 奖罚调制
|
||
updates, opt_state = optimizer.update(grads, opt_state)
|
||
params = optax.apply_updates(params, updates)
|
||
return params, acts_nudged, opt_state, {"E_free": E_free, "E_nudged": E_nudged}
|
||
```
|
||
|
||
---
|
||
|
||
## 7. 验证策略
|
||
|
||
### 7.1 单元测试
|
||
|
||
| 测试 | 验证 |
|
||
|------|------|
|
||
| 能量单调下降 | 松弛过程中能量严格递减 |
|
||
| EqProp ≈ BP | 小网络上 EqProp 收敛到与 BP 相同解 |
|
||
| 代理梯度 | spike_fn 的 jax.grad 非零且平滑 |
|
||
| 剪枝保精度 | 50% 幅值剪枝后准确率 >80% 原始 |
|
||
| 拓扑合法性 | 所有生成的图都是有效 DAG |
|
||
|
||
### 7.2 集成测试基准
|
||
|
||
| 基准 | 目标 |
|
||
|------|------|
|
||
| MNIST (EqProp) | >95% 准确率, 20 epochs |
|
||
| MNIST (reward-modulated) | >93% 准确率 |
|
||
| CartPole-v1 | 500 episodes 内奖励 >400 |
|
||
| CIFAR-10 + 剪枝 | >80% 准确率, 参数量减半 |
|
||
|
||
### 7.3 与 BP 对比
|
||
|
||
| 基准 | BP 基线 | NeuroGraph | 通过标准 |
|
||
|------|---------|-----------|----------|
|
||
| MNIST | >99% | >97% | 误差 <2% |
|
||
| CIFAR-10 | >85% | >80% | 误差 <5% |
|
||
| CartPole | ~500 episodes | ~800 episodes | 2x 内可解 |
|
||
|
||
### 7.4 诊断指标
|
||
|
||
- **能量轨迹**:能量 vs 松弛迭代(应单调递减)
|
||
- **自由-推动间隙**:`E_nudged - E_free`(应与 loss 正相关)
|
||
- **权重更新幅值**:`‖Δw‖` 每层(应稳定,不爆炸/消失)
|
||
- **稀疏度演化**:零权重比例随训练步数
|
||
- **拓扑演化**:节点/边数量随训练步数
|
||
|
||
---
|
||
|
||
## 8. 风险与缓解
|
||
|
||
| 风险 | 严重性 | 缓解措施 |
|
||
|------|--------|---------|
|
||
| 平衡不收敛 | 高 | 阻尼更新 + 梯度裁剪 + 自适应 activity_lr |
|
||
| 能量函数局部极小 | 高 | 小推动力 β + 良好初始化 + 复合正则 |
|
||
| 结构搜索时 JAX 形状不匹配 | 高 | 固定最大邻接矩阵 + 掩码 |
|
||
| 剪枝不可逆退化 | 中 | 渐进式剪枝 + 结构可塑性恢复边 |
|
||
| 奖罚信号太稀疏 | 中 | 奖励塑形 + 好奇心内在奖励 |
|
||
| 深层平衡梯度消失 | 高 | 残差连接 + 逐层松弛 |
|
||
|
||
---
|
||
|
||
## 9. 参考资源
|
||
|
||
### 核心论文
|
||
|
||
1. Scellier & Bengio (2017). *Equilibrium Propagation: Energy-based Learning for Feedforward and Recurrent Networks*
|
||
2. LeCun (2006). *A Tutorial on Energy-Based Learning*
|
||
3. Millidge, Tsang, & Rao (2021). *Predictive Coding: a Theoretical and Experimental Review*
|
||
4. Liu et al. (2019). *DARTS: Differentiable Architecture Search*
|
||
5. Bi & Poo (2019). *Biologically plausible learning in deep networks*
|
||
|
||
### 参考实现
|
||
|
||
| 项目 | 语言 | 链接 |
|
||
|------|------|------|
|
||
| Equilibrium-Propagation | PyTorch | github.com/Laborieux-Axel/Equilibrium-Propagation |
|
||
| EB-JEPA | PyTorch | github.com/facebookresearch/eb_jepa |
|
||
| BioTorch | PyTorch | github.com/jsalbert/biotorch |
|
||
| JPC (Predictive Coding) | JAX | github.com/thebuckleylab/jpc |
|
||
| snnTorch | PyTorch | github.com/jeshraghian/snntorch |
|
||
| JaxPruner | JAX | github.com/google-research/jaxpruner |
|
||
| awesome-ebm | 列表 | github.com/yataobian/awesome-ebm |
|
||
| awesome-biologically-motivated-learning | 列表 | github.com/jsalbert/awesome-biologically-motivated-learning |
|
||
|
||
---
|
||
|
||
## 10. 术语表
|
||
|
||
| 术语 | 英文 | 说明 |
|
||
|------|------|------|
|
||
| 能量函数 | Energy Function | 定义网络状态的标量函数,越低越好 |
|
||
| 平衡传播 | Equilibrium Propagation (EqProp) | 通过两个平衡态的能量差更新权重 |
|
||
| 自由相 | Free Phase | 仅有输入,无目标推动的松弛过程 |
|
||
| 推动相 | Nudged Phase | 输出层被目标轻微推动后的松弛过程 |
|
||
| 松弛 | Relaxation | 活动量沿能量梯度下降的过程 |
|
||
| 三因子规则 | Three-factor Rule | pre × post × neuromodulator 的突触更新 |
|
||
| 资格迹 | Eligibility Trace | 记录历史活动关联的衰减记忆 |
|
||
| 代理梯度 | Surrogate Gradient | 不可导函数的近似梯度(如脉冲函数) |
|
||
| 结构可塑性 | Structural Plasticity | 增删突触连接的能力 |
|
||
| 图结构搜索 | Graph Neural Architecture Search | 在图拓扑空间中搜索最优架构 |
|