コード例 #1
0
    def test_benchmark_sampledheterogeneousbreadthfirstwalk(self, benchmark):
        """Benchmark ``run`` with n=5 walks and n_size=[5, 5] from root node 0."""
        walker = SampledHeterogeneousBreadthFirstWalk(create_simple_test_graph())

        def sample_once():
            return walker.run(nodes=[0], n=5, n_size=[5, 5])

        benchmark(sample_once)
コード例 #2
0
    def __init__(self, G, batch_size, num_samples, seed=None, name=None):
        """Validate the graph and build the schema-aware neighbour sampler.

        Args:
            G: The graph to sample from; must be a ``StellarGraphBase``.
            batch_size: Stored as ``self.batch_size``.
            num_samples: Stored as ``self.num_samples``.
            seed: Optional seed forwarded to the sampler.
            name: Optional name, stored as ``self.name``.

        Raises:
            TypeError: If ``G`` is not a ``StellarGraphBase`` instance.
        """
        if isinstance(G, StellarGraphBase):
            G.check_graph_for_ml(features=True)
        else:
            raise TypeError("Graph must be a StellarGraph object.")

        self.graph = G
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.name = name

        # HinSAGE needs a graph schema with type maps; the sampler shares it.
        schema = G.create_graph_schema(create_type_maps=True)
        self.schema = schema
        self.sampler = SampledHeterogeneousBreadthFirstWalk(
            G, graph_schema=schema, seed=seed
        )
    def test_walk_generation_single_root_node_loner(self):
        """Sampling from an isolated root node (no edges, no self loop).

        Node 0 is a 'user' node; in the simple test graph user nodes have two
        edge types (rating and friend), so the first hop has two sample lists.
        """
        walker = SampledHeterogeneousBreadthFirstWalk(create_simple_test_graph())
        isolated_root = 0

        # A zero neighbourhood size gives an empty sample list per edge type.
        walks = walker.run(nodes=[isolated_root], n=1, n_size=[0])
        assert len(walks) == 1
        only_walk = walks[0]
        assert len(only_walk) == 3
        assert only_walk[0][0] == isolated_root
        assert len(only_walk[1]) == 0
        assert len(only_walk[2]) == 0

        # With non-zero sizes every slot of the two-hop walk is still emitted,
        # with the missing neighbours padded by None.
        walks = walker.run(nodes=[isolated_root], n=1, n_size=[2, 3])
        assert len(walks) == 1
        only_walk = walks[0]
        assert len(only_walk) == 9
        assert only_walk[0][0] == isolated_root
        for first_hop in only_walk[1:3]:
            assert len(first_hop) == 2
            assert all(x is None for x in first_hop)
コード例 #4
0
    def test_parameter_checking(self):
        """``run`` and the constructor reject malformed arguments with ValueError."""
        g = create_simple_test_graph()

        graph_schema = g.create_graph_schema(create_type_maps=True)
        bfw = SampledHeterogeneousBreadthFirstWalk(g, graph_schema)

        nodes = [0, 1]
        n = 1
        n_size = [1]
        seed = 1001

        # nodes should be a list of node ids, even for a single node
        with pytest.raises(ValueError):
            bfw.run(nodes=None, n=n, n_size=n_size, seed=seed)
        with pytest.raises(ValueError):
            bfw.run(nodes=0, n=n, n_size=n_size, seed=seed)

        # n has to be a positive integer
        with pytest.raises(ValueError):
            bfw.run(nodes=nodes, n=-1, n_size=n_size, seed=seed)
        with pytest.raises(ValueError):
            bfw.run(nodes=nodes, n=10.1, n_size=n_size, seed=seed)
        with pytest.raises(ValueError):
            bfw.run(nodes=nodes, n=0, n_size=n_size, seed=seed)

        # n_size has to be a list of non-negative integers (n_size=[0] is valid)
        with pytest.raises(ValueError):
            bfw.run(nodes=nodes, n=n, n_size=0, seed=seed)
        with pytest.raises(ValueError):
            bfw.run(nodes=nodes, n=n, n_size=[-5], seed=seed)
        with pytest.raises(ValueError):
            # Keep n valid so this case exercises the non-integer n_size check
            # rather than failing on n.
            bfw.run(nodes=nodes, n=n, n_size=[2.4], seed=seed)
        with pytest.raises(ValueError):
            bfw.run(nodes=nodes, n=n, n_size=(1, 2), seed=seed)

        # graph_schema must be None or a GraphSchema
        with pytest.raises(ValueError):
            SampledHeterogeneousBreadthFirstWalk(g, graph_schema="graph schema")
        with pytest.raises(ValueError):
            SampledHeterogeneousBreadthFirstWalk(g, graph_schema=9092)

        # Negative, non-integer and string seeds are rejected
        with pytest.raises(ValueError):
            bfw.run(nodes=nodes, n=n, n_size=n_size, seed=-1235)
        with pytest.raises(ValueError):
            bfw.run(nodes=nodes, n=n, n_size=n_size, seed=10.987665)
        with pytest.raises(ValueError):
            bfw.run(nodes=nodes, n=n, n_size=n_size, seed=-982.4746)
        with pytest.raises(ValueError):
            bfw.run(nodes=nodes, n=n, n_size=n_size, seed="don't be random")

        # An empty list of root nodes is not an error: it yields an empty result.
        nodes = []
        subgraph = bfw.run(nodes=nodes, n=n, n_size=n_size, seed=seed)
        assert len(subgraph) == 0
コード例 #5
0
    def test_walk_generation_many_root_nodes(self):
        """Sampled breadth-first walks from several root nodes in one call.

        Checks that ``run`` returns ``len(nodes) * n`` walks. It also pins the
        seeded (seed=999) output against recorded golden values, first on the
        simple test graph and then on the multi-graph.
        """

        g = create_simple_test_graph()
        bfw = SampledHeterogeneousBreadthFirstWalk(g)

        nodes = [0, 7]  # both nodes are type user
        n = 1
        n_size = [0]

        # Each walk has 3 lists and starts with its own root node
        subgraphs = bfw.run(nodes=nodes, n=n, n_size=n_size, seed=999)
        assert len(subgraphs) == len(nodes) * n
        for i, subgraph in enumerate(subgraphs):
            assert len(subgraph) == 3
            assert subgraph[0][0] == nodes[i]  # should equal the root node

        n_size = [1]
        subgraphs = bfw.run(nodes=nodes, n=n, n_size=n_size, seed=999)
        assert len(subgraphs) == len(nodes) * n
        for i, subgraph in enumerate(subgraphs):
            assert len(subgraph) == 3

        n_size = [2]
        subgraphs = bfw.run(nodes=nodes, n=n, n_size=n_size, seed=999)
        assert len(subgraphs) == len(nodes) * n
        for i, subgraph in enumerate(subgraphs):
            assert len(subgraph) == 3

        # Golden values for seed=999. Node 0 is isolated, so its slots are None;
        # node 7's only neighbour is itself (self loop).
        valid_result = [[[0], [None, None], [None, None]], [[7], [7, 7], [None, None]]]
        for a, b in zip(subgraphs, valid_result):
            assert a == b

        # Two hops; note node ids in the test graph mix ints and the string "5"
        n_size = [2, 2]
        nodes = [0, 4]
        subgraphs = bfw.run(nodes=nodes, n=n, n_size=n_size, seed=999)
        assert len(subgraphs) == len(nodes) * n
        assert subgraphs[0][0][0] == nodes[0]
        valid_result = [
            [
                [0],
                [None, None],
                [None, None],
                [None, None],
                [None, None],
                [None, None],
                [None, None],
                [None, None],
                [None, None],
            ],
            [[4], [1, "5"], [2, 2], [4, 4], [2, 2], [4, 1], [3, 6], [4, 1], [4, 4]],
        ]
        for a, b in zip(subgraphs, valid_result):
            assert a == b

        # Root nodes of different types produce walks of different lengths
        n_size = [2, 3]
        nodes = [1, 6]  # a user and a movie node respectively
        subgraphs = bfw.run(nodes=nodes, n=n, n_size=n_size, seed=999)
        assert len(subgraphs) == len(nodes) * n
        assert subgraphs[0][0][0] == nodes[0]
        valid_result = [
            [
                [1],
                [4, "5"],
                [3, 2],
                ["5", 1, 1],
                [2, 2, 2],
                [4, 4, 1],
                [6, 3, 6],
                [1, 1, 1],
                [1, 4, 1],
            ],
            [[6], ["5", "5"], [4, 1, 4], [6, 3, 3], [1, 1, 1], [6, 6, 6]],
        ]
        for a, b in zip(subgraphs, valid_result):
            assert a == b

        # Several walks per root node: only the count is checked
        n = 5
        subgraphs = bfw.run(nodes=nodes, n=n, n_size=n_size, seed=999)
        assert len(subgraphs) == len(nodes) * n

        #
        # Test with multi-graph
        #
        g = create_multi_test_graph()
        bfw = SampledHeterogeneousBreadthFirstWalk(g)

        nodes = [1, 6]
        n = 1
        n_size = [2]

        subgraphs = bfw.run(nodes=nodes, n=n, n_size=n_size, seed=999)
        assert len(subgraphs) == n * len(nodes)
        valid_result = [[[1], [4, 4], ["5", 4], [3, 2]], [[6], ["5", "5"]]]
        assert subgraphs == valid_result

        # The remaining cases only check the number of walks returned
        n = 1
        n_size = [2, 3]

        subgraphs = bfw.run(nodes=nodes, n=n, n_size=n_size, seed=999)
        assert len(subgraphs) == n * len(nodes)

        nodes = [4, "5", 0]
        n = 1
        n_size = [3, 3, 1]

        subgraphs = bfw.run(nodes=nodes, n=n, n_size=n_size, seed=999)
        assert len(subgraphs) == n * len(nodes)

        n = 99
        subgraphs = bfw.run(nodes=nodes, n=n, n_size=n_size, seed=999)
        assert len(subgraphs) == n * len(nodes)
コード例 #6
0
    def test_walk_generation_single_root_node(self):
        """Seeded sampled walks from a single root node.

        Pins the output of ``run`` for fixed seeds against recorded golden
        values. It covers the simple test graph and then the multi-graph.
        Each walk is a list of node-id lists with the root node first.
        """
        g = create_simple_test_graph()
        bfw = SampledHeterogeneousBreadthFirstWalk(g)

        nodes = [3]
        n = 1
        n_size = [2]

        subgraphs = bfw.run(nodes=nodes, n=n, n_size=n_size, seed=42)
        assert len(subgraphs) == n
        assert subgraphs == [[[3], [1, 1]]]

        n_size = [3]
        subgraphs = bfw.run(nodes=nodes, n=n, n_size=n_size, seed=42)
        assert len(subgraphs) == n
        assert subgraphs == [[[3], [1, 1, "5"]]]

        n_size = [1, 1]
        subgraphs = bfw.run(nodes=nodes, n=n, n_size=n_size, seed=42)
        assert len(subgraphs) == n
        assert subgraphs == [[[3], [1], [4], [3]]]

        # Node ids in the test graph mix ints and the string "5"
        nodes = ["5"]
        n_size = [2, 3]
        subgraphs = bfw.run(nodes=nodes, n=n, n_size=n_size, seed=42)
        assert len(subgraphs) == n
        assert subgraphs == [
            [
                ["5"],
                [1, 1],
                [6, 3],
                [4, 4, 4],
                [2, 3, 2],
                [4, 4, 4],
                [2, 2, 2],
                ["5", "5", "5"],
                ["5", 1, 1],
            ]
        ]

        # Several walks from the same root with the same seed
        nodes = ["5"]
        n_size = [2, 3]
        n = 3
        subgraphs = bfw.run(nodes=nodes, n=n, n_size=n_size, seed=42)
        assert len(subgraphs) == n
        valid_result = [
            [
                ["5"],
                [1, 1],
                [6, 3],
                [4, 4, 4],
                [2, 3, 2],
                [4, 4, 4],
                [2, 2, 2],
                ["5", "5", "5"],
                ["5", 1, 1],
            ],
            [
                ["5"],
                [4, 4],
                [6, 3],
                [1, "5", 1],
                [2, 2, 2],
                ["5", "5", "5"],
                [2, 2, 2],
                ["5", "5", "5"],
                ["5", 1, 1],
            ],
            [
                ["5"],
                [1, 1],
                [6, 3],
                [4, 4, "5"],
                [3, 3, 3],
                [4, "5", "5"],
                [2, 3, 2],
                ["5", "5", "5"],
                ["5", "5", "5"],
            ],
        ]
        for a, b in zip(subgraphs, valid_result):
            assert len(a) == len(b)
            assert a == b

        #
        # Test with multi-graph
        #
        g = create_multi_test_graph()
        bfw = SampledHeterogeneousBreadthFirstWalk(g)

        nodes = [1]
        n = 1
        n_size = [2]

        subgraphs = bfw.run(nodes=nodes, n=n, n_size=n_size, seed=19893839)
        assert len(subgraphs) == n
        assert subgraphs == [[[1], [4, 4], [4, 4], [2, 3]]]

        n_size = [2, 3]
        subgraphs = bfw.run(nodes=nodes, n=n, n_size=n_size, seed=19893839)
        assert len(subgraphs) == n
        valid_result = [
            [
                [1],
                [4, 4],
                [4, 4],
                [2, 3],
                [1, 1, 1],
                ["5", 1, 1],
                [2, 2, 2],
                [1, 1, 1],
                ["5", "5", 1],
                [2, 2, 2],
                [1, 1, 1],
                ["5", "5", 1],
                [2, 2, 2],
                [1, 1, 1],
                ["5", 1, "5"],
                [2, 2, 2],
                [1, 1, 4],
                [1, "5", 1],
            ]
        ]
        for a, b in zip(subgraphs, valid_result):
            assert len(a) == len(b)
            assert a == b

        # A zero size at the second hop yields empty lists for every second-hop slot
        nodes = [1]
        n_size = [2, 0]
        n = 2
        subgraphs = bfw.run(nodes=nodes, n=n, n_size=n_size, seed=19893839)
        assert len(subgraphs) == n
        valid_result = [
            [
                [1],
                [4, 4],
                [4, 4],
                [2, 3],
                [],
                [],
                [],
                [],
                [],
                [],
                [],
                [],
                [],
                [],
                [],
                [],
                [],
                [],
            ],
            [
                [1],
                [4, 4],
                ["5", "5"],
                [2, 2],
                [],
                [],
                [],
                [],
                [],
                [],
                [],
                [],
                [],
                [],
                [],
                [],
                [],
                [],
            ],
        ]
        for a, b in zip(subgraphs, valid_result):
            assert len(a) == len(b)
            assert a == b
コード例 #7
0
    def test_walk_generation_single_root_node_self_loner(self):
        """Sampling from a root node whose only edge is a self loop.

        Node 7 is connected only to itself by an edge of type "friend". Every
        sampled entry must therefore be either 7 or None.
        """
        g = create_simple_test_graph()
        bfw = SampledHeterogeneousBreadthFirstWalk(g)

        root_node_id = 7
        nodes = [root_node_id]
        n = 1

        n_size = [0]
        subgraphs = bfw.run(nodes=nodes, n=n, n_size=n_size)
        assert len(subgraphs) == 1
        assert len(subgraphs[0]) == 3
        assert subgraphs[0][0][0] == root_node_id  # the root node id
        # Node 7 is of type 'user'. In the simple test graph that type has 2 edge
        # types (rating and friend), so 2 empty sample lists are returned.
        assert len(subgraphs[0][1]) == 0
        assert len(subgraphs[0][2]) == 0

        n_size = [1]
        subgraphs = bfw.run(nodes=nodes, n=n, n_size=n_size)
        assert len(subgraphs) == 1
        assert len(subgraphs[0]) == 3
        assert subgraphs[0][0][0] == root_node_id  # the root node id
        # The friend edge type samples the root itself through the self loop.
        # The rating edge type has no neighbours and is padded with None.
        assert subgraphs[0][1][0] == root_node_id
        assert len(subgraphs[0][2]) == 1
        assert subgraphs[0][2] == [None]

        n_size = [2, 2]
        subgraphs = bfw.run(nodes=nodes, n=n, n_size=n_size)
        # Expected:
        #  [[[7], [7, 7], [None, None], [7, 7], [None, None], [7, 7], [None, None], [None, None], [None, None]]]
        assert len(subgraphs) == 1
        assert len(subgraphs[0]) == 9
        for level in subgraphs[0]:
            assert isinstance(level, list)
            # All values should be root_node_id or None
            assert all((value == root_node_id) or (value is None) for value in level)

        n_size = [2, 2, 3]
        subgraphs2 = bfw.run(nodes=nodes, n=n, n_size=n_size)
        # The walk is breadth-first: the first 9 lists match the depth <= 2
        # output above and are followed by 20 depth-3 lists of length 3.
        assert len(subgraphs2) == 1
        assert len(subgraphs2[0]) == 29

        # The previous walk should be a prefix of this one. Compare every one of
        # its lists, not just the root list.
        prefix_len = len(subgraphs[0])
        assert subgraphs2[0][:prefix_len] == subgraphs[0]

        for level in subgraphs2[0]:
            assert isinstance(level, list)
            assert all((value == root_node_id) or (value is None) for value in level)
    def test_walk_generation_single_root_node(self):
        """Seeded sampled walks from a single root node, pinned to golden values.

        It first uses a test graph built with ``create_test_graph(self_loop=True)``
        and then one built with ``create_test_graph(multi=True)``. Each walk is a
        list of node-id lists with the root node first.
        """

        g = create_test_graph(self_loop=True)
        bfw = SampledHeterogeneousBreadthFirstWalk(g)

        nodes = [3]
        n = 1
        n_size = [2]

        subgraphs = bfw.run(nodes=nodes, n=n, n_size=n_size, seed=42)
        assert len(subgraphs) == n
        assert subgraphs == [[[3], ["5", 1]]]

        n_size = [3]
        subgraphs = bfw.run(nodes=nodes, n=n, n_size=n_size, seed=42)
        assert len(subgraphs) == n
        assert subgraphs == [[[3], ["5", 1, 1]]]

        n_size = [1, 1]
        subgraphs = bfw.run(nodes=nodes, n=n, n_size=n_size, seed=42)
        assert len(subgraphs) == n
        assert subgraphs == [[[3], ["5"], [1], [3]]]

        # Node ids in the test graph mix ints and the string "5"
        nodes = ["5"]
        n_size = [2, 3]
        subgraphs = bfw.run(nodes=nodes, n=n, n_size=n_size, seed=42)
        assert len(subgraphs) == n
        assert subgraphs == [[
            ["5"],
            [4, 1],
            [3, 3],
            ["5", "5", "5"],
            [2, 2, 2],
            [4, "5", 4],
            [2, 3, 3],
            [1, "5", "5"],
            [1, "5", "5"],
        ]]

        # Several walks from the same root with the same seed; the first walk
        # matches the single-walk result above.
        nodes = ["5"]
        n_size = [2, 3]
        n = 3
        subgraphs = bfw.run(nodes=nodes, n=n, n_size=n_size, seed=42)
        assert len(subgraphs) == n
        valid_result = [
            [
                ["5"],
                [4, 1],
                [3, 3],
                ["5", "5", "5"],
                [2, 2, 2],
                [4, "5", 4],
                [2, 3, 3],
                [1, "5", "5"],
                [1, "5", "5"],
            ],
            [
                ["5"],
                [1, 1],
                [6, 3],
                [4, 4, "5"],
                [3, 3, 3],
                ["5", "5", 4],
                [3, 3, 3],
                ["5", "5", "5"],
                [1, 1, 1],
            ],
            [
                ["5"],
                [1, 1],
                [3, 3],
                ["5", 4, 4],
                [2, 2, 3],
                ["5", "5", 4],
                [3, 2, 2],
                ["5", "5", "5"],
                ["5", "5", "5"],
            ],
        ]
        for a, b in zip(subgraphs, valid_result):
            assert len(a) == len(b)
            assert a == b
        #
        # Test with multi-graph
        #
        g = create_test_graph(multi=True)
        bfw = SampledHeterogeneousBreadthFirstWalk(g)

        nodes = [1]
        n = 1
        n_size = [2]

        subgraphs = bfw.run(nodes=nodes, n=n, n_size=n_size, seed=19893839)
        assert len(subgraphs) == n
        assert subgraphs == [[[1], [4, 4], [4, 4], [2, 2]]]

        n_size = [2, 3]
        subgraphs = bfw.run(nodes=nodes, n=n, n_size=n_size, seed=19893839)
        assert len(subgraphs) == n
        valid_result = [[
            [1],
            [4, 4],
            [4, 4],
            [2, 2],
            [1, 1, 1],
            ["5", 1, 1],
            [2, 2, 2],
            [1, 1, 1],
            [1, "5", 1],
            [2, 2, 2],
            [1, 1, 1],
            ["5", 1, "5"],
            [2, 2, 2],
            [1, 1, 1],
            ["5", "5", 1],
            [2, 2, 2],
            [4, 1, 1],
            [4, 1, 4],
        ]]
        for a, b in zip(subgraphs, valid_result):
            assert len(a) == len(b)
            assert a == b

        # A zero size at the second hop yields empty lists for every second-hop slot
        nodes = [1]
        n_size = [2, 0]
        n = 2
        subgraphs = bfw.run(nodes=nodes, n=n, n_size=n_size, seed=19893839)
        assert len(subgraphs) == n
        valid_result = [
            [
                [1],
                [4, 4],
                [4, 4],
                [2, 2],
                [],
                [],
                [],
                [],
                [],
                [],
                [],
                [],
                [],
                [],
                [],
                [],
                [],
                [],
            ],
            [
                [1],
                [4, 4],
                ["5", "5"],
                [2, 2],
                [],
                [],
                [],
                [],
                [],
                [],
                [],
                [],
                [],
                [],
                [],
                [],
                [],
                [],
            ],
        ]
        for a, b in zip(subgraphs, valid_result):
            assert len(a) == len(b)
            assert a == b
コード例 #9
0
class HinSAGELinkGenerator:
    """A data generator for link prediction with Heterogeneous HinSAGE models

    At minimum, supply the StellarGraph, the batch size, and the number of
    node samples for each layer of the HinSAGE model.

    The supplied graph should be a StellarGraph object that is ready for
    machine learning. Currently the model requires node features for all
    nodes in the graph.

    Use the :meth:`flow` method supplying the links and (optionally) targets
    to get an object that can be used as a Keras data generator.

    Note that you don't need to pass link_type (target link type) to the link mapper, considering that:

    * The mapper actually only cares about (src,dst) node types, and these can be inferred from the passed
      link ids (although this might be expensive, as it requires parsing the links ids passed - yet only once)

    * It's possible to do link prediction on a graph where that link type is completely removed from the graph
      (e.g., "same_as" links in ER)


    Example::

        G_generator = HinSAGELinkGenerator(G, 50, [10,10])
        data_gen = G_generator.flow(edge_ids)

    Args:
        G (StellarGraph): A machine-learning ready graph.
        batch_size (int): Size of batch of links to return.
        num_samples (list): List of number of neighbour node samples per HinSAGE layer (hop) to take.
        seed (int or str, optional): Random seed for the sampling methods.
        name (optional): Name of generator.
    """

    def __init__(self, G, batch_size, num_samples, seed=None, name=None):
        if not isinstance(G, StellarGraphBase):
            raise TypeError("Graph must be a StellarGraph object.")

        G.check_graph_for_ml(features=True)

        self.graph = G
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.name = name

        # We need a schema for compatibility with HinSAGE
        self.schema = G.create_graph_schema(create_type_maps=True)

        # The sampler used to generate random samples of neighbours
        self.sampler = SampledHeterogeneousBreadthFirstWalk(
            G, graph_schema=self.schema, seed=seed
        )

    def _get_features(self, node_samples, head_size):
        """
        Collect features from sampled nodes.

        Args:
            node_samples: A list of ``(node_type, node_ids)`` pairs.
            head_size: The number of head nodes (typically the batch size).

        Returns:
            A list of numpy arrays, one per entry of ``node_samples``, each
            reshaped to ``(head_size, -1, feature_size)``.
        """
        # Fetch features per node type (feature size may differ between types).
        # NOTE(review): nodes with no samples are expected to yield zero feature
        # rows, presumably handled inside get_feature_for_nodes — confirm.
        batch_feats = [
            self.graph.get_feature_for_nodes(layer_nodes, nt)
            for nt, layer_nodes in node_samples
        ]

        # Resize features to (head_size, n_neighbours, feature_size)
        return [np.reshape(a, (head_size, -1, a.shape[1])) for a in batch_feats]

    def sample_features(self, head_links, sampling_schema):
        """
        Sample neighbours recursively from the head nodes, collect the features of the
        sampled nodes, and return these as a list of feature arrays for the HinSAGE
        algorithm.

        Args:
            head_links: A sequence of (src, dst) edges to perform sampling for.
            sampling_schema: A pair of lists (for src and dst heads) of
                ``(node_type, slot_indices)`` entries describing how sampled
                slots are grouped for the model.

        Returns:
            A list with one numpy array per paired entry of
            ``sampling_schema[0]`` / ``sampling_schema[1]``, each of shape
            ``(len(head_links), -1, feature_size)``.
        """
        nodes_by_type = []
        for ii in range(2):
            # Position ii of every edge: the src heads for ii == 0, dst heads for ii == 1
            head_nodes = [e[ii] for e in head_links]

            # One sampled walk per head node
            node_samples = self.sampler.run(
                nodes=head_nodes, n=1, n_size=self.num_samples
            )

            # For each (node type, slot indices) schema entry, concatenate the
            # sampled nodes at those slots, walk by walk. A flat comprehension
            # avoids the quadratic cost of repeated list concatenation.
            nodes_by_type.append(
                [
                    (
                        nt,
                        [
                            node
                            for samples in node_samples
                            for ks in indices
                            for node in samples[ks]
                        ],
                    )
                    for nt, indices in sampling_schema[ii]
                ]
            )

        # Interlace src and dst groupings entry by entry: keep the src node type
        # and append the dst nodes after the src nodes.
        nodes_by_type = [
            (src_nt, src_nodes + dst_nodes)
            for (src_nt, src_nodes), (_, dst_nodes) in zip(
                nodes_by_type[0], nodes_by_type[1]
            )
        ]

        return self._get_features(nodes_by_type, len(head_links))

    def flow(self, link_ids, targets=None, shuffle=False):
        """
        Creates a generator/sequence object for training or evaluation
        with the supplied edge IDs and numeric targets.

        The edge IDs are the edges to train or inference on. They are
        expected to be tuples of (source_id, destination_id).

        The targets are an array of numeric targets corresponding to the
        supplied link_ids to be used by the downstream task. They should
        be given in the same order as the list of link IDs.
        If they are not specified (for example, for use in prediction),
        the targets will not be available to the downstream task.

        Note that the shuffle argument should be True for training and
        False for prediction.

        Args:
            link_ids: an iterable of (src_id, dst_id) tuples specifying the
                edges.
            targets: a 2D array of numeric targets with shape
                ``(len(link_ids), target_size)``
            shuffle (bool): If True the link_ids will be shuffled at each
                epoch, if False the link_ids will be processed in order.

        Returns:
            A LinkSequence object to use with the HinSAGE model
            methods :meth:`fit_generator`, :meth:`evaluate_generator`, and :meth:`predict_generator`

        Raises:
            TypeError: If ``link_ids`` is not iterable.
        """
        # The ABC aliases in `collections` (e.g. collections.Iterable) were
        # removed in Python 3.10; the ABCs live in collections.abc.
        import collections.abc

        if not isinstance(link_ids, collections.abc.Iterable):
            raise TypeError(
                "Argument to .flow not recognised. "
                "Please pass a list of samples or a UnsupervisedSampler object."
            )

        return LinkSequence(self, link_ids, targets, shuffle)