时空图神经网络原理及Pytorch实现
发布时间:2024-04-30
在我们生活的这个充满联系的世界中,从微观的分子结构到宏观的社交网络,再到复杂的城市设计结构,都隐藏着一张张相互关联的图数据。这些图数据仿佛一张张神秘的网,将世界万物紧密相连。而图神经网络(GNN)作为一种革命性的技术,正以其强大的能力,逐渐揭开这些图数据的面纱,让我们能够更深入地理解和利用它们。
图神经网络的出现,为我们提供了一种全新的建模和学习方式。它不仅能够捕捉数据的空间结构,还能够揭示图结构中的复杂关系。无论是在生物学领域,如蛋白质结构分析和药物发现,还是在社会学领域,如社交网络模拟和舆情分析,图神经网络都展现出了惊人的应用潜力。
更令人兴奋的是,图神经网络还可以与其他机器学习模型进行融 合,形成更加强大的模型。例如,将图神经网络与序列模型结合,形成时空图神经网络(Spatail-Temporal Graph),不仅能够捕捉数据的时间和空间依赖性,还能够更全面地揭示数据的内在规律和趋势。这种融合模型的出现,为各个领域的研究和应用带来了更多的可能性。
在时空图神经网络中,时间维度被巧妙地引入到了图结构中。这意味着,原本静止的节点特征现在会随着时间的推移而发生变化。这种变化不仅反映了节点之间的动态关系,还为我们提供了更丰富的信息,使我们能够更准确地预测和分析各种复杂现象。
不过,GNN模型和序列模型(如简单RNN、LSTM或GRU)本身就复杂。结合这些模型以处理空间和时间依赖性是强大的,但也很复杂:难以理解,也难以实现。
所以在 这篇文章中,我们将深入探讨这些模型的原理,并实现一个相对简单的示例,以更深入地理解它们的能力和应用。
图神经网络(GNN) 我们先介绍一些入门的知识简要讨论GNN。
图G可以定义为G = (V, E),其中V是节点集,E是它们之间的边。
一个包含n个节点的图的特征矩阵,每个节点具有f个特征,是所有特征的连接:
GNN的关键问题是所有连接节点之间的消息传递,这种邻居特征转换和聚合可以写成:
A是图的邻接矩阵,I是允许自连接的单位矩阵。虽然这不是完整的方程,但这已经可以说明可以学习不同节点之间空间依赖性的图卷积网络的基础。一个经典的图神经网络如下图所示:
时空图神经网络 (ST-GNN) ST-GNN中每个时间步都是一个图,并通过GCN/GAT网络传递,以获得嵌入数据空间相互依赖性的结果编码图。然后这些编码图可以像时间序列数据一样进行建模,只要保留每个时间步骤的数据的图结构的完整性。下图演示了这两个步骤,时间模型可以是从ARIMA或简单的循环神经网络或者是transformers的任何序列模型。
我们下面使用简单的循环神经网络来绘制ST-GNN的组件
上面就是ST-GNN的基本原理,将GNN和序列模型(如RNN、LSTM、GRU、Transformers 等)结合。如果你已经熟悉这些序列和GNN模型,那么理论来说是非常简单的,但是实际操作的时候就会有一些复杂,所以我们下面将直接使用Pytorch实现一个简单的ST-GNN。
ST-GNN的Pytorch实现 首先要说明:为了用于演示我将使用大型科技公司的股市数据。虽然这些数据本质上不是图数据,但这种网络可能会捕捉到这些公司之间的 相互依赖性,例如一个公司的表现(好或坏)可能反过来影响市场中其他公司的价值。但这只是一个演示,我们并不建议在股市预测中使用ST-GNN。
加载数据,直接使用yfinance里面什么都有
为了适应ST-GNN,所以我们要将数据进行转换以适应模型的要求
将标量时间序列数据集转换为图形数据结构是一个将传统数据转换为图神经网络可以处理的形式的关键步骤。这里描述的功能和类如下:
邻接矩阵的定义 : AdjacencyMatrix 函数定义了图的邻接矩阵(连通性),这通常是基于手头物理系统的结构来完成的。然而,在这里,作者仅使用 了一个全1矩阵,即所有节点都与所有其他节点相连。 股市数据集类 : StockMarketDataset 类旨在为训练时空图神经网络(ST-GNNs)创建数据集。这个类中包含的方法有: 数据序列生成 : DatasetCreate 方法生成数据序列。 构造图边 : _create_edges 方法使用邻接矩阵构造图的边。 生成数据序列 : _create_sequences 方法通过在输入的股市数据上滑动窗口来生成数据序列。 这种数据准备代码可以很容易地适应其他问题。这包括定义每个时间步的节点间的连接方式,并利用滑动窗口方法提取可以供模型学习的序列特征。通过这种方法,原本简单的时间序列数据被转化为具有复杂关系和时间依赖性的图形数据结构,从而可以使用图神经网络来进行更深入的分析和预测。
训练-验证-测试分割。
我们的模型包括一个GATConv和2个GRU层作为编码器,1个GRU层+全连接层作为解码器。GATconv是GNN部分,可以捕获空间依赖性,GRU层可以捕获数据的时间动态。代码包括大量的数据重塑,这样可以保证每一层的输入维度相同。这也是我们所说的ST-GNN实现中最复杂的部分,所以如果向具体了解输各层输入的维度,可以在向前传递的不同阶段打印x的形状,并将其与GRU和Linear层的预期输入尺寸的文档进行比较。
复制 import torch
import torch. nn. functional as F
from torch_geometric. nn import GATConv
class ST_GNN_Model ( torch. nn. Module) :
def __init__ ( self, in_channels, out_channels, n_nodes, gru_hs_l1, gru_hs_l2, heads= 1 , dropout= 0.01 ) :
super ( ST_GAT, self) . __init__( )
self. n_pred = out_channels
self. heads = heads
self. dropout = dropout
self. n_nodes = n_nodes
self. gru_hidden_size_l1 = gru_hs_l1
self. gru_hidden_size_l2 = gru_hs_l2
self. decoder_hidden_size = self. gru_hidden_size_l2
self. gat = GATConv( in_channels= in_channels, out_channels= in_channels,
heads= heads, dropout= dropout, cnotallow= False )
self. encoder_gru_l1 = torch. nn. GRU( input_size= self. n_nodes,
hidden_size= self. gru_hidden_size_l1, num_layers= 1 ,
bias = True )
self. encoder_gru_l2 = torch. nn. GRU( input_size= self. gru_hidden_size_l1,
hidden_size= self. gru_hidden_size_l2, num_layers = 1 ,
bias = True )
self. GRU_decoder = torch. nn. GRU( input_size = self. gru_hidden_size_l2, hidden_size = self. decoder_hidden_size,
num_layers = 1 , bias = True , dropout= self. dropout)
self. prediction_layer = torch. nn. Linear( self. decoder_hidden_size, self. n_nodes* self. n_pred, bias= True )
def forward ( self, data, device) :
x, edge_index = data. x, data. edge_index
if device == 'cpu' :
x = torch. FloatTensor( x)
else :
x = torch. cuda. FloatTensor( x)
x = self. gat( x, edge_index)
x = F. dropout( x, self. dropout, training= self. training)
batch_size = data. num_graphs
n_node = int ( data. num_nodes / batch_size)
x = torch. reshape( x, ( batch_size, n_node, data. num_features) )
x = torch. movedim( x, 2 , 0 )
encoderl1_outputs, _ = self. encoder_gru_l1( x)
x = F. relu( encoderl1_outputs)
encoderl2_outputs, h2 = self. encoder_gru_l2( x)
x = F. relu( encoderl2_outputs)
x, _ = self. GRU_decoder( x, h2)
x = torch. squeeze( x[ - 1 , : , : ] )
x = self. prediction_layer( x)
x = torch. reshape( x, ( batch_size, self. n_nodes, self. n_pred) )
x = torch. reshape( x, ( batch_size* self. n_nodes, self. n_pred) )
return x
1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17. 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34. 35. 36. 37. 38. 39. 40. 41. 42. 43. 44. 45. 46. 47. 48. 49. 50.
训练过程与pytorch中的任何网络训练过程几乎相同。
模型训练完成了,下面就可视化模型的预测能力。对于每个数据输入,下面的代码预测模型输出,并随后绘制模型输出与基础真值的关系。
复制 @torch. no_grad ( )
def Extract_results ( model, device, dataloader, type = '' ) :
model. eval ( )
model. to( device)
n = 0
for i, batch in enumerate ( dataloader) :
batch = batch. to( device)
if batch. x. shape[ 0 ] == 1 :
pass
else :
with torch. no_grad( ) :
pred = model( batch, device)
truth = batch. y. view( pred. shape)
if i == 0 :
y_pred = torch. zeros( len ( dataloader) , pred. shape[ 0 ] , pred. shape[ 1 ] )
y_truth = torch. zeros( len ( dataloader) , pred. shape[ 0 ] , pred. shape[ 1 ] )
y_pred[ i, : pred. shape[ 0 ] , : ] = pred
y_truth[ i, : pred. shape[ 0 ] , : ] = truth
n += 1
y_pred_flat = torch. reshape( y_pred, ( len ( dataloader) , batch_size, n_nodes, n_pred) )
y_truth_flat = torch. reshape( y_truth, ( len ( dataloader) , batch_size, n_nodes, n_pred) )
return y_pred_flat, y_truth_flat
def plot_results ( predictions, actual, step, node) :
predictions = torch. tensor( predictions[ : , : , node, step] ) . squeeze( )
actual = torch. tensor( actual[ : , : , node, step] ) . squeeze( )
pred_values_float = torch. reshape( predictions, ( - 1 , ) )
actual_values_float = torch. reshape( actual, ( - 1 , ) )
scatter_trace = go. Scatter(
x= actual_values_float,
y= pred_values_float,
mode= 'markers' ,
marker= dict (
size= 10 ,
opacity= 0.5 ,
color= 'rgba(255,255,255,0)' ,
line= dict (
width= 2 ,
color= 'rgba(152, 0, 0, .8)' ,
)
) ,
name= 'Actual vs Predicted'
)
line_trace = go. Scatter(
x= [ min ( actual_values_float) , max ( actual_values_float) ] ,
y= [ min ( actual_values_float) , max ( actual_values_float) ] ,
mode= 'lines' ,
marker= dict ( color= 'blue' ) ,
name= 'Perfect Prediction'
)
data = [ scatter_trace, line_trace]
layout = dict (
title= 'Actual vs Predicted Values' ,
xaxis= dict ( title= 'Actual Values' ) ,
yaxis= dict ( title= 'Predicted Values' ) ,
autosize= False ,
width= 800 ,
height= 600
)
fig = dict ( data= data, layout= layout)
iplot( fig)
y_pred, y_truth = Extract_results( model, device, test_dataloader, 'Test' )
plot_results( y_pred, y_truth, 9 , 0 )
1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17. 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34. 35. 36. 37. 38. 39. 40. 41. 42. 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53. 54. 55. 56. 57. 58. 59. 60. 61. 62. 63. 64.
对于6个节点(公司),给出过去50个值,做出10个预测。下面是第一个节点的第10步预测与真值的图。看起来看不错,但并不一定意味着就很好。因为对于时间序列数据,下一个值的最佳估计量总是前一个值。如果没有得到很好的训练,这些模型可以输出与输入数据的最后一个值相似的值,而不是捕获时间动态。
对于给定的节点,我们可以绘制历史输入、预测和真值进行比较,查看预测是否捕获了模式。
第一个节点(Google)在测试数据集的4个不同点上的预测实际上比我想象的要好,其他的看来不怎么样。
总结 我的理解是未来的股票价格不能通过单纯的历史价值自回归来预测,因为股票是由现实世界的事件决定的,这并没有体现在历史价值中。这也就是我们在前面说的不建议在股市预测中使用ST-GNN,我们使用这个数据集只是因为它容易获取。最后不要忘集我们本篇文章的目的,学习ST-GNN的基本概念,以及通过Pytorch代码实现来了解ST-GNN的工作原理。