neuron/docs/design.md
2026-06-03 21:40:29 +08:00

448 lines
16 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# NeuroGraph 架构设计文档
> 类生物能量梯度深度学习框架
---
## 1. 项目定位
NeuroGraph 探索的是区别于传统反向传播Backpropagation, BP的全新训练模式。核心思想是大脑中的神经元不会做链式求导它们通过局部活动模式和能量最小化来学习。
### 1.1 与传统 BP 的区别
| 维度 | 传统 BP | NeuroGraph |
|------|---------|-----------|
| 梯度来源 | 解析链式法则 | 能量景观松弛过程 |
| 信息流 | 前向 + 反向传播 | 双向活动传播(自由相 + 推动相) |
| 学习信号 | Loss 梯度 | 能量差 (∂E_nudged - ∂E_free) |
| 拓扑结构 | 人工设计固定 | 自主剪枝 + 结构探索 |
| 调控机制 | 全局优化器 | 局部规则 + 全局奖罚调制 |
### 1.2 核心隐喻
把网络看作一个物理系统:给定输入后,神经元活动会像弹簧一样"松弛"到能量最低状态。学习不是计算梯度,而是比较"自由松弛态"和"被目标推动后的松弛态"之间的能量差。
---
## 2. 技术栈
### 2.1 核心依赖
```
JAX → 自动微分 + JIT 编译 + 向量化
Equinox → 神经网络模块系统eqx.Module
Optax → 优化器链SGD, Adam, 自定义变换)
Gymnasium → RL 环境标准接口
JaxPruner → 稀疏掩码基础设施
```
### 2.2 为什么不用 C/C++ 手撸
- 没有自动微分,能量梯度验证极其困难
- 没有现成的 RL 环境Gymnasium 有 200+ 环境)
- 没有 GPU 加速的张量运算库
- 研究迭代速度会慢 10 倍以上
JAX 本身就是"手撸友好"的:它是 NumPy + 自动微分 + JIT不强迫你接受任何高层抽象。你可以从零定义能量函数、松弛动力学、学习规则同时拥有 GPU 加速和自动微分。
### 2.3 为什么不用 Julia
Julia 的科学计算生态很强,但:
- 生物可塑性学习几乎没有现成参考实现
- GPU 生态不如 JAX 成熟
- 社区规模小得多
---
## 3. 核心算法
### 3.1 能量函数
网络的状态(活动量 + 权重)定义了一个能量景观。能量越低,状态越"好"。
```
E(θ, a, x, y) = w_data · E_data + w_reg · E_reg + w_sparse · E_sparse + w_struct · E_struct
```
| 项 | 公式 | 作用 |
|----|------|------|
| E_data | ‖f_θ(x) - y‖² | 预测误差 |
| E_reg | λ · ‖θ‖² | 权重衰减 |
| E_sparse | γ · ‖a‖₁ | 激活稀疏性 |
| E_struct | η · ‖A‖² | 拓扑复杂度 |
### 3.2 平衡传播Equilibrium Propagation, EqProp
这是 Scellier & Bengio (2017) 提出的算法,核心步骤:
```
Step 1: 自由相Free Phase
输入 x 固定到可见层
活动量沿能量梯度松弛a ← a - η_a · ∂E/∂a
直到收敛 → 记录 a_free, E_free
Step 2: 推动相Nudged Phase
输出层被目标 y 轻微推动a_out += β · (y - a_out)
继续松弛直到收敛 → 记录 a_nudged, E_nudged
Step 3: 权重更新
Δθ = -(∂E(θ, a_nudged)/∂θ - ∂E(θ, a_free)/∂θ) / β
```
**关键洞察**:权重更新量正比于两个平衡态之间的能量梯度差。当 β → 0 时EqProp 等价于 BP。
### 3.3 三因子学习规则(奖罚调制)
在生物大脑中,突触可塑性受神经递质(如多巴胺)调制:
```
Δw_ij = reward · e_ij · (pre_i · post_j)
```
- `pre_i`:突触前神经元活动
- `post_j`:突触后神经元活动
- `e_ij`eligibility trace资格迹记录历史活动关联
- `reward`:全局奖罚信号(+1 奖励,-1 惩罚)
这使得学习信号不再是解析梯度,而是"行为后果"的函数。
### 3.4 自主剪枝 + 结构可塑性
网络自主决定哪些连接有用:
```
删除边:|w_ij| < threshold 且 活动贡献度低
添加边:节点 i 和 j 的活动互信息高 → 建立新连接
```
这模拟了大脑的突触修剪synaptic pruning和突触发生synaptogenesis
### 3.5 可微图结构搜索DARTS 风格)
网络拓扑用邻接矩阵 A[i,j] 表示,每个位置有一组可学习的架构参数 α_ij
```
P(edge_ij exists) = sigmoid(α_ij)
```
架构参数和网络权重联合优化,训练结束后离散化:保留 sigmoid(α_ij) > 0.5 的边。
---
## 4. 架构设计
### 4.1 模块划分
```
src/neurograph/
├── core/ # 核心引擎(最底层)
│ ├── energy.py # 能量函数定义
│ ├── equilibrium.py # 松弛循环动力学
│ ├── neurons.py # 神经元模型
│ ├── surrogate.py # 代理梯度
│ └── topology.py # 图拓扑表示
├── learning/ # 学习规则(依赖 core
│ ├── eqprop.py # 平衡传播更新
│ ├── reward.py # 奖罚调制
│ ├── hebbian.py # Hebbian 局部规则
│ └── optimizer.py # Optax 包装
├── pruning/ # 剪枝(依赖 core + learning
│ ├── magnitude.py # 幅值剪枝
│ ├── activity.py # 活跃度剪枝
│ └── structural.py # 结构可塑性
├── architecture/ # 结构探索(依赖 pruning
│ ├── graph_nas.py # 可微图搜索
│ ├── mutation.py # 拓扑变异
│ └── supernet.py # 超网络
├── env/ # 环境接口
│ ├── wrapper.py # Gymnasium 适配
│ └── simple/ # 内置简单环境
└── utils/ # 工具
├── visualization.py # 可视化
├── logging.py # 日志
└── metrics.py # 指标
```
### 4.2 依赖关系
```
core ← learning ← pruning ← architecture
env (提供交互数据)
```
### 4.3 网络拓扑表示
用邻接矩阵 `[N, N]` 表示网络连接:
```python
class NetworkTopology(eqx.Module):
adjacency: jnp.ndarray # [N, N] 邻接矩阵
node_types: jnp.ndarray # [N] 节点类型 (0=visible, 1=hidden, 2=output)
edge_ops: jnp.ndarray # [N, N] 操作类型索引
node_states: jnp.ndarray # [N, state_dim] 当前活动
```
**选择理由**
- 可微(矩阵元素是连续值)
- 直接支持矩阵乘法做消息传递
- 支持 DARTS 风格优化(学习 edge_weights
- 大网络时可切换为 `jax.experimental.sparse.BCSR`
### 4.4 数据流:完整训练步
```
┌─────────────────────────────────────────────────────────┐
│ 训练步 (training step) │
├─────────────────────────────────────────────────────────┤
│ │
│ ① 输入钳位 │
│ activities[visible] = observations │
│ │
│ ② 自由相松弛 (jax.lax.while_loop) │
│ for t in max_iters: │
│ grad = jax.grad(energy, argnums=1)(params, acts) │
│ acts -= activity_lr * grad │
│ if converged: break │
│ → activities_free, energy_free │
│ │
│ ③ 推动相松弛 │
│ acts[output] += β * (target - acts[output]) │
│ 同②松弛 │
│ → activities_nudged, energy_nudged │
│ │
│ ④ 权重更新 (EqProp) │
│ grad_free = jax.grad(energy, 0)(params, acts_free) │
│ grad_nudged = jax.grad(energy, 0)(params, acts_nud) │
│ delta_w = -(grad_nudged - grad_free) / β │
│ delta_w *= reward_signal # 如有奖罚调制 │
│ params = optax.apply_updates(params, delta_w) │
│ │
│ ⑤ 剪枝 (每 N 步) │
│ mask = pruning_scheduler(params, activities, step) │
│ params *= mask │
│ │
│ ⑥ 结构探索 (每 M 步) │
│ topology = graph_nas.evaluate_and_mutate(topology) │
│ │
│ ⑦ 环境步进 (RL 模式) │
│ action = read_output(acts_nudged) │
│ next_obs, reward, done = env.step(action) │
│ │
└─────────────────────────────────────────────────────────┘
```
整个 ②-④ 步用 `jax.jit` 编译为一个函数。⑤⑥ 步在 JIT 外部执行(改变参数树形状)。
---
## 5. 开发阶段
### Phase 1能量学习引擎核心
**目标**MNIST >95% 准确率
| 任务 | 文件 | 说明 |
|------|------|------|
| 能量函数 | `core/energy.py` | 复合能量定义 |
| 松弛循环 | `core/equilibrium.py` | jax.lax.while_loop |
| 神经元模型 | `core/neurons.py` | rate-based, LIF |
| 代理梯度 | `core/surrogate.py` | jax.custom_jvp |
| 拓扑表示 | `core/topology.py` | 邻接矩阵 + 节点类型 |
| EqProp 更新 | `learning/eqprop.py` | 核心学习规则 |
| Optax 包装 | `learning/optimizer.py` | 可组合优化器 |
| MNIST 实验 | `experiments/phase1/mnist_eqprop.py` | 基线验证 |
### Phase 2奖罚系统 + 环境集成
**目标**CartPole-v1 奖励 >400
| 任务 | 文件 | 说明 |
|------|------|------|
| 三因子规则 | `learning/reward.py` | reward × eligibility × gradient |
| 资格迹 | `learning/reward.py` | 历史活动关联追踪 |
| 环境包装 | `env/wrapper.py` | Gymnasium → energy agent |
| CartPole 实验 | `experiments/phase2/cartpole_reward.py` | RL 基线 |
### Phase 3自主剪枝 + 结构探索
**目标**CIFAR-10 >80%,参数量 <50%
| 任务 | 文件 | 说明 |
|------|------|------|
| 幅值剪枝 | `pruning/magnitude.py` | 权重幅值阈值 |
| 活跃度剪枝 | `pruning/activity.py` | 活动贡献度 |
| 结构可塑性 | `pruning/structural.py` | 增删边 |
| NAS | `architecture/graph_nas.py` | 可微架构搜索 |
| 拓扑变异 | `architecture/mutation.py` | 离散变异算子 |
| CIFAR 实验 | `experiments/phase3/pruning_dynamics.py` | 剪枝 + NAS |
### Phase 4全系统集成
- 端到端流水线
- 多基准测试 + 消融实验
- 文档 + 教程
---
## 6. 关键 JAX API 使用指南
### 6.1 松弛循环
```python
def relax(energy_fn, params, activities, inputs, targets, nudging, config):
"""松弛到能量最低态。"""
def energy_at(acts):
return energy_fn(params, acts, inputs, targets, nudging, config)
def cond(state):
iters, _, curr_energy = state
new_e = energy_at(state[1])
return (iters < config.max_iters) & (jnp.abs(curr_energy - new_e) > config.tol)
def body(state):
iters, acts, _ = state
grad = jax.grad(energy_at)(acts)
new_acts = acts - config.activity_lr * grad
new_e = energy_at(new_acts)
return iters + 1, new_acts, new_e
init = (0, activities, jnp.inf)
return jax.lax.while_loop(cond, body, init)
```
### 6.2 代理梯度LIF 脉冲神经元)
```python
@jax.custom_jvp
def spike_fn(x, threshold=1.0):
return jnp.where(x >= threshold, 1.0, 0.0)
@spike_fn.defjvp
def spike_fn_jvp(primals, tangents):
x, = primals
dx, = tangents
# 高斯代理梯度
surrogate = jnp.exp(-0.5 * ((x - threshold) / 0.5) ** 2)
return spike_fn(x, threshold), surrogate * dx
```
### 6.3 编译策略
```python
# 整个训练步 JIT 编译
@jax.jit
def train_step(params, activities, inputs, targets, reward, opt_state):
# 自由相
acts_free, E_free, _ = relax(energy_fn, params, activities, inputs, None, 0.0, cfg)
# 推动相
acts_nudged, E_nudged, _ = relax(energy_fn, params, acts_free, inputs, targets, β, cfg)
# 权重更新
grads = eqprop_grads(params, acts_free, acts_nudged, β)
grads *= reward # 奖罚调制
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)
return params, acts_nudged, opt_state, {"E_free": E_free, "E_nudged": E_nudged}
```
---
## 7. 验证策略
### 7.1 单元测试
| 测试 | 验证 |
|------|------|
| 能量单调下降 | 松弛过程中能量严格递减 |
| EqProp BP | 小网络上 EqProp 收敛到与 BP 相同解 |
| 代理梯度 | spike_fn jax.grad 非零且平滑 |
| 剪枝保精度 | 50% 幅值剪枝后准确率 >80% 原始 |
| 拓扑合法性 | 所有生成的图都是有效 DAG |
### 7.2 集成测试基准
| 基准 | 目标 |
|------|------|
| MNIST (EqProp) | >95% 准确率, 20 epochs |
| MNIST (reward-modulated) | >93% 准确率 |
| CartPole-v1 | 500 episodes 内奖励 >400 |
| CIFAR-10 + 剪枝 | >80% 准确率, 参数量减半 |
### 7.3 与 BP 对比
| 基准 | BP 基线 | NeuroGraph | 通过标准 |
|------|---------|-----------|----------|
| MNIST | >99% | >97% | 误差 <2% |
| CIFAR-10 | >85% | >80% | 误差 <5% |
| CartPole | ~500 episodes | ~800 episodes | 2x 内可解 |
### 7.4 诊断指标
- **能量轨迹**能量 vs 松弛迭代应单调递减
- **自由-推动间隙**`E_nudged - E_free`应与 loss 正相关
- **权重更新幅值**`‖Δw‖` 每层应稳定不爆炸/消失
- **稀疏度演化**零权重比例随训练步数
- **拓扑演化**节点/边数量随训练步数
---
## 8. 风险与缓解
| 风险 | 严重性 | 缓解措施 |
|------|--------|---------|
| 平衡不收敛 | | 阻尼更新 + 梯度裁剪 + 自适应 activity_lr |
| 能量函数局部极小 | | 小推动力 β + 良好初始化 + 复合正则 |
| 结构搜索时 JAX 形状不匹配 | | 固定最大邻接矩阵 + 掩码 |
| 剪枝不可逆退化 | | 渐进式剪枝 + 结构可塑性恢复边 |
| 奖罚信号太稀疏 | | 奖励塑形 + 好奇心内在奖励 |
| 深层平衡梯度消失 | | 残差连接 + 逐层松弛 |
---
## 9. 参考资源
### 核心论文
1. Scellier & Bengio (2017). *Equilibrium Propagation: Energy-based Learning for Feedforward and Recurrent Networks*
2. LeCun (2006). *A Tutorial on Energy-Based Learning*
3. Millidge, Tsang, & Rao (2021). *Predictive Coding: a Theoretical and Experimental Review*
4. Liu et al. (2019). *DARTS: Differentiable Architecture Search*
5. Bi & Poo (2019). *Biologically plausible learning in deep networks*
### 参考实现
| 项目 | 语言 | 链接 |
|------|------|------|
| Equilibrium-Propagation | PyTorch | github.com/Laborieux-Axel/Equilibrium-Propagation |
| EB-JEPA | PyTorch | github.com/facebookresearch/eb_jepa |
| BioTorch | PyTorch | github.com/jsalbert/biotorch |
| JPC (Predictive Coding) | JAX | github.com/thebuckleylab/jpc |
| snnTorch | PyTorch | github.com/jeshraghian/snntorch |
| JaxPruner | JAX | github.com/google-research/jaxpruner |
| awesome-ebm | 列表 | github.com/yataobian/awesome-ebm |
| awesome-biologically-motivated-learning | 列表 | github.com/jsalbert/awesome-biologically-motivated-learning |
---
## 10. 术语表
| 术语 | 英文 | 说明 |
|------|------|------|
| 能量函数 | Energy Function | 定义网络状态的标量函数越低越好 |
| 平衡传播 | Equilibrium Propagation (EqProp) | 通过两个平衡态的能量差更新权重 |
| 自由相 | Free Phase | 仅有输入无目标推动的松弛过程 |
| 推动相 | Nudged Phase | 输出层被目标轻微推动后的松弛过程 |
| 松弛 | Relaxation | 活动量沿能量梯度下降的过程 |
| 三因子规则 | Three-factor Rule | pre × post × neuromodulator 的突触更新 |
| 资格迹 | Eligibility Trace | 记录历史活动关联的衰减记忆 |
| 代理梯度 | Surrogate Gradient | 不可导函数的近似梯度如脉冲函数 |
| 结构可塑性 | Structural Plasticity | 增删突触连接的能力 |
| 图结构搜索 | Graph Neural Architecture Search | 在图拓扑空间中搜索最优架构 |