def test_created_variables(self, use_edges, use_receiver_nodes,
                           use_sender_nodes, use_globals,
                           expected_first_dim_w):
    """Checks the parameter names and shapes registered by an EdgeBlock.

    The edge model is an MLP configured with a single output size, so the
    block is expected to own exactly one bias and one weight. Each parameter
    shape is compared with its axes reversed, so the weight is expected as
    `[expected_first_dim_w, output_size]`.
    """
    output_size = 10
    input_graph = self._get_input_graph()
    mlp_fn = functools.partial(MLP, output_sizes=[output_size])  # TODO: change
    edge_block = blocks_torch.EdgeBlock(
        edge_model_fn=mlp_fn,
        use_edges=use_edges,
        use_receiver_nodes=use_receiver_nodes,
        use_sender_nodes=use_sender_nodes,
        use_globals=use_globals)
    # The block is called once before its state is inspected (presumably so
    # any lazily-built submodules exist -- confirm against blocks_torch).
    edge_block(input_graph)

    # Reverse each shape; NOTE(review): assumes MLP stores weights as
    # (out, in) like torch.nn.Linear, so reversal yields [in, out].
    var_shapes_dict = {
        var_name: list(tensor.shape)[::-1]
        for var_name, tensor in edge_block.state_dict().items()
    }
    expected_var_shapes_dict = {
        "_edge_model.mlp.0.bias": [output_size],
        "_edge_model.mlp.0.weight": [expected_first_dim_w, output_size],
    }
    self.assertDictEqual(expected_var_shapes_dict, var_shapes_dict)
def test_unused_field_can_be_none(self, use_edges, use_nodes, use_globals,
                                  none_field):
    """Checks that computation can handle non-necessary fields left None.

    `none_field` is set to None in the input graph. The EdgeBlock is
    configured so that this field is not needed, and its output edges must
    equal the concatenation of the used inputs scaled by `self._scale`. The
    nodes and globals must pass through unchanged.
    """

    def assert_field_unchanged(before, after):
        # `assertEqual` on multi-element tensors raises, because elementwise
        # `==` returns a tensor whose truth value is ambiguous. Compare
        # None-ness first, then values.
        if before is None or after is None:
            self.assertIs(before, after)
        else:
            self.assertTrue(torch.equal(before, after))

    input_graph = self._get_input_graph([none_field])
    edge_block = blocks_torch.EdgeBlock(
        edge_model_fn=self._edge_model_fn,
        use_edges=use_edges,
        use_receiver_nodes=use_nodes,
        use_sender_nodes=use_nodes,
        use_globals=use_globals)
    output_graph = edge_block(input_graph)

    model_inputs = []
    if use_edges:
        model_inputs.append(input_graph.edges)
    if use_nodes:
        # `use_nodes` drives both the receiver and the sender flag above, so
        # both broadcasts belong under this branch.
        model_inputs.append(
            blocks_torch.broadcast_receiver_nodes_to_edges(input_graph))
        model_inputs.append(
            blocks_torch.broadcast_sender_nodes_to_edges(input_graph))
    if use_globals:
        model_inputs.append(
            blocks_torch.broadcast_globals_to_edges(input_graph))
    model_inputs = torch.cat(model_inputs, dim=-1)

    assert_field_unchanged(input_graph.nodes, output_graph.nodes)
    assert_field_unchanged(input_graph.globals, output_graph.globals)

    # detach() so .numpy() also works if the edge model produces tensors
    # that track gradients.
    expected_output_edges = model_inputs.detach().numpy() * self._scale
    self.assertNDArrayNear(
        expected_output_edges, output_graph.edges.detach().numpy(), err=1e-4)
def test_output_values(self, use_edges, use_receiver_nodes, use_sender_nodes,
                       use_globals):
    """Compares the output of an EdgeBlock to an explicit computation.

    The expected edges are the concatenation, along the last axis, of every
    enabled input, scaled by `self._scale`. Nodes and globals must pass
    through unchanged.
    """

    def assert_field_unchanged(before, after):
        # `assertEqual` on multi-element tensors raises, because elementwise
        # `==` returns a tensor whose truth value is ambiguous. Compare
        # None-ness first, then values.
        if before is None or after is None:
            self.assertIs(before, after)
        else:
            self.assertTrue(torch.equal(before, after))

    input_graph = self._get_input_graph()
    edge_block = blocks_torch.EdgeBlock(
        edge_model_fn=self._edge_model_fn,
        use_edges=use_edges,
        use_receiver_nodes=use_receiver_nodes,
        use_sender_nodes=use_sender_nodes,
        use_globals=use_globals)
    output_graph = edge_block(input_graph)

    # Inputs are concatenated in the same order the block is expected to use:
    # edges, receiver nodes, sender nodes, globals.
    model_inputs = []
    if use_edges:
        model_inputs.append(input_graph.edges)
    if use_receiver_nodes:
        model_inputs.append(
            blocks_torch.broadcast_receiver_nodes_to_edges(input_graph))
    if use_sender_nodes:
        model_inputs.append(
            blocks_torch.broadcast_sender_nodes_to_edges(input_graph))
    if use_globals:
        model_inputs.append(
            blocks_torch.broadcast_globals_to_edges(input_graph))
    model_inputs = torch.cat(model_inputs, dim=-1)

    assert_field_unchanged(input_graph.nodes, output_graph.nodes)
    assert_field_unchanged(input_graph.globals, output_graph.globals)

    # detach() so .numpy() also works if the edge model produces tensors
    # that track gradients.
    expected_output_edges = model_inputs.detach().numpy() * self._scale
    self.assertNDArrayNear(
        expected_output_edges, output_graph.edges.detach().numpy(), err=1e-4)
def __init__(self, edge_model_fn, node_model_fn, reducer=scatter_add,
             name="interaction_network"):
    """Initializes the InteractionNetwork module.

    Args:
      edge_model_fn: A callable that will be passed to `EdgeBlock` to perform
        per-edge computations. The callable must return a module (e.g. a
        `torch.nn.Module`, or equivalent; see `blocks_torch.EdgeBlock` for
        details). The shape of this module's output must match that of the
        input nodes, except for the first and last axes.
      node_model_fn: A callable that will be passed to `NodeBlock` to perform
        per-node computations. The callable must return a module (or
        equivalent; see `blocks_torch.NodeBlock` for details).
      reducer: Reducer used by the NodeBlock to aggregate received edges.
        Defaults to `scatter_add`.
      name: The module name. Currently unused: it is not forwarded to the
        parent constructor or to the sub-blocks.
    """
    super(InteractionNetwork, self).__init__()
    # Edges are updated from edges plus sender/receiver nodes (EdgeBlock
    # defaults); globals are not used.
    self._edge_block = blocks_torch.EdgeBlock(edge_model_fn=edge_model_fn,
                                              use_globals=False)
    # Nodes aggregate only the edges they receive, via `reducer`.
    self._node_block = blocks_torch.NodeBlock(
        node_model_fn=node_model_fn,
        use_sent_edges=False,
        use_globals=False,
        received_edges_reducer=reducer)
def __init__(self, edge_model_fn, global_model_fn, reducer=scatter_add,
             name="relation_network"):
    """Initializes the RelationNetwork module.

    Args:
      edge_model_fn: A callable that will be passed to EdgeBlock to perform
        per-edge computations. The callable must return a module (e.g. a
        `torch.nn.Module`, or equivalent; see `blocks_torch.EdgeBlock` for
        details).
      global_model_fn: A callable that will be passed to GlobalBlock to
        perform per-global computations. The callable must return a module
        (or equivalent; see `blocks_torch.GlobalBlock` for details).
      reducer: Reducer used by the GlobalBlock to aggregate edges. Defaults
        to `scatter_add`.
      name: The module name. Currently unused: it is not forwarded to the
        parent constructor or to the sub-blocks.
    """
    super(RelationNetwork, self).__init__()
    # Each edge is computed only from its sender and receiver nodes.
    self._edge_block = blocks_torch.EdgeBlock(edge_model_fn=edge_model_fn,
                                              use_edges=False,
                                              use_receiver_nodes=True,
                                              use_sender_nodes=True,
                                              use_globals=False)
    # The global output is computed only from the aggregated edges.
    self._global_block = blocks_torch.GlobalBlock(
        global_model_fn=global_model_fn,
        use_edges=True,
        use_nodes=False,
        use_globals=False,
        edges_reducer=reducer)
def test_no_input_raises_exception(self):
    """Checks that an EdgeBlock with every input disabled raises ValueError."""
    # assertRaisesRegexp was a deprecated alias removed in Python 3.12.
    with self.assertRaisesRegex(ValueError, "At least one of "):
        blocks_torch.EdgeBlock(
            edge_model_fn=self._edge_model_fn,
            use_edges=False,
            use_receiver_nodes=False,
            use_sender_nodes=False,
            use_globals=False)
def test_compatible_higher_rank_no_raise(self):
    """Higher-rank inputs whose shapes agree with each other must not raise.

    Every feature field of the shaped graph has its two middle axes swapped
    consistently, and a convolutional edge model is built and run on it.
    """

    def swap_middle_axes(tensor):
        return tensor.permute(0, 2, 1, 3)  # TODO: change

    input_graph = self._get_shaped_input_graph().map(swap_middle_axes)
    conv_fn = functools.partial(
        Conv2D, output_channels=10, kernel_shape=[3, 3])  # TODO: change
    network = blocks_torch.EdgeBlock(conv_fn)
    self._assert_build_and_run(network, input_graph)
def test_missing_field_raises_exception(self, use_edges, use_receiver_nodes,
                                        use_sender_nodes, use_globals,
                                        none_fields):
    """Checks that missing a required field raises an exception.

    The fields listed in `none_fields` are set to None. The parameterization
    enables at least one of them, so calling the block must raise ValueError.
    """
    input_graph = self._get_input_graph(none_fields)
    edge_block = blocks_torch.EdgeBlock(
        edge_model_fn=self._edge_model_fn,
        use_edges=use_edges,
        use_receiver_nodes=use_receiver_nodes,
        use_sender_nodes=use_sender_nodes,
        use_globals=use_globals)
    # assertRaisesRegexp was a deprecated alias removed in Python 3.12.
    with self.assertRaisesRegex(ValueError, "field cannot be None"):
        edge_block(input_graph)
def __init__(self, edge_model_fn, node_model_fn, global_model_fn,
             reducer=scatter_add, edge_block_opt=None, node_block_opt=None,
             global_block_opt=None, name="graph_network"):
    """Initializes the GraphNetwork module.

    Args:
      edge_model_fn: A callable that will be passed to EdgeBlock to perform
        per-edge computations. The callable must return a module (e.g. a
        `torch.nn.Module`, or equivalent; see `blocks_torch.EdgeBlock` for
        details).
      node_model_fn: A callable that will be passed to NodeBlock to perform
        per-node computations. The callable must return a module (or
        equivalent; see `blocks_torch.NodeBlock` for details).
      global_model_fn: A callable that will be passed to GlobalBlock to
        perform per-global computations. The callable must return a module
        (or equivalent; see `blocks_torch.GlobalBlock` for details).
      reducer: Reducer used by NodeBlock and GlobalBlock to aggregate nodes
        and edges. Defaults to `scatter_add`. Any reducers specified in
        `node_block_opt` and `global_block_opt` override it.
      edge_block_opt: Additional options to be passed to the EdgeBlock. Can
        contain the keys `use_edges`, `use_receiver_nodes`,
        `use_sender_nodes` and `use_globals`. By default, these are all True.
      node_block_opt: Additional options to be passed to the NodeBlock. Can
        contain the keys `use_received_edges`, `use_nodes` and `use_globals`
        (all True by default), `use_sent_edges` (False by default), and
        `received_edges_reducer` and `sent_edges_reducer` (both default to
        `reducer`).
      global_block_opt: Additional options to be passed to the GlobalBlock.
        Can contain the keys `use_edges`, `use_nodes` and `use_globals` (all
        True by default), and `edges_reducer` and `nodes_reducer` (both
        default to `reducer`).
      name: The module name. Currently unused: it is not forwarded to the
        parent constructor or to the sub-blocks.
    """
    super(GraphNetwork, self).__init__()
    # Merge caller-supplied options with defaults (see Args above); `reducer`
    # is passed as the fallback aggregation for the node and global blocks.
    edge_block_opt = _make_default_edge_block_opt(edge_block_opt)
    node_block_opt = _make_default_node_block_opt(node_block_opt, reducer)
    global_block_opt = _make_default_global_block_opt(
        global_block_opt, reducer)

    self._edge_block = blocks_torch.EdgeBlock(edge_model_fn=edge_model_fn,
                                              **edge_block_opt)
    self._node_block = blocks_torch.NodeBlock(node_model_fn=node_model_fn,
                                              **node_block_opt)
    self._global_block = blocks_torch.GlobalBlock(
        global_model_fn=global_model_fn, **global_block_opt)
def __init__(self, edge_model_fn, node_encoder_model_fn, node_model_fn,
             reducer=scatter_add, name="comm_net"):
    """Initializes the CommNet module.

    Args:
      edge_model_fn: A callable to be passed to EdgeBlock. The callable must
        return a module (e.g. a `torch.nn.Module`, or equivalent; see
        `blocks_torch.EdgeBlock` for details).
      node_encoder_model_fn: A callable to be passed to the NodeBlock
        responsible for the first encoding of the nodes. The callable must
        return a module (or equivalent; see `blocks_torch.NodeBlock` for
        details). The shape of this module's output should match the shape
        of the module built by `edge_model_fn`, except for the first and last
        dimensions.
      node_model_fn: A callable to be passed to NodeBlock. The callable must
        return a module (or equivalent; see `blocks_torch.NodeBlock` for
        details).
      reducer: Reduction used when aggregating the edges into the nodes. This
        should be a callable with the same signature as `scatter_add`, which
        is the default.
      name: The module name. Currently unused: it is not forwarded to the
        parent constructor or to the sub-blocks.
    """
    super(CommNet, self).__init__()
    # Computes $\Psi_{com}(x_j)$ in Eq. (2) of 1706.06122
    self._edge_block = blocks_torch.EdgeBlock(edge_model_fn=edge_model_fn,
                                              use_edges=False,
                                              use_receiver_nodes=False,
                                              use_sender_nodes=True,
                                              use_globals=False)
    # Computes $\Phi(x_i)$ in Eq. (2) of 1706.06122
    self._node_encoder_block = blocks_torch.NodeBlock(
        node_model_fn=node_encoder_model_fn,
        use_received_edges=False,
        use_sent_edges=False,
        use_nodes=True,
        use_globals=False,
        received_edges_reducer=reducer,
        name="node_encoder_block")
    # Computes $\Theta(..)$ in Eq.(2) of 1706.06122
    self._node_block = blocks_torch.NodeBlock(
        node_model_fn=node_model_fn,
        use_received_edges=True,
        use_sent_edges=False,
        use_nodes=True,
        use_globals=False,
        received_edges_reducer=reducer)
def test_incompatible_higher_rank_inputs_no_raise(self, use_edges,
                                                  use_receiver_nodes,
                                                  use_sender_nodes,
                                                  use_globals, field):
    """No exception should occur if a differently shaped field is not used.

    The two middle axes of `field` are swapped, so its shape no longer agrees
    with the other fields. The parameterization leaves that field unused by
    the EdgeBlock, so building and running must still succeed.
    """
    base_graph = self._get_shaped_input_graph()
    permuted_field = getattr(base_graph, field).permute(0, 2, 1, 3)
    input_graph = base_graph.replace(**{field: permuted_field})  # TODO: change
    conv_fn = functools.partial(
        Conv2D, output_channels=10, kernel_shape=[3, 3])  # TODO: change
    network = blocks_torch.EdgeBlock(
        conv_fn,
        use_edges=use_edges,
        use_receiver_nodes=use_receiver_nodes,
        use_sender_nodes=use_sender_nodes,
        use_globals=use_globals)
    self._assert_build_and_run(network, input_graph)
def test_incompatible_higher_rank_inputs_raises(self, use_edges,
                                                use_receiver_nodes,
                                                use_sender_nodes,
                                                use_globals, field):
    """An exception should be raised if the inputs have incompatible shapes.

    The two middle axes of `field` are swapped, and the parameterization
    enables that field, so combining it with the other inputs must fail.
    """
    input_graph = self._get_shaped_input_graph()
    input_graph = input_graph.replace(
        **{field: getattr(input_graph, field).permute(0, 2, 1, 3)})
    network = blocks_torch.EdgeBlock(
        functools.partial(Conv2D, output_channels=10, kernel_shape=[3, 3]),
        use_edges=use_edges,
        use_receiver_nodes=use_receiver_nodes,
        use_sender_nodes=use_sender_nodes,
        use_globals=use_globals)
    # NOTE(review): "in both shapes must be equal" is TensorFlow's concat error
    # text. torch.cat raises RuntimeError with different wording, so this only
    # passes if blocks_torch validates shapes itself and raises ValueError
    # with this message -- confirm.
    # assertRaisesRegexp was a deprecated alias removed in Python 3.12.
    with self.assertRaisesRegex(ValueError, "in both shapes must be equal"):
        network(input_graph)