class DBGNN(Module):
    """Implementation of time-aware graph neural network DBGNN ([Reference paper](https://openreview.net/pdf?id=Dbkqs1EhTr)).

    Two parallel stacks of ``GCNConv`` layers are applied, one to the
    first-order graph and one to the higher-order graph. A bipartite layer then
    combines the higher-order and first-order embeddings, and a final linear
    layer maps the result to ``num_classes`` outputs.

    Args:
        num_classes: number of classes (output size of the final linear layer).
        num_features: input feature sizes
            ``[first_order_num_features, higher_order_num_features]``.
        hidden_dims: layer widths. Each GCN stack has ``len(hidden_dims) - 1``
            layers with output sizes ``hidden_dims[0], ..., hidden_dims[-2]``;
            ``hidden_dims[-1]`` is the output size of the bipartite layer.
            Must contain at least two entries.
        p_dropout: drop-out probability. Dropout is applied (in training mode
            only) before every GCN layer, after each GCN stack and after the
            bipartite layer.

    Raises:
        ValueError: if ``num_features`` or ``hidden_dims`` has fewer than two
            entries.
    """

    def __init__(self, num_classes: int, num_features: list[int], hidden_dims: list[int], p_dropout: float = 0.0):
        # Validate before building anything: the indexing below
        # (num_features[1], hidden_dims[-2]) would otherwise fail with an
        # uninformative IndexError part-way through construction.
        if len(num_features) < 2:
            raise ValueError(
                "num_features must contain [first_order_num_features, higher_order_num_features], "
                f"got {len(num_features)} entries"
            )
        if len(hidden_dims) < 2:
            raise ValueError(
                "hidden_dims must contain at least two entries (GCN widths plus the bipartite "
                f"output width), got {len(hidden_dims)}"
            )

        super().__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        self.hidden_dims = hidden_dims
        self.p_dropout = p_dropout

        # Input layers of both stacks. The higher-order layer is created first
        # and the stacks are then grown in lock-step. Keeping this order keeps
        # module registration (state_dict key order) and the sequence of random
        # weight initialisations unchanged.
        self.higher_order_layers = ModuleList()
        self.higher_order_layers.append(GCNConv(self.num_features[1], self.hidden_dims[0]))
        self.first_order_layers = ModuleList()
        self.first_order_layers.append(GCNConv(self.num_features[0], self.hidden_dims[0]))

        # Remaining GCN layers: hidden_dims[i - 1] -> hidden_dims[i] for
        # i = 1 .. len(hidden_dims) - 2. The last width is reserved for the
        # bipartite layer.
        for in_dim, out_dim in zip(self.hidden_dims[:-2], self.hidden_dims[1:-1]):
            self.higher_order_layers.append(GCNConv(in_dim, out_dim))
            self.first_order_layers.append(GCNConv(in_dim, out_dim))

        # Bipartite layer: hidden_dims[-2] -> hidden_dims[-1]. In forward it
        # receives the (higher-order, first-order) embedding pair.
        self.bipartite_layer = BipartiteGraphOperator(self.hidden_dims[-2], self.hidden_dims[-1])

        # Linear classification head
        self.lin = torch.nn.Linear(self.hidden_dims[-1], num_classes)

    def forward(self, data):
        """Compute class scores from first-order and higher-order graph data.

        Args:
            data: graph data object. The attributes read are ``x``, ``x_h``,
                ``edge_index``, ``edge_weights``, ``edge_index_higher_order``,
                ``edge_weights_higher_order``, ``bipartite_edge_index``,
                ``num_ho_nodes`` and ``num_nodes``.

        Returns:
            Output of the final linear layer, whose last dimension is
            ``num_classes``. Presumably one row per first-order node, since the
            bipartite layer is called with ``M=data.num_nodes``. TODO: confirm
            against ``BipartiteGraphOperator``.
        """
        x = data.x
        x_h = data.x_h

        # First-order convolutions
        for layer in self.first_order_layers:
            x = F.dropout(x, p=self.p_dropout, training=self.training)
            x = F.elu(layer(x, data.edge_index, data.edge_weights))
        x = F.dropout(x, p=self.p_dropout, training=self.training)

        # Higher-order convolutions
        for layer in self.higher_order_layers:
            x_h = F.dropout(x_h, p=self.p_dropout, training=self.training)
            x_h = F.elu(layer(x_h, data.edge_index_higher_order, data.edge_weights_higher_order))
        x_h = F.dropout(x_h, p=self.p_dropout, training=self.training)

        # Bipartite message passing between higher-order and first-order embeddings
        x = F.elu(
            self.bipartite_layer((x_h, x), data.bipartite_edge_index, N=data.num_ho_nodes, M=data.num_nodes)
        )
        x = F.dropout(x, p=self.p_dropout, training=self.training)

        # Linear layer
        return self.lin(x)