commit a97ee39bd4f2054eb0bbeffe549523ed421e1663 Author: yongjiang.lin Date: Wed Jun 3 21:40:29 2026 +0800 定义框架 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..55189e8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,37 @@ +# Python +__pycache__/ +*.py[cod] +*.egg-info/ +*.egg +dist/ +build/ + +# Virtual environments +.venv/ +venv/ + +# IDE +.vscode/ +.idea/ +*.swp +*.swo + +# Jupyter +.ipynb_checkpoints/ + +# Test/Coverage +.coverage +htmlcov/ +.pytest_cache/ + +# JAX +jax_cache/ + +# Experiment outputs +checkpoints/ +runs/ +logs/ + +# OS +.DS_Store +Thumbs.db diff --git a/README.md b/README.md new file mode 100644 index 0000000..0c8eef7 --- /dev/null +++ b/README.md @@ -0,0 +1,43 @@ +# NeuroGraph + +Bio-inspired energy-gradient deep learning framework. + +## Overview + +NeuroGraph explores training paradigms beyond backpropagation: +- **Energy-based learning**: Networks relax to energy minima instead of computing analytic gradients +- **Reward-modulated plasticity**: Three-factor learning rules with neuromodulator signals +- **Autonomous pruning**: Self-organizing network topology through structural plasticity +- **Architecture search**: Differentiable graph exploration for optimal connectivity + +## Setup + +```bash +pip install -e ".[dev]" +``` + +For GPU support: +```bash +pip install -e ".[dev,gpu]" +``` + +## Quick Start + +```python +from neurograph.core.energy import compute_energy, EnergyConfig + +config = EnergyConfig(data_weight=1.0, reg_weight=0.01) +energy = compute_energy(params, activities, inputs, targets, config=config) +``` + +## Project Structure + +``` +src/neurograph/ +├── core/ # Energy functions, equilibrium dynamics, neurons +├── learning/ # EqProp, reward modulation, Hebbian rules +├── pruning/ # Magnitude/activity-based pruning, structural plasticity +├── architecture/ # Graph NAS, topology mutation +├── env/ # Gymnasium wrappers +└── utils/ # Visualization, logging, metrics +``` diff --git a/docs/design.md b/docs/design.md new file mode 100644 index 0000000..ebae78f --- /dev/null +++ b/docs/design.md @@ -0,0 +1,447 @@ +# NeuroGraph 架构设计文档 + +> 类生物能量梯度深度学习框架 + +--- + +## 1. 项目定位 + +NeuroGraph 探索的是区别于传统反向传播(Backpropagation, BP)的全新训练模式。核心思想是:大脑中的神经元不会做链式求导,它们通过局部活动模式和能量最小化来学习。 + +### 1.1 与传统 BP 的区别 + +| 维度 | 传统 BP | NeuroGraph | +|------|---------|-----------| +| 梯度来源 | 解析链式法则 | 能量景观松弛过程 | +| 信息流 | 前向 + 反向传播 | 双向活动传播(自由相 + 推动相) | +| 学习信号 | Loss 梯度 | 能量差 (∂E_nudged - ∂E_free) | +| 拓扑结构 | 人工设计固定 | 自主剪枝 + 结构探索 | +| 调控机制 | 全局优化器 | 局部规则 + 全局奖罚调制 | + +### 1.2 核心隐喻 + +把网络看作一个物理系统:给定输入后,神经元活动会像弹簧一样"松弛"到能量最低状态。学习不是计算梯度,而是比较"自由松弛态"和"被目标推动后的松弛态"之间的能量差。 + +--- + +## 2. 技术栈 + +### 2.1 核心依赖 + +``` +JAX → 自动微分 + JIT 编译 + 向量化 +Equinox → 神经网络模块系统(eqx.Module) +Optax → 优化器链(SGD, Adam, 自定义变换) +Gymnasium → RL 环境标准接口 +JaxPruner → 稀疏掩码基础设施 +``` + +### 2.2 为什么不用 C/C++ 手撸 + +- 没有自动微分,能量梯度验证极其困难 +- 没有现成的 RL 环境(Gymnasium 有 200+ 环境) +- 没有 GPU 加速的张量运算库 +- 研究迭代速度会慢 10 倍以上 + +JAX 本身就是"手撸友好"的:它是 NumPy + 自动微分 + JIT,不强迫你接受任何高层抽象。你可以从零定义能量函数、松弛动力学、学习规则,同时拥有 GPU 加速和自动微分。 + +### 2.3 为什么不用 Julia + +Julia 的科学计算生态很强,但: +- 生物可塑性学习几乎没有现成参考实现 +- GPU 生态不如 JAX 成熟 +- 社区规模小得多 + +--- + +## 3. 核心算法 + +### 3.1 能量函数 + +网络的状态(活动量 + 权重)定义了一个能量景观。能量越低,状态越"好"。 + +``` +E(θ, a, x, y) = w_data · E_data + w_reg · E_reg + w_sparse · E_sparse + w_struct · E_struct +``` + +| 项 | 公式 | 作用 | +|----|------|------| +| E_data | ‖f_θ(x) - y‖² | 预测误差 | +| E_reg | λ · ‖θ‖² | 权重衰减 | +| E_sparse | γ · ‖a‖₁ | 激活稀疏性 | +| E_struct | η · ‖A‖² | 拓扑复杂度 | + +### 3.2 平衡传播(Equilibrium Propagation, EqProp) + +这是 Scellier & Bengio (2017) 提出的算法,核心步骤: + +``` +Step 1: 自由相(Free Phase) + 输入 x 固定到可见层 + 活动量沿能量梯度松弛:a ← a - η_a · ∂E/∂a + 直到收敛 → 记录 a_free, E_free + +Step 2: 推动相(Nudged Phase) + 输出层被目标 y 轻微推动:a_out += β · (y - a_out) + 继续松弛直到收敛 → 记录 a_nudged, E_nudged + +Step 3: 权重更新 + Δθ = -(∂E(θ, a_nudged)/∂θ - ∂E(θ, a_free)/∂θ) / β +``` + +**关键洞察**:权重更新量正比于两个平衡态之间的能量梯度差。当 β → 0 时,EqProp 等价于 BP。 + +### 3.3 三因子学习规则(奖罚调制) + +在生物大脑中,突触可塑性受神经递质(如多巴胺)调制: + +``` +Δw_ij = reward · e_ij · (pre_i · post_j) +``` + +- `pre_i`:突触前神经元活动 +- `post_j`:突触后神经元活动 +- `e_ij`:eligibility trace(资格迹),记录历史活动关联 +- `reward`:全局奖罚信号(+1 奖励,-1 惩罚) + +这使得学习信号不再是解析梯度,而是"行为后果"的函数。 + +### 3.4 自主剪枝 + 结构可塑性 + +网络自主决定哪些连接有用: + +``` +删除边:|w_ij| < threshold 且 活动贡献度低 +添加边:节点 i 和 j 的活动互信息高 → 建立新连接 +``` + +这模拟了大脑的突触修剪(synaptic pruning)和突触发生(synaptogenesis)。 + +### 3.5 可微图结构搜索(DARTS 风格) + +网络拓扑用邻接矩阵 A[i,j] 表示,每个位置有一组可学习的架构参数 α_ij: + +``` +P(edge_ij exists) = sigmoid(α_ij) +``` + +架构参数和网络权重联合优化,训练结束后离散化:保留 sigmoid(α_ij) > 0.5 的边。 + +--- + +## 4. 架构设计 + +### 4.1 模块划分 + +``` +src/neurograph/ +│ +├── core/ # 核心引擎(最底层) +│ ├── energy.py # 能量函数定义 +│ ├── equilibrium.py # 松弛循环动力学 +│ ├── neurons.py # 神经元模型 +│ ├── surrogate.py # 代理梯度 +│ └── topology.py # 图拓扑表示 +│ +├── learning/ # 学习规则(依赖 core) +│ ├── eqprop.py # 平衡传播更新 +│ ├── reward.py # 奖罚调制 +│ ├── hebbian.py # Hebbian 局部规则 +│ └── optimizer.py # Optax 包装 +│ +├── pruning/ # 剪枝(依赖 core + learning) +│ ├── magnitude.py # 幅值剪枝 +│ ├── activity.py # 活跃度剪枝 +│ └── structural.py # 结构可塑性 +│ +├── architecture/ # 结构探索(依赖 pruning) +│ ├── graph_nas.py # 可微图搜索 +│ ├── mutation.py # 拓扑变异 +│ └── supernet.py # 超网络 +│ +├── env/ # 环境接口 +│ ├── wrapper.py # Gymnasium 适配 +│ └── simple/ # 内置简单环境 +│ +└── utils/ # 工具 + ├── visualization.py # 可视化 + ├── logging.py # 日志 + └── metrics.py # 指标 +``` + +### 4.2 依赖关系 + +``` +core ← learning ← pruning ← architecture + ↑ + env (提供交互数据) +``` + +### 4.3 网络拓扑表示 + +用邻接矩阵 `[N, N]` 表示网络连接: + +```python +class NetworkTopology(eqx.Module): + adjacency: jnp.ndarray # [N, N] 邻接矩阵 + node_types: jnp.ndarray # [N] 节点类型 (0=visible, 1=hidden, 2=output) + edge_ops: jnp.ndarray # [N, N] 操作类型索引 + node_states: jnp.ndarray # [N, state_dim] 当前活动 +``` + +**选择理由**: +- 可微(矩阵元素是连续值) +- 直接支持矩阵乘法做消息传递 +- 支持 DARTS 风格优化(学习 edge_weights) +- 大网络时可切换为 `jax.experimental.sparse.BCSR` + +### 4.4 数据流:完整训练步 + +``` +┌─────────────────────────────────────────────────────────┐ +│ 训练步 (training step) │ +├─────────────────────────────────────────────────────────┤ +│ │ +│ ① 输入钳位 │ +│ activities[visible] = observations │ +│ │ +│ ② 自由相松弛 (jax.lax.while_loop) │ +│ for t in max_iters: │ +│ grad = jax.grad(energy, argnums=1)(params, acts) │ +│ acts -= activity_lr * grad │ +│ if converged: break │ +│ → activities_free, energy_free │ +│ │ +│ ③ 推动相松弛 │ +│ acts[output] += β * (target - acts[output]) │ +│ 同②松弛 │ +│ → activities_nudged, energy_nudged │ +│ │ +│ ④ 权重更新 (EqProp) │ +│ grad_free = jax.grad(energy, 0)(params, acts_free) │ +│ grad_nudged = jax.grad(energy, 0)(params, acts_nud) │ +│ delta_w = -(grad_nudged - grad_free) / β │ +│ delta_w *= reward_signal # 如有奖罚调制 │ +│ params = optax.apply_updates(params, delta_w) │ +│ │ +│ ⑤ 剪枝 (每 N 步) │ +│ mask = pruning_scheduler(params, activities, step) │ +│ params *= mask │ +│ │ +│ ⑥ 结构探索 (每 M 步) │ +│ topology = graph_nas.evaluate_and_mutate(topology) │ +│ │ +│ ⑦ 环境步进 (RL 模式) │ +│ action = read_output(acts_nudged) │ +│ next_obs, reward, done = env.step(action) │ +│ │ +└─────────────────────────────────────────────────────────┘ +``` + +整个 ②-④ 步用 `jax.jit` 编译为一个函数。⑤⑥ 步在 JIT 外部执行(改变参数树形状)。 + +--- + +## 5. 开发阶段 + +### Phase 1:能量学习引擎(核心) + +**目标**:MNIST >95% 准确率 + +| 任务 | 文件 | 说明 | +|------|------|------| +| 能量函数 | `core/energy.py` | 复合能量定义 | +| 松弛循环 | `core/equilibrium.py` | jax.lax.while_loop | +| 神经元模型 | `core/neurons.py` | rate-based, LIF | +| 代理梯度 | `core/surrogate.py` | jax.custom_jvp | +| 拓扑表示 | `core/topology.py` | 邻接矩阵 + 节点类型 | +| EqProp 更新 | `learning/eqprop.py` | 核心学习规则 | +| Optax 包装 | `learning/optimizer.py` | 可组合优化器 | +| MNIST 实验 | `experiments/phase1/mnist_eqprop.py` | 基线验证 | + +### Phase 2:奖罚系统 + 环境集成 + +**目标**:CartPole-v1 奖励 >400 + +| 任务 | 文件 | 说明 | +|------|------|------| +| 三因子规则 | `learning/reward.py` | reward × eligibility × gradient | +| 资格迹 | `learning/reward.py` | 历史活动关联追踪 | +| 环境包装 | `env/wrapper.py` | Gymnasium → energy agent | +| CartPole 实验 | `experiments/phase2/cartpole_reward.py` | RL 基线 | + +### Phase 3:自主剪枝 + 结构探索 + +**目标**:CIFAR-10 >80%,参数量 <50% + +| 任务 | 文件 | 说明 | +|------|------|------| +| 幅值剪枝 | `pruning/magnitude.py` | 权重幅值阈值 | +| 活跃度剪枝 | `pruning/activity.py` | 活动贡献度 | +| 结构可塑性 | `pruning/structural.py` | 增删边 | +| 图 NAS | `architecture/graph_nas.py` | 可微架构搜索 | +| 拓扑变异 | `architecture/mutation.py` | 离散变异算子 | +| CIFAR 实验 | `experiments/phase3/pruning_dynamics.py` | 剪枝 + NAS | + +### Phase 4:全系统集成 + +- 端到端流水线 +- 多基准测试 + 消融实验 +- 文档 + 教程 + +--- + +## 6. 关键 JAX API 使用指南 + +### 6.1 松弛循环 + +```python +def relax(energy_fn, params, activities, inputs, targets, nudging, config): + """松弛到能量最低态。""" + + def energy_at(acts): + return energy_fn(params, acts, inputs, targets, nudging, config) + + def cond(state): + iters, _, curr_energy = state + new_e = energy_at(state[1]) + return (iters < config.max_iters) & (jnp.abs(curr_energy - new_e) > config.tol) + + def body(state): + iters, acts, _ = state + grad = jax.grad(energy_at)(acts) + new_acts = acts - config.activity_lr * grad + new_e = energy_at(new_acts) + return iters + 1, new_acts, new_e + + init = (0, activities, jnp.inf) + return jax.lax.while_loop(cond, body, init) +``` + +### 6.2 代理梯度(LIF 脉冲神经元) + +```python +@jax.custom_jvp +def spike_fn(x, threshold=1.0): + return jnp.where(x >= threshold, 1.0, 0.0) + +@spike_fn.defjvp +def spike_fn_jvp(primals, tangents): + x, = primals + dx, = tangents + # 高斯代理梯度 + surrogate = jnp.exp(-0.5 * ((x - threshold) / 0.5) ** 2) + return spike_fn(x, threshold), surrogate * dx +``` + +### 6.3 编译策略 + +```python +# 整个训练步 JIT 编译 +@jax.jit +def train_step(params, activities, inputs, targets, reward, opt_state): + # 自由相 + acts_free, E_free, _ = relax(energy_fn, params, activities, inputs, None, 0.0, cfg) + # 推动相 + acts_nudged, E_nudged, _ = relax(energy_fn, params, acts_free, inputs, targets, β, cfg) + # 权重更新 + grads = eqprop_grads(params, acts_free, acts_nudged, β) + grads *= reward # 奖罚调制 + updates, opt_state = optimizer.update(grads, opt_state) + params = optax.apply_updates(params, updates) + return params, acts_nudged, opt_state, {"E_free": E_free, "E_nudged": E_nudged} +``` + +--- + +## 7. 验证策略 + +### 7.1 单元测试 + +| 测试 | 验证 | +|------|------| +| 能量单调下降 | 松弛过程中能量严格递减 | +| EqProp ≈ BP | 小网络上 EqProp 收敛到与 BP 相同解 | +| 代理梯度 | spike_fn 的 jax.grad 非零且平滑 | +| 剪枝保精度 | 50% 幅值剪枝后准确率 >80% 原始 | +| 拓扑合法性 | 所有生成的图都是有效 DAG | + +### 7.2 集成测试基准 + +| 基准 | 目标 | +|------|------| +| MNIST (EqProp) | >95% 准确率, 20 epochs | +| MNIST (reward-modulated) | >93% 准确率 | +| CartPole-v1 | 500 episodes 内奖励 >400 | +| CIFAR-10 + 剪枝 | >80% 准确率, 参数量减半 | + +### 7.3 与 BP 对比 + +| 基准 | BP 基线 | NeuroGraph | 通过标准 | +|------|---------|-----------|----------| +| MNIST | >99% | >97% | 误差 <2% | +| CIFAR-10 | >85% | >80% | 误差 <5% | +| CartPole | ~500 episodes | ~800 episodes | 2x 内可解 | + +### 7.4 诊断指标 + +- **能量轨迹**:能量 vs 松弛迭代(应单调递减) +- **自由-推动间隙**:`E_nudged - E_free`(应与 loss 正相关) +- **权重更新幅值**:`‖Δw‖` 每层(应稳定,不爆炸/消失) +- **稀疏度演化**:零权重比例随训练步数 +- **拓扑演化**:节点/边数量随训练步数 + +--- + +## 8. 风险与缓解 + +| 风险 | 严重性 | 缓解措施 | +|------|--------|---------| +| 平衡不收敛 | 高 | 阻尼更新 + 梯度裁剪 + 自适应 activity_lr | +| 能量函数局部极小 | 高 | 小推动力 β + 良好初始化 + 复合正则 | +| 结构搜索时 JAX 形状不匹配 | 高 | 固定最大邻接矩阵 + 掩码 | +| 剪枝不可逆退化 | 中 | 渐进式剪枝 + 结构可塑性恢复边 | +| 奖罚信号太稀疏 | 中 | 奖励塑形 + 好奇心内在奖励 | +| 深层平衡梯度消失 | 高 | 残差连接 + 逐层松弛 | + +--- + +## 9. 参考资源 + +### 核心论文 + +1. Scellier & Bengio (2017). *Equilibrium Propagation: Energy-based Learning for Feedforward and Recurrent Networks* +2. LeCun (2006). *A Tutorial on Energy-Based Learning* +3. Millidge, Tsang, & Rao (2021). *Predictive Coding: a Theoretical and Experimental Review* +4. Liu et al. (2019). *DARTS: Differentiable Architecture Search* +5. Bi & Poo (2019). *Biologically plausible learning in deep networks* + +### 参考实现 + +| 项目 | 语言 | 链接 | +|------|------|------| +| Equilibrium-Propagation | PyTorch | github.com/Laborieux-Axel/Equilibrium-Propagation | +| EB-JEPA | PyTorch | github.com/facebookresearch/eb_jepa | +| BioTorch | PyTorch | github.com/jsalbert/biotorch | +| JPC (Predictive Coding) | JAX | github.com/thebuckleylab/jpc | +| snnTorch | PyTorch | github.com/jeshraghian/snntorch | +| JaxPruner | JAX | github.com/google-research/jaxpruner | +| awesome-ebm | 列表 | github.com/yataobian/awesome-ebm | +| awesome-biologically-motivated-learning | 列表 | github.com/jsalbert/awesome-biologically-motivated-learning | + +--- + +## 10. 术语表 + +| 术语 | 英文 | 说明 | +|------|------|------| +| 能量函数 | Energy Function | 定义网络状态的标量函数,越低越好 | +| 平衡传播 | Equilibrium Propagation (EqProp) | 通过两个平衡态的能量差更新权重 | +| 自由相 | Free Phase | 仅有输入,无目标推动的松弛过程 | +| 推动相 | Nudged Phase | 输出层被目标轻微推动后的松弛过程 | +| 松弛 | Relaxation | 活动量沿能量梯度下降的过程 | +| 三因子规则 | Three-factor Rule | pre × post × neuromodulator 的突触更新 | +| 资格迹 | Eligibility Trace | 记录历史活动关联的衰减记忆 | +| 代理梯度 | Surrogate Gradient | 不可导函数的近似梯度(如脉冲函数) | +| 结构可塑性 | Structural Plasticity | 增删突触连接的能力 | +| 图结构搜索 | Graph Neural Architecture Search | 在图拓扑空间中搜索最优架构 | diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..cd1111c --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,38 @@ +[project] +name = "neurograph" +version = "0.1.0" +description = "Bio-inspired energy-gradient deep learning framework" +requires-python = ">=3.10" + +dependencies = [ + "jax>=0.4.30", + "jaxlib>=0.4.30", + "equinox>=0.11.0", + "optax>=0.2.0", + "gymnasium>=1.0.0", + "numpy>=1.26.0", + "matplotlib>=3.8.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0", + "ruff>=0.4.0", +] +gpu = [ + "jax[cuda12]>=0.4.30", +] + +[build-system] +requires = ["setuptools>=68.0"] +build-backend = "setuptools.backends._legacy:_Backend" + +[tool.setuptools.packages.find] +where = ["src"] + +[tool.ruff] +target-version = "py310" +line-length = 100 + +[tool.pytest.ini_options] +testpaths = ["tests"] diff --git a/src/neurograph/__init__.py b/src/neurograph/__init__.py new file mode 100644 index 0000000..41ce54c --- /dev/null +++ b/src/neurograph/__init__.py @@ -0,0 +1,3 @@ +"""NeuroGraph — Bio-inspired energy-gradient deep learning framework.""" + +__version__ = "0.1.0" diff --git a/src/neurograph/architecture/__init__.py b/src/neurograph/architecture/__init__.py new file mode 100644 index 0000000..8f990de --- /dev/null +++ b/src/neurograph/architecture/__init__.py @@ -0,0 +1 @@ +"""Architecture search and topology mutation.""" diff --git a/src/neurograph/core/__init__.py b/src/neurograph/core/__init__.py new file mode 100644 index 0000000..e7be1a1 --- /dev/null +++ b/src/neurograph/core/__init__.py @@ -0,0 +1 @@ +"""Core energy-based learning engine.""" diff --git a/src/neurograph/core/energy.py b/src/neurograph/core/energy.py new file mode 100644 index 0000000..0491159 --- /dev/null +++ b/src/neurograph/core/energy.py @@ -0,0 +1,55 @@ +"""Energy function definitions. + +The energy function is the foundation of the framework. +It defines a scalar energy landscape over network activities +and parameters, such that learning occurs through relaxation +to energy minima rather than analytic backpropagation. + +Energy composition: + E = E_data + E_regularization + E_sparsity + E_structural +""" + +from dataclasses import dataclass + +import jax +import jax.numpy as jnp + + +@dataclass(frozen=True) +class EnergyConfig: + """Weights for each energy term.""" + + data_weight: float = 1.0 + reg_weight: float = 0.01 + sparse_weight: float = 0.001 + structural_weight: float = 0.0 + + +def compute_energy( + params: dict, + activities: jax.Array, + inputs: jax.Array, + targets: jax.Array | None = None, + nudging_strength: float = 0.0, + config: EnergyConfig = EnergyConfig(), +) -> jax.Array: + """Compute composite energy. + + Args: + params: Network parameters (weights, biases). + activities: Current neuron activities. + inputs: Input data clamped to visible nodes. + targets: Target outputs (None for free phase). + nudging_strength: How strongly to push output toward targets. + config: Energy term weights. + + Returns: + Scalar energy value. + """ + # TODO: Implement energy terms + # E_data: prediction error + # E_reg: weight decay + # E_sparse: activity sparsity + # E_structural: topology complexity prior + + raise NotImplementedError("Implement your energy function here.") diff --git a/src/neurograph/env/__init__.py b/src/neurograph/env/__init__.py new file mode 100644 index 0000000..5f74c98 --- /dev/null +++ b/src/neurograph/env/__init__.py @@ -0,0 +1 @@ +"""Environment interface and wrappers.""" diff --git a/src/neurograph/learning/__init__.py b/src/neurograph/learning/__init__.py new file mode 100644 index 0000000..74b5e16 --- /dev/null +++ b/src/neurograph/learning/__init__.py @@ -0,0 +1 @@ +"""Learning rules and update mechanisms.""" diff --git a/src/neurograph/pruning/__init__.py b/src/neurograph/pruning/__init__.py new file mode 100644 index 0000000..27a70c4 --- /dev/null +++ b/src/neurograph/pruning/__init__.py @@ -0,0 +1 @@ +"""Autonomous pruning and structural plasticity.""" diff --git a/src/neurograph/utils/__init__.py b/src/neurograph/utils/__init__.py new file mode 100644 index 0000000..11c5725 --- /dev/null +++ b/src/neurograph/utils/__init__.py @@ -0,0 +1 @@ +"""Utilities — visualization, logging, metrics."""