コード例 #1
0
    def __init__(self,
                 edge_model_fn,
                 node_encoder_model_fn,
                 node_model_fn,
                 reducer=scatter_add,
                 name="comm_net"):
        """Builds the three sub-blocks of a CommNet (arXiv:1706.06122).

        Args:
          edge_model_fn: Callable handed to the `EdgeBlock`; it must build a
            module (see `EdgeBlock` for the exact contract). The edge block is
            fed only with sender-node features.
          node_encoder_model_fn: Callable handed to the `NodeBlock` that
            performs the first encoding of the nodes; it must build a module
            (see `NodeBlock`). Its output shape must match the shape of the
            module built by `edge_model_fn`, except for the first and last
            dimensions.
          node_model_fn: Callable handed to the second `NodeBlock`; it must
            build a module (see `NodeBlock`).
          reducer: Reduction used to aggregate received edges at each node.
            Defaults to `scatter_add`; a replacement should be call-compatible
            with that default.
          name: Module name. Accepted for API compatibility; this constructor
            does not forward it anywhere.
        """
        super(CommNet, self).__init__()

        # $\Psi_{com}(x_j)$ in Eq. (2) of 1706.06122: the edge model only sees
        # the features of each edge's sender node.
        sender_only = dict(use_edges=False,
                           use_receiver_nodes=False,
                           use_sender_nodes=True,
                           use_globals=False)
        self._edge_block = blocks_torch.EdgeBlock(edge_model_fn=edge_model_fn,
                                                  **sender_only)

        # $\Phi(x_i)$ in Eq. (2): a node-wise encoder over node features only.
        self._node_encoder_block = blocks_torch.NodeBlock(
            name="node_encoder_block",
            node_model_fn=node_encoder_model_fn,
            use_nodes=True,
            use_received_edges=False,
            use_sent_edges=False,
            use_globals=False,
            received_edges_reducer=reducer)

        # $\Theta(..)$ in Eq. (2): combines node features with the received
        # edge messages, aggregated with `reducer`.
        self._node_block = blocks_torch.NodeBlock(
            node_model_fn=node_model_fn,
            use_nodes=True,
            use_received_edges=True,
            use_sent_edges=False,
            use_globals=False,
            received_edges_reducer=reducer)
コード例 #2
0
    def test_unused_field_can_be_none(self, use_edges, use_nodes, use_globals,
                                      none_field):
        """Checks that computation can handle non-necessary fields left None.

        Runs a NodeBlock on a graph whose `none_field` is None. That field is
        presumably one the `use_*` flags do not request; see the
        parameterization. The test then checks two things:
        - edges and globals pass through unchanged;
        - the produced nodes equal `self._scale` times the explicit
          concatenation of the enabled inputs.
        """
        input_graph = self._get_input_graph([none_field])
        node_block = blocks_torch.NodeBlock(node_model_fn=self._node_model_fn,
                                            use_received_edges=use_edges,
                                            use_sent_edges=use_edges,
                                            use_nodes=use_nodes,
                                            use_globals=use_globals)
        output_graph = node_block(input_graph)

        model_inputs = []
        if use_edges:
            model_inputs.append(
                blocks_torch.ReceivedEdgesToNodesAggregator(
                    torch_scatter.scatter_add)(input_graph))
            model_inputs.append(
                blocks_torch.SentEdgesToNodesAggregator(
                    torch_scatter.scatter_add)(input_graph))
        if use_nodes:
            model_inputs.append(input_graph.nodes)
        if use_globals:
            model_inputs.append(
                blocks_torch.broadcast_globals_to_nodes(input_graph))

        # torch.cat(dim=-1) concatenates along the last axis, the same
        # semantics as tf.concat(axis=-1).
        model_inputs = torch.cat(model_inputs, dim=-1)

        # Edges and globals must be left untouched. `assertEqual` cannot be
        # used here: for multi-element tensors `not (t1 == t2)` raises
        # "Boolean value of Tensor ... is ambiguous". Compare None fields by
        # identity and tensor fields element-wise instead.
        for field in ("edges", "globals"):
            expected = getattr(input_graph, field)
            actual = getattr(output_graph, field)
            if expected is None:
                self.assertIsNone(actual)
            else:
                self.assertIsNotNone(actual)
                self.assertTrue(torch.equal(expected, actual))

        # detach()/cpu() so .numpy() also works on tensors that require grad
        # or live on an accelerator.
        expected_output_nodes = (
            model_inputs.detach().cpu().numpy() * self._scale)
        self.assertNDArrayNear(expected_output_nodes,
                               output_graph.nodes.detach().cpu().numpy(),
                               err=1e-4)
コード例 #3
0
    def test_created_variables(self, use_received_edges, use_sent_edges,
                               use_nodes, use_globals, expected_first_dim_w):
        """Verifies the parameter names and shapes created by a NodeBlock.

        Each tensor shape in the block's state_dict is compared in reversed
        order. The expected weight shape is therefore written as
        [expected_first_dim_w, output_size], the reverse of the stored tensor
        shape.
        """
        output_size = 10
        input_graph = self._get_input_graph()

        mlp_fn = functools.partial(MLP, output_sizes=[output_size])
        node_block = blocks_torch.NodeBlock(
            node_model_fn=mlp_fn,
            use_received_edges=use_received_edges,
            use_sent_edges=use_sent_edges,
            use_nodes=use_nodes,
            use_globals=use_globals)

        # Connect the block once before inspecting its state. This is
        # presumably needed in case parameters are created lazily; verify.
        node_block(input_graph)

        var_shapes_dict = {
            key: list(tensor.shape)[::-1]
            for key, tensor in node_block.state_dict().items()
        }
        expected_var_shapes_dict = {
            "_node_model.mlp.0.bias": [output_size],
            "_node_model.mlp.0.weight": [expected_first_dim_w, output_size],
        }
        self.assertDictEqual(expected_var_shapes_dict, var_shapes_dict)
コード例 #4
0
    def __init__(self,
                 edge_model_fn,
                 node_model_fn,
                 reducer=scatter_add,
                 name="interaction_network"):
        """Builds the edge and node blocks of an InteractionNetwork.

        Args:
          edge_model_fn: Callable handed to `EdgeBlock` for the per-edge
            computation; it must build a module (see `blocks.EdgeBlock`). The
            output shape of that module must match the shape of the input
            nodes, except for the first and last axes.
          node_model_fn: Callable handed to `NodeBlock` for the per-node
            computation; it must build a module (see `blocks.NodeBlock`).
          reducer: Reducer the `NodeBlock` uses to aggregate received edges.
            Defaults to `scatter_add`.
          name: Module name. Accepted for API compatibility; this constructor
            does not forward it anywhere.
        """
        super(InteractionNetwork, self).__init__()

        # Edge update: globals are disabled; every other EdgeBlock input keeps
        # its default.
        self._edge_block = blocks_torch.EdgeBlock(
            edge_model_fn=edge_model_fn, use_globals=False)

        # Node update: sent edges and globals are disabled, received edges are
        # aggregated with `reducer`, and the remaining inputs keep NodeBlock
        # defaults.
        self._node_block = blocks_torch.NodeBlock(
            node_model_fn=node_model_fn,
            received_edges_reducer=reducer,
            use_globals=False,
            use_sent_edges=False)
コード例 #5
0
    def __init__(self,
                 node_model_fn,
                 global_model_fn,
                 reducer=scatter_add,
                 name="deep_sets"):
        """Builds the node and global blocks of a DeepSets model.

        Args:
          node_model_fn: Callable handed to the `NodeBlock`; it must build a
            module (see `NodeBlock`). Its output shape must equal the shape of
            the input graph's global features, except for the first and last
            axes.
          global_model_fn: Callable handed to the `GlobalBlock`; it must build
            a module (see `GlobalBlock`).
          reducer: Reduction the `GlobalBlock` uses to aggregate node features
            into the globals. Defaults to `scatter_add`; a replacement should
            be call-compatible with that default.
          name: Module name. Accepted for API compatibility; this constructor
            does not forward it anywhere.
        """
        super(DeepSets, self).__init__()

        # Per-node update from node features and globals; edges are unused.
        node_inputs = dict(use_nodes=True,
                           use_globals=True,
                           use_received_edges=False,
                           use_sent_edges=False)
        self._node_block = blocks_torch.NodeBlock(node_model_fn=node_model_fn,
                                                  **node_inputs)

        # Global update from the `reducer`-aggregated node features only.
        self._global_block = blocks_torch.GlobalBlock(
            global_model_fn=global_model_fn,
            nodes_reducer=reducer,
            use_nodes=True,
            use_edges=False,
            use_globals=False)
コード例 #6
0
 def test_no_input_raises_exception(self):
     """Checks that receiving no input raises an exception.

     Every `use_*` flag is disabled, so constructing the NodeBlock itself
     must fail with a ValueError.
     """
     # `assertRaisesRegexp` is a deprecated alias that Python 3.12 removed;
     # `assertRaisesRegex` is the supported name.
     with self.assertRaisesRegex(ValueError, "At least one of "):
         blocks_torch.NodeBlock(node_model_fn=self._node_model_fn,
                                use_received_edges=False,
                                use_sent_edges=False,
                                use_nodes=False,
                                use_globals=False)
コード例 #7
0
 def test_compatible_higher_rank_no_raise(self):
     """No exception should occur with higher-rank tensors.

     The same permutation of axes 1 and 2 is applied through `map` to the
     rank-4 graph fields. The inputs therefore stay mutually compatible and
     the convolutional NodeBlock must build and run.
     """
     def _swap_middle_axes(tensor):
         return tensor.permute(0, 2, 1, 3)  # TODO: change

     graph = self._get_shaped_input_graph().map(_swap_middle_axes)
     conv_fn = functools.partial(Conv2D, output_channels=10,
                                 kernel_shape=[3, 3])  # TODO: change
     network = blocks_torch.NodeBlock(conv_fn)
     self._assert_build_and_run(network, graph)
コード例 #8
0
 def test_missing_field_raises_exception(self, use_received_edges,
                                         use_sent_edges, use_nodes,
                                         use_globals, none_fields):
     """Checks that missing a required field raises an exception.

     The fields in `none_fields` are set to None in the input graph. The
     parameterization presumably makes at least one of them required by the
     `use_*` flags. Construction must succeed, but calling the block must
     raise ValueError.
     """
     input_graph = self._get_input_graph(none_fields)
     node_block = blocks_torch.NodeBlock(
         node_model_fn=self._node_model_fn,
         use_received_edges=use_received_edges,
         use_sent_edges=use_sent_edges,
         use_nodes=use_nodes,
         use_globals=use_globals)
     # `assertRaisesRegexp` was removed in Python 3.12; use the supported
     # `assertRaisesRegex` spelling.
     with self.assertRaisesRegex(ValueError, "field cannot be None"):
         node_block(input_graph)
コード例 #9
0
 def test_missing_aggregation_raises_exception(self, use_received_edges,
                                               use_sent_edges,
                                               received_edges_reducer,
                                               sent_edges_reducer):
     """Checks that missing a required aggregation argument raises an error.

     The check happens at construction time: a NodeBlock that enables an
     edge input must reject a None reducer for it with a ValueError.
     """
     # `assertRaisesRegexp` was removed in Python 3.12; use the supported
     # `assertRaisesRegex` spelling.
     with self.assertRaisesRegex(ValueError, "should not be None"):
         blocks_torch.NodeBlock(
             node_model_fn=self._node_model_fn,
             use_received_edges=use_received_edges,
             use_sent_edges=use_sent_edges,
             use_nodes=False,
             use_globals=False,
             received_edges_reducer=received_edges_reducer,
             sent_edges_reducer=sent_edges_reducer)
コード例 #10
0
    def __init__(self,
                 edge_model_fn,
                 node_model_fn,
                 global_model_fn,
                 reducer=scatter_add,
                 edge_block_opt=None,
                 node_block_opt=None,
                 global_block_opt=None,
                 name="graph_network"):
        """Builds the edge, node and global blocks of a full GraphNetwork.

        Args:
          edge_model_fn: Callable handed to `EdgeBlock` for the per-edge
            computation; it must build a module (see `EdgeBlock`).
          node_model_fn: Callable handed to `NodeBlock` for the per-node
            computation; it must build a module (see `NodeBlock`).
          global_model_fn: Callable handed to `GlobalBlock` for the per-global
            computation; it must build a module (see `GlobalBlock`).
          reducer: Reducer that `NodeBlock` and `GlobalBlock` use to aggregate
            nodes and edges. Defaults to `scatter_add`. Any reducers given in
            `node_block_opt` or `global_block_opt` take precedence.
          edge_block_opt: Extra options for the `EdgeBlock`. May contain
            `use_edges`, `use_receiver_nodes`, `use_sender_nodes` and
            `use_globals`; all default to True.
          node_block_opt: Extra options for the `NodeBlock`. May contain:
            - `use_received_edges`, `use_nodes`, `use_globals` (default True);
            - `use_sent_edges` (default False);
            - `received_edges_reducer`, `sent_edges_reducer` (default
              `reducer`).
          global_block_opt: Extra options for the `GlobalBlock`. May contain
            `use_edges`, `use_nodes`, `use_globals` (default True), and
            `edges_reducer`, `nodes_reducer` (default `reducer`).
          name: Module name. Accepted for API compatibility; this constructor
            does not forward it anywhere.
        """
        super(GraphNetwork, self).__init__()

        # Merge caller-supplied options with the defaults described above.
        edge_kwargs = _make_default_edge_block_opt(edge_block_opt)
        node_kwargs = _make_default_node_block_opt(node_block_opt, reducer)
        global_kwargs = _make_default_global_block_opt(global_block_opt,
                                                       reducer)

        # Build the sub-blocks in edge -> node -> global order.
        self._edge_block = blocks_torch.EdgeBlock(
            edge_model_fn=edge_model_fn, **edge_kwargs)
        self._node_block = blocks_torch.NodeBlock(
            node_model_fn=node_model_fn, **node_kwargs)
        self._global_block = blocks_torch.GlobalBlock(
            global_model_fn=global_model_fn, **global_kwargs)
コード例 #11
0
 def test_incompatible_higher_rank_inputs_no_raise(self, use_received_edges,
                                                   use_sent_edges,
                                                   use_nodes, use_globals,
                                                   field):
     """No exception should occur if a differently shaped field is not used.

     Only `field` has its axes 1 and 2 swapped, which makes it
     shape-incompatible with the other fields. The `use_*` flags are
     parameterized so that the block never reads it.
     """
     graph = self._get_shaped_input_graph()
     permuted = getattr(graph, field).permute(0, 2, 1, 3)  # TODO: change
     graph = graph.replace(**{field: permuted})

     block_flags = dict(use_received_edges=use_received_edges,
                        use_sent_edges=use_sent_edges,
                        use_nodes=use_nodes,
                        use_globals=use_globals)
     conv_fn = functools.partial(Conv2D, output_channels=10,
                                 kernel_shape=[3, 3])  # TODO: change
     network = blocks_torch.NodeBlock(conv_fn, **block_flags)
     self._assert_build_and_run(network, graph)
コード例 #12
0
 def test_incompatible_higher_rank_inputs_raises(self, use_received_edges,
                                                 use_sent_edges, use_nodes,
                                                 use_globals, field):
     """An exception should be raised if the inputs have incompatible shapes.

     Only `field` has its axes 1 and 2 swapped. The block is parameterized to
     consume that field, so calling it must fail with a ValueError.
     """
     input_graph = self._get_shaped_input_graph()
     input_graph = input_graph.replace(**{
         field:
         getattr(input_graph, field).permute(0, 2, 1, 3)
     })  # TODO: change
     network = blocks_torch.NodeBlock(
         functools.partial(Conv2D, output_channels=10,
                           kernel_shape=[3, 3]),  # TODO: change
         use_received_edges=use_received_edges,
         use_sent_edges=use_sent_edges,
         use_nodes=use_nodes,
         use_globals=use_globals)
     # `assertRaisesRegexp` was removed in Python 3.12; use the supported
     # `assertRaisesRegex` spelling.
     # NOTE(review): the expected text is TensorFlow's concat error message.
     # torch.cat itself presumably raises a RuntimeError with different
     # wording, so confirm that blocks_torch raises a ValueError with this
     # text.
     with self.assertRaisesRegex(ValueError,
                                 "in both shapes must be equal"):
         network(input_graph)
コード例 #13
0
    def test_output_values(self, use_received_edges, use_sent_edges, use_nodes,
                           use_globals, received_edges_reducer,
                           sent_edges_reducer):
        """Compares the output of a NodeBlock to an explicit computation.

        Recomputes the model input by concatenating the enabled features:
        aggregated received edges, aggregated sent edges, nodes and
        broadcast globals. It then checks two things:
        - edges and globals pass through unchanged;
        - the output nodes equal that input scaled by `self._scale`.
        """
        input_graph = self._get_input_graph()
        node_block = blocks_torch.NodeBlock(
            node_model_fn=self._node_model_fn,
            use_received_edges=use_received_edges,
            use_sent_edges=use_sent_edges,
            use_nodes=use_nodes,
            use_globals=use_globals,
            received_edges_reducer=received_edges_reducer,
            sent_edges_reducer=sent_edges_reducer)
        output_graph = node_block(input_graph)

        model_inputs = []
        if use_received_edges:
            model_inputs.append(
                blocks_torch.ReceivedEdgesToNodesAggregator(
                    received_edges_reducer)(input_graph))
        if use_sent_edges:
            model_inputs.append(
                blocks_torch.SentEdgesToNodesAggregator(sent_edges_reducer)(
                    input_graph))
        if use_nodes:
            model_inputs.append(input_graph.nodes)
        if use_globals:
            model_inputs.append(
                blocks_torch.broadcast_globals_to_nodes(input_graph))

        # torch.cat(dim=-1) concatenates along the last axis, the same
        # semantics as tf.concat(axis=-1).
        model_inputs = torch.cat(model_inputs, dim=-1)

        # Edges and globals must be left untouched. `assertEqual` cannot be
        # used here: for multi-element tensors `not (t1 == t2)` raises
        # "Boolean value of Tensor ... is ambiguous". Compare None fields by
        # identity and tensor fields element-wise instead.
        for field in ("edges", "globals"):
            expected = getattr(input_graph, field)
            actual = getattr(output_graph, field)
            if expected is None:
                self.assertIsNone(actual)
            else:
                self.assertIsNotNone(actual)
                self.assertTrue(torch.equal(expected, actual))

        # detach()/cpu() so .numpy() also works on tensors that require grad
        # or live on an accelerator.
        expected_output_nodes = (
            model_inputs.detach().cpu().numpy() * self._scale)
        self.assertNDArrayNear(expected_output_nodes,
                               output_graph.nodes.detach().cpu().numpy(),
                               err=1e-4)