CfC - 封闭式连续时间神经网络

Created at: 2021-01-09 12:55:21
Language: Python
License: Apache-2.0

闭式连续时间模型

闭式连续时间神经网络 (CfC) 是功能强大的顺序神经信息处理单元。

纸质开放获取:https://www.nature.com/articles/s42256-022-00556-7

阿尔克西夫:https://arxiv.org/abs/2106.13898

要求

  • Python3.6 或更高版本
  • Tensorflow 2.4 或更高版本
  • PyTorch 1.8 或更高版本
  • pytorch-lightning 1.3.0 或更高版本
  • scikit-learn 0.24.2 或更高版本

模块说明

  • tf_cfc.py
    CfC(各种版本)在Tensorflow 2.x中的实现
  • torch_cfc.py
    在 PyTorch 中实现 CfC(各种版本)
  • train_physio.py
    在 PyTorch 中的 Physionet 2012 数据集上训练 CfC 模型(代码改编自 Rubanova 等人,2019 年)
  • train_xor.py
    在 Tensorflow 中的 XOR 数据集上训练 CfC 模型(代码改编自 Lechner & Hasani,2020 年)
  • train_imdb.py
    在 Tensorflow 中的 IMDB 数据集上训练 CfC 模型(代码改编自 Keras 示例网站)
  • train_walker.py
    在 Tensorflow 中的 Walker2d 数据集上训练 CfC 模型(代码改编自 Lechner & Hasani,2020 年)
  • irregular_sampled_datasets.py
    来自莱希纳和哈萨尼的数据集(相同的拆分)(2020)
  • duv_physionet.py
    和来自Rubanova等人的Physionet数据集(相同拆分)(2019)
    duv_utils.py

用法

除以下三个标志之外的所有训练脚本

  • no_gate
    在没有 (1 西格) 部件的情况下运行 CfC
  • minimal
    运行 CfC 直接解决方案
  • use_ltc
    使用半隐式 ODE 求解器而不是 CfC 运行 LTC
  • use_mixed
    将 CfC 的 RNN 状态与 LSTM 混合,以避免梯度消失

如果未提供这些标志,则使用完整的 CfC 模型

例如

python3 train_physio.py

在 Physionet 数据集上训练完整的 CfC 模型。

同样地

train_walker.py --minimal

在 walker2d 数据集上运行直接 CfC 解决方案。

要下载Lechner & Hasani 2020的Walker2d数据集,请运行

source download_dataset.sh

引用

	title = {Closed-form continuous-time neural networks},
	journal = {Nature Machine Intelligence},
	author = {Hasani, Ramin and Lechner, Mathias and Amini, Alexander and Liebenwein, Lucas and Ray, Aaron and Tschaikowski, Max and Teschl, Gerald and Rus, Daniela},
  issn = {2522-5839},
	month = nov,
	year = {2022},
}