定义框架
This commit is contained in:
commit
a97ee39bd4
37
.gitignore
vendored
Normal file
37
.gitignore
vendored
Normal file
@ -0,0 +1,37 @@
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*.egg-info/
|
||||
*.egg
|
||||
dist/
|
||||
build/
|
||||
|
||||
# Virtual environments
|
||||
.venv/
|
||||
venv/
|
||||
|
||||
# IDE
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
|
||||
# Jupyter
|
||||
.ipynb_checkpoints/
|
||||
|
||||
# Test/Coverage
|
||||
.coverage
|
||||
htmlcov/
|
||||
.pytest_cache/
|
||||
|
||||
# JAX
|
||||
jax_cache/
|
||||
|
||||
# Experiment outputs
|
||||
checkpoints/
|
||||
runs/
|
||||
logs/
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
43
README.md
Normal file
43
README.md
Normal file
@ -0,0 +1,43 @@
|
||||
# NeuroGraph
|
||||
|
||||
Bio-inspired energy-gradient deep learning framework.
|
||||
|
||||
## Overview
|
||||
|
||||
NeuroGraph explores training paradigms beyond backpropagation:
|
||||
- **Energy-based learning**: Networks relax to energy minima instead of computing analytic gradients
|
||||
- **Reward-modulated plasticity**: Three-factor learning rules with neuromodulator signals
|
||||
- **Autonomous pruning**: Self-organizing network topology through structural plasticity
|
||||
- **Architecture search**: Differentiable graph exploration for optimal connectivity
|
||||
|
||||
## Setup
|
||||
|
||||
```bash
|
||||
pip install -e ".[dev]"
|
||||
```
|
||||
|
||||
For GPU support:
|
||||
```bash
|
||||
pip install -e ".[dev,gpu]"
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
```python
|
||||
from neurograph.core.energy import compute_energy, EnergyConfig
|
||||
|
||||
config = EnergyConfig(data_weight=1.0, reg_weight=0.01)
|
||||
energy = compute_energy(params, activities, inputs, targets, config=config)
|
||||
```
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
src/neurograph/
|
||||
├── core/ # Energy functions, equilibrium dynamics, neurons
|
||||
├── learning/ # EqProp, reward modulation, Hebbian rules
|
||||
├── pruning/ # Magnitude/activity-based pruning, structural plasticity
|
||||
├── architecture/ # Graph NAS, topology mutation
|
||||
├── env/ # Gymnasium wrappers
|
||||
└── utils/ # Visualization, logging, metrics
|
||||
```
|
||||
447
docs/design.md
Normal file
447
docs/design.md
Normal file
@ -0,0 +1,447 @@
|
||||
# NeuroGraph 架构设计文档
|
||||
|
||||
> 类生物能量梯度深度学习框架
|
||||
|
||||
---
|
||||
|
||||
## 1. 项目定位
|
||||
|
||||
NeuroGraph 探索的是区别于传统反向传播(Backpropagation, BP)的全新训练模式。核心思想是:大脑中的神经元不会做链式求导,它们通过局部活动模式和能量最小化来学习。
|
||||
|
||||
### 1.1 与传统 BP 的区别
|
||||
|
||||
| 维度 | 传统 BP | NeuroGraph |
|
||||
|------|---------|-----------|
|
||||
| 梯度来源 | 解析链式法则 | 能量景观松弛过程 |
|
||||
| 信息流 | 前向 + 反向传播 | 双向活动传播(自由相 + 推动相) |
|
||||
| 学习信号 | Loss 梯度 | 能量差 (∂E_nudged - ∂E_free) |
|
||||
| 拓扑结构 | 人工设计固定 | 自主剪枝 + 结构探索 |
|
||||
| 调控机制 | 全局优化器 | 局部规则 + 全局奖罚调制 |
|
||||
|
||||
### 1.2 核心隐喻
|
||||
|
||||
把网络看作一个物理系统:给定输入后,神经元活动会像弹簧一样"松弛"到能量最低状态。学习不是计算梯度,而是比较"自由松弛态"和"被目标推动后的松弛态"之间的能量差。
|
||||
|
||||
---
|
||||
|
||||
## 2. 技术栈
|
||||
|
||||
### 2.1 核心依赖
|
||||
|
||||
```
|
||||
JAX → 自动微分 + JIT 编译 + 向量化
|
||||
Equinox → 神经网络模块系统(eqx.Module)
|
||||
Optax → 优化器链(SGD, Adam, 自定义变换)
|
||||
Gymnasium → RL 环境标准接口
|
||||
JaxPruner → 稀疏掩码基础设施
|
||||
```
|
||||
|
||||
### 2.2 为什么不用 C/C++ 手撸
|
||||
|
||||
- 没有自动微分,能量梯度验证极其困难
|
||||
- 没有现成的 RL 环境(Gymnasium 有 200+ 环境)
|
||||
- 没有 GPU 加速的张量运算库
|
||||
- 研究迭代速度会慢 10 倍以上
|
||||
|
||||
JAX 本身就是"手撸友好"的:它是 NumPy + 自动微分 + JIT,不强迫你接受任何高层抽象。你可以从零定义能量函数、松弛动力学、学习规则,同时拥有 GPU 加速和自动微分。
|
||||
|
||||
### 2.3 为什么不用 Julia
|
||||
|
||||
Julia 的科学计算生态很强,但:
|
||||
- 生物可塑性学习几乎没有现成参考实现
|
||||
- GPU 生态不如 JAX 成熟
|
||||
- 社区规模小得多
|
||||
|
||||
---
|
||||
|
||||
## 3. 核心算法
|
||||
|
||||
### 3.1 能量函数
|
||||
|
||||
网络的状态(活动量 + 权重)定义了一个能量景观。能量越低,状态越"好"。
|
||||
|
||||
```
|
||||
E(θ, a, x, y) = w_data · E_data + w_reg · E_reg + w_sparse · E_sparse + w_struct · E_struct
|
||||
```
|
||||
|
||||
| 项 | 公式 | 作用 |
|
||||
|----|------|------|
|
||||
| E_data | ‖f_θ(x) - y‖² | 预测误差 |
|
||||
| E_reg | λ · ‖θ‖² | 权重衰减 |
|
||||
| E_sparse | γ · ‖a‖₁ | 激活稀疏性 |
|
||||
| E_struct | η · ‖A‖² | 拓扑复杂度 |
|
||||
|
||||
### 3.2 平衡传播(Equilibrium Propagation, EqProp)
|
||||
|
||||
这是 Scellier & Bengio (2017) 提出的算法,核心步骤:
|
||||
|
||||
```
|
||||
Step 1: 自由相(Free Phase)
|
||||
输入 x 固定到可见层
|
||||
活动量沿能量梯度松弛:a ← a - η_a · ∂E/∂a
|
||||
直到收敛 → 记录 a_free, E_free
|
||||
|
||||
Step 2: 推动相(Nudged Phase)
|
||||
输出层被目标 y 轻微推动:a_out += β · (y - a_out)
|
||||
继续松弛直到收敛 → 记录 a_nudged, E_nudged
|
||||
|
||||
Step 3: 权重更新
|
||||
Δθ = -(∂E(θ, a_nudged)/∂θ - ∂E(θ, a_free)/∂θ) / β
|
||||
```
|
||||
|
||||
**关键洞察**:权重更新量正比于两个平衡态之间的能量梯度差。当 β → 0 时,EqProp 等价于 BP。
|
||||
|
||||
### 3.3 三因子学习规则(奖罚调制)
|
||||
|
||||
在生物大脑中,突触可塑性受神经递质(如多巴胺)调制:
|
||||
|
||||
```
|
||||
Δw_ij = reward · e_ij · (pre_i · post_j)
|
||||
```
|
||||
|
||||
- `pre_i`:突触前神经元活动
|
||||
- `post_j`:突触后神经元活动
|
||||
- `e_ij`:eligibility trace(资格迹),记录历史活动关联
|
||||
- `reward`:全局奖罚信号(+1 奖励,-1 惩罚)
|
||||
|
||||
这使得学习信号不再是解析梯度,而是"行为后果"的函数。
|
||||
|
||||
### 3.4 自主剪枝 + 结构可塑性
|
||||
|
||||
网络自主决定哪些连接有用:
|
||||
|
||||
```
|
||||
删除边:|w_ij| < threshold 且 活动贡献度低
|
||||
添加边:节点 i 和 j 的活动互信息高 → 建立新连接
|
||||
```
|
||||
|
||||
这模拟了大脑的突触修剪(synaptic pruning)和突触发生(synaptogenesis)。
|
||||
|
||||
### 3.5 可微图结构搜索(DARTS 风格)
|
||||
|
||||
网络拓扑用邻接矩阵 A[i,j] 表示,每个位置有一组可学习的架构参数 α_ij:
|
||||
|
||||
```
|
||||
P(edge_ij exists) = sigmoid(α_ij)
|
||||
```
|
||||
|
||||
架构参数和网络权重联合优化,训练结束后离散化:保留 sigmoid(α_ij) > 0.5 的边。
|
||||
|
||||
---
|
||||
|
||||
## 4. 架构设计
|
||||
|
||||
### 4.1 模块划分
|
||||
|
||||
```
|
||||
src/neurograph/
|
||||
│
|
||||
├── core/ # 核心引擎(最底层)
|
||||
│ ├── energy.py # 能量函数定义
|
||||
│ ├── equilibrium.py # 松弛循环动力学
|
||||
│ ├── neurons.py # 神经元模型
|
||||
│ ├── surrogate.py # 代理梯度
|
||||
│ └── topology.py # 图拓扑表示
|
||||
│
|
||||
├── learning/ # 学习规则(依赖 core)
|
||||
│ ├── eqprop.py # 平衡传播更新
|
||||
│ ├── reward.py # 奖罚调制
|
||||
│ ├── hebbian.py # Hebbian 局部规则
|
||||
│ └── optimizer.py # Optax 包装
|
||||
│
|
||||
├── pruning/ # 剪枝(依赖 core + learning)
|
||||
│ ├── magnitude.py # 幅值剪枝
|
||||
│ ├── activity.py # 活跃度剪枝
|
||||
│ └── structural.py # 结构可塑性
|
||||
│
|
||||
├── architecture/ # 结构探索(依赖 pruning)
|
||||
│ ├── graph_nas.py # 可微图搜索
|
||||
│ ├── mutation.py # 拓扑变异
|
||||
│ └── supernet.py # 超网络
|
||||
│
|
||||
├── env/ # 环境接口
|
||||
│ ├── wrapper.py # Gymnasium 适配
|
||||
│ └── simple/ # 内置简单环境
|
||||
│
|
||||
└── utils/ # 工具
|
||||
├── visualization.py # 可视化
|
||||
├── logging.py # 日志
|
||||
└── metrics.py # 指标
|
||||
```
|
||||
|
||||
### 4.2 依赖关系
|
||||
|
||||
```
|
||||
core ← learning ← pruning ← architecture
|
||||
↑
|
||||
env (提供交互数据)
|
||||
```
|
||||
|
||||
### 4.3 网络拓扑表示
|
||||
|
||||
用邻接矩阵 `[N, N]` 表示网络连接:
|
||||
|
||||
```python
|
||||
class NetworkTopology(eqx.Module):
|
||||
adjacency: jnp.ndarray # [N, N] 邻接矩阵
|
||||
node_types: jnp.ndarray # [N] 节点类型 (0=visible, 1=hidden, 2=output)
|
||||
edge_ops: jnp.ndarray # [N, N] 操作类型索引
|
||||
node_states: jnp.ndarray # [N, state_dim] 当前活动
|
||||
```
|
||||
|
||||
**选择理由**:
|
||||
- 可微(矩阵元素是连续值)
|
||||
- 直接支持矩阵乘法做消息传递
|
||||
- 支持 DARTS 风格优化(学习 edge_weights)
|
||||
- 大网络时可切换为 `jax.experimental.sparse.BCSR`
|
||||
|
||||
### 4.4 数据流:完整训练步
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────┐
|
||||
│ 训练步 (training step) │
|
||||
├─────────────────────────────────────────────────────────┤
|
||||
│ │
|
||||
│ ① 输入钳位 │
|
||||
│ activities[visible] = observations │
|
||||
│ │
|
||||
│ ② 自由相松弛 (jax.lax.while_loop) │
|
||||
│ for t in max_iters: │
|
||||
│ grad = jax.grad(energy, argnums=1)(params, acts) │
|
||||
│ acts -= activity_lr * grad │
|
||||
│ if converged: break │
|
||||
│ → activities_free, energy_free │
|
||||
│ │
|
||||
│ ③ 推动相松弛 │
|
||||
│ acts[output] += β * (target - acts[output]) │
|
||||
│ 同②松弛 │
|
||||
│ → activities_nudged, energy_nudged │
|
||||
│ │
|
||||
│ ④ 权重更新 (EqProp) │
|
||||
│ grad_free = jax.grad(energy, 0)(params, acts_free) │
|
||||
│ grad_nudged = jax.grad(energy, 0)(params, acts_nud) │
|
||||
│ delta_w = -(grad_nudged - grad_free) / β │
|
||||
│ delta_w *= reward_signal # 如有奖罚调制 │
|
||||
│ params = optax.apply_updates(params, delta_w) │
|
||||
│ │
|
||||
│ ⑤ 剪枝 (每 N 步) │
|
||||
│ mask = pruning_scheduler(params, activities, step) │
|
||||
│ params *= mask │
|
||||
│ │
|
||||
│ ⑥ 结构探索 (每 M 步) │
|
||||
│ topology = graph_nas.evaluate_and_mutate(topology) │
|
||||
│ │
|
||||
│ ⑦ 环境步进 (RL 模式) │
|
||||
│ action = read_output(acts_nudged) │
|
||||
│ next_obs, reward, done = env.step(action) │
|
||||
│ │
|
||||
└─────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
整个 ②-④ 步用 `jax.jit` 编译为一个函数。⑤⑥ 步在 JIT 外部执行(改变参数树形状)。
|
||||
|
||||
---
|
||||
|
||||
## 5. 开发阶段
|
||||
|
||||
### Phase 1:能量学习引擎(核心)
|
||||
|
||||
**目标**:MNIST >95% 准确率
|
||||
|
||||
| 任务 | 文件 | 说明 |
|
||||
|------|------|------|
|
||||
| 能量函数 | `core/energy.py` | 复合能量定义 |
|
||||
| 松弛循环 | `core/equilibrium.py` | jax.lax.while_loop |
|
||||
| 神经元模型 | `core/neurons.py` | rate-based, LIF |
|
||||
| 代理梯度 | `core/surrogate.py` | jax.custom_jvp |
|
||||
| 拓扑表示 | `core/topology.py` | 邻接矩阵 + 节点类型 |
|
||||
| EqProp 更新 | `learning/eqprop.py` | 核心学习规则 |
|
||||
| Optax 包装 | `learning/optimizer.py` | 可组合优化器 |
|
||||
| MNIST 实验 | `experiments/phase1/mnist_eqprop.py` | 基线验证 |
|
||||
|
||||
### Phase 2:奖罚系统 + 环境集成
|
||||
|
||||
**目标**:CartPole-v1 奖励 >400
|
||||
|
||||
| 任务 | 文件 | 说明 |
|
||||
|------|------|------|
|
||||
| 三因子规则 | `learning/reward.py` | reward × eligibility × gradient |
|
||||
| 资格迹 | `learning/reward.py` | 历史活动关联追踪 |
|
||||
| 环境包装 | `env/wrapper.py` | Gymnasium → energy agent |
|
||||
| CartPole 实验 | `experiments/phase2/cartpole_reward.py` | RL 基线 |
|
||||
|
||||
### Phase 3:自主剪枝 + 结构探索
|
||||
|
||||
**目标**:CIFAR-10 >80%,参数量 <50%
|
||||
|
||||
| 任务 | 文件 | 说明 |
|
||||
|------|------|------|
|
||||
| 幅值剪枝 | `pruning/magnitude.py` | 权重幅值阈值 |
|
||||
| 活跃度剪枝 | `pruning/activity.py` | 活动贡献度 |
|
||||
| 结构可塑性 | `pruning/structural.py` | 增删边 |
|
||||
| 图 NAS | `architecture/graph_nas.py` | 可微架构搜索 |
|
||||
| 拓扑变异 | `architecture/mutation.py` | 离散变异算子 |
|
||||
| CIFAR 实验 | `experiments/phase3/pruning_dynamics.py` | 剪枝 + NAS |
|
||||
|
||||
### Phase 4:全系统集成
|
||||
|
||||
- 端到端流水线
|
||||
- 多基准测试 + 消融实验
|
||||
- 文档 + 教程
|
||||
|
||||
---
|
||||
|
||||
## 6. 关键 JAX API 使用指南
|
||||
|
||||
### 6.1 松弛循环
|
||||
|
||||
```python
|
||||
def relax(energy_fn, params, activities, inputs, targets, nudging, config):
|
||||
"""松弛到能量最低态。"""
|
||||
|
||||
def energy_at(acts):
|
||||
return energy_fn(params, acts, inputs, targets, nudging, config)
|
||||
|
||||
def cond(state):
|
||||
iters, _, curr_energy = state
|
||||
new_e = energy_at(state[1])
|
||||
return (iters < config.max_iters) & (jnp.abs(curr_energy - new_e) > config.tol)
|
||||
|
||||
def body(state):
|
||||
iters, acts, _ = state
|
||||
grad = jax.grad(energy_at)(acts)
|
||||
new_acts = acts - config.activity_lr * grad
|
||||
new_e = energy_at(new_acts)
|
||||
return iters + 1, new_acts, new_e
|
||||
|
||||
init = (0, activities, jnp.inf)
|
||||
return jax.lax.while_loop(cond, body, init)
|
||||
```
|
||||
|
||||
### 6.2 代理梯度(LIF 脉冲神经元)
|
||||
|
||||
```python
|
||||
@jax.custom_jvp
|
||||
def spike_fn(x, threshold=1.0):
|
||||
return jnp.where(x >= threshold, 1.0, 0.0)
|
||||
|
||||
@spike_fn.defjvp
|
||||
def spike_fn_jvp(primals, tangents):
|
||||
x, = primals
|
||||
dx, = tangents
|
||||
# 高斯代理梯度
|
||||
surrogate = jnp.exp(-0.5 * ((x - threshold) / 0.5) ** 2)
|
||||
return spike_fn(x, threshold), surrogate * dx
|
||||
```
|
||||
|
||||
### 6.3 编译策略
|
||||
|
||||
```python
|
||||
# 整个训练步 JIT 编译
|
||||
@jax.jit
|
||||
def train_step(params, activities, inputs, targets, reward, opt_state):
|
||||
# 自由相
|
||||
acts_free, E_free, _ = relax(energy_fn, params, activities, inputs, None, 0.0, cfg)
|
||||
# 推动相
|
||||
acts_nudged, E_nudged, _ = relax(energy_fn, params, acts_free, inputs, targets, β, cfg)
|
||||
# 权重更新
|
||||
grads = eqprop_grads(params, acts_free, acts_nudged, β)
|
||||
grads *= reward # 奖罚调制
|
||||
updates, opt_state = optimizer.update(grads, opt_state)
|
||||
params = optax.apply_updates(params, updates)
|
||||
return params, acts_nudged, opt_state, {"E_free": E_free, "E_nudged": E_nudged}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7. 验证策略
|
||||
|
||||
### 7.1 单元测试
|
||||
|
||||
| 测试 | 验证 |
|
||||
|------|------|
|
||||
| 能量单调下降 | 松弛过程中能量严格递减 |
|
||||
| EqProp ≈ BP | 小网络上 EqProp 收敛到与 BP 相同解 |
|
||||
| 代理梯度 | spike_fn 的 jax.grad 非零且平滑 |
|
||||
| 剪枝保精度 | 50% 幅值剪枝后准确率 >80% 原始 |
|
||||
| 拓扑合法性 | 所有生成的图都是有效 DAG |
|
||||
|
||||
### 7.2 集成测试基准
|
||||
|
||||
| 基准 | 目标 |
|
||||
|------|------|
|
||||
| MNIST (EqProp) | >95% 准确率, 20 epochs |
|
||||
| MNIST (reward-modulated) | >93% 准确率 |
|
||||
| CartPole-v1 | 500 episodes 内奖励 >400 |
|
||||
| CIFAR-10 + 剪枝 | >80% 准确率, 参数量减半 |
|
||||
|
||||
### 7.3 与 BP 对比
|
||||
|
||||
| 基准 | BP 基线 | NeuroGraph | 通过标准 |
|
||||
|------|---------|-----------|----------|
|
||||
| MNIST | >99% | >97% | 误差 <2% |
|
||||
| CIFAR-10 | >85% | >80% | 误差 <5% |
|
||||
| CartPole | ~500 episodes | ~800 episodes | 2x 内可解 |
|
||||
|
||||
### 7.4 诊断指标
|
||||
|
||||
- **能量轨迹**:能量 vs 松弛迭代(应单调递减)
|
||||
- **自由-推动间隙**:`E_nudged - E_free`(应与 loss 正相关)
|
||||
- **权重更新幅值**:`‖Δw‖` 每层(应稳定,不爆炸/消失)
|
||||
- **稀疏度演化**:零权重比例随训练步数
|
||||
- **拓扑演化**:节点/边数量随训练步数
|
||||
|
||||
---
|
||||
|
||||
## 8. 风险与缓解
|
||||
|
||||
| 风险 | 严重性 | 缓解措施 |
|
||||
|------|--------|---------|
|
||||
| 平衡不收敛 | 高 | 阻尼更新 + 梯度裁剪 + 自适应 activity_lr |
|
||||
| 能量函数局部极小 | 高 | 小推动力 β + 良好初始化 + 复合正则 |
|
||||
| 结构搜索时 JAX 形状不匹配 | 高 | 固定最大邻接矩阵 + 掩码 |
|
||||
| 剪枝不可逆退化 | 中 | 渐进式剪枝 + 结构可塑性恢复边 |
|
||||
| 奖罚信号太稀疏 | 中 | 奖励塑形 + 好奇心内在奖励 |
|
||||
| 深层平衡梯度消失 | 高 | 残差连接 + 逐层松弛 |
|
||||
|
||||
---
|
||||
|
||||
## 9. 参考资源
|
||||
|
||||
### 核心论文
|
||||
|
||||
1. Scellier & Bengio (2017). *Equilibrium Propagation: Energy-based Learning for Feedforward and Recurrent Networks*
|
||||
2. LeCun (2006). *A Tutorial on Energy-Based Learning*
|
||||
3. Millidge, Tsang, & Rao (2021). *Predictive Coding: a Theoretical and Experimental Review*
|
||||
4. Liu et al. (2019). *DARTS: Differentiable Architecture Search*
|
||||
5. Bi & Poo (2019). *Biologically plausible learning in deep networks*
|
||||
|
||||
### 参考实现
|
||||
|
||||
| 项目 | 语言 | 链接 |
|
||||
|------|------|------|
|
||||
| Equilibrium-Propagation | PyTorch | github.com/Laborieux-Axel/Equilibrium-Propagation |
|
||||
| EB-JEPA | PyTorch | github.com/facebookresearch/eb_jepa |
|
||||
| BioTorch | PyTorch | github.com/jsalbert/biotorch |
|
||||
| JPC (Predictive Coding) | JAX | github.com/thebuckleylab/jpc |
|
||||
| snnTorch | PyTorch | github.com/jeshraghian/snntorch |
|
||||
| JaxPruner | JAX | github.com/google-research/jaxpruner |
|
||||
| awesome-ebm | 列表 | github.com/yataobian/awesome-ebm |
|
||||
| awesome-biologically-motivated-learning | 列表 | github.com/jsalbert/awesome-biologically-motivated-learning |
|
||||
|
||||
---
|
||||
|
||||
## 10. 术语表
|
||||
|
||||
| 术语 | 英文 | 说明 |
|
||||
|------|------|------|
|
||||
| 能量函数 | Energy Function | 定义网络状态的标量函数,越低越好 |
|
||||
| 平衡传播 | Equilibrium Propagation (EqProp) | 通过两个平衡态的能量差更新权重 |
|
||||
| 自由相 | Free Phase | 仅有输入,无目标推动的松弛过程 |
|
||||
| 推动相 | Nudged Phase | 输出层被目标轻微推动后的松弛过程 |
|
||||
| 松弛 | Relaxation | 活动量沿能量梯度下降的过程 |
|
||||
| 三因子规则 | Three-factor Rule | pre × post × neuromodulator 的突触更新 |
|
||||
| 资格迹 | Eligibility Trace | 记录历史活动关联的衰减记忆 |
|
||||
| 代理梯度 | Surrogate Gradient | 不可导函数的近似梯度(如脉冲函数) |
|
||||
| 结构可塑性 | Structural Plasticity | 增删突触连接的能力 |
|
||||
| 图结构搜索 | Graph Neural Architecture Search | 在图拓扑空间中搜索最优架构 |
|
||||
38
pyproject.toml
Normal file
38
pyproject.toml
Normal file
@ -0,0 +1,38 @@
|
||||
[project]
|
||||
name = "neurograph"
|
||||
version = "0.1.0"
|
||||
description = "Bio-inspired energy-gradient deep learning framework"
|
||||
requires-python = ">=3.10"
|
||||
|
||||
dependencies = [
|
||||
"jax>=0.4.30",
|
||||
"jaxlib>=0.4.30",
|
||||
"equinox>=0.11.0",
|
||||
"optax>=0.2.0",
|
||||
"gymnasium>=1.0.0",
|
||||
"numpy>=1.26.0",
|
||||
"matplotlib>=3.8.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=8.0",
|
||||
"ruff>=0.4.0",
|
||||
]
|
||||
gpu = [
|
||||
"jax[cuda12]>=0.4.30",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools>=68.0"]
|
||||
build-backend = "setuptools.backends._legacy:_Backend"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py310"
|
||||
line-length = 100
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
3
src/neurograph/__init__.py
Normal file
3
src/neurograph/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
"""NeuroGraph — Bio-inspired energy-gradient deep learning framework."""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
1
src/neurograph/architecture/__init__.py
Normal file
1
src/neurograph/architecture/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""Architecture search and topology mutation."""
|
||||
1
src/neurograph/core/__init__.py
Normal file
1
src/neurograph/core/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""Core energy-based learning engine."""
|
||||
55
src/neurograph/core/energy.py
Normal file
55
src/neurograph/core/energy.py
Normal file
@ -0,0 +1,55 @@
|
||||
"""Energy function definitions.
|
||||
|
||||
The energy function is the foundation of the framework.
|
||||
It defines a scalar energy landscape over network activities
|
||||
and parameters, such that learning occurs through relaxation
|
||||
to energy minima rather than analytic backpropagation.
|
||||
|
||||
Energy composition:
|
||||
E = E_data + E_regularization + E_sparsity + E_structural
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class EnergyConfig:
|
||||
"""Weights for each energy term."""
|
||||
|
||||
data_weight: float = 1.0
|
||||
reg_weight: float = 0.01
|
||||
sparse_weight: float = 0.001
|
||||
structural_weight: float = 0.0
|
||||
|
||||
|
||||
def compute_energy(
|
||||
params: dict,
|
||||
activities: jax.Array,
|
||||
inputs: jax.Array,
|
||||
targets: jax.Array | None = None,
|
||||
nudging_strength: float = 0.0,
|
||||
config: EnergyConfig = EnergyConfig(),
|
||||
) -> jax.Array:
|
||||
"""Compute composite energy.
|
||||
|
||||
Args:
|
||||
params: Network parameters (weights, biases).
|
||||
activities: Current neuron activities.
|
||||
inputs: Input data clamped to visible nodes.
|
||||
targets: Target outputs (None for free phase).
|
||||
nudging_strength: How strongly to push output toward targets.
|
||||
config: Energy term weights.
|
||||
|
||||
Returns:
|
||||
Scalar energy value.
|
||||
"""
|
||||
# TODO: Implement energy terms
|
||||
# E_data: prediction error
|
||||
# E_reg: weight decay
|
||||
# E_sparse: activity sparsity
|
||||
# E_structural: topology complexity prior
|
||||
|
||||
raise NotImplementedError("Implement your energy function here.")
|
||||
1
src/neurograph/env/__init__.py
vendored
Normal file
1
src/neurograph/env/__init__.py
vendored
Normal file
@ -0,0 +1 @@
|
||||
"""Environment interface and wrappers."""
|
||||
1
src/neurograph/learning/__init__.py
Normal file
1
src/neurograph/learning/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""Learning rules and update mechanisms."""
|
||||
1
src/neurograph/pruning/__init__.py
Normal file
1
src/neurograph/pruning/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""Autonomous pruning and structural plasticity."""
|
||||
1
src/neurograph/utils/__init__.py
Normal file
1
src/neurograph/utils/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""Utilities — visualization, logging, metrics."""
|
||||
Loading…
Reference in New Issue
Block a user