Skip to content

network_plots

Network plots with tikz.

NetworkPlot

Bases: pathpyG.visualisations._tikz.core.TikzPlot

Network plot class for a static network.

Source code in src/pathpyG/visualisations/_tikz/network_plots.py
class NetworkPlot(TikzPlot):
    """Network plot class for a static network.

    On construction the plot data is cleaned up and, if a layout is
    configured, the node positions are rescaled to the requested canvas.
    :meth:`to_tikz` then renders the data as ``\\Vertex`` / ``\\Edge``
    commands.

    Note:
        The ``data`` dictionary passed to the constructor is modified in
        place: unsupported attributes are dropped, the edge attribute
        ``size`` is renamed to ``lw``, hex colors are converted and node
        positions may be rescaled.
    """

    _kind = "network"

    def __init__(self, data: dict, **kwargs: Any) -> None:
        """Initialize network plot class.

        Args:
            data: Plot data with a ``"nodes"`` and an ``"edges"`` list of
                attribute dictionaries. It is cleaned up in place.
            **kwargs: Plot configuration, stored as ``self.config``.
                ``width`` and ``height`` default to 6 when not given.
        """
        super().__init__()
        self.data = data
        self.config = kwargs
        self.config["width"] = self.config.pop("width", 6)
        self.config["height"] = self.config.pop("height", 6)
        self.generate()

    def generate(self) -> None:
        """Clean up the data and update the layout."""
        self._compute_node_data()
        self._compute_edge_data()
        self._update_layout()

    @staticmethod
    def _clean_attributes(item: dict, allowed: set, mapping: dict) -> None:
        """Rename, filter and convert the attributes of one node or edge.

        Args:
            item: Attribute dictionary of a node or an edge; modified in
                place.
            allowed: Attribute names that are kept; all others are removed.
            mapping: Attribute names to rename (``old -> new``) before the
                filter is applied.
        """
        # iterate over a snapshot of the keys because the dict is mutated
        for key in list(item):
            if key in mapping:
                item[mapping[key]] = item.pop(key)
            if key not in allowed:
                item.pop(key, None)

        # hex colors are rewritten as "{r,g,b}" and flagged with RGB so that
        # to_tikz() emits the bare ``RGB`` option next to the color value
        color = item.get("color", None)
        if isinstance(color, str) and "#" in color:
            rgb = hex_to_rgb(color)
            item["color"] = f"{{{rgb[0]},{rgb[1]},{rgb[2]}}}"
            item["RGB"] = True

    def _compute_node_data(self) -> None:
        """Generate the data structure for the nodes."""
        default: set = {"uid", "x", "y", "size", "color", "opacity"}
        mapping: dict = {}

        for node in self.data["nodes"]:
            self._clean_attributes(node, default, mapping)

    def _compute_edge_data(self) -> None:
        """Generate the data structure for the edges."""
        default: set = {"uid", "source", "target", "lw", "color", "opacity"}
        # the generic "size" attribute becomes the line width of the edge
        mapping: dict = {"size": "lw"}

        for edge in self.data["edges"]:
            self._clean_attributes(edge, default, mapping)

    def _update_layout(self, default_size: float = 0.6) -> None:
        """Rescale and center the node positions on the configured canvas.

        Nothing happens unless the configuration contains a ``layout``
        entry, or when there are no nodes. Otherwise the node centers are
        scaled to fit into ``width`` x ``height`` (minus the margins, which
        are enlarged by half the size of the outermost nodes) and shifted
        so that they are centered inside that area.

        Args:
            default_size: Size assumed for nodes without a ``size``
                attribute.
        """
        # rescaling only happens when a layout was explicitly configured
        if self.config.get("layout") is None:
            return

        # get data
        nodes = self.data["nodes"]
        layout = {n["uid"]: (n["x"], n["y"]) for n in nodes}
        if not layout:
            # no nodes: nothing to scale, and zip(*...) below would fail
            return
        sizes = {n["uid"]: n.get("size", default_size) for n in nodes}

        # get config values
        width = self.config["width"]
        height = self.config["height"]
        keep_aspect_ratio = self.config.get("keep_aspect_ratio", True)
        margin = self.config.get("margin", 0.0)
        margins = {"top": margin, "left": margin, "bottom": margin, "right": margin}

        # scaling ratios; inf marks an axis along which the nodes have no extent
        x_ratio = float("inf")
        y_ratio = float("inf")

        # calculate absolute min and max coordinates (centers +/- half the size)
        x_absolute = []
        y_absolute = []
        for uid, (_x, _y) in layout.items():
            _s = sizes[uid] / 2
            x_absolute.extend([_x - _s, _x + _s])
            y_absolute.extend([_y - _s, _y + _s])

        # calculate min and max center coordinates
        x_values, y_values = zip(*layout.values())
        x_min, x_max = min(x_values), max(x_values)
        y_min, y_max = min(y_values), max(y_values)

        # enlarge the margins by how far node shapes extend past the centers
        margins["left"] += abs(x_min - min(x_absolute))
        margins["bottom"] += abs(y_min - min(y_absolute))
        margins["top"] += abs(y_max - max(y_absolute))
        margins["right"] += abs(x_max - max(x_absolute))

        if x_max - x_min > 0:
            x_ratio = (width - margins["left"] - margins["right"]) / (x_max - x_min)
        if y_max - y_min > 0:
            y_ratio = (height - margins["top"] - margins["bottom"]) / (y_max - y_min)

        if keep_aspect_ratio:
            scaling = (min(x_ratio, y_ratio), min(x_ratio, y_ratio))
        else:
            scaling = (x_ratio, y_ratio)

        # an axis without extent is left unscaled
        if scaling[0] == float("inf"):
            scaling = (1, scaling[1])
        if scaling[1] == float("inf"):
            scaling = (scaling[0], 1)

        # apply scaling to the points
        _layout = {n: (x * scaling[0], y * scaling[1]) for n, (x, y) in layout.items()}

        # find min and max values of the points
        x_values, y_values = zip(*_layout.values())
        x_min, x_max = min(x_values), max(x_values)
        y_min, y_max = min(y_values), max(y_values)

        # calculate the translation: center of the drawable area minus the
        # center of the scaled points
        translation = (
            ((width - margins["left"] - margins["right"]) / 2 + margins["left"])
            - ((x_max - x_min) / 2 + x_min),
            ((height - margins["top"] - margins["bottom"]) / 2 + margins["bottom"])
            - ((y_max - y_min) / 2 + y_min),
        )

        # apply translation to the points
        _layout = {
            n: (x + translation[0], y + translation[1]) for n, (x, y) in _layout.items()
        }

        # update node position for the plot
        for node in self.data["nodes"]:
            node["x"], node["y"] = _layout[node["uid"]]

    def to_tikz(self) -> str:
        """Convert the plot to TikZ code.

        The plot data is only read, never modified, so the method can be
        called repeatedly and always yields the same result.

        Returns:
            One ``\\Vertex[...]{uid}`` line per node followed by one
            ``\\Edge[...](source)(target)`` line per edge.

        Raises:
            KeyError: If a node has no ``uid`` or an edge has no ``source``
                or ``target``.
        """

        def _add_args(args: dict) -> str:
            # ``True`` values are emitted as bare flags, everything else as
            # ``key=value``; every option is preceded by a comma
            return "".join(
                f",{key}" if value is True else f",{key}={value}"
                for key, value in args.items()
            )

        lines = []
        for node in self.data["nodes"]:
            uid = node["uid"]
            options = {k: v for k, v in node.items() if k != "uid"}
            lines.append(f"\\Vertex[{_add_args(options)}]{{{uid}}}\n")

        structural = ("uid", "source", "target")
        for edge in self.data["edges"]:
            source = edge["source"]
            target = edge["target"]
            options = {k: v for k, v in edge.items() if k not in structural}
            lines.append(f"\\Edge[{_add_args(options)}]({source})({target})\n")
        return "".join(lines)

__init__

Initialize network plot class.

Source code in src/pathpyG/visualisations/_tikz/network_plots.py
def __init__(self, data: dict, **kwargs: Any) -> None:
    """Set up the plot from its data and configuration.

    Args:
        data: Plot data; stored as ``self.data``.
        **kwargs: Plot configuration; stored as ``self.config``. ``width``
            and ``height`` fall back to 6 when they are not supplied.
    """
    super().__init__()
    self.data = data
    self.config = kwargs
    # make sure both canvas dimensions are present in the configuration
    for dimension in ("width", "height"):
        self.config[dimension] = self.config.pop(dimension, 6)
    self.generate()

generate

Clean up data.

Source code in src/pathpyG/visualisations/_tikz/network_plots.py
def generate(self) -> None:
    """Clean up the data and update the layout.

    Runs the node clean-up, the edge clean-up and the layout update, in
    that order.
    """
    # each step is looked up right before it runs, in the listed order
    for step in ("_compute_node_data", "_compute_edge_data", "_update_layout"):
        getattr(self, step)()

to_tikz

Convert the plot to TikZ code.

Source code in src/pathpyG/visualisations/_tikz/network_plots.py
def to_tikz(self) -> str:
    """Convert the plot to TikZ code.

    The plot data is only read, never modified, so the method can be called
    repeatedly and always yields the same result.

    Returns:
        One ``\\Vertex[...]{uid}`` line per node followed by one
        ``\\Edge[...](source)(target)`` line per edge.

    Raises:
        KeyError: If a node has no ``uid`` or an edge has no ``source`` or
            ``target``.
    """

    def _add_args(args: dict) -> str:
        # ``True`` values are emitted as bare flags, everything else as
        # ``key=value``; every option is preceded by a comma
        return "".join(
            f",{key}" if value is True else f",{key}={value}"
            for key, value in args.items()
        )

    lines = []
    for node in self.data["nodes"]:
        uid = node["uid"]
        options = {k: v for k, v in node.items() if k != "uid"}
        lines.append(f"\\Vertex[{_add_args(options)}]{{{uid}}}\n")

    structural = ("uid", "source", "target")
    for edge in self.data["edges"]:
        source = edge["source"]
        target = edge["target"]
        options = {k: v for k, v in edge.items() if k not in structural}
        lines.append(f"\\Edge[{_add_args(options)}]({source})({target})\n")
    return "".join(lines)

StaticNetworkPlot

Bases: pathpyG.visualisations._tikz.network_plots.NetworkPlot

Network plot class for a static network.

Source code in src/pathpyG/visualisations/_tikz/network_plots.py
class StaticNetworkPlot(NetworkPlot):
    """Network plot class for a static network.

    Adds no behaviour of its own: everything is inherited from
    ``NetworkPlot``; only the ``_kind`` tag differs.
    """

    # plot kind tag; presumably used to select this class for static
    # networks -- confirm against the code that reads ``_kind``
    _kind = "static"

TemporalNetworkPlot

Bases: pathpyG.visualisations._tikz.network_plots.NetworkPlot

Network plot class for a temporal network (not implemented yet).

Source code in src/pathpyG/visualisations/_tikz/network_plots.py
class TemporalNetworkPlot(NetworkPlot):
    """Network plot class for a temporal network.

    Temporal plots are not implemented here yet: instantiating this class
    always raises ``NotImplementedError``.
    """

    _kind = "temporal"

    def __init__(self, data: dict, **kwargs: Any) -> None:
        """Initialize network plot class.

        Args:
            data: Plot data (currently unused).
            **kwargs: Plot configuration (currently unused).

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Temporal network plots are not implemented for the tikz backend."
        )

__init__

Initialize network plot class.

Source code in src/pathpyG/visualisations/_tikz/network_plots.py
def __init__(self, data: dict, **kwargs: Any) -> None:
    """Initialize network plot class.

    Args:
        data: Plot data (currently unused).
        **kwargs: Plot configuration (currently unused).

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError(
        "Temporal network plots are not implemented for the tikz backend."
    )