General Concepts of the Tensor-based Implementations
Prerequisites
First, we need to set up a Python environment with PyTorch, PyTorch Geometric, and PathpyG installed. Depending on where you are executing this notebook, this might already be (partially) done. For example, Google Colab has PyTorch installed by default, so we only need to install the remaining dependencies. The DevContainer that is part of our GitHub repository, on the other hand, already has all of the necessary dependencies installed.
In the following, we list the installation commands for Google Colab as Jupyter magic commands. For other environments, uncomment or comment out the commands as necessary. For more details on how to install pathpyG, especially with GPU support, we refer to our documentation. Note that %%capture discards the full output of the cell so that installation details do not clutter this tutorial. If you want to see the output, remove the %%capture line.
%%capture
# !pip install torch
# !pip install torch_geometric
# !pip install git+https://github.com/pathpy/pathpyG.git
Motivation and Learning Objectives
The inner workings of the core classes of PathpyG are based on tensor operations provided by PyTorch and PyTorch Geometric. In particular, the creation of higher-order structures using the lift-order functions and the MultiOrderModel relies heavily on tensor operations for efficiency reasons. While these implementations are highly optimized, they can be hard to read and understand for newcomers. This tutorial aims to explain the general concepts and ideas behind these implementations in a more accessible way. Additionally, we provide step-by-step explanations of the core functions in the following sections.
import torch
from torch_geometric.data import Data
from torch_geometric.utils import cumsum, degree, sort_edge_index
import pathpyG as pp
Order-lifting and Line Graph Transformations
At the core of creating higher-order models is the lift_order_edge_index function, which is essentially a line graph transformation. Given the edge index of a graph and the number of nodes in the graph, this function creates the edge index of the corresponding line graph. Let's look at an example:
mapping = pp.IndexMap(list("abcdef"))
graph = pp.Graph.from_edge_index(
edge_index=torch.tensor([[0, 1, 3, 4, 2, 2, 5], [2, 2, 5, 5, 3, 4, 0]]), mapping=mapping
)
pp.plot(graph, node_label=graph.nodes)
We can create the line graph for this graph using the lift_order_edge_index function as follows:
second_order_edge_index = pp.algorithms.lift_order.lift_order_edge_index(edge_index=graph.data.edge_index, num_nodes=graph.n)
second_order_mapping = pp.IndexMap(graph.edges)
second_order_data = Data(edge_index=second_order_edge_index, node_sequence=graph.data.edge_index.t())
line_graph = pp.Graph(data=second_order_data, mapping=second_order_mapping)
pp.plot(line_graph, node_label=line_graph.nodes)
To create the higher-order PathpyG.Graph, we needed to specify a node_sequence in the Data object. Above, the node sequence was given by the original edges of the graph. This node_sequence keeps track of which original nodes correspond to each node in the higher-order graph. In a second-order graph, each higher-order node corresponds to an edge in the original graph. In a graph of order k, each higher-order node corresponds to a path with k nodes (i.e., a path of length k-1) in the original graph. With this, we can always trace back which original nodes a higher-order node corresponds to.
As long as we have this mapping from higher-order nodes to original nodes, we can always do an additional line graph transformation to create even higher order graphs. Below, we create a third-order graph:
third_order_edge_index = pp.algorithms.lift_order.lift_order_edge_index(edge_index=line_graph.data.edge_index, num_nodes=line_graph.n)
third_order_data = Data(edge_index=third_order_edge_index, node_sequence=torch.cat([line_graph.data.node_sequence[line_graph.data.edge_index[0]], line_graph.data.node_sequence[line_graph.data.edge_index[1]][:, -1:]], dim=1))
third_order_mapping = pp.IndexMap([tuple(seq) for seq in graph.mapping.to_ids(third_order_data.node_sequence).tolist()])
third_order_graph = pp.Graph(data=third_order_data, mapping=third_order_mapping)
pp.plot(third_order_graph, node_label=third_order_graph.nodes)
Note that above, we constructed the node_sequence for the third-order graph from the sequences of the two nodes that form each edge in the second-order graph. These two sequences overlap: apart from the first node of the higher-order source and the last node of the higher-order target, they contain the same nodes, which represent the shared part of the paths. Therefore, we take the full sequence of the source node and append only the last node of the target node's sequence.
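The concatenation rule can be mirrored with plain Python tuples. The following is a minimal sketch, not PathpyG code: the helper extend_node_sequences and the hard-coded second-order nodes and edges (taken from the example graph, with edges sorted by source node) are assumptions made for illustration.

```python
def extend_node_sequences(node_sequences, ho_edges):
    """Build the node sequences of the next order from the current order.

    node_sequences: list of tuples, one tuple of original nodes per higher-order node
    ho_edges: list of (src, dst) index pairs between higher-order nodes
    """
    next_sequences = []
    for src, dst in ho_edges:
        src_seq, dst_seq = node_sequences[src], node_sequences[dst]
        # The sequences overlap in all but the first node of src and the last node of dst
        assert src_seq[1:] == dst_seq[:-1]
        # Keep the full source sequence and append only the last node of the target
        next_sequences.append(src_seq + dst_seq[-1:])
    return next_sequences


# Second-order nodes: the edges of the example graph, sorted by source node
so_nodes = [("a", "c"), ("b", "c"), ("c", "d"), ("c", "e"), ("d", "f"), ("e", "f"), ("f", "a")]
# Second-order edges between these nodes (line graph of the example graph)
so_edges = [(0, 2), (0, 3), (1, 2), (1, 3), (2, 4), (3, 5), (4, 6), (5, 6), (6, 0)]

third_order_nodes = extend_node_sequences(so_nodes, so_edges)
print(third_order_nodes[:2])  # [('a', 'c', 'd'), ('a', 'c', 'e')]
```

The assertion inside the loop makes the overlap assumption explicit; it would fail if an edge connected two higher-order nodes whose paths do not share their middle part.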
Under the Hood of lift_order_edge_index
Let us now take a closer look at how the lift_order_edge_index function works under the hood. The whole function essentially only needs 10 lines of code and looks as follows:
def lift_order_edge_index(edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
    outdegree = degree(edge_index[0], dtype=torch.long, num_nodes=num_nodes)
    outdegree_per_dst = outdegree[edge_index[1]]
    num_new_edges = outdegree_per_dst.sum()
    ho_edge_srcs = torch.repeat_interleave(outdegree_per_dst)
    ptrs = cumsum(outdegree, dim=0)[:-1]
    ho_edge_dsts = torch.repeat_interleave(ptrs[edge_index[1]], outdegree_per_dst)
    idx_correction = torch.arange(num_new_edges, dtype=torch.long)
    idx_correction -= cumsum(outdegree_per_dst, dim=0)[ho_edge_srcs]
    ho_edge_dsts += idx_correction
    return torch.stack([ho_edge_srcs, ho_edge_dsts], dim=0)
However, the heavy use of tensor operations obscures what exactly the function does. Let us break down the function step by step to understand what is happening internally.
Note
Due to the high complexity of the tensor operations, we provide two explanations of each step that describe the same concepts in different words. One explanation is added to the code snippets as comments, and the other is provided in the markdown cells between the code snippets.
Edge index must be sorted!
The lift_order_edge_index function assumes that the input edge_index is sorted by source nodes. The function itself does not enforce this, because we ensure that the edge indices are sorted whenever we create a PathpyG.Graph object. Sorting is crucial for the correct functioning of the lift_order_edge_index function.
- The function first computes the outdegree of each node in the graph using the degree function from torch_geometric.utils. This gives us a tensor containing the number of outgoing edges for each node.
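To see why sorting matters, the following plain-Python sketch computes start pointers from the outdegrees (as the function does later on) and applies them to the example edge index before and after sorting it by source node. The helper start_pointers is hypothetical and not part of PathpyG; sorting by (source, destination) pairs is an assumption that suffices for the illustration.

```python
from collections import Counter


def start_pointers(sources, num_nodes):
    """Exclusive prefix sum of the outdegrees: where each node's edges start if sorted by source."""
    outdegree = Counter(sources)
    ptrs, total = [], 0
    for node in range(num_nodes):
        ptrs.append(total)
        total += outdegree[node]
    return ptrs


# Edge index of the example graph as passed to from_edge_index (not sorted by source)
src = [0, 1, 3, 4, 2, 2, 5]
dst = [2, 2, 5, 5, 3, 4, 0]
ptrs = start_pointers(src, num_nodes=6)

# Node 2 ("c") has two outgoing edges, supposedly stored at positions ptrs[2] and ptrs[2] + 1
print([src[i] for i in range(ptrs[2], ptrs[2] + 2)])  # [3, 4] -> edges of other nodes

# After sorting, the same pointers select the outgoing edges of node 2
sorted_src = [s for s, _ in sorted(zip(src, dst))]
print([sorted_src[i] for i in range(ptrs[2], ptrs[2] + 2)])  # [2, 2]
```

The outdegrees, and therefore the pointers, do not depend on the order of the edges, but the pointers only describe valid positions once all outgoing edges of a node are stored next to each other.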
# Compute the outdegree of each node used to get all the edge combinations leading to a higher-order edge
outdegree = degree(graph.data.edge_index[0], dtype=torch.long, num_nodes=graph.n)
print("Outdegree per node:")
for node in graph.nodes:
    print(f"\t{node}: {outdegree[graph.mapping.to_idx(node)].item()}")
Outdegree per node:
a: 1
b: 1
c: 2
d: 1
e: 1
f: 1
- Next, we map the outdegree values to the destination nodes of each edge in the edge index. This gives us a tensor where each entry corresponds to the outdegree of the target node of each edge.
Note
This helps us because, for the line graph transformation, we need to transform each edge into a node and then connect two of these new nodes (previously edges) whenever the target node of the first original edge is the source node of the second. Therefore, we need to create a higher-order edge for each combination of incoming and outgoing edges at each node in the original graph. The outdegree of the target node tells us how many outgoing edges there are for each target node, which directly translates to how many higher-order edges we need to create for each incoming edge.
# For each center node, we need to combine each outgoing edge with each incoming edge
# We achieve this by creating `outdegree` number of edges for each destination node
# of the old edge index
outdegree_per_dst = outdegree[graph.data.edge_index[1]]
print("\nOutdegree per destination node of each edge:")
for e, outdeg in zip(graph.edges, outdegree_per_dst.tolist()):
    print(f"\t{e}: {outdeg}")
Outdegree per destination node of each edge:
('a', 'c'): 2
('b', 'c'): 2
('c', 'd'): 1
('c', 'e'): 1
('d', 'f'): 1
('e', 'f'): 1
('f', 'a'): 1
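The first two steps can be reproduced without tensors. This is a plain-Python sketch for intuition only; the hard-coded edge list assumes the node indices a=0, b=1, ..., f=5 and the sorted edge order shown in the output above.

```python
from collections import Counter

# Sorted edge list of the example graph (a=0, b=1, c=2, d=3, e=4, f=5)
edges = [(0, 2), (1, 2), (2, 3), (2, 4), (3, 5), (4, 5), (5, 0)]
num_nodes = 6

# Step 1: outdegree of every node, i.e. how often each node appears as a source
counts = Counter(src for src, _ in edges)
outdegree = [counts[node] for node in range(num_nodes)]

# Step 2: for each edge, look up the outdegree of its destination (a gather by index)
outdegree_per_dst = [outdegree[dst] for _, dst in edges]

print(outdegree)          # [1, 1, 2, 1, 1, 1]
print(outdegree_per_dst)  # [2, 2, 1, 1, 1, 1, 1]
```

In the tensor version, the gather in step 2 is the single indexing expression outdegree[edge_index[1]].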
- Next, we create the source nodes of the higher-order edges. Each original edge becomes a higher-order node whose index is its position in the original edge index, i.e., the indices range from 0 to the number of edges minus one. We then repeat each index according to the outdegree of the corresponding edge's target node; calling torch.repeat_interleave with only a tensor of repeat counts does both steps at once. This way, we create one source entry for each combination of an incoming and an outgoing edge at each target node, and these combinations become the edges of the higher-order graph.
# Use each edge from the edge index as node and assign the new indices in the order of the original edge index
# Each higher order node has one outgoing edge for each outgoing edge of the original destination node
# Since we keep the ordering, we can just repeat each node using the `outdegree_per_dst` tensor
ho_edge_srcs = torch.repeat_interleave(outdegree_per_dst)
print("\nHigher-order edge source indices:\n", ho_edge_srcs.tolist())
print("Higher-order edge sources:\n", graph.mapping.to_ids(graph.data.edge_index[:, ho_edge_srcs]).T)
Higher-order edge source indices:
[0, 0, 1, 1, 2, 3, 4, 5, 6]
Higher-order edge sources:
[['a' 'c']
 ['a' 'c']
 ['b' 'c']
 ['b' 'c']
 ['c' 'd']
 ['c' 'e']
 ['d' 'f']
 ['e' 'f']
 ['f' 'a']]
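When torch.repeat_interleave receives only a tensor of counts, it emits index i exactly counts[i] times. The hypothetical helper repeat_indices below is a plain-Python sketch of that behavior, applied to the outdegree_per_dst values from the previous step.

```python
def repeat_indices(counts):
    """Plain-Python sketch of torch.repeat_interleave(counts) for a 1-D tensor of counts:
    index i is emitted counts[i] times."""
    return [i for i, c in enumerate(counts) for _ in range(c)]


outdegree_per_dst = [2, 2, 1, 1, 1, 1, 1]
ho_edge_srcs = repeat_indices(outdegree_per_dst)
print(ho_edge_srcs)  # [0, 0, 1, 1, 2, 3, 4, 5, 6]
```

An index with a count of zero, i.e. an edge whose target has no outgoing edges, simply disappears from the result, which matches the fact that such an edge cannot be continued.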
- Now, we need to create the target nodes for the higher-order edges. For this, we first need to know where the edges of each node start in the original edge index. We can compute this by calculating the cumulative sum of the outdegree values of all nodes. This gives us a tensor where each entry corresponds to the starting index of the edges for each node in the original edge index.
Cumulative Sum
There is one cumsum implementation in PyTorch and one in PyTorch Geometric. The one in PyTorch Geometric starts with an initial zero value, while the one in PyTorch does not. This means that the torch.cumsum function will give us the end pointers of the edges for each node, while the torch_geometric.utils.cumsum function will give us the start pointers (including a last pointer that is equal to the total number of edges). Therefore, we use the torch_geometric.utils.cumsum function here and remove the last entry afterwards.
# For each node, we calculate pointers of shape (num_nodes,) that indicate the start of the original edges
# (new higher-order nodes) that have the node as source node
ptrs = cumsum(outdegree, dim=0)[:-1]
print("Edge start pointers per node:\n", ptrs.tolist())
Edge start pointers per node: [0, 1, 2, 4, 5, 6]
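The difference between the two cumsum variants can be illustrated with itertools.accumulate. This sketch only mimics the described behavior of torch.cumsum (inclusive) and torch_geometric.utils.cumsum (with a prepended zero) on plain lists.

```python
from itertools import accumulate

outdegree = [1, 1, 2, 1, 1, 1]

# Inclusive prefix sum, like torch.cumsum: (exclusive) end pointer of each node's edges
inclusive = list(accumulate(outdegree))
# Prefix sum with a leading zero, like torch_geometric.utils.cumsum
exclusive = [0] + inclusive
# Dropping the last entry leaves exactly one start pointer per node
ptrs = exclusive[:-1]

print(inclusive)  # [1, 2, 4, 5, 6, 7]
print(ptrs)       # [0, 1, 2, 4, 5, 6]
```

The outgoing edges of node i therefore occupy the positions ptrs[i] up to, but excluding, inclusive[i] in the sorted edge index.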
- With the starting pointers of the edges for each node, we can start creating the target nodes for the higher-order edges. Remember that we assigned the higher-order node indices based on the order of edges in the original edge index and ordered the higher-order source nodes accordingly. Therefore, we are essentially going through each edge and combining it with each outgoing edge of its target node to create the higher-order edges. Since the edges are ordered by source node, the outgoing edges of each node form a contiguous block in the original edge index. This means that for each edge in the original graph, we can look up where the outgoing edges of its target node start using the ptrs tensor we created in the previous step. We then repeat these starting pointers according to the outdegree of the corresponding target node, creating one higher-order target for each combination of incoming and outgoing edges at that node.
# Use these pointers to get the start of the edges for each higher-order src and repeat it `outdegree` times
# Since we keep the ordering, all new higher-order edges that have the same src are indexed consecutively
ho_edge_dsts = torch.repeat_interleave(ptrs[graph.data.edge_index[1]], outdegree_per_dst)
print("Higher-order edge destination indices (before correction):\n", ho_edge_dsts.tolist())
Higher-order edge destination indices (before correction): [2, 2, 2, 2, 4, 5, 6, 6, 0]
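With two arguments, torch.repeat_interleave(values, counts) repeats values[i] exactly counts[i] times. The hypothetical helper repeat_values sketches this in plain Python and reproduces the uncorrected destinations from the example.

```python
def repeat_values(values, counts):
    """Plain-Python sketch of torch.repeat_interleave(values, counts):
    values[i] is emitted counts[i] times."""
    return [v for v, c in zip(values, counts) for _ in range(c)]


ptrs = [0, 1, 2, 4, 5, 6]                  # start pointer of each node's outgoing edges
dsts = [2, 2, 3, 4, 5, 5, 0]               # destination node of each original edge
outdegree_per_dst = [2, 2, 1, 1, 1, 1, 1]

# Start pointer of the destination's outgoing edges, repeated once per possible continuation
ho_edge_dsts = repeat_values([ptrs[d] for d in dsts], outdegree_per_dst)
print(ho_edge_dsts)  # [2, 2, 2, 2, 4, 5, 6, 6, 0]
```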
- At this point, the indices of the higher-order target nodes are not yet correct. Since we only repeated the starting pointers, every higher-order edge in a group (i.e., all higher-order edges that share the same higher-order source node) points to the first outgoing edge of the corresponding target node. Within each group, we still need to count up from this starting pointer. Because the outgoing edges of each node are stored consecutively, the k-th higher-order edge in a group simply has to point to the k-th outgoing edge. To compute these offsets, we first create a correction index that counts up from 0 to the total number of higher-order edges minus one.
# Since the above only repeats the start of the edges, we need to add (0, 1, 2, 3, ...)
# for all `outdegree` number of edges consecutively to get the correct destination nodes
# We can achieve this by starting with a range (0, 1, ..., num_new_edges - 1)
idx_correction = torch.arange(ho_edge_srcs.size(0), dtype=torch.long)
print("Index correction (before adjustment):\n", idx_correction.tolist())
Index correction (before adjustment): [0, 1, 2, 3, 4, 5, 6, 7, 8]
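- We then subtract, for each higher-order edge, the position at which the group of its higher-order source node starts. This position is given by the cumulative sum (with a leading zero) of outdegree_per_dst, evaluated at the higher-order source node. This effectively resets the counter at the beginning of each group, so that the correction index counts 0, 1, 2, ... within each group.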
# Then, we subtract the cumulative sum of the outdegree for each destination node
idx_correction -= cumsum(outdegree_per_dst, dim=0)[ho_edge_srcs]
print("Index correction (after adjustment):\n", idx_correction.tolist())
Index correction (after adjustment): [0, 1, 0, 1, 0, 0, 0, 0, 0]
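The correction is simply the position of each higher-order edge within its group. The plain-Python sketch below computes it in the same way as the tensor code (global position minus group start) and checks it against a more direct formulation that counts up inside each group.

```python
from itertools import accumulate

outdegree_per_dst = [2, 2, 1, 1, 1, 1, 1]
ho_edge_srcs = [0, 0, 1, 1, 2, 3, 4, 5, 6]

# Position at which each higher-order source's group starts (prefix sum with a leading zero)
group_start = [0] + list(accumulate(outdegree_per_dst))[:-1]
# Global position minus group start = offset within the group
idx_correction = [pos - group_start[src] for pos, src in enumerate(ho_edge_srcs)]
print(idx_correction)  # [0, 1, 0, 1, 0, 0, 0, 0, 0]

# Direct formulation: count 0, 1, ..., c - 1 for every group of size c
direct = [k for c in outdegree_per_dst for k in range(c)]
assert direct == idx_correction
```

The tensor code prefers the subtraction because it only needs arange, cumsum, and indexing, all of which are vectorized, whereas the direct formulation requires a loop over groups.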
- Finally, we add this correction index to the starting pointers of the edges for each target node to get the correct indices for the higher-order target nodes.
# Add this tensor to the destination nodes to get the correct destination nodes for each higher-order edge
ho_edge_dsts += idx_correction
print("Higher-order edge destination indices (after correction):\n", ho_edge_dsts.tolist())
print("Higher-order edge destinations:\n", graph.mapping.to_ids(graph.data.edge_index[:, ho_edge_dsts]).T)
Higher-order edge destination indices (after correction):
[2, 3, 2, 3, 4, 5, 6, 6, 0]
Higher-order edge destinations:
[['c' 'd']
 ['c' 'e']
 ['c' 'd']
 ['c' 'e']
 ['d' 'f']
 ['e' 'f']
 ['f' 'a']
 ['f' 'a']
 ['a' 'c']]
This gives us the final higher-order edge index that we can return from the function.
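To double-check the result, one can compare it with a naive line graph construction that tests every pair of edges. This quadratic plain-Python sketch (the helper naive_line_graph is hypothetical) is only meant for small examples; because the edges are sorted by source node, its pairs even come out in the same order as those of the tensor version.

```python
def naive_line_graph(edges):
    """Reference line graph: connect edge i to edge j whenever the target of i is the source of j.
    Quadratic in the number of edges, so only useful for checking small examples."""
    return [
        (i, j)
        for i, (_, v) in enumerate(edges)
        for j, (u, _) in enumerate(edges)
        if v == u
    ]


# Sorted edge list of the example graph (a=0, b=1, c=2, d=3, e=4, f=5)
edges = [(0, 2), (1, 2), (2, 3), (2, 4), (3, 5), (4, 5), (5, 0)]
pairs = naive_line_graph(edges)
ho_edge_srcs = [i for i, _ in pairs]
ho_edge_dsts = [j for _, j in pairs]
print(ho_edge_srcs)  # [0, 0, 1, 1, 2, 3, 4, 5, 6]
print(ho_edge_dsts)  # [2, 3, 2, 3, 4, 5, 6, 6, 0]
```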
Temporal Order Lifting
One of the core functionalities of PathpyG is the ability to create temporal higher-order models. For this, an extension of the lift_order_edge_index function to temporal graphs is needed. We implement this in the lift_order_temporal function. This function works similarly to the lift_order_edge_index function, but with some additional steps to account for the temporal aspect of the graph. The main difference is that we need to ensure that the higher-order edges respect the temporal ordering of the original edges. Let us take a look at an example:
tedges = [
("a", "b", 1),
("a", "b", 2),
("b", "a", 3),
("b", "c", 3),
("d", "c", 4),
("a", "b", 4),
("c", "b", 4),
("c", "d", 5),
("b", "a", 5),
("c", "b", 6),
]
t = pp.TemporalGraph.from_edge_list(tedges)
pp.plot(t, node_label=t.nodes)
We can create a second-order graph from this temporal graph using the lift_order_temporal function. This second-order graph is typically referred to as an event graph. Each node in the graph is an event (edge) in the original temporal graph, and two events are connected if the second event starts at the node where the first one ends and occurs strictly later, but at most delta time units later. Here, we set delta=2, which means that two such events can be connected if the time difference between them is positive and at most 2 time units.
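The connection rule can be written down directly as a brute-force check over all pairs of events. This plain-Python sketch (the helper event_graph_edges is hypothetical) uses positions in the tedges list as event indices, which need not coincide with PathpyG's internal ordering of the temporal edges.

```python
def event_graph_edges(tedges, delta):
    """Brute-force event graph: event i -> event j if j starts where i ends
    and happens strictly later, but at most `delta` time units later."""
    return [
        (i, j)
        for i, (_, v, t_i) in enumerate(tedges)
        for j, (u, _, t_j) in enumerate(tedges)
        if v == u and 0 < t_j - t_i <= delta
    ]


tedges = [
    ("a", "b", 1), ("a", "b", 2), ("b", "a", 3), ("b", "c", 3), ("d", "c", 4),
    ("a", "b", 4), ("c", "b", 4), ("c", "d", 5), ("b", "a", 5), ("c", "b", 6),
]
edges = event_graph_edges(tedges, delta=2)
print(len(edges))  # 11
```

For example, ("c", "d", 5) cannot be followed by ("d", "c", 4) because the second event happens earlier.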
event_edge_index = pp.algorithms.temporal.lift_order_temporal(t, delta=2)
event_mapping = pp.IndexMap(t.temporal_edges)
event_data = Data(edge_index=event_edge_index, node_sequence=t.data.edge_index.t())
event_graph = pp.Graph(data=event_data, mapping=event_mapping)
pp.plot(event_graph, node_label=event_graph.nodes)
Starting with the event graph, we have a static higher-order representation of the temporal graph that we can use to create higher-order models. For all subsequent lift-order transformations, we can use the same principles as described in the previous section on order-lifting and line graph transformations.
Internals of the lift_order_temporal Function
The simplest way to implement the lift_order_temporal function would be to first create the full higher-order edge index using the lift_order_edge_index function and then filter out the edges that do not respect the temporal ordering. The filter function could look as follows:
def filter_time_respecting_edges(event_edge_index: torch.Tensor, timestamps: torch.Tensor, delta: int) -> torch.Tensor:
    # Subtract timestamps of the two events to get the time difference
    time_diff = timestamps[event_edge_index[1]] - timestamps[event_edge_index[0]]
    # Create masks for filtering
    # Remove non-time-respecting higher-order edges (the second event must happen strictly later)
    positive_mask = time_diff > 0
    # Remove edges that are too far apart in time based on delta
    delta_mask = time_diff <= delta
    # Combine masks to get the final time-respecting edges
    time_respecting_mask = positive_mask & delta_mask
    # Filter the event_edge_index using the time_respecting_mask
    return event_edge_index[:, time_respecting_mask]
We can combine the above filter function with the lift_order_edge_index function to create a lift-order function for temporal graphs as follows:
Warning
If we use the standard lift_order_edge_index function, we need to ensure that the input edge index is sorted by source nodes because the edge_index of a TemporalGraph is sorted by time and not by source nodes.
# Sort by source node indices
sorted_edge_index, time = sort_edge_index(t.data.edge_index.as_tensor(), t.data.time)
# Lift the edge index to the second order
second_order_edge_index = pp.algorithms.lift_order.lift_order_edge_index(edge_index=sorted_edge_index, num_nodes=t.n)
# Filter the edges based on the lifted edge index
filtered_edge_index = filter_time_respecting_edges(second_order_edge_index, timestamps=time, delta=2)
# Create `pp.Graph` from the filtered edge index
filtered_event_mapping = pp.IndexMap([tuple([*t.mapping.to_ids(edge).tolist(), timestamp.item()]) for edge, timestamp in zip(sorted_edge_index.t(), time)])
filtered_event_data = Data(edge_index=filtered_edge_index, node_sequence=sorted_edge_index.t())
filtered_event_graph = pp.Graph(data=filtered_event_data, mapping=filtered_event_mapping)
pp.plot(filtered_event_graph, node_label=filtered_event_graph.nodes)
Note
The indexing of the above implementation is different from the one currently implemented in PathpyG. So while the illustrations look identical, the actual indices of the higher-order nodes will differ.
However, the above implementation has a large memory consumption for graphs with many edges because the full higher-order edge index is created before filtering. Therefore, we implement a more memory-efficient version in PathpyG that constructs the higher-order edges from the temporal graph sequentially for each timestamp. This implementation looks as follows:
def lift_order_temporal(g: pp.TemporalGraph, delta: int = 1):
    indices = torch.arange(0, g.data.edge_index.size(1))
    unique_t = torch.unique(g.data.time)
    second_order = []
    # lift order: find possible continuations for edges in each time stamp
    for t in unique_t:
        # find indices of all source edges that occur at unique timestamp t
        src_time_mask = g.data.time == t
        src_edge_idx = indices[src_time_mask]
        # find indices of all edges that can possibly continue edges occurring at time t for the given delta
        dst_time_mask = (g.data.time > t) & (g.data.time <= t + delta)
        dst_edge_idx = indices[dst_time_mask]
        if dst_edge_idx.size(0) > 0 and src_edge_idx.size(0) > 0:
            # compute second-order edges between src and dst idx
            # create all possible combinations of src and dst edges
            x = torch.cartesian_prod(src_edge_idx, dst_edge_idx)
            # filter combinations for real higher-order edges
            # for all edges where dst in src_edges (g.data.edge_index[1, x[:, 0]]) matches src in dst_edges (g.data.edge_index[0, x[:, 1]])
            ho_edge_index = x[g.data.edge_index[1, x[:, 0]] == g.data.edge_index[0, x[:, 1]]]
            second_order.append(ho_edge_index)
    ho_index = torch.cat(second_order, dim=0).t().contiguous()
    return ho_index
Note that above we do not use the same indexing trick as in the standard lift_order_edge_index function. Instead, for each timestamp, we create all possible combinations of the edges occurring at that timestamp with all edges in the following delta time window. Therefore, we need a filtering step afterwards to ensure that only valid higher-order edges are kept, i.e., pairs in which the target of the first edge is the source of the second. However, we can skip the sorting step beforehand because we create all possible edge combinations using the Cartesian product.
It is also possible to combine both approaches, i.e., to create the higher-order edges for each timestamp separately using the indexing trick from the standard lift_order_edge_index function. While this saves the filtering step, it again requires sorting the edges beforehand, and it has been shown to perform similarly to the method above. The code would look as follows:
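The per-timestamp strategy of lift_order_temporal can be mirrored in plain Python with itertools.product. This is a sketch for intuition, not PathpyG code: the helper lift_per_timestamp is hypothetical, and events are identified by their position in the tedges list from the example.

```python
from itertools import product


def lift_per_timestamp(tedges, delta):
    """Per-timestamp sketch: pair all events at time t with all events in (t, t + delta],
    then keep the pairs whose nodes line up (no sorting needed)."""
    indices_by_time = {}
    for idx, (_, _, t) in enumerate(tedges):
        indices_by_time.setdefault(t, []).append(idx)
    result = []
    for t in sorted(indices_by_time):
        src_idx = indices_by_time[t]
        dst_idx = [j for j, (_, _, t_j) in enumerate(tedges) if t < t_j <= t + delta]
        # Cartesian product of all candidates, followed by the filtering step
        result.extend((i, j) for i, j in product(src_idx, dst_idx) if tedges[i][1] == tedges[j][0])
    return result


tedges = [
    ("a", "b", 1), ("a", "b", 2), ("b", "a", 3), ("b", "c", 3), ("d", "c", 4),
    ("a", "b", 4), ("c", "b", 4), ("c", "d", 5), ("b", "a", 5), ("c", "b", 6),
]
ho_edges = lift_per_timestamp(tedges, delta=2)
print(len(ho_edges))  # 11
```

Only the candidates of one time window are materialized at a time, which is where the memory savings over lifting the full static edge index come from.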
def lift_order_temporal_combined(g: pp.TemporalGraph, delta: int = 1):
    indices = torch.arange(0, g.data.edge_index.size(1))
    unique_t = torch.unique(g.data.time)
    second_order = []
    # lift order: find possible continuations for edges in each time stamp
    for i in range(unique_t.size(0)):
        t = unique_t[i]
        # find indices of all source edges that occur at unique timestamp t
        src_time_mask = g.data.time == t
        src_edge_idx = indices[src_time_mask]
        # find indices of all edges that can possibly continue edges occurring at time t for the given delta
        dst_time_mask = (g.data.time > t) & (g.data.time <= t + delta)
        dst_node_mask = torch.isin(g.data.edge_index[0], g.data.edge_index[1, src_edge_idx])
        dst_edge_idx = indices[dst_time_mask & dst_node_mask]
        if dst_edge_idx.size(0) > 0 and src_edge_idx.size(0) > 0:
            # get sorted dst edges for efficient processing
            src_edges = g.data.edge_index[:, src_edge_idx]
            dst_edges = g.data.edge_index[:, dst_edge_idx]
            sorted_idx = torch.argsort(dst_edges[0])
            dst_edge_idx = dst_edge_idx[sorted_idx]
            dst_edges = dst_edges[:, sorted_idx]
            # Use indexing trick to create higher-order edges
            outdegree = degree(dst_edges[0], dtype=torch.long, num_nodes=g.n)
            outdegree_per_dst = outdegree[src_edges[1]]
            num_new_edges = outdegree_per_dst.sum()
            ho_edge_srcs = torch.repeat_interleave(outdegree_per_dst)
            ptrs = cumsum(outdegree, dim=0)[:-1]
            ho_edge_dsts = torch.repeat_interleave(ptrs[src_edges[1]], outdegree_per_dst)
            idx_correction = torch.arange(num_new_edges, dtype=torch.long)
            idx_correction -= cumsum(outdegree_per_dst, dim=0)[ho_edge_srcs]
            ho_edge_dsts += idx_correction
            second_order.append(torch.stack([src_edge_idx[ho_edge_srcs], dst_edge_idx[ho_edge_dsts]], dim=0))
    ho_index = torch.cat(second_order, dim=1)
    return ho_index
In contrast to the lift_order_edge_index implementation, the temporal version splits the edges into source and destination edges based on timestamps. For each timestamp, we select the edges that occur at that timestamp as source edges and all edges that occur at later timestamps within the delta time window (and start at a target node of a source edge) as destination edges. The outdegree used by the indexing trick is then computed only over these destination edges, so each source edge is repeated once for each of its possible continuations instead of once for every outgoing edge of its target node.
Paths in PathpyG
One other core functionality of PathpyG is the ability to work with paths. Paths are sequences of nodes that represent a walk through the graph. We show an example below:
path_mapping = pp.IndexMap(list("abcde"))
paths = pp.PathData(mapping=path_mapping)
paths.append_walk(list("ab"))
paths.append_walk(list("abd"))
paths.append_walk(list("abec"))
paths.append_walk(list("dbecb"))
pp.plot(
pp.Graph.from_edge_index(paths.data.edge_index), node_label=paths.mapping.to_ids(paths.data.node_sequence).tolist()
)
pp.PathData is the core class for working with paths in PathpyG. It allows us to gather a collection of paths that are all walks on the same underlying graph. All paths are stored using one edge_index internally. Thus, two nodes in a path that both correspond to the same node in the underlying graph will not share the same index in the path graph. Instead, each occurrence of a node in a path is represented by a separate node in the path graph. This allows us to represent paths that visit the same node multiple times without ambiguity. The information about the underlying graph is stored in the internal PathData.data.node_sequence tensor, similar to higher-order graphs. Let us look at the example above to illustrate this:
print("The paths represented using edge index look as follows:")
for edge in paths.data.edge_index.t():
    print(
        f"\tInternal {edge.tolist()}: Underlying graph edge {paths.mapping.to_ids(paths.data.node_sequence[edge].view(-1)).tolist()}"
    )
The paths represented using edge index look as follows:
Internal [0, 1]: Underlying graph edge ['a', 'b']
Internal [2, 3]: Underlying graph edge ['a', 'b']
Internal [3, 4]: Underlying graph edge ['b', 'd']
Internal [5, 6]: Underlying graph edge ['a', 'b']
Internal [6, 7]: Underlying graph edge ['b', 'e']
Internal [7, 8]: Underlying graph edge ['e', 'c']
Internal [9, 10]: Underlying graph edge ['d', 'b']
Internal [10, 11]: Underlying graph edge ['b', 'e']
Internal [11, 12]: Underlying graph edge ['e', 'c']
Internal [12, 13]: Underlying graph edge ['c', 'b']
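The flat representation can be reproduced with plain Python lists: every node occurrence receives the next free index, and consecutive occurrences within a walk are linked. The helper flatten_walks is hypothetical and only meant to illustrate the idea; PathData's actual internals may differ, but for this example the resulting indices match the output above.

```python
def flatten_walks(walks):
    """Store a collection of walks as one edge list: every node occurrence gets its own index,
    and node_sequence records which underlying node each index stands for."""
    node_sequence, edge_list, num_nodes = [], [], []
    for walk in walks:
        offset = len(node_sequence)  # first free index for this walk
        node_sequence.extend(walk)
        edge_list.extend((offset + k, offset + k + 1) for k in range(len(walk) - 1))
        num_nodes.append(len(walk))
    return node_sequence, edge_list, num_nodes


node_sequence, edge_list, num_nodes = flatten_walks(["ab", "abd", "abec", "dbecb"])
print(edge_list[:3])  # [(0, 1), (2, 3), (3, 4)]
print(num_nodes)      # [2, 3, 4, 5]
```

Because the walks are stored one after another, no edge connects two different walks, and the edge list is automatically sorted by source index.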
PathData additionally stores some metadata about the paths so that you can easily access information about which nodes belong to which path. This includes:
- dag_weight: A tensor that stores the weight of each path (i.e., the number of times the path was observed).
- dag_num_edges: A tensor that stores the number of edges in each path.
- dag_num_nodes: A tensor that stores the number of nodes in each path.
Using this information, you can, e.g., access the second path in the collection as follows:
start = paths.data.dag_num_nodes[:1].sum().item()
end = start + paths.data.dag_num_nodes[1].item()
paths.mapping.to_ids(paths.data.node_sequence[start:end].view(-1)).tolist()
['a', 'b', 'd']
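The same slicing generalizes to any path index by precomputing all start offsets once. In this plain-Python sketch, the hypothetical helper get_path takes a list of node counts that plays the role of paths.data.dag_num_nodes.

```python
from itertools import accumulate


def get_path(node_sequence, num_nodes_per_path, i):
    """Return the i-th path from a flat node sequence, given the number of nodes of each path."""
    starts = [0] + list(accumulate(num_nodes_per_path))
    return node_sequence[starts[i]:starts[i + 1]]


node_sequence = list("ab" "abd" "abec" "dbecb")
num_nodes_per_path = [2, 3, 4, 5]  # plays the role of paths.data.dag_num_nodes

print(get_path(node_sequence, num_nodes_per_path, 1))  # ['a', 'b', 'd']
print(get_path(node_sequence, num_nodes_per_path, 3))  # ['d', 'b', 'e', 'c', 'b']
```

This is the same prefix-sum pattern used for the start pointers in lift_order_edge_index, applied to paths instead of outgoing edges.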
Lastly, since we are using an edge_index internally, the lift_order_edge_index function works out-of-the-box for paths. A second-order representation of the paths can be created as follows:
second_order_edge_index = pp.algorithms.lift_order.lift_order_edge_index(
edge_index=paths.data.edge_index, num_nodes=paths.data.num_nodes
)
second_order_paths = pp.Graph.from_edge_index(edge_index=second_order_edge_index)
pp.plot(
second_order_paths,
node_label=paths.mapping.to_ids(paths.data.node_sequence[paths.data.edge_index.t()].squeeze()).tolist(),
)
Multi-Order Models
With the concepts above, we can now create multi-order models using the MultiOrderModel class. This class allows us to create higher-order models of arbitrary order from a given temporal graph or collection of paths. Let's look at an example of creating a multi-order model from a temporal graph:
m_t = pp.MultiOrderModel.from_temporal_graph(t, max_order=2)
pp.plot(m_t.layers[2], node_label=m_t.layers[2].nodes)
We can see that the second-order graph created by the MultiOrderModel differs from the one created directly by the lift_order_temporal function. This is because the MultiOrderModel uses a higher-order DeBruijn graph representation, which merges higher-order nodes that correspond to the same path in the original graph. Temporal edges that appear as different nodes in the event graph are therefore merged into one node in the DeBruijn graph if they correspond to the same path in the original graph. This results in a more compact representation of the higher-order graph.
The same is true for paths. We can create a multi-order model from a collection of paths as follows:
m_p = pp.MultiOrderModel.from_path_data(paths, max_order=2)
pp.plot(m_p.layers[2], node_label=m_p.layers[2].nodes)
We can see that the higher-order node a->b, which appeared three times in the second-order graph created by the lift_order_edge_index function, is now merged into one node in the DeBruijn graph representation.
Internals of the MultiOrderModel Class
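The merging can be sketched in plain Python: map each distinct node sequence to one DeBruijn node and count duplicate edges as weights. This is only an illustration of the idea with hard-coded data from the paths example (one second-order node per edge occurrence, in the order of the internal edge index); PathpyG performs this with tensor operations, and the node indices it assigns may differ.

```python
from collections import Counter

# Second-order nodes of the paths example, one per edge occurrence in the internal edge index
so_node_sequence = [
    ("a", "b"), ("a", "b"), ("b", "d"), ("a", "b"), ("b", "e"),
    ("e", "c"), ("d", "b"), ("b", "e"), ("e", "c"), ("c", "b"),
]
# Second-order edges between these occurrences (consecutive edges within each path)
so_edges = [(1, 2), (3, 4), (4, 5), (6, 7), (7, 8), (8, 9)]

# One DeBruijn node per distinct sequence, numbered by first occurrence
unique_nodes = list(dict.fromkeys(so_node_sequence))
node_id = {seq: k for k, seq in enumerate(unique_nodes)}

# Relabel the edges with the merged node ids and count duplicates as edge weights
weights = Counter((node_id[so_node_sequence[u]], node_id[so_node_sequence[v]]) for u, v in so_edges)
print(len(unique_nodes))  # 6
print(weights[(node_id[("b", "e")], node_id[("e", "c")])])  # 2
```

The transition b->e followed by e->c was observed in two different paths, so the merged edge carries weight 2.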
Let us now take a closer look at how the MultiOrderModel class works under the hood. We already saw that the MultiOrderModel merges higher-order nodes from the line/event graph transformations.
This is done in 3 distinct steps which we will go through using the paths example above:
- Order Lifting: First, we create the higher-order edge index using the appropriate lift-order function. When lifting from the first order, we use lift_order_edge_index or lift_order_temporal, depending on whether we are working with paths or temporal graphs. When lifting from the second order and beyond, we use lift_order_edge_index regardless of the input type.
Note
While we merge the higher-order nodes and aggregate the higher-order edges for each order, we need to use the original (unmerged) higher-order edge index to create the next order. Only the original higher-order edge index preserves which path segments were actually observed together; lifting the merged graph would connect segments of different paths that merely share a higher-order node.
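The same effect can already be seen at the lowest order. In the following plain-Python sketch with two hypothetical paths that share the middle node b, lifting the merged first-order graph produces transitions that were never observed.

```python
# Two observed paths that share the middle node b
observed = ["abc", "xby"]

# Observed second-order transitions: consecutive edge pairs within each path
observed_transitions = {
    ((p[k], p[k + 1]), (p[k + 1], p[k + 2])) for p in observed for k in range(len(p) - 2)
}

# Line graph of the merged first-order graph: any edge into b can be followed by any edge out of b
edges = {(p[k], p[k + 1]) for p in observed for k in range(len(p) - 1)}
merged_transitions = {(e1, e2) for e1 in edges for e2 in edges if e1[1] == e2[0]}

print(len(observed_transitions), len(merged_transitions))  # 2 4
print(sorted(merged_transitions - observed_transitions))
```

The two extra transitions, such as a->b followed by b->y, combine segments of different paths, which is exactly what lifting from the unmerged edge index avoids.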
# We create the third-order representation of the paths
third_order_edge_index = pp.algorithms.lift_order.lift_order_edge_index(
edge_index=second_order_paths.data.edge_index, num_nodes=second_order_paths.n
)
- Update Node Sequences: Next, we need to update the internal node_sequence tensor to reflect the new higher-order nodes. For this, we create a new node_sequence by concatenating the last node of the target node sequence to the source node sequence. This way, we create a new sequence that corresponds to the paths represented by the next-order nodes.
second_order_node_sequence = paths.data.node_sequence[paths.data.edge_index.t()].squeeze()
third_order_node_sequence = torch.cat([
second_order_node_sequence[second_order_paths.data.edge_index[0]],
second_order_node_sequence[second_order_paths.data.edge_index[1]][:, -1:]
], dim=1)
- Merge Higher-Order Nodes: Finally, we need to merge the higher-order nodes that correspond to the same path in the original graph. For this, we create a mapping from each unique node_sequence to a unique index. We can then use this mapping to update the higher-order edge index to reflect the merged nodes and then aggregate duplicate edges.
third_order_paths = pp.algorithms.lift_order.aggregate_edge_index(
edge_index=third_order_edge_index, node_sequence=third_order_node_sequence
)
After performing these steps, we can again visualize the resulting higher-order graph:
third_order_paths.mapping = pp.IndexMap([tuple(paths.mapping.to_ids(v).tolist()) for v in third_order_paths.data.node_sequence])
pp.plot(third_order_paths, node_label=third_order_paths.nodes)
These steps can be repeated for each order until we reach the desired maximum order for the MultiOrderModel.
Other Tensor-based Implementations
The concepts from above can also be useful to implement other functionalities using tensor operations.
Longest Path Extraction
One example is the extraction of all longest paths, i.e., all paths from a root node to a leaf node that cannot be extended further, from a directed acyclic graph (DAG). This can be done by extending all paths simultaneously, one edge at a time, starting from the root nodes. We provide an example implementation below:
def get_all_paths_DAG(g: pp.Graph) -> dict:
    """Calculate all existing paths from any root node to any leaf node in a directed acyclic graph (DAG).

    Returns a dictionary that maps the path length (number of edges) to a list of paths."""
    paths_of_length = {}
    edge_index = g.data.edge_index.as_tensor()
    # calculate degrees
    out_degree = degree(edge_index[0], num_nodes=g.n, dtype=torch.long)
    in_degree = degree(edge_index[1], num_nodes=g.n, dtype=torch.long)
    # identify root nodes with in-degree zero and leaf nodes with out-degree zero
    roots = torch.where(in_degree == 0)[0]
    leafs = out_degree == 0
    # create path tensor that contains all paths that are not yet at a leaf node
    paths = roots.unsqueeze(1)
    # roots that are also leafs (isolated nodes) form paths of length 0
    paths_of_length[0] = paths[leafs[roots]].cpu().tolist()
    # continue all paths that are not at a leaf node
    paths = paths[~leafs[roots]]
    # remember nodes that haven't been traversed yet
    nodes = roots[~leafs[roots]]
    # start pointers of the outgoing edges of each node (edge_index is sorted by source)
    ptrs = cumsum(out_degree, dim=0)
    # extend all paths by one edge per step; step <= g.n only guards against cycles
    step = 1
    while nodes.size(0) > 0 and step <= g.n:
        idx_repeat = torch.repeat_interleave(out_degree[nodes])
        next_idx = torch.repeat_interleave(ptrs[nodes], out_degree[nodes])
        idx_correction = (
            torch.arange(next_idx.size(0), device=edge_index.device) - cumsum(out_degree[nodes], dim=0)[idx_repeat]
        )
        next_idx += idx_correction
        next_nodes = edge_index[1][next_idx]
        paths = torch.cat([paths[idx_repeat], next_nodes.unsqueeze(1)], dim=1)
        paths_of_length[step] = paths[leafs[next_nodes]].tolist()
        paths = paths[~leafs[next_nodes]]
        nodes = next_nodes[~leafs[next_nodes]]
        step += 1
    return paths_of_length
The function above starts at all root nodes (nodes with no incoming edges) and extends all current paths by one edge per iteration. Whenever a path reaches a leaf node (a node with no outgoing edges), it is stored under its length in the result and removed from the current paths. This continues until all paths have reached a leaf node, so the result contains every root-to-leaf path of the DAG.
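For testing the tensor version on small graphs, a depth-first reference that enumerates the same root-to-leaf paths can be handy. This plain-Python sketch is hypothetical (all_root_to_leaf_paths is not part of PathpyG) and groups paths by their number of edges, as the dictionary returned above does; the order of paths within a group may differ.

```python
from collections import defaultdict


def all_root_to_leaf_paths(edges, num_nodes):
    """Depth-first reference: enumerate every path from a root (in-degree 0) to a leaf
    (out-degree 0) in a DAG, grouped by the number of edges."""
    successors = defaultdict(list)
    has_incoming = set()
    for u, v in edges:
        successors[u].append(v)
        has_incoming.add(v)
    paths_of_length = defaultdict(list)
    stack = [[r] for r in range(num_nodes) if r not in has_incoming]
    while stack:
        path = stack.pop()
        nxt = successors[path[-1]]
        if not nxt:
            # Leaf reached: store the path under its number of edges
            paths_of_length[len(path) - 1].append(path)
        else:
            stack.extend(path + [v] for v in nxt)
    return dict(paths_of_length)


# Small DAG: 0 -> 1 -> 3, 0 -> 2 -> 3, 2 -> 4, plus an isolated node 5
dag_paths = all_root_to_leaf_paths([(0, 1), (0, 2), (1, 3), (2, 3), (2, 4)], num_nodes=6)
print(dag_paths[0])          # [[5]]
print(sorted(dag_paths[2]))  # [[0, 1, 3], [0, 2, 3], [0, 2, 4]]
```

The tensor version advances all paths in lockstep instead of one at a time, which is what makes it suitable for large DAGs and GPUs.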
Tip
Getting the next nodes for all current paths is done using a similar indexing trick as in the lift_order_edge_index function. This allows us to efficiently get all next nodes for all current paths in one go using tensor operations.