예제 #1
0
    def __init__(self,
                 node_model_fn,
                 global_model_fn,
                 reducer=tf.unsorted_segment_sum,
                 name="deep_sets"):
        """Initializes the DeepSets module.

        Two sub-blocks are created here: a NodeBlock that reads node features
        and globals (but no edges), and a GlobalBlock that reads only nodes,
        pooling them with `reducer`.

        Args:
          node_model_fn: A callable to be passed to NodeBlock. The callable must
            return a Sonnet module (or equivalent; see NodeBlock for details).
            The shape of this module's output must equal the shape of the input
            graph's global features, but for the first and last axis.
          global_model_fn: A callable to be passed to GlobalBlock. The callable
            must return a Sonnet module (or equivalent; see GlobalBlock for
            details).
          reducer: Reduction to be used when aggregating the nodes in the
            globals. This should be a callable whose signature matches
            tf.unsorted_segment_sum.
          name: The module name.
        """
        super(DeepSets, self).__init__(name=name)

        # Node update: own node features plus globals; edges are ignored.
        node_block_kwargs = dict(use_received_edges=False,
                                 use_sent_edges=False,
                                 use_nodes=True,
                                 use_globals=True)
        # Global update: nodes only, reduced with `reducer`.
        global_block_kwargs = dict(use_edges=False,
                                   use_nodes=True,
                                   use_globals=False,
                                   nodes_reducer=reducer)

        # Construct the sub-blocks inside this module's variable scope.
        with self._enter_variable_scope():
            self._node_block = blocks.NodeBlock(node_model_fn=node_model_fn,
                                                **node_block_kwargs)
            self._global_block = blocks.GlobalBlock(
                global_model_fn=global_model_fn, **global_block_kwargs)
예제 #2
0
    def __init__(self,
                 edge_model_fn,
                 global_model_fn,
                 reducer=tf.unsorted_segment_sum,
                 name="relation_network"):
        """Initializes the RelationNetwork module.

        Two sub-blocks are created here: an EdgeBlock that reads only the
        sender and receiver node features of each edge, and a GlobalBlock that
        reads only edges, pooling them with `reducer`.

        Args:
          edge_model_fn: A callable that will be passed to EdgeBlock to perform
            per-edge computations. The callable must return a Sonnet module (or
            equivalent; see EdgeBlock for details).
          global_model_fn: A callable that will be passed to GlobalBlock to
            perform per-global computations. The callable must return a Sonnet
            module (or equivalent; see GlobalBlock for details).
          reducer: Reducer to be used by GlobalBlock to aggregate edges.
            Defaults to tf.unsorted_segment_sum.
          name: The module name.
        """
        super(RelationNetwork, self).__init__(name=name)

        # Edge update: sender and receiver nodes only; no prior edge features
        # and no globals.
        edge_block_kwargs = dict(use_edges=False,
                                 use_receiver_nodes=True,
                                 use_sender_nodes=True,
                                 use_globals=False)
        # Global update: edges only, reduced with `reducer`.
        global_block_kwargs = dict(use_edges=True,
                                   use_nodes=False,
                                   use_globals=False,
                                   edges_reducer=reducer)

        # Construct the sub-blocks inside this module's variable scope.
        with self._enter_variable_scope():
            self._edge_block = blocks.EdgeBlock(edge_model_fn=edge_model_fn,
                                                **edge_block_kwargs)

            self._global_block = blocks.GlobalBlock(
                global_model_fn=global_model_fn, **global_block_kwargs)
예제 #3
0
    def test_created_variables(self, use_edges, use_nodes, use_globals,
                               expected_first_dim_w):
        """Verifies the variable names and shapes created by a GlobalBlock.

        The global model is a one-layer MLP, so exactly one bias and one weight
        variable are expected. The weight's leading dimension is supplied by
        the test parameters as `expected_first_dim_w`.
        """
        output_size = 10
        weight_shape = [expected_first_dim_w, output_size]
        expected_var_shapes_dict = {
            "global_block/mlp/linear_0/b:0": [output_size],
            "global_block/mlp/linear_0/w:0": weight_shape,
        }

        input_graph = self._get_input_graph()

        mlp_fn = functools.partial(snt.nets.MLP, output_sizes=[output_size])
        global_block = blocks.GlobalBlock(global_model_fn=mlp_fn,
                                          use_edges=use_edges,
                                          use_nodes=use_nodes,
                                          use_globals=use_globals)

        # Call the block once so its variables exist before inspecting them.
        global_block(input_graph)

        var_shapes_dict = {
            v.name: v.get_shape().as_list()
            for v in global_block.get_variables()
        }
        self.assertDictEqual(expected_var_shapes_dict, var_shapes_dict)
예제 #4
0
    def test_unused_field_can_be_none(self, use_edges, use_nodes, use_globals,
                                      none_field):
        """Checks that computation can handle non-necessary fields left None.

        `none_field` is passed to `_get_input_graph` (presumably to leave that
        field as None), and the `use_*` flags are expected not to require it.

        The block's output globals are compared against an explicit
        computation. That computation concatenates the enabled inputs (edges
        and nodes aggregated with tf.unsorted_segment_sum, plus the globals)
        and multiplies the result by `self._scale`. This presumes
        `self._global_model_fn` scales its input by `self._scale`; see the
        test fixture.
        """
        input_graph = self._get_input_graph([none_field])
        global_block = blocks.GlobalBlock(
            global_model_fn=self._global_model_fn,
            use_edges=use_edges,
            use_nodes=use_nodes,
            use_globals=use_globals)
        output_graph = global_block(input_graph)

        # Build the expected model input in the order edges, nodes, globals.
        model_input_parts = []
        if use_edges:
            model_input_parts.append(
                blocks.EdgesToGlobalsAggregator(
                    tf.unsorted_segment_sum)(input_graph))
        if use_nodes:
            model_input_parts.append(
                blocks.NodesToGlobalsAggregator(
                    tf.unsorted_segment_sum)(input_graph))
        if use_globals:
            model_input_parts.append(input_graph.globals)

        model_inputs = tf.concat(model_input_parts, axis=-1)
        # Edges and nodes must pass through the block untouched. Assert
        # identity explicitly: assertEqual on tensors only worked because
        # graph-mode Tensor.__eq__ falls back to object identity, which does
        # not hold once tensor equality is element-wise.
        self.assertIs(input_graph.edges, output_graph.edges)
        self.assertIs(input_graph.nodes, output_graph.nodes)

        with self.test_session() as sess:
            actual_globals, model_inputs_out = sess.run(
                (output_graph.globals, model_inputs))

        expected_output_globals = model_inputs_out * self._scale
        self.assertNDArrayNear(expected_output_globals,
                               actual_globals,
                               err=1e-4)
예제 #5
0
    def test_output_values(self, use_edges, use_nodes, use_globals,
                           edges_reducer, nodes_reducer):
        """Compares the output of a GlobalBlock to an explicit computation.

        The expected model input concatenates the enabled inputs (edges
        aggregated with `edges_reducer`, nodes aggregated with `nodes_reducer`,
        and the globals) along the last axis. The expected output is that
        input multiplied by `self._scale`. This presumes
        `self._global_model_fn` scales its input by `self._scale`; see the
        test fixture.
        """
        input_graph = self._get_input_graph()
        global_block = blocks.GlobalBlock(
            global_model_fn=self._global_model_fn,
            use_edges=use_edges,
            use_nodes=use_nodes,
            use_globals=use_globals,
            edges_reducer=edges_reducer,
            nodes_reducer=nodes_reducer)
        output_graph = global_block(input_graph)

        # Build the expected model input in the order edges, nodes, globals.
        model_input_parts = []
        if use_edges:
            model_input_parts.append(
                blocks.EdgesToGlobalsAggregator(edges_reducer)(input_graph))
        if use_nodes:
            model_input_parts.append(
                blocks.NodesToGlobalsAggregator(nodes_reducer)(input_graph))
        if use_globals:
            model_input_parts.append(input_graph.globals)

        model_inputs = tf.concat(model_input_parts, axis=-1)
        # Edges and nodes must pass through the block untouched. Assert
        # identity explicitly: assertEqual on tensors only worked because
        # graph-mode Tensor.__eq__ falls back to object identity, which does
        # not hold once tensor equality is element-wise.
        self.assertIs(input_graph.edges, output_graph.edges)
        self.assertIs(input_graph.nodes, output_graph.nodes)

        with self.test_session() as sess:
            output_graph_out, model_inputs_out = sess.run(
                (output_graph, model_inputs))

        expected_output_globals = model_inputs_out * self._scale
        self.assertNDArrayNear(expected_output_globals,
                               output_graph_out.globals,
                               err=1e-4)
예제 #6
0
 def test_no_input_raises_exception(self):
     """Checks that a GlobalBlock with every input disabled is rejected.

     The ValueError is expected at construction time, before the block is
     connected to any graph.
     """
     all_inputs_disabled = dict(use_edges=False,
                                use_nodes=False,
                                use_globals=False)
     with self.assertRaisesRegexp(ValueError, "At least one of "):
         blocks.GlobalBlock(global_model_fn=self._global_model_fn,
                            **all_inputs_disabled)
예제 #7
0
 def test_compatible_higher_rank_no_raise(self):
     """No exception should occur with higher-rank tensors.

     The graph's features are all transposed with one shared permutation via
     `map`, so they stay mutually compatible and a Conv2D global model must
     build and run on them.
     """
     def swap_axes_1_and_2(tensor):
         # The 4-element permutation requires rank-4 tensors.
         return tf.transpose(tensor, [0, 2, 1, 3])

     input_graph = self._get_shaped_input_graph().map(swap_axes_1_and_2)
     conv_model_fn = functools.partial(snt.Conv2D,
                                       output_channels=10,
                                       kernel_shape=[3, 3])
     network = blocks.GlobalBlock(conv_model_fn)
     self._assert_build_and_run(network, input_graph)
예제 #8
0
 def test_missing_field_raises_exception(self, use_edges, use_nodes,
                                         use_globals, none_field):
     """Checks that missing a required field raises an exception.

     The test parameters are expected to pair a `none_field` with `use_*`
     flags that require it. Construction succeeds; the ValueError is expected
     when the block is connected to the graph.
     """
     input_graph = self._get_input_graph([none_field])
     block_flags = dict(use_edges=use_edges,
                        use_nodes=use_nodes,
                        use_globals=use_globals)
     global_block = blocks.GlobalBlock(global_model_fn=self._global_model_fn,
                                       **block_flags)
     with self.assertRaisesRegexp(ValueError, "field cannot be None"):
         global_block(input_graph)
예제 #9
0
 def test_missing_aggregation_raises_exception(self, use_edges, use_nodes,
                                               edges_reducer,
                                               nodes_reducer):
     """Checks that missing a required aggregation argument raises an error.

     The test parameters are expected to enable an input (edges or nodes)
     while passing None for its reducer. Construction itself must then raise
     a ValueError.
     """
     block_options = dict(use_edges=use_edges,
                          use_nodes=use_nodes,
                          use_globals=False,
                          edges_reducer=edges_reducer,
                          nodes_reducer=nodes_reducer)
     with self.assertRaisesRegexp(ValueError, "should not be None"):
         blocks.GlobalBlock(global_model_fn=self._global_model_fn,
                            **block_options)
예제 #10
0
    def __init__(self,
                 edge_model_fn,
                 node_model_fn,
                 global_model_fn,
                 reducer=tf.unsorted_segment_sum,
                 edge_block_opt=None,
                 node_block_opt=None,
                 global_block_opt=None,
                 name="graph_network"):
        """Initializes the GraphNetwork module.

        Args:
          edge_model_fn: A callable that will be passed to EdgeBlock to perform
            per-edge computations. The callable must return a Sonnet module (or
            equivalent; see EdgeBlock for details).
          node_model_fn: A callable that will be passed to NodeBlock to perform
            per-node computations. The callable must return a Sonnet module (or
            equivalent; see NodeBlock for details).
          global_model_fn: A callable that will be passed to GlobalBlock to
            perform per-global computations. The callable must return a Sonnet
            module (or equivalent; see GlobalBlock for details).
          reducer: Reducer to be used by NodeBlock and GlobalBlock to aggregate
            nodes and edges. Defaults to tf.unsorted_segment_sum. This will be
            overridden by the reducers specified in `node_block_opt` and
            `global_block_opt`, if any.
          edge_block_opt: Additional options to be passed to the EdgeBlock. Can
            contain the keys `use_edges`, `use_receiver_nodes`,
            `use_sender_nodes`, `use_globals`. By default, these are all True.
          node_block_opt: Additional options to be passed to the NodeBlock. Can
            contain the keys `use_received_edges`, `use_sent_edges`,
            `use_nodes`, `use_globals` (all set to True by default), and
            `received_edges_reducer`, `sent_edges_reducer` (which default to
            `reducer`).
          global_block_opt: Additional options to be passed to the GlobalBlock.
            Can contain the keys `use_edges`, `use_nodes`, `use_globals` (all
            set to True by default), and `edges_reducer`, `nodes_reducer`
            (which default to `reducer`).
          name: The module name.
        """
        super(GraphNetwork, self).__init__(name=name)
        # Fill in defaults for options the caller did not supply. The node and
        # global option helpers also receive `reducer` to use as their default
        # reducer.
        edge_block_opt = _make_default_edge_block_opt(edge_block_opt)
        node_block_opt = _make_default_node_block_opt(node_block_opt, reducer)
        global_block_opt = _make_default_global_block_opt(
            global_block_opt, reducer)

        # Construct the sub-blocks inside this module's variable scope.
        with self._enter_variable_scope():
            self._edge_block = blocks.EdgeBlock(edge_model_fn=edge_model_fn,
                                                **edge_block_opt)
            self._node_block = blocks.NodeBlock(node_model_fn=node_model_fn,
                                                **node_block_opt)
            self._global_block = blocks.GlobalBlock(
                global_model_fn=global_model_fn, **global_block_opt)
예제 #11
0
 def test_incompatible_higher_rank_inputs_no_raise(self, use_edges,
                                                   use_nodes, use_globals,
                                                   field):
     """No exception should occur if a differently shaped field is not used.

     Only `field` is transposed (axes 1 and 2 swapped). This makes it
     differently shaped from the other fields, assuming those axes have
     different sizes in the fixture. The test parameters are expected to
     disable that field through the `use_*` flags, so building and running
     the block must still succeed.
     """
     input_graph = self._get_shaped_input_graph()
     input_graph = input_graph.replace(
         **{field: tf.transpose(getattr(input_graph, field), [0, 2, 1, 3])})
     conv_model_fn = functools.partial(snt.Conv2D,
                                       output_channels=10,
                                       kernel_shape=[3, 3])
     network = blocks.GlobalBlock(conv_model_fn,
                                  use_edges=use_edges,
                                  use_nodes=use_nodes,
                                  use_globals=use_globals)
     self._assert_build_and_run(network, input_graph)
예제 #12
0
 def test_incompatible_higher_rank_inputs_raises(self, use_edges, use_nodes,
                                                 use_globals, field):
     """An exception should be raised if the inputs have incompatible shapes.

     Only `field` is transposed (axes 1 and 2 swapped). The test parameters
     are expected to keep that field enabled alongside another input, so
     connecting the block must raise a ValueError about mismatched shapes.
     Construction itself is not expected to raise.
     """
     input_graph = self._get_shaped_input_graph()
     input_graph = input_graph.replace(
         **{field: tf.transpose(getattr(input_graph, field), [0, 2, 1, 3])})
     conv_model_fn = functools.partial(snt.Conv2D,
                                       output_channels=10,
                                       kernel_shape=[3, 3])
     network = blocks.GlobalBlock(conv_model_fn,
                                  use_edges=use_edges,
                                  use_nodes=use_nodes,
                                  use_globals=use_globals)
     with self.assertRaisesRegexp(ValueError,
                                  "in both shapes must be equal"):
         network(input_graph)