Example #1
0
    def test_dynamic_batch_sizes(self, block_constructor):
        """Checks that all batch sizes are as expected through a GraphNetwork.

        A two-graph batch (SMALL_GRAPH_1 and SMALL_GRAPH_2) is converted to
        constant tensors and fed through the block built by
        `block_constructor`, using an MLP with `output_sizes=[10]` as the
        inner model. Every field of the result must keep the leading
        dimension of the corresponding input field.
        """
        # Eager TF2: inputs are plain constants, so no placeholders are needed.
        graph_batch = utils_np.data_dicts_to_graphs_tuple(
            [SMALL_GRAPH_1, SMALL_GRAPH_2]).map(
                tf.constant, fields=graphs.ALL_FIELDS)
        mlp_factory = functools.partial(snt.nets.MLP, output_sizes=[10])
        block = block_constructor(mlp_factory)
        numpy_output = utils_tf.nest_to_numpy(block(graph_batch))

        for field_name, input_field in graph_batch._asdict().items():
            expected_rows = input_field.shape[0]
            output_field = getattr(numpy_output, field_name)
            self.assertEqual(expected_rows, output_field.shape[0])
  def test_getitem_one(self, use_tensor_index):
    """Checks that `utils_tf.get_graph` extracts a single graph by index.

    When `use_tensor_index` is true the index is passed as a `tf.constant`
    instead of a Python int. The extracted graph must match the expected
    data dict field by field, with consistent node and edge counts.
    """
    position = 2
    want = self.graphs_dicts_out[position]
    lookup = tf.constant(position) if use_tensor_index else position

    batched = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
    single = utils_tf.nest_to_numpy(utils_tf.get_graph(batched, lookup))
    # Unpacking into a 1-tuple also asserts that exactly one graph came back.
    (got,) = utils_np.graphs_tuple_to_data_dicts(single)

    for field, expected_value in want.items():
      self.assertAllClose(expected_value, got[field])
    self.assertEqual(want["nodes"].shape[0], got["n_node"])
    self.assertEqual(want["edges"].shape[0], got["n_edge"])
  def test_getitem(self, use_tensor_slice):
    """Checks that `utils_tf.get_graph` extracts a contiguous slice of graphs.

    When `use_tensor_slice` is true the slice bounds are `tf.constant`s
    instead of Python ints. The extracted graphs must match the expected
    data dicts one-for-one, with consistent node and edge counts.
    """
    index = slice(1, 3)
    expected = self.graphs_dicts_out[index]

    if use_tensor_slice:
      index = slice(tf.constant(index.start), tf.constant(index.stop))

    graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(self.graphs_dicts_in)
    graphs2 = utils_tf.get_graph(graphs_tuple, index)

    graphs2 = utils_tf.nest_to_numpy(graphs2)
    actual = utils_np.graphs_tuple_to_data_dicts(graphs2)

    # `zip` stops at the shorter sequence, so without this check a slice that
    # returned too few (or too many) graphs would still pass.
    self.assertEqual(len(expected), len(actual))
    for ex, ac in zip(expected, actual):
      for k, v in ex.items():
        self.assertAllClose(v, ac[k])
      self.assertEqual(ex["nodes"].shape[0], ac["n_node"])
      self.assertEqual(ex["edges"].shape[0], ac["n_edge"])
Example #4
0
    def test_output_values(self, use_received_edges, use_sent_edges, use_nodes,
                           use_globals, received_edges_reducer,
                           sent_edges_reducer):
        """Compares the output of a NodeBlock to an explicit computation.

        The expected node features are rebuilt by hand. The enabled inputs
        (aggregated received edges, aggregated sent edges, current nodes,
        broadcast globals) are concatenated in that order along the last axis
        and multiplied by `self._scale`. The test also checks that edges and
        globals are passed through as the very same objects.
        """
        input_graph = self._get_input_graph()
        node_block = blocks.NodeBlock(
            node_model_fn=self._node_model_fn,
            use_received_edges=use_received_edges,
            use_sent_edges=use_sent_edges,
            use_nodes=use_nodes,
            use_globals=use_globals,
            received_edges_reducer=received_edges_reducer,
            sent_edges_reducer=sent_edges_reducer)
        output_graph = node_block(input_graph)

        # Rebuild the node-model input by hand, in the order the test expects
        # the NodeBlock to concatenate its inputs.
        model_inputs = []
        if use_received_edges:
            model_inputs.append(
                blocks.ReceivedEdgesToNodesAggregator(received_edges_reducer)(
                    input_graph))
        if use_sent_edges:
            model_inputs.append(
                blocks.SentEdgesToNodesAggregator(sent_edges_reducer)(
                    input_graph))
        if use_nodes:
            model_inputs.append(input_graph.nodes)
        if use_globals:
            model_inputs.append(blocks.broadcast_globals_to_nodes(input_graph))

        model_inputs = tf.concat(model_inputs, axis=-1)
        # Only nodes may change; edges and globals must be the same objects.
        self.assertIs(input_graph.edges, output_graph.edges)
        self.assertIs(input_graph.globals, output_graph.globals)

        # Convert both sides of the comparison to NumPy. The previous
        # `model_inputs_out = model_inputs` was a no-op alias that left a
        # tf.Tensor on one side.
        output_graph_out = utils_tf.nest_to_numpy(output_graph)
        model_inputs_out = utils_tf.nest_to_numpy(model_inputs)

        # NOTE(review): assumes `self._node_model_fn` multiplies its input by
        # `self._scale` (both defined outside this view).
        expected_output_nodes = model_inputs_out * self._scale
        self.assertNDArrayNear(expected_output_nodes,
                               output_graph_out.nodes,
                               err=1e-4)
Example #5
0
#   nd = len(graph.nodes())
#   probs = np.zeros((nd, nd))
#   for edge in graph.edges(data=True):
#     probs[edge[0], edge[1]] = edge[2]["features"][0]
#   ax.matshow(probs[sort_indices][:, sort_indices], cmap="viridis")
#   ax.grid(False)



# Bounds on the number of elements per generated example; passed straight
# through to `create_data` (defined outside this view).
num_elements_min_max = (5, 10)

# Generate a single example (first argument is 1).
# NOTE(review): `create_data` is not visible here; it presumably returns
# GraphsTuples for the inputs, targets, sort indices and ranks of a sorting
# task -- confirm against its definition.
inputs, targets, sort_indices, ranks = create_data(
    1, num_elements_min_max)

# Extract node features as NumPy arrays.
# NOTE(review): an earlier inline comment claimed `inputs_nodes` has shape
# [7, 1]. The element count is configurable via `num_elements_min_max`, so
# the leading dimension presumably varies between runs -- confirm.
inputs_nodes = inputs.nodes.numpy()
targets = utils_tf.nest_to_numpy(targets)
sort_indices_nodes = sort_indices.nodes.numpy()
ranks_nodes = ranks.nodes.numpy()

# Drop singleton axes and cast to int so the result can be used as an index
# array (as in `probs[sort_indices][:, sort_indices]` in the commented-out
# plotting code).
sort_indices = np.squeeze(sort_indices_nodes).astype(int)

# # Plot sort linked lists.
# # The matrix plots show each element from the sorted list (rows), and which
# # element they link to as next largest (columns). Ground truth is a diagonal
# # offset toward the upper-right by one.
# fig = plt.figure(1, figsize=(4, 4))
# fig.clf()
# ax = fig.add_subplot(1, 1, 1)
# plot_linked_list(ax,
#                  utils_np.graphs_tuple_to_networkxs(targets)[0], sort_indices)
# ax.set_title("Element-to-element links for sorted elements")
Example #6
0
    def test_add_remove_padding(self, experimental_unconnected_padding_edges,
                                nested_features):
        """Pads a GraphsTuple, inspects the padding, then strips it back off.

        A batch of `unpadded_batch_size` copies of one random graph is padded
        with `utils_tf.pad_graphs_tuple` inside a `tf.function`, and the
        padding is removed again with `utils_tf.remove_graphs_tuple_padding`.
        The test verifies that:
          * the padded fields have the requested static leading sizes;
          * the padded batch contains exactly `pad_graphs_to` graphs;
          * the original graphs survive padding unchanged;
          * all padding nodes/edges live in the first padding graph, with zero
            features and sender/receiver indices that depend on
            `experimental_unconnected_padding_edges`;
          * removing the padding recovers the original GraphsTuple.

        Args:
          experimental_unconnected_padding_edges: forwarded to
            `utils_tf.pad_graphs_tuple`.
          nested_features: if true, edges/nodes/globals are wrapped in nested
            list/dict/tuple structures before padding, to check that the
            nesting is preserved.
        """
        data_dict = test_utils_tf1.generate_random_data_dict(
            (7, ), (8, ), (9, ),
            num_nodes_range=(10, 15),
            num_edges_range=(20, 25))
        node_size_np = data_dict["nodes"].shape[0]
        edge_size_np = data_dict["edges"].shape[0]
        unpadded_batch_size = 2
        graphs_tuple = utils_tf.data_dicts_to_graphs_tuple(
            unpadded_batch_size * [data_dict])

        if nested_features:
            graphs_tuple = graphs_tuple.replace(edges=[graphs_tuple.edges, {}],
                                                nodes=({
                                                    "tensor":
                                                    graphs_tuple.nodes
                                                }, ),
                                                globals=(
                                                    [],
                                                    graphs_tuple.globals,
                                                ))

        num_padding_nodes = 3
        num_padding_edges = 4
        num_padding_graphs = 5
        pad_nodes_to = unpadded_batch_size * node_size_np + num_padding_nodes
        pad_edges_to = unpadded_batch_size * edge_size_np + num_padding_edges
        pad_graphs_to = unpadded_batch_size + num_padding_graphs

        def _get_padded_and_recovered_graphs_tuple(graphs_tuple):
            padded_graphs_tuple = utils_tf.pad_graphs_tuple(
                graphs_tuple, pad_nodes_to, pad_edges_to, pad_graphs_to,
                experimental_unconnected_padding_edges)

            # Check that we have statically defined shapes after padding.
            self.assertEqual(_leading_static_shape(padded_graphs_tuple.nodes),
                             pad_nodes_to)
            self.assertEqual(_leading_static_shape(padded_graphs_tuple.edges),
                             pad_edges_to)
            self.assertEqual(
                _leading_static_shape(padded_graphs_tuple.senders),
                pad_edges_to)
            self.assertEqual(
                _leading_static_shape(padded_graphs_tuple.receivers),
                pad_edges_to)
            self.assertEqual(
                _leading_static_shape(padded_graphs_tuple.globals),
                pad_graphs_to)
            self.assertEqual(_leading_static_shape(padded_graphs_tuple.n_node),
                             pad_graphs_to)
            self.assertEqual(_leading_static_shape(padded_graphs_tuple.n_edge),
                             pad_graphs_to)

            # Check that we can remove the padding.
            graphs_tuple_size = utils_tf.get_graphs_tuple_size(graphs_tuple)
            recovered_graphs_tuple = utils_tf.remove_graphs_tuple_padding(
                padded_graphs_tuple, graphs_tuple_size)

            return padded_graphs_tuple, recovered_graphs_tuple

        # Put it into a tf.function so the shapes are unknown statically.
        compiled_fn = _compile_with_tf_function(
            _get_padded_and_recovered_graphs_tuple, graphs_tuple)
        padded_graphs_tuple, recovered_graphs_tuple = compiled_fn(graphs_tuple)

        if nested_features:
            # Check that the whole structure of the outputs are the same.
            tree.assert_same_structure(padded_graphs_tuple, graphs_tuple)
            tree.assert_same_structure(recovered_graphs_tuple, graphs_tuple)

            # Undo the nesting for the rest of the test.
            def remove_nesting(this_graphs_tuple):
                return this_graphs_tuple.replace(
                    edges=this_graphs_tuple.edges[0],
                    nodes=this_graphs_tuple.nodes[0]["tensor"],
                    globals=this_graphs_tuple.globals[1])

            graphs_tuple = remove_nesting(graphs_tuple)
            padded_graphs_tuple = remove_nesting(padded_graphs_tuple)
            recovered_graphs_tuple = remove_nesting(recovered_graphs_tuple)

        # Inspect the padded_graphs_tuple.
        padded_graphs_tuple_data_dicts = utils_np.graphs_tuple_to_data_dicts(
            utils_tf.nest_to_numpy(padded_graphs_tuple))
        graphs_tuple_data_dicts = utils_np.graphs_tuple_to_data_dicts(
            utils_tf.nest_to_numpy(graphs_tuple))

        # Count the graphs, not the GraphsTuple: len() of the tuple itself is
        # its number of fields, which here happens to equal pad_graphs_to and
        # so made the old assertion vacuous.
        self.assertLen(padded_graphs_tuple_data_dicts, pad_graphs_to)

        # The first `unpadded_batch_size` graphs must be untouched by padding.
        for example_i in range(unpadded_batch_size):
            tree.map_structure(self.assertAllEqual,
                               graphs_tuple_data_dicts[example_i],
                               padded_graphs_tuple_data_dicts[example_i])

        padding_data_dicts = padded_graphs_tuple_data_dicts[
            unpadded_batch_size:]
        # The first padding graph holds all padding nodes and edges; any
        # further padding graphs are empty.
        for i, padding_data_dict in enumerate(padding_data_dicts):

            # Only the first padding graph has nodes and edges.
            num_nodes = num_padding_nodes if i == 0 else 0
            num_edges = num_padding_edges if i == 0 else 0

            self.assertAllEqual(padding_data_dict["globals"],
                                np.zeros([9], dtype=np.float32))

            self.assertEqual(padding_data_dict["n_node"], num_nodes)
            self.assertAllEqual(padding_data_dict["nodes"],
                                np.zeros([num_nodes, 7], dtype=np.float32))
            self.assertEqual(padding_data_dict["n_edge"], num_edges)
            self.assertAllEqual(padding_data_dict["edges"],
                                np.zeros([num_edges, 8], dtype=np.float32))

            if experimental_unconnected_padding_edges:
                # Padding edges point one past the graph's last node.
                self.assertAllEqual(
                    padding_data_dict["receivers"],
                    np.zeros([num_edges], dtype=np.int32) + num_nodes)
                self.assertAllEqual(
                    padding_data_dict["senders"],
                    np.zeros([num_edges], dtype=np.int32) + num_nodes)
            else:
                self.assertAllEqual(padding_data_dict["receivers"],
                                    np.zeros([num_edges], dtype=np.int32))
                self.assertAllEqual(padding_data_dict["senders"],
                                    np.zeros([num_edges], dtype=np.int32))

        # Check that the recovered_graphs_tuple after removing padding is identical.
        tree.map_structure(self.assertAllEqual, graphs_tuple._asdict(),
                           recovered_graphs_tuple._asdict())