예제 #1
0
    def __init__(self,
                 edge_model_fn,
                 node_model_fn,
                 reducer=tf.unsorted_segment_sum,
                 name="interaction_network"):
        """Builds an InteractionNetwork from an edge block and a node block.

        Args:
          edge_model_fn: Callable handed to `blocks.EdgeBlock` to build the
            per-edge model. It must return a Sonnet module (or an equivalent;
            see `blocks.EdgeBlock` for details). The output shape of that
            module must match the shape of the input nodes except for the
            first and last axes.
          node_model_fn: Callable handed to `blocks.NodeBlock` to build the
            per-node model. It must return a Sonnet module (or an equivalent;
            see `blocks.NodeBlock` for details).
          reducer: Reducer the NodeBlock uses to aggregate edges. Defaults to
            `tf.unsorted_segment_sum`.
          name: The module name.
        """
        super(InteractionNetwork, self).__init__(name=name)

        # Neither block reads globals; the node block also ignores sent edges.
        edge_block_kwargs = dict(edge_model_fn=edge_model_fn,
                                 use_globals=False)
        node_block_kwargs = dict(node_model_fn=node_model_fn,
                                 use_sent_edges=False,
                                 use_globals=False,
                                 received_edges_reducer=reducer)

        with self._enter_variable_scope():
            self._edge_block = blocks.EdgeBlock(**edge_block_kwargs)
            self._node_block = blocks.NodeBlock(**node_block_kwargs)
예제 #2
0
    def test_created_variables(self, use_edges, use_receiver_nodes,
                               use_sender_nodes, use_globals,
                               expected_first_dim_w):
        """Checks the names and shapes of the variables an EdgeBlock creates.

        The edge model is an MLP with a single output layer, so exactly one
        bias and one weight matrix are expected. `expected_first_dim_w` is
        the expected input width for the given combination of `use_*` flags.
        """
        output_size = 10
        input_graph = self._get_input_graph()

        mlp_fn = functools.partial(snt.nets.MLP, output_sizes=[output_size])
        edge_block = blocks.EdgeBlock(edge_model_fn=mlp_fn,
                                      use_edges=use_edges,
                                      use_receiver_nodes=use_receiver_nodes,
                                      use_sender_nodes=use_sender_nodes,
                                      use_globals=use_globals)
        edge_block(input_graph)

        actual_shapes = {}
        for variable in edge_block.get_variables():
            actual_shapes[variable.name] = variable.get_shape().as_list()

        layer_prefix = "edge_block/mlp/linear_0/"
        expected_shapes = {
            layer_prefix + "b:0": [output_size],
            layer_prefix + "w:0": [expected_first_dim_w, output_size],
        }
        self.assertDictEqual(expected_shapes, actual_shapes)
예제 #3
0
    def __init__(self,
                 edge_model_fn,
                 global_model_fn,
                 reducer=tf.unsorted_segment_sum,
                 name="relation_network"):
        """Builds a RelationNetwork from an edge block and a global block.

        Args:
          edge_model_fn: Callable handed to `blocks.EdgeBlock` to build the
            per-edge model. It must return a Sonnet module (or an equivalent;
            see EdgeBlock for details).
          global_model_fn: Callable handed to `blocks.GlobalBlock` to build
            the global model. It must return a Sonnet module (or an
            equivalent; see GlobalBlock for details).
          reducer: Reducer the GlobalBlock uses to aggregate edges. Defaults
            to `tf.unsorted_segment_sum`.
          name: The module name.
        """
        super(RelationNetwork, self).__init__(name=name)

        # Edge model input: only the receiver and sender nodes of each edge.
        edge_flags = dict(use_edges=False,
                          use_receiver_nodes=True,
                          use_sender_nodes=True,
                          use_globals=False)
        # Global model input: only the edges, aggregated with `reducer`.
        global_flags = dict(use_edges=True,
                            use_nodes=False,
                            use_globals=False)

        with self._enter_variable_scope():
            self._edge_block = blocks.EdgeBlock(edge_model_fn=edge_model_fn,
                                                **edge_flags)
            self._global_block = blocks.GlobalBlock(
                global_model_fn=global_model_fn,
                edges_reducer=reducer,
                **global_flags)
예제 #4
0
    def test_unused_field_can_be_none(self, use_edges, use_nodes, use_globals,
                                      none_field):
        """Verifies an EdgeBlock copes with an unused field being None."""
        input_graph = self._get_input_graph([none_field])
        edge_block = blocks.EdgeBlock(edge_model_fn=self._edge_model_fn,
                                      use_edges=use_edges,
                                      use_receiver_nodes=use_nodes,
                                      use_sender_nodes=use_nodes,
                                      use_globals=use_globals)
        output_graph = edge_block(input_graph)

        # Rebuild the model input by hand, mirroring the inputs enabled on the
        # block; `use_nodes` toggles receivers and senders together.
        pieces = []
        if use_edges:
            pieces.append(input_graph.edges)
        if use_nodes:
            pieces.extend([
                blocks.broadcast_receiver_nodes_to_edges(input_graph),
                blocks.broadcast_sender_nodes_to_edges(input_graph),
            ])
        if use_globals:
            pieces.append(blocks.broadcast_globals_to_edges(input_graph))
        concatenated_inputs = tf.concat(pieces, axis=-1)

        # Nodes and globals must be passed through untouched.
        self.assertEqual(input_graph.nodes, output_graph.nodes)
        self.assertEqual(input_graph.globals, output_graph.globals)

        with self.test_session() as sess:
            edges_value, inputs_value = sess.run(
                (output_graph.edges, concatenated_inputs))

        # The expected edges are the concatenated inputs scaled by
        # `self._scale`.
        self.assertNDArrayNear(inputs_value * self._scale, edges_value,
                               err=1e-4)
예제 #5
0
    def test_output_values(self, use_edges, use_receiver_nodes,
                           use_sender_nodes, use_globals):
        """Checks EdgeBlock output against a hand-built reference computation."""
        input_graph = self._get_input_graph()
        edge_block = blocks.EdgeBlock(edge_model_fn=self._edge_model_fn,
                                      use_edges=use_edges,
                                      use_receiver_nodes=use_receiver_nodes,
                                      use_sender_nodes=use_sender_nodes,
                                      use_globals=use_globals)
        output_graph = edge_block(input_graph)

        # One (flag, factory) pair per candidate input, in concatenation
        # order. Factories are only invoked for enabled inputs, so disabled
        # inputs add no ops to the graph.
        chunk_factories = [
            (use_edges, lambda: input_graph.edges),
            (use_receiver_nodes,
             lambda: blocks.broadcast_receiver_nodes_to_edges(input_graph)),
            (use_sender_nodes,
             lambda: blocks.broadcast_sender_nodes_to_edges(input_graph)),
            (use_globals,
             lambda: blocks.broadcast_globals_to_edges(input_graph)),
        ]
        model_inputs = tf.concat(
            [make_chunk() for enabled, make_chunk in chunk_factories
             if enabled],
            axis=-1)

        # Nodes and globals must be passed through untouched.
        self.assertEqual(input_graph.nodes, output_graph.nodes)
        self.assertEqual(input_graph.globals, output_graph.globals)

        with self.test_session() as sess:
            output_graph_out, model_inputs_out = sess.run(
                (output_graph, model_inputs))

        self.assertNDArrayNear(model_inputs_out * self._scale,
                               output_graph_out.edges,
                               err=1e-4)
예제 #6
0
 def test_no_input_raises_exception(self):
     """Ensures an EdgeBlock with every input disabled is rejected."""
     all_inputs_disabled = dict(use_edges=False,
                                use_receiver_nodes=False,
                                use_sender_nodes=False,
                                use_globals=False)
     with self.assertRaisesRegexp(ValueError, "At least one of "):
         blocks.EdgeBlock(edge_model_fn=self._edge_model_fn,
                          **all_inputs_disabled)
예제 #7
0
 def test_compatible_higher_rank_no_raise(self):
     """Checks that an EdgeBlock builds and runs on higher-rank inputs."""
     # Apply the same [0, 2, 1, 3] transpose to every field, so the fields
     # stay consistent with one another.
     transposed_graph = self._get_shaped_input_graph().map(
         lambda field: tf.transpose(field, [0, 2, 1, 3]))
     conv_fn = functools.partial(snt.Conv2D,
                                 output_channels=10,
                                 kernel_shape=[3, 3])
     self._assert_build_and_run(blocks.EdgeBlock(conv_fn), transposed_graph)
예제 #8
0
 def test_missing_field_raises_exception(self, use_edges,
                                         use_receiver_nodes,
                                         use_sender_nodes, use_globals,
                                         none_fields):
     """Ensures calling an EdgeBlock fails when a field it needs is None."""
     graph_with_missing_fields = self._get_input_graph(none_fields)
     edge_block = blocks.EdgeBlock(edge_model_fn=self._edge_model_fn,
                                   use_edges=use_edges,
                                   use_receiver_nodes=use_receiver_nodes,
                                   use_sender_nodes=use_sender_nodes,
                                   use_globals=use_globals)
     expected_message = "field cannot be None"
     with self.assertRaisesRegexp(ValueError, expected_message):
         edge_block(graph_with_missing_fields)
예제 #9
0
    def __init__(self,
                 edge_model_fn,
                 node_model_fn,
                 global_model_fn,
                 reducer=tf.unsorted_segment_sum,
                 edge_block_opt=None,
                 node_block_opt=None,
                 global_block_opt=None,
                 name="graph_network"):
        """Builds a full GraphNetwork: an edge, a node and a global block.

        Args:
          edge_model_fn: Callable handed to EdgeBlock to build the per-edge
            model. It must return a Sonnet module (or an equivalent; see
            EdgeBlock for details).
          node_model_fn: Callable handed to NodeBlock to build the per-node
            model. It must return a Sonnet module (or an equivalent; see
            NodeBlock for details).
          global_model_fn: Callable handed to GlobalBlock to build the global
            model. It must return a Sonnet module (or an equivalent; see
            GlobalBlock for details).
          reducer: Reducer used by NodeBlock and GlobalBlock to aggregate
            nodes and edges. Defaults to tf.unsorted_segment_sum. Reducers
            given explicitly in `node_block_opt` or `global_block_opt` take
            precedence over this one.
          edge_block_opt: Extra options for the EdgeBlock. May contain the keys
            `use_edges`, `use_receiver_nodes`, `use_sender_nodes` and
            `use_globals`, all True by default.
          node_block_opt: Extra options for the NodeBlock. May contain the keys
            `use_received_edges`, `use_sent_edges`, `use_nodes`, `use_globals`
            (all True by default), and `received_edges_reducer`,
            `sent_edges_reducer` (both default to `reducer`).
          global_block_opt: Extra options for the GlobalBlock. May contain the
            keys `use_edges`, `use_nodes`, `use_globals` (all True by
            default), and `edges_reducer`, `nodes_reducer` (both default to
            `reducer`).
          name: The module name.
        """
        super(GraphNetwork, self).__init__(name=name)

        # Fill in defaults for whatever options the caller left out.
        edge_opts = _make_default_edge_block_opt(edge_block_opt)
        node_opts = _make_default_node_block_opt(node_block_opt, reducer)
        global_opts = _make_default_global_block_opt(global_block_opt, reducer)

        with self._enter_variable_scope():
            self._edge_block = blocks.EdgeBlock(
                edge_model_fn=edge_model_fn, **edge_opts)
            self._node_block = blocks.NodeBlock(
                node_model_fn=node_model_fn, **node_opts)
            self._global_block = blocks.GlobalBlock(
                global_model_fn=global_model_fn, **global_opts)
예제 #10
0
    def __init__(self,
                 edge_model_fn,
                 node_encoder_model_fn,
                 node_model_fn,
                 reducer=tf.unsorted_segment_sum,
                 name="comm_net"):
        """Builds a CommNet from one edge block and two node blocks.

        Args:
          edge_model_fn: Callable handed to EdgeBlock. It must return a Sonnet
            module (or an equivalent; see EdgeBlock for details).
          node_encoder_model_fn: Callable handed to the NodeBlock that
            performs the initial encoding of the nodes. It must return a
            Sonnet module (or an equivalent; see NodeBlock for details). The
            output shape of this module should match the output shape of the
            module built by `edge_model_fn`, except for the first and last
            dimensions.
          node_model_fn: Callable handed to NodeBlock. It must return a Sonnet
            module (or an equivalent; see NodeBlock for details).
          reducer: Reduction used when aggregating edges into nodes. It should
            be a callable with the same signature as tf.unsorted_segment_sum.
          name: The module name.
        """
        super(CommNet, self).__init__(name=name)

        # Flags shared by both node blocks: nodes in, no sent edges, no
        # globals. `received_edges_reducer` is also passed to the encoder,
        # even though that block sets `use_received_edges=False`.
        shared_node_flags = dict(use_sent_edges=False,
                                 use_nodes=True,
                                 use_globals=False,
                                 received_edges_reducer=reducer)

        with self._enter_variable_scope():
            # $\Psi_{com}(x_j)$ in Eq. (2) of 1706.06122: each edge is computed
            # from its sender node only.
            self._edge_block = blocks.EdgeBlock(
                edge_model_fn=edge_model_fn,
                use_edges=False,
                use_receiver_nodes=False,
                use_sender_nodes=True,
                use_globals=False)
            # $\Phi(x_i)$ in Eq. (2) of 1706.06122: encoding of the node
            # features alone.
            self._node_encoder_block = blocks.NodeBlock(
                node_model_fn=node_encoder_model_fn,
                use_received_edges=False,
                name="node_encoder_block",
                **shared_node_flags)
            # $\Theta(..)$ in Eq. (2) of 1706.06122: each node combined with
            # its reduced received edges.
            self._node_block = blocks.NodeBlock(
                node_model_fn=node_model_fn,
                use_received_edges=True,
                **shared_node_flags)
예제 #11
0
 def test_incompatible_higher_rank_inputs_no_raise(self, use_edges,
                                                   use_receiver_nodes,
                                                   use_sender_nodes,
                                                   use_globals, field):
     """No exception should occur if a differently shaped field is not used."""
     base_graph = self._get_shaped_input_graph()
     # Transpose only `field`, so its layout no longer matches the others.
     swapped = tf.transpose(getattr(base_graph, field), [0, 2, 1, 3])
     input_graph = base_graph.replace(**{field: swapped})

     conv_fn = functools.partial(snt.Conv2D,
                                 output_channels=10,
                                 kernel_shape=[3, 3])
     network = blocks.EdgeBlock(conv_fn,
                                use_edges=use_edges,
                                use_receiver_nodes=use_receiver_nodes,
                                use_sender_nodes=use_sender_nodes,
                                use_globals=use_globals)
     self._assert_build_and_run(network, input_graph)
예제 #12
0
 def test_incompatible_higher_rank_inputs_raises(self, use_edges,
                                                 use_receiver_nodes,
                                                 use_sender_nodes,
                                                 use_globals, field):
     """An exception should be raised if the inputs have incompatible shapes."""
     base_graph = self._get_shaped_input_graph()
     # Transpose only `field`, so its layout no longer matches the others.
     swapped = tf.transpose(getattr(base_graph, field), [0, 2, 1, 3])
     input_graph = base_graph.replace(**{field: swapped})

     conv_fn = functools.partial(snt.Conv2D,
                                 output_channels=10,
                                 kernel_shape=[3, 3])
     network = blocks.EdgeBlock(conv_fn,
                                use_edges=use_edges,
                                use_receiver_nodes=use_receiver_nodes,
                                use_sender_nodes=use_sender_nodes,
                                use_globals=use_globals)
     expected_message = "in both shapes must be equal"
     with self.assertRaisesRegexp(ValueError, expected_message):
         network(input_graph)