def __init__(self,
             node_model_fn,
             global_model_fn,
             reducer=tf.unsorted_segment_sum,
             name="deep_sets"):
    """Builds a DeepSets module from a NodeBlock and a GlobalBlock.

    The NodeBlock is configured to read node features and globals (never
    edges). The GlobalBlock is configured to read only nodes, aggregated with
    `reducer`.

    Args:
      node_model_fn: Callable handed to `blocks.NodeBlock`. It must return a
        Sonnet module (or equivalent; see NodeBlock). Apart from the first and
        last axis, the output of that module must have the same shape as the
        global features of the input graph.
      global_model_fn: Callable handed to `blocks.GlobalBlock`. It must return
        a Sonnet module (or equivalent; see GlobalBlock).
      reducer: Reduction used to aggregate the nodes into the globals. Its
        signature must match `tf.unsorted_segment_sum`.
      name: Name of the module.
    """
    super(DeepSets, self).__init__(name=name)

    node_block_kwargs = dict(
        use_received_edges=False,
        use_sent_edges=False,
        use_nodes=True,
        use_globals=True)
    global_block_kwargs = dict(
        use_edges=False,
        use_nodes=True,
        use_globals=False,
        nodes_reducer=reducer)

    # Presumably entering the module's variable scope nests the sub-blocks
    # under this module (Sonnet 1 convention) -- confirm against Sonnet docs.
    with self._enter_variable_scope():
        self._node_block = blocks.NodeBlock(
            node_model_fn=node_model_fn, **node_block_kwargs)
        self._global_block = blocks.GlobalBlock(
            global_model_fn=global_model_fn, **global_block_kwargs)
def __init__(self,
             edge_model_fn,
             global_model_fn,
             reducer=tf.unsorted_segment_sum,
             name="relation_network"):
    """Builds a RelationNetwork from an EdgeBlock and a GlobalBlock.

    The EdgeBlock is configured to read only the sender and receiver node
    features of each edge (neither edge features nor globals). The
    GlobalBlock is configured to read only edges, aggregated with `reducer`.

    Args:
      edge_model_fn: Callable handed to `blocks.EdgeBlock` for the per-edge
        computation. It must return a Sonnet module (or equivalent; see
        EdgeBlock).
      global_model_fn: Callable handed to `blocks.GlobalBlock` for the
        per-global computation. It must return a Sonnet module (or
        equivalent; see GlobalBlock).
      reducer: Reducer the GlobalBlock uses to aggregate edges. Defaults to
        `tf.unsorted_segment_sum`.
      name: Name of the module.
    """
    super(RelationNetwork, self).__init__(name=name)

    edge_block_kwargs = dict(
        use_edges=False,
        use_receiver_nodes=True,
        use_sender_nodes=True,
        use_globals=False)
    global_block_kwargs = dict(
        use_edges=True,
        use_nodes=False,
        use_globals=False,
        edges_reducer=reducer)

    with self._enter_variable_scope():
        self._edge_block = blocks.EdgeBlock(
            edge_model_fn=edge_model_fn, **edge_block_kwargs)
        self._global_block = blocks.GlobalBlock(
            global_model_fn=global_model_fn, **global_block_kwargs)
def test_created_variables(self, use_edges, use_nodes, use_globals,
                           expected_first_dim_w):
    """Checks the names and shapes of the variables a GlobalBlock creates."""
    output_size = 10
    input_graph = self._get_input_graph()
    mlp_fn = functools.partial(snt.nets.MLP, output_sizes=[output_size])
    global_block = blocks.GlobalBlock(
        global_model_fn=mlp_fn,
        use_edges=use_edges,
        use_nodes=use_nodes,
        use_globals=use_globals)
    global_block(input_graph)

    # A one-layer MLP gives one bias and one weight matrix. The weight's
    # leading dimension is supplied per parameterization.
    expected = {
        "global_block/mlp/linear_0/b:0": [output_size],
        "global_block/mlp/linear_0/w:0": [expected_first_dim_w, output_size],
    }
    actual = dict(
        (variable.name, variable.get_shape().as_list())
        for variable in global_block.get_variables())
    self.assertDictEqual(expected, actual)
def test_unused_field_can_be_none(self, use_edges, use_nodes, use_globals,
                                  none_field):
    """Checks that a GlobalBlock tolerates None in fields it does not use."""
    input_graph = self._get_input_graph([none_field])
    global_block = blocks.GlobalBlock(
        global_model_fn=self._global_model_fn,
        use_edges=use_edges,
        use_nodes=use_nodes,
        use_globals=use_globals)
    output_graph = global_block(input_graph)

    # Rebuild by hand the concatenated input the global model should receive.
    pieces = []
    if use_edges:
        edges_aggregator = blocks.EdgesToGlobalsAggregator(
            tf.unsorted_segment_sum)
        pieces.append(edges_aggregator(input_graph))
    if use_nodes:
        nodes_aggregator = blocks.NodesToGlobalsAggregator(
            tf.unsorted_segment_sum)
        pieces.append(nodes_aggregator(input_graph))
    if use_globals:
        pieces.append(input_graph.globals)
    model_inputs = tf.concat(pieces, axis=-1)

    # Edges and nodes are expected to pass through unchanged.
    self.assertEqual(input_graph.edges, output_graph.edges)
    self.assertEqual(input_graph.nodes, output_graph.nodes)

    with self.test_session() as sess:
        actual_globals, model_inputs_out = sess.run(
            (output_graph.globals, model_inputs))
        # NOTE(review): assumes self._global_model_fn multiplies its input by
        # self._scale (fixture not shown).
        self.assertNDArrayNear(model_inputs_out * self._scale,
                               actual_globals, err=1e-4)
def test_output_values(self, use_edges, use_nodes, use_globals,
                       edges_reducer, nodes_reducer):
    """Checks GlobalBlock output against an explicitly computed reference."""
    input_graph = self._get_input_graph()
    global_block = blocks.GlobalBlock(
        global_model_fn=self._global_model_fn,
        use_edges=use_edges,
        use_nodes=use_nodes,
        use_globals=use_globals,
        edges_reducer=edges_reducer,
        nodes_reducer=nodes_reducer)
    output_graph = global_block(input_graph)

    # Reference input: the enabled aggregations / globals, concatenated on
    # the last axis in edges -> nodes -> globals order.
    pieces = []
    if use_edges:
        edges_aggregator = blocks.EdgesToGlobalsAggregator(edges_reducer)
        pieces.append(edges_aggregator(input_graph))
    if use_nodes:
        nodes_aggregator = blocks.NodesToGlobalsAggregator(nodes_reducer)
        pieces.append(nodes_aggregator(input_graph))
    if use_globals:
        pieces.append(input_graph.globals)
    model_inputs = tf.concat(pieces, axis=-1)

    self.assertEqual(input_graph.edges, output_graph.edges)
    self.assertEqual(input_graph.nodes, output_graph.nodes)

    with self.test_session() as sess:
        output_graph_out, model_inputs_out = sess.run(
            (output_graph, model_inputs))
        # NOTE(review): assumes self._global_model_fn scales by self._scale.
        self.assertNDArrayNear(model_inputs_out * self._scale,
                               output_graph_out.globals, err=1e-4)
def test_no_input_raises_exception(self):
    """Checks that constructing a GlobalBlock with no inputs enabled fails."""
    all_inputs_disabled = dict(
        use_edges=False, use_nodes=False, use_globals=False)
    with self.assertRaisesRegexp(ValueError, "At least one of "):
        blocks.GlobalBlock(global_model_fn=self._global_model_fn,
                           **all_inputs_disabled)
def test_compatible_higher_rank_no_raise(self):
    """Checks that rank-4 inputs with mutually consistent shapes are accepted."""

    def swap_axes_1_and_2(tensor):
        # The 4-element permutation fixes these tensors to rank 4.
        return tf.transpose(tensor, [0, 2, 1, 3])

    # The same transpose is applied to every field `map` visits, so the
    # fields stay consistent with one another.
    input_graph = self._get_shaped_input_graph().map(swap_axes_1_and_2)
    conv_fn = functools.partial(
        snt.Conv2D, output_channels=10, kernel_shape=[3, 3])
    network = blocks.GlobalBlock(conv_fn)
    self._assert_build_and_run(network, input_graph)
def test_missing_field_raises_exception(self, use_edges, use_nodes,
                                        use_globals, none_field):
    """Checks that a None value in a field the block uses raises on call."""
    input_graph = self._get_input_graph([none_field])
    usage = dict(
        use_edges=use_edges, use_nodes=use_nodes, use_globals=use_globals)
    global_block = blocks.GlobalBlock(
        global_model_fn=self._global_model_fn, **usage)
    # The error surfaces when the block is connected, not when it is built.
    with self.assertRaisesRegexp(ValueError, "field cannot be None"):
        global_block(input_graph)
def test_missing_aggregation_raises_exception(self, use_edges, use_nodes,
                                              edges_reducer, nodes_reducer):
    """Checks that a needed-but-missing reducer is rejected at construction."""
    block_kwargs = dict(
        global_model_fn=self._global_model_fn,
        use_edges=use_edges,
        use_nodes=use_nodes,
        use_globals=False,
        edges_reducer=edges_reducer,
        nodes_reducer=nodes_reducer)
    with self.assertRaisesRegexp(ValueError, "should not be None"):
        blocks.GlobalBlock(**block_kwargs)
def __init__(self,
             edge_model_fn,
             node_model_fn,
             global_model_fn,
             reducer=tf.unsorted_segment_sum,
             edge_block_opt=None,
             node_block_opt=None,
             global_block_opt=None,
             name="graph_network"):
    """Builds a GraphNetwork from an EdgeBlock, a NodeBlock and a GlobalBlock.

    Args:
      edge_model_fn: Callable handed to `blocks.EdgeBlock` for the per-edge
        computation. It must return a Sonnet module (or equivalent; see
        EdgeBlock).
      node_model_fn: Callable handed to `blocks.NodeBlock` for the per-node
        computation. It must return a Sonnet module (or equivalent; see
        NodeBlock).
      global_model_fn: Callable handed to `blocks.GlobalBlock` for the
        per-global computation. It must return a Sonnet module (or
        equivalent; see GlobalBlock).
      reducer: Reducer used by the NodeBlock and the GlobalBlock to aggregate
        nodes and edges. Defaults to `tf.unsorted_segment_sum`. A reducer
        given explicitly in `node_block_opt` or `global_block_opt` takes
        precedence over this one.
      edge_block_opt: Extra options forwarded to the EdgeBlock. May contain
        the keys `use_edges`, `use_receiver_nodes`, `use_sender_nodes` and
        `use_globals`, all of which default to True.
      node_block_opt: Extra options forwarded to the NodeBlock. May contain
        the keys `use_received_edges`, `use_sent_edges`, `use_nodes` and
        `use_globals` (all True by default), plus `received_edges_reducer`
        and `sent_edges_reducer` (both default to `reducer`).
      global_block_opt: Extra options forwarded to the GlobalBlock. May
        contain the keys `use_edges`, `use_nodes` and `use_globals` (all True
        by default), plus `edges_reducer` and `nodes_reducer` (both default
        to `reducer`).
      name: Name of the module.
    """
    super(GraphNetwork, self).__init__(name=name)

    # The `_make_default_*` helpers (defined elsewhere) presumably fill in the
    # defaults listed above; the node/global helpers also receive `reducer`.
    edge_opt = _make_default_edge_block_opt(edge_block_opt)
    node_opt = _make_default_node_block_opt(node_block_opt, reducer)
    global_opt = _make_default_global_block_opt(global_block_opt, reducer)

    with self._enter_variable_scope():
        self._edge_block = blocks.EdgeBlock(
            edge_model_fn=edge_model_fn, **edge_opt)
        self._node_block = blocks.NodeBlock(
            node_model_fn=node_model_fn, **node_opt)
        self._global_block = blocks.GlobalBlock(
            global_model_fn=global_model_fn, **global_opt)
def test_incompatible_higher_rank_inputs_no_raise(self, use_edges, use_nodes,
                                                  use_globals, field):
    """Checks that a differently shaped field is tolerated when it is unused."""
    input_graph = self._get_shaped_input_graph()
    # Transpose only `field`. If axes 1 and 2 differ in size, its shape then
    # disagrees with the untouched fields.
    swapped = tf.transpose(getattr(input_graph, field), [0, 2, 1, 3])
    input_graph = input_graph.replace(**{field: swapped})
    conv_fn = functools.partial(
        snt.Conv2D, output_channels=10, kernel_shape=[3, 3])
    network = blocks.GlobalBlock(conv_fn,
                                 use_edges=use_edges,
                                 use_nodes=use_nodes,
                                 use_globals=use_globals)
    self._assert_build_and_run(network, input_graph)
def test_incompatible_higher_rank_inputs_raises(self, use_edges, use_nodes,
                                                use_globals, field):
    """Checks that an exception is raised when used inputs have clashing shapes."""
    input_graph = self._get_shaped_input_graph()
    # Transpose only `field` so its shape no longer matches the other fields
    # (assuming axes 1 and 2 differ in size).
    swapped = tf.transpose(getattr(input_graph, field), [0, 2, 1, 3])
    input_graph = input_graph.replace(**{field: swapped})
    conv_fn = functools.partial(
        snt.Conv2D, output_channels=10, kernel_shape=[3, 3])
    network = blocks.GlobalBlock(conv_fn,
                                 use_edges=use_edges,
                                 use_nodes=use_nodes,
                                 use_globals=use_globals)
    with self.assertRaisesRegexp(ValueError, "in both shapes must be equal"):
        network(input_graph)