def __init__(self,
             edge_model_fn,
             node_encoder_model_fn,
             node_model_fn,
             reducer=scatter_add,
             name="comm_net"):
    """Initializes the CommNet module.

    Args:
      edge_model_fn: Callable handed to `blocks_torch.EdgeBlock`. It must
        return a module (see EdgeBlock for the exact requirements).
      node_encoder_model_fn: Callable handed to the `NodeBlock` that performs
        the initial encoding of the nodes. It must return a module (see
        NodeBlock). Its output shape should match the output shape of the
        module built by `edge_model_fn`, except for the first and last
        dimensions.
      node_model_fn: Callable handed to the final `NodeBlock`. It must return
        a module (see NodeBlock).
      reducer: Reduction used to aggregate the received edges into the nodes.
        Defaults to `scatter_add`; a replacement should be callable in the
        same way.
      name: The module name. Not used by this implementation.
    """
    super(CommNet, self).__init__()

    # Psi_com(x_j) in Eq. (2) of arXiv:1706.06122: each edge's message is
    # computed from its sender node only.
    self._edge_block = blocks_torch.EdgeBlock(
        edge_model_fn=edge_model_fn,
        use_edges=False,
        use_receiver_nodes=False,
        use_sender_nodes=True,
        use_globals=False)

    # Phi(x_i) in Eq. (2) of arXiv:1706.06122: encodes each node from its own
    # features only.
    self._node_encoder_block = blocks_torch.NodeBlock(
        node_model_fn=node_encoder_model_fn,
        use_received_edges=False,
        use_sent_edges=False,
        use_nodes=True,
        use_globals=False,
        received_edges_reducer=reducer,
        name="node_encoder_block")

    # Theta(..) in Eq. (2) of arXiv:1706.06122: combines each node with the
    # reduced messages it received.
    self._node_block = blocks_torch.NodeBlock(
        node_model_fn=node_model_fn,
        use_received_edges=True,
        use_sent_edges=False,
        use_nodes=True,
        use_globals=False,
        received_edges_reducer=reducer)
def test_unused_field_can_be_none(self, use_edges, use_nodes, use_globals,
                                  none_field):
    """Checks that computation can handle non-necessary fields left None.

    Builds a NodeBlock that does not need `none_field`, then runs it on a
    graph where that field is None. It verifies that:
      * edges and globals pass through unchanged;
      * the output nodes equal the hand-built model input times
        `self._scale`.
    """
    input_graph = self._get_input_graph([none_field])
    node_block = blocks_torch.NodeBlock(
        node_model_fn=self._node_model_fn,
        use_received_edges=use_edges,
        use_sent_edges=use_edges,
        use_nodes=use_nodes,
        use_globals=use_globals)
    output_graph = node_block(input_graph)

    # Rebuild by hand the input the node model should have received.
    model_inputs = []
    if use_edges:
        model_inputs.append(
            blocks_torch.ReceivedEdgesToNodesAggregator(
                torch_scatter.scatter_add)(input_graph))
        model_inputs.append(
            blocks_torch.SentEdgesToNodesAggregator(
                torch_scatter.scatter_add)(input_graph))
    if use_nodes:
        model_inputs.append(input_graph.nodes)
    if use_globals:
        model_inputs.append(
            blocks_torch.broadcast_globals_to_nodes(input_graph))
    # torch.cat(dim=-1) joins along the last axis, like tf.concat(axis=-1).
    model_inputs = torch.cat(model_inputs, dim=-1)

    def assert_field_unchanged(expected, actual):
        # assertEqual would evaluate `not expected == actual`. For a
        # multi-element torch tensor, `==` is elementwise, and converting
        # that result to bool raises RuntimeError. Compare explicitly.
        if expected is None:
            self.assertIsNone(actual)
        else:
            self.assertIsNotNone(actual)
            self.assertTrue(torch.equal(expected, actual))

    assert_field_unchanged(input_graph.edges, output_graph.edges)
    assert_field_unchanged(input_graph.globals, output_graph.globals)

    # detach()/cpu() keep .numpy() working if a tensor requires grad or
    # lives on an accelerator.
    expected_output_nodes = model_inputs.detach().cpu().numpy() * self._scale
    self.assertNDArrayNear(expected_output_nodes,
                           output_graph.nodes.detach().cpu().numpy(),
                           err=1e-4)
def test_created_variables(self, use_received_edges, use_sent_edges,
                           use_nodes, use_globals, expected_first_dim_w):
    """Checks the names and shapes of the parameters a NodeBlock registers."""
    output_size = 10
    input_graph = self._get_input_graph()
    mlp_fn = functools.partial(MLP, output_sizes=[output_size])
    node_block = blocks_torch.NodeBlock(
        node_model_fn=mlp_fn,
        use_received_edges=use_received_edges,
        use_sent_edges=use_sent_edges,
        use_nodes=use_nodes,
        use_globals=use_globals)
    node_block(input_graph)

    # Each shape is reversed before comparison. The stored weight is laid out
    # as (output_size, input_dim); reversing it gives the
    # (input_dim, output_size) form used by the expectations below.
    actual_shapes = {
        param_name: list(reversed(tensor.shape))
        for param_name, tensor in node_block.state_dict().items()
    }
    expected_shapes = {
        "_node_model.mlp.0.bias": [output_size],
        "_node_model.mlp.0.weight": [expected_first_dim_w, output_size],
    }
    self.assertDictEqual(expected_shapes, actual_shapes)
def __init__(self,
             edge_model_fn,
             node_model_fn,
             reducer=scatter_add,
             name="interaction_network"):
    """Initializes the InteractionNetwork module.

    Args:
      edge_model_fn: Callable handed to `blocks_torch.EdgeBlock` for the
        per-edge computation. It must return a module (see EdgeBlock for
        details). The module's output shape must match the shape of the
        input nodes, except for the first and last axes.
      node_model_fn: Callable handed to `blocks_torch.NodeBlock` for the
        per-node computation. It must return a module (see NodeBlock for
        details).
      reducer: Reducer the NodeBlock uses to aggregate received edges.
        Defaults to `scatter_add`.
      name: The module name. Not used by this implementation.
    """
    super(InteractionNetwork, self).__init__()

    # Edge update; global features are excluded.
    self._edge_block = blocks_torch.EdgeBlock(
        edge_model_fn=edge_model_fn,
        use_globals=False)

    # Node update from received edges (sent edges and globals excluded).
    self._node_block = blocks_torch.NodeBlock(
        node_model_fn=node_model_fn,
        use_sent_edges=False,
        use_globals=False,
        received_edges_reducer=reducer)
def __init__(self,
             node_model_fn,
             global_model_fn,
             reducer=scatter_add,
             name="deep_sets"):
    """Initializes the DeepSets module.

    Args:
      node_model_fn: Callable handed to `blocks_torch.NodeBlock`. It must
        return a module (see NodeBlock for details). The module's output
        shape must equal the shape of the input graph's global features,
        except for the first and last axes.
      global_model_fn: Callable handed to `blocks_torch.GlobalBlock`. It must
        return a module (see GlobalBlock for details).
      reducer: Reduction used to aggregate the nodes into the globals.
        Defaults to `scatter_add`; a replacement should be callable in the
        same way.
      name: The module name. Not used by this implementation.
    """
    super(DeepSets, self).__init__()

    # Per-node update that sees each node's features and the graph globals,
    # but no edges.
    self._node_block = blocks_torch.NodeBlock(
        node_model_fn=node_model_fn,
        use_received_edges=False,
        use_sent_edges=False,
        use_nodes=True,
        use_globals=True)

    # Global update computed only from the reduced nodes.
    self._global_block = blocks_torch.GlobalBlock(
        global_model_fn=global_model_fn,
        use_edges=False,
        use_nodes=True,
        use_globals=False,
        nodes_reducer=reducer)
def test_no_input_raises_exception(self):
    """Checks that a NodeBlock configured with no inputs raises ValueError."""
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
    # which was removed from unittest in Python 3.12.
    with self.assertRaisesRegex(ValueError, "At least one of "):
        blocks_torch.NodeBlock(
            node_model_fn=self._node_model_fn,
            use_received_edges=False,
            use_sent_edges=False,
            use_nodes=False,
            use_globals=False)
def test_compatible_higher_rank_no_raise(self):
    """No exception should occur with higher-rank tensors of consistent shape.

    Every field of the graph is permuted the same way, so the shapes stay
    mutually compatible.
    """
    def swap_axes_1_and_2(tensor):
        # TODO: change (axis order may need revisiting for PyTorch layouts).
        return tensor.permute(0, 2, 1, 3)

    permuted_graph = self._get_shaped_input_graph().map(swap_axes_1_and_2)
    # TODO: change (Conv2D / output_channels / kernel_shape look like a
    # Sonnet-style signature; confirm the PyTorch equivalent).
    conv_fn = functools.partial(
        Conv2D, output_channels=10, kernel_shape=[3, 3])
    network = blocks_torch.NodeBlock(conv_fn)
    self._assert_build_and_run(network, permuted_graph)
def test_missing_field_raises_exception(self, use_received_edges,
                                        use_sent_edges, use_nodes,
                                        use_globals, none_fields):
    """Checks that missing a required field raises an exception.

    The graph is built with `none_fields` set to None. Calling a NodeBlock
    that needs one of those fields must raise ValueError.
    """
    input_graph = self._get_input_graph(none_fields)
    node_block = blocks_torch.NodeBlock(
        node_model_fn=self._node_model_fn,
        use_received_edges=use_received_edges,
        use_sent_edges=use_sent_edges,
        use_nodes=use_nodes,
        use_globals=use_globals)
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
    # which was removed from unittest in Python 3.12.
    with self.assertRaisesRegex(ValueError, "field cannot be None"):
        node_block(input_graph)
def test_missing_aggregation_raises_exception(self, use_received_edges,
                                              use_sent_edges,
                                              received_edges_reducer,
                                              sent_edges_reducer):
    """Checks that missing a required aggregation argument raises an error.

    Construction (not the call) must raise ValueError when an edge input is
    enabled but its reducer is None.
    """
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
    # which was removed from unittest in Python 3.12.
    with self.assertRaisesRegex(ValueError, "should not be None"):
        blocks_torch.NodeBlock(
            node_model_fn=self._node_model_fn,
            use_received_edges=use_received_edges,
            use_sent_edges=use_sent_edges,
            use_nodes=False,
            use_globals=False,
            received_edges_reducer=received_edges_reducer,
            sent_edges_reducer=sent_edges_reducer)
def __init__(self,
             edge_model_fn,
             node_model_fn,
             global_model_fn,
             reducer=scatter_add,
             edge_block_opt=None,
             node_block_opt=None,
             global_block_opt=None,
             name="graph_network"):
    """Initializes the GraphNetwork module.

    Args:
      edge_model_fn: Callable handed to `blocks_torch.EdgeBlock` for the
        per-edge computation. It must return a module (see EdgeBlock for
        details).
      node_model_fn: Callable handed to `blocks_torch.NodeBlock` for the
        per-node computation. It must return a module (see NodeBlock for
        details).
      global_model_fn: Callable handed to `blocks_torch.GlobalBlock` for the
        per-graph (global) computation. It must return a module (see
        GlobalBlock for details).
      reducer: Reducer the NodeBlock and GlobalBlock use to aggregate nodes
        and edges. Defaults to `scatter_add`. Reducers given explicitly in
        `node_block_opt` or `global_block_opt` take precedence.
      edge_block_opt: Extra options for the EdgeBlock. May contain
        `use_edges`, `use_receiver_nodes`, `use_sender_nodes` and
        `use_globals` (documented as all True by default).
      node_block_opt: Extra options for the NodeBlock. May contain
        `use_received_edges`, `use_nodes` and `use_globals` (documented as
        True by default), `use_sent_edges` (documented as False by default),
        and `received_edges_reducer` / `sent_edges_reducer` (default to
        `reducer`).
      global_block_opt: Extra options for the GlobalBlock. May contain
        `use_edges`, `use_nodes` and `use_globals` (documented as True by
        default), and `edges_reducer` / `nodes_reducer` (default to
        `reducer`).
      name: The module name. Not used by this implementation.
    """
    super(GraphNetwork, self).__init__()

    # Merge the caller's options with the defaults. The shared `reducer` is
    # threaded into the node and global options.
    edge_opts = _make_default_edge_block_opt(edge_block_opt)
    node_opts = _make_default_node_block_opt(node_block_opt, reducer)
    global_opts = _make_default_global_block_opt(global_block_opt, reducer)

    # Sub-blocks are created (and registered) in edge -> node -> global order.
    self._edge_block = blocks_torch.EdgeBlock(
        edge_model_fn=edge_model_fn, **edge_opts)
    self._node_block = blocks_torch.NodeBlock(
        node_model_fn=node_model_fn, **node_opts)
    self._global_block = blocks_torch.GlobalBlock(
        global_model_fn=global_model_fn, **global_opts)
def test_incompatible_higher_rank_inputs_no_raise(self, use_received_edges,
                                                  use_sent_edges, use_nodes,
                                                  use_globals, field):
    """No exception should occur if a differently shaped field is not used.

    Only `field` is permuted. The block is configured so that it never reads
    that field, so the shape mismatch must go unnoticed.
    """
    shaped_graph = self._get_shaped_input_graph()
    # TODO: change (axis order may need revisiting for PyTorch layouts).
    mismatched = getattr(shaped_graph, field).permute(0, 2, 1, 3)
    input_graph = shaped_graph.replace(**{field: mismatched})

    # TODO: change (Conv2D / output_channels / kernel_shape look like a
    # Sonnet-style signature; confirm the PyTorch equivalent).
    conv_fn = functools.partial(
        Conv2D, output_channels=10, kernel_shape=[3, 3])
    network = blocks_torch.NodeBlock(
        conv_fn,
        use_received_edges=use_received_edges,
        use_sent_edges=use_sent_edges,
        use_nodes=use_nodes,
        use_globals=use_globals)
    self._assert_build_and_run(network, input_graph)
def test_incompatible_higher_rank_inputs_raises(self, use_received_edges,
                                                use_sent_edges, use_nodes,
                                                use_globals, field):
    """An exception should be raised if the inputs have incompatible shapes.

    Only `field` is permuted, and the block is configured to read it, so
    combining it with the other inputs must fail.
    """
    input_graph = self._get_shaped_input_graph()
    input_graph = input_graph.replace(**{
        field: getattr(input_graph, field).permute(0, 2, 1, 3)
    })  # TODO: change
    network = blocks_torch.NodeBlock(
        functools.partial(Conv2D, output_channels=10,
                          kernel_shape=[3, 3]),  # TODO: change
        use_received_edges=use_received_edges,
        use_sent_edges=use_sent_edges,
        use_nodes=use_nodes,
        use_globals=use_globals)
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias,
    # which was removed from unittest in Python 3.12.
    # NOTE(review): "in both shapes must be equal" looks like TensorFlow's
    # concat error text; torch.cat raises RuntimeError with a different
    # message. Confirm that blocks_torch re-raises it as this ValueError.
    with self.assertRaisesRegex(ValueError, "in both shapes must be equal"):
        network(input_graph)
def test_output_values(self, use_received_edges, use_sent_edges, use_nodes,
                       use_globals, received_edges_reducer,
                       sent_edges_reducer):
    """Compares the output of a NodeBlock to an explicit computation.

    The node model input is rebuilt by hand from the enabled sources. The
    test then checks that:
      * the block's output nodes equal that input times `self._scale`;
      * edges and globals pass through unchanged.
    """
    input_graph = self._get_input_graph()
    node_block = blocks_torch.NodeBlock(
        node_model_fn=self._node_model_fn,
        use_received_edges=use_received_edges,
        use_sent_edges=use_sent_edges,
        use_nodes=use_nodes,
        use_globals=use_globals,
        received_edges_reducer=received_edges_reducer,
        sent_edges_reducer=sent_edges_reducer)
    output_graph = node_block(input_graph)

    model_inputs = []
    if use_received_edges:
        model_inputs.append(
            blocks_torch.ReceivedEdgesToNodesAggregator(
                received_edges_reducer)(input_graph))
    if use_sent_edges:
        model_inputs.append(
            blocks_torch.SentEdgesToNodesAggregator(sent_edges_reducer)(
                input_graph))
    if use_nodes:
        model_inputs.append(input_graph.nodes)
    if use_globals:
        model_inputs.append(
            blocks_torch.broadcast_globals_to_nodes(input_graph))
    # torch.cat(dim=-1) joins along the last axis, like tf.concat(axis=-1).
    model_inputs = torch.cat(model_inputs, dim=-1)

    def assert_field_unchanged(expected, actual):
        # assertEqual would evaluate `not expected == actual`. For a
        # multi-element torch tensor, `==` is elementwise, and converting
        # that result to bool raises RuntimeError. Compare explicitly.
        if expected is None:
            self.assertIsNone(actual)
        else:
            self.assertIsNotNone(actual)
            self.assertTrue(torch.equal(expected, actual))

    assert_field_unchanged(input_graph.edges, output_graph.edges)
    assert_field_unchanged(input_graph.globals, output_graph.globals)

    # detach()/cpu() keep .numpy() working if a tensor requires grad or
    # lives on an accelerator.
    expected_output_nodes = model_inputs.detach().cpu().numpy() * self._scale
    self.assertNDArrayNear(expected_output_nodes,
                           output_graph.nodes.detach().cpu().numpy(),
                           err=1e-4)