STIDGCN - 学习笔记
STIDGCN - 学习笔记
1. 文章摘要
准确的交通预测对于城市交通管理、路线规划和流量检测至关重要。时空模型的最新进展显着改进了交通预测中复杂的时空相关性的建模。不幸的是,之前的大多数研究在跨不同感知视角有效建模时空相关性方面遇到了挑战,并且忽略了时空相关性之间的交互学习。此外,受空间异质性的限制,大多数研究未能考虑每个节点不同的时空模式。为了克服这些限制,我们提出了一种用于流量预测的时空交互式动态图卷积网络(STIDGCN)。具体来说,文章提出了一个由空间和时间模块组成的交互式学习框架,用于对流量数据进行下采样。该框架旨在通过采用从全局到局部的感知视角来捕捉空间和时间的相关性,并通过积极的反馈促进它们的相互利用。在空间模块中,我们基于图构造方法设计了动态图卷积网络。该网络旨在利用考虑时空异质性的流量模式库作为查询来重建数据驱动的动态图结构。重构的图结构可以揭示交通网络中节点之间的动态关联。对八个真实世界流量数据集的大量实验表明,STIDGCN 在平衡计算成本的同时优于最先进的基线。
源代码:
源代码
论文:
论文
2. 模型核心结构
- STI结构把输入分成多个序列,并且向下游不断分裂,形成类似二叉树的结构。这样做的目的:类似时序卷积,例如原始序列id:12345,分裂后序列A就是135,B就是246,这样序列A就能侧重分析到索引1和3之间的关系,如果原始数据是五分钟级别的,做一次STI可以认为变成关注十分钟级别的,两次就是关注二十分钟级别的。
- 创新点:创新了一个DGCN模块,把原始序列分成奇序列(
)和偶序列
,送入时序卷积块中提取卷积信息,然后送到DGCN块中提取空间信息。最重要的一点是是交互学习,也就是看图中的红线,奇序列经过图模块提取完空间信息后的隐向量将会与偶序列进行哈德马积(红线、绿线),反过来偶数序列时序卷积、DGCN后也和奇序列进行哈德玛积。这个过程会进行两次,因此形成一个时空交互的结构,这个过程是本文最大的创新点。
3. 各模块介绍
3.1 Encoder
STI模块
STI模块里面包含时间卷积TSConv模块和图卷积DGCN模块,两个模块在STI模块中进行交互学习。
TSConv模块
文章提出使用TSConv模块作为时间模块,捕捉时间关联性。
时序卷积模块使用二维的CNN对padding后的序列进行卷积,两层卷积分别使用(stride = 1,padding = s1,)、(1,s2)作为卷积核,s1、s2是预先定义好的核尺寸,文章后续对该参数进行了敏感性分析。TSConv可与定义为如下公式:
其中H_t 和 H_t’ 代表了TSConv的隐状态,这里省略了激活函数。
通过两层的卷积,能够提取到单个序列上的时序的动态性
DGCN模块
文章设计了一个权重共享的DGCN模块,作为空间模块,捕捉空间动态相关性。
因为文章处理的是动态图,预先不知道图的邻接矩阵,得通过DGCN结构提取空间状态表征
分成两个步骤:
①动态图重构:模拟动态的邻接矩阵
②对于构建的动态图,聚合周边节点的信息
输入是TSConv模块学习后的嵌入表示,这里输入的维度是Hg∈R(C×N×t’),C表示隐藏层的维度,channel,N是节点的个数,t’是时序的长度。
这里由于输入的 尺寸不一样,因此先经过一个全连接层,得到聚合输入Hf ∈R(C×N),随后和模式库(φ,Pattern Bank)进行交叉注意力(Hf和φ,得到Ap)和自注意力(Hf和Hf,得到Ah)计算。这里Pattern Bank是一个可训练的矩阵,φ∈R(C×N),可以认为是一种节点嵌入。得到嵌入矩阵后,进行交叉注意力【公式(6)】和自注意力【公式(7)】计算。:
如此操作后,会得到两个邻接矩阵,Ap和Ah,这是两个矩阵,大小都是N×N
然后把得到的两个空间注意力邻接矩阵使用Concat操作拼接,得到一个2N×N的矩阵,为了和下游的尺度对齐,又经过一个全连接层,这一步的目的是把矩阵的尺度从N×2N变成N×N
经过全连接层输出Af矩阵,可以认为Af矩阵中包含了节点和节点之间的空间关联关系,这是一个N×N的动态邻接矩阵。用这种方式计算的邻接矩阵会计算所有节点的空间关联性,然而实际上在图中不是所有节点都连接在一起的。因此还需要屏蔽掉不相关的节点。文章是对Af与矩阵M进行哈德玛积,矩阵M可以认为是注意力里面的Mask,取得是一个节点与他最相邻的K个邻居,通过Top-K的方式选取。(更通俗来说,一个节点i分别和各个节点计算相关性,最相关的K个节点标记1,否则标记0.)
得到邻接矩阵就可以进行图卷积操作了,文章采用扩散GCN进行动态图卷积,扩散图卷积把节点的动态变化描述成一个“扩散”过程,扩散图卷积聚合了图中节点之间的信息。扩散信号院子目标节点和当前最近的节点。扩散GCN可以描述为:
这里的Hg就是最开始的TSConv表征后的隐向量,尺寸是C×N×t’,W 是自学习权重的矩阵,尺寸是N×N,Af是刚刚得到的邻接矩阵,大小N×N。这个地方矩阵乘的维度没有写的很清楚,纳闷了很久维度不一样怎么乘,看代码,我们可以看到矩阵相乘是在N维度上相乘,最终输出还是B×C×N×T(TODO 单独实验一下DGCN模块的输入、输出,验证)。
时空交互学习
最终整个交互学习的过程,用公式的话可以描述成,和首图是对得上的:
3.2 Decoder
解码器用于输入编码器的编码后的特征He进行解码,得到最终的计算结果Y。首先隐特征He送入GLU门控线性单元,门控单元就是将Hg分别经过两层FC,其中一层加上sigmoid激活函数,并进行哈德玛积。经过激活和FC后得到最终的预测结果Y。文章指出:没有采用自回归的方式生成y设结果是为了提高计算效率、缩小误差积累。
4. 实验
4.1 对比实验
模型采用了PEMS多个数据集,在多个数据集上进行测试,取得了SOTA的效果。
评价指标选的是MAE,MAPE,RMSE
4.2 消融实验
消融实验:进行了四组消融实验:
w/o IL:STIDGCN替换掉时空交互学习模块,使用串行策略进行学习。
w/o TSConv: 移除TSConv模块
w/o GG:移除图生成模块
w/o DGCN: 使用普通GCN替换DGCN,并且输入的邻接矩阵采用预定义图