def _build(self, input_graph, hidden_size=50, attn_scale=1.0,
           attn_dropout_keep_prob=1.0, regularizer=None, is_training=False):
    node_values = input_graph.nodes
    edge_values = input_graph.edges
    value_dims = node_values.shape[-1].value
    assert value_dims == edge_values.shape[-1].value

    # Compute edge values from the edge, sender and receiver features.
    # - edge_values = [total_num_edges, value_dims]
    edge_value_block = blocks.EdgeBlock(
        edge_model_fn=lambda: snt.Linear(output_size=value_dims,
                                         regularizers={'w': regularizer}),
        use_edges=True,
        use_receiver_nodes=True,
        use_sender_nodes=True,
        use_globals=False,
        name='update_edge_values')
    edge_values = edge_value_block(input_graph).edges
    tf.summary.histogram('mpnn/edge_values', edge_values)

    logits_block = blocks.EdgeBlock(
        edge_model_fn=lambda: snt.Linear(output_size=1,
                                         regularizers={'w': regularizer}),
        # edge_model_fn=lambda: snt.nets.MLP(output_sizes=[hidden_size, 1],
        #                                    activation=tf.nn.tanh,
        #                                    regularizers={'w': regularizer}),
        use_edges=True,
        use_receiver_nodes=True,
        use_sender_nodes=True,
        use_globals=False,
        name='update_attention_logits')
    attention_weights_logits = attn_scale * logits_block(input_graph).edges
    tf.summary.histogram('mpnn/logits', attention_weights_logits)

    # Normalize the logits over the edges received by each node.
    normalized_attention_weight = modules._received_edges_normalizer(
        input_graph.replace(edges=attention_weights_logits),
        normalizer=self._normalizer)
    normalized_attention_weight = slim.dropout(normalized_attention_weight,
                                               attn_dropout_keep_prob,
                                               is_training=is_training)

    # Attend to the edge values according to the weights.
    # - attended_edges = [total_num_edges, value_dims]
    attended_edges = edge_values * normalized_attention_weight

    # Sum the attended values received by each node.
    # - aggregated_attended_values = [total_num_nodes, value_dims]
    received_edges_aggregator = blocks.ReceivedEdgesToNodesAggregator(
        reducer=tf.math.unsorted_segment_sum)
    aggregated_attended_values = received_edges_aggregator(
        input_graph.replace(edges=attended_edges))

    return input_graph.replace(nodes=aggregated_attended_values,
                               edges=edge_values)
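# Usage sketch (illustrative, not from the source). It assumes the _build
# above belongs to an snt.AbstractModule subclass, hypothetically named
# `AttentionMPNN` here, whose constructor sets self._normalizer (e.g. a
# softmax over received edges). The toy graph sizes are also assumptions.
import numpy as np
import tensorflow as tf
from graph_nets import utils_tf

toy_graph = utils_tf.data_dicts_to_graphs_tuple([{
    "globals": np.zeros([1], dtype=np.float32),
    "nodes": np.random.rand(4, 16).astype(np.float32),   # 4 nodes, 16 features
    "edges": np.random.rand(6, 16).astype(np.float32),   # 6 edges, 16 features
    "senders": [0, 1, 2, 3, 0, 2],
    "receivers": [1, 2, 3, 0, 2, 1],
}])

attention_layer = AttentionMPNN()  # hypothetical enclosing module
# Output nodes hold the attention-weighted sum of the edge values received by
# each node; output edges hold the updated edge values.
attended_graph = attention_layer(toy_graph, attn_scale=2.0, is_training=True)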
def __init__(self, name="DecaySimulator"): super(DecaySimulator, self).__init__(name=name) self._node_linear = make_mlp_model() self._node_rnn = snt.GRU(hidden_size=LATENT_SIZE, name='node_rnn') self._node_proper = snt.nets.MLP([4], activate_final=False) self._edge_block = blocks.EdgeBlock(edge_model_fn=make_mlp_model, use_edges=False, use_receiver_nodes=True, use_sender_nodes=True, use_globals=False, name='edge_encoder_block') self._node_encoder_block = blocks.NodeBlock( node_model_fn=make_mlp_model, use_received_edges=False, use_sent_edges=False, use_nodes=True, use_globals=False, name='node_encoder_block') self._global_encoder_block = blocks.GlobalBlock( global_model_fn=make_mlp_model, use_edges=True, use_nodes=True, use_globals=False, nodes_reducer=tf.math.unsorted_segment_sum, edges_reducer=tf.math.unsorted_segment_sum, name='global_encoder_block') self._core = MLPGraphNetwork() # self._core = InteractionNetwork( # edge_model_fn=make_mlp_model, # node_model_fn=make_mlp_model, # reducer=tf.math.unsorted_segment_sum # ) # # Transforms the outputs into appropriate shapes. node_output_size = 64 node_fn = lambda: snt.Sequential([ snt.nets.MLP( [node_output_size], activation=tf.nn.relu, # default is relu name='node_output') ]) global_output_size = 1 global_fn = lambda: snt.Sequential([ snt.nets.MLP( [global_output_size], activation=tf.nn.relu, # default is relu name='global_output'), tf.sigmoid ]) self._output_transform = modules.GraphIndependent( edge_model_fn=None, node_model_fn=node_fn, global_model_fn=global_fn)
def __init__(self, name="DeepGraphInfoMax"): super(DeepGraphInfoMax, self).__init__(name=name) self._edge_block = blocks.EdgeBlock( edge_model_fn=lambda: snt.nets.MLP([LATENT_SIZE] * 2, activation=tf.nn.relu, activate_final=True, use_dropout=True), use_edges=False, use_receiver_nodes=True, use_sender_nodes=True, use_globals=False, name='edge_encoder_block') self._node_encoder_block = blocks.NodeBlock( node_model_fn=make_mlp_model, use_received_edges=False, use_sent_edges=False, use_nodes=True, use_globals=False, name='node_encoder_block') self._core = modules.InteractionNetwork( edge_model_fn=make_mlp_model, node_model_fn=make_mlp_model, reducer=tf.unsorted_segment_sum)
def __init__(self, name, edge_model_fn, node_model_fn, global_model_fn):
    """Initializes the model.

    :param edge_model_fn: Function passed to the edge block; in this paper it
        is an MLP.
    :param node_model_fn: Function passed to the node block; in this paper,
        for the bAbI task, it is an LSTM over timesteps.
    :param global_model_fn: Function passed to the global block; in this paper
        it is an MLP.
    """
    super(RRN, self).__init__(name=name)

    self._edge_block = blocks.EdgeBlock(
        edge_model_fn=edge_model_fn,
        use_edges=False,
        use_receiver_nodes=True,
        use_sender_nodes=True,
        use_globals=False)

    self._global_block = blocks.GlobalBlock(
        global_model_fn=global_model_fn,
        use_edges=False,
        use_nodes=True,
        use_globals=False,
        nodes_reducer=tf.unsorted_segment_sum)

    self._node_model_fn = node_model_fn
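# Usage sketch (illustrative, not from the source): instantiating the RRN
# above. The MLP widths and LSTM hidden size are assumptions; only the
# constructor is shown, so the recurrent node update itself is not exercised.
import sonnet as snt

rrn = RRN(
    name="rrn",
    edge_model_fn=lambda: snt.nets.MLP([128, 128]),     # message MLP
    node_model_fn=lambda: snt.LSTM(hidden_size=128),    # per-node recurrence
    global_model_fn=lambda: snt.nets.MLP([128, 128]))   # readout MLP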
def __init__(self, edge_model_fn, node_model_fn, name="fwd_gnn"): # aggregator_fn = tf.math.unsorted_segment_sum, """Initializes the transition model""" super(GraphNeuralNetwork_transition, self).__init__(name=name) with self._enter_variable_scope(): self._edge_block = blocks.EdgeBlock(edge_model_fn=edge_model_fn, use_edges=False, use_receiver_nodes=True, use_sender_nodes=True, use_globals=False, name='edge_block') # self._node_block = blocks.NodeBlock( # node_model_fn=node_model_fn, # use_received_edges=True, # use_sent_edges=True, # use_nodes=True, # use_globals=True, # received_edges_reducer=tf.math.unsorted_segment_sum, # sent_edges_reducer=tf.math.unsorted_segment_sum, # name="node_block") self._sent_edges_aggregator = blocks.SentEdgesToNodesAggregator( reducer=tf.math.unsorted_segment_sum) self._node_model = node_model_fn()
def __init__(self, name="FourTopPredictor"): super(FourTopPredictor, self).__init__(name=name) self._edge_block = blocks.EdgeBlock(edge_model_fn=make_mlp_model, use_edges=False, use_receiver_nodes=True, use_sender_nodes=True, use_globals=False, name='edge_encoder_block') self._node_encoder_block = blocks.NodeBlock( node_model_fn=make_mlp_model, use_received_edges=False, use_sent_edges=False, use_nodes=True, use_globals=False, name='node_encoder_block') self._global_block = blocks.GlobalBlock( global_model_fn=make_mlp_model, use_edges=True, use_nodes=True, use_globals=False, ) self._core = MLPGraphNetwork() # Transforms the outputs into appropriate shapes. global_output_size = n_target_node_features * n_max_tops self._global_nn = snt.nets.MLP( [128, 128, global_output_size], activation=tf.nn.leaky_relu, # default is relu, tanh dropout_rate=0.30, name='global_output')
def __init__(self, edge_model_fn, node_model_fn, reducer=tf.unsorted_segment_sum, name="interaction_network"): """Initializes the InteractionNetwork module. Args: edge_model_fn: A callable that will be passed to `EdgeBlock` to perform per-edge computations. The callable must return a Sonnet module (or equivalent; see `blocks.EdgeBlock` for details), and the shape of the output of this module must match the one of the input nodes, but for the first and last axis. node_model_fn: A callable that will be passed to `NodeBlock` to perform per-node computations. The callable must return a Sonnet module (or equivalent; see `blocks.NodeBlock` for details). reducer: Reducer to be used by NodeBlock to aggregate edges. Defaults to tf.unsorted_segment_sum. name: The module name. """ super(InteractionNetwork, self).__init__(name=name) with self._enter_variable_scope(): self._edge_block = blocks.EdgeBlock(edge_model_fn=edge_model_fn, use_globals=False) self._node_block = blocks.NodeBlock(node_model_fn=node_model_fn, use_received_edges=True, use_sent_edges=True, use_globals=False, received_edges_reducer=reducer)
def __init__(self, edge_model_fn, node_model_fn, reducer=tf.unsorted_segment_sum, name="tcomm_net"): super(TCommNet, self).__init__(name=name) with self._enter_variable_scope(): self._edge_block = blocks.LEdgeBlock(edge_model_fn=edge_model_fn, use_edges=False, use_receiver_nodes=False, use_sender_nodes=True, use_globals=False, use_reverse_edges=True) self._edge_block2 = blocks.EdgeBlock(edge_model_fn=edge_model_fn, use_edges=False, use_receiver_nodes=False, use_sender_nodes=True, use_globals=False) self._node_block = blocks.toLNodeBlock( node_model_fn=node_model_fn, use_received_edges=True, use_sent_edges=False, use_nodes=True, use_globals=False, received_edges_reducer=reducer)
def __init__(self,
             edge_model_fn,
             global_model_fn,
             reducer=tf.unsorted_segment_sum,
             name="relation_network"):
    """Initializes the RelationNetwork module.

    Args:
      edge_model_fn: A callable that will be passed to EdgeBlock to perform
        per-edge computations. The callable must return a Sonnet module (or
        equivalent; see EdgeBlock for details).
      global_model_fn: A callable that will be passed to GlobalBlock to
        perform per-global computations. The callable must return a Sonnet
        module (or equivalent; see GlobalBlock for details).
      reducer: Reducer to be used by GlobalBlock to aggregate edges. Defaults
        to tf.unsorted_segment_sum.
      name: The module name.
    """
    super(RelationNetwork, self).__init__(name=name)
    with self._enter_variable_scope():
        self._edge_block = blocks.EdgeBlock(
            edge_model_fn=edge_model_fn,
            use_edges=False,
            use_receiver_nodes=True,
            use_sender_nodes=True,
            use_globals=False)
        self._global_block = blocks.GlobalBlock(
            global_model_fn=global_model_fn,
            use_edges=True,
            use_nodes=False,
            use_globals=False,
            edges_reducer=reducer)
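# Usage sketch (illustrative, not from the source): a RelationNetwork as
# defined above turns per-edge MLP outputs into one global vector per graph.
# The model sizes are assumptions; assuming a _build like the library's, the
# returned graph carries the new globals while nodes and edges pass through.
import sonnet as snt

relation_net = RelationNetwork(
    edge_model_fn=lambda: snt.nets.MLP([64, 64]),    # applied per edge
    global_model_fn=lambda: snt.nets.MLP([64, 1]))   # applied to the aggregate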
def __init__(self, edge_model_fn, node_model_fn, reducer=tf.unsorted_segment_sum, name="comm_net"): """Initializes the CommNet module. Args: edge_model_fn: A callable to be passed to EdgeBlock. The callable must return a Sonnet module (or equivalent; see EdgeBlock for details). node_encoder_model_fn: A callable to be passed to the NodeBlock responsible for the first encoding of the nodes. The callable must return a Sonnet module (or equivalent; see NodeBlock for details). The shape of this module's output should match the shape of the module built by `edge_model_fn`, but for the first and last dimension. node_model_fn: A callable to be passed to NodeBlock. The callable must return a Sonnet module (or equivalent; see NodeBlock for details). reducer: Reduction to be used when aggregating the edges in the nodes. This should be a callable whose signature matches tf.unsorted_segment_sum. name: The module name. """ super(CommNet, self).__init__(name=name) with self._enter_variable_scope(): # Computes $\Psi_{com}(x_j)$ in Eq. (2) of 1706.06122 self._edge_block = blocks.EdgeBlock(edge_model_fn=edge_model_fn, use_edges=False, use_receiver_nodes=False, use_sender_nodes=True, use_globals=False)
def test_compatible_higher_rank_no_raise(self):
    """No exception should occur with higher-rank tensors."""
    input_graph = self._get_shaped_input_graph()
    input_graph = input_graph.map(lambda v: tf.transpose(v, [0, 2, 1, 3]))
    network = blocks.EdgeBlock(
        functools.partial(snt.Conv2D, output_channels=10, kernel_shape=[3, 3]))
    self._assert_build_and_run(network, input_graph)
def test_same_as_subblocks(self, reducer):
    """Compares the output to explicit subblocks output.

    Args:
      reducer: The reducer used in the `NodeBlock` and `GlobalBlock`.
    """
    input_graph = self._get_input_graph()

    edge_model_fn = functools.partial(snt.Linear, output_size=5)
    node_model_fn = functools.partial(snt.Linear, output_size=10)
    global_model_fn = functools.partial(snt.Linear, output_size=15)

    graph_network = modules.GraphNetwork(
        edge_model_fn=edge_model_fn,
        node_model_fn=node_model_fn,
        global_model_fn=global_model_fn,
        reducer=reducer)
    output_graph = graph_network(input_graph)

    edge_block = blocks.EdgeBlock(
        edge_model_fn=lambda: graph_network._edge_block._edge_model,
        use_sender_nodes=True,
        use_edges=True,
        use_receiver_nodes=True,
        use_globals=True)
    node_block = blocks.NodeBlock(
        node_model_fn=lambda: graph_network._node_block._node_model,
        use_nodes=True,
        use_sent_edges=False,
        use_received_edges=True,
        use_globals=True,
        received_edges_reducer=reducer)
    global_block = blocks.GlobalBlock(
        global_model_fn=lambda: graph_network._global_block._global_model,
        use_nodes=True,
        use_edges=True,
        use_globals=True,
        edges_reducer=reducer,
        nodes_reducer=reducer)

    expected_output_edge_block = edge_block(input_graph)
    expected_output_node_block = node_block(expected_output_edge_block)
    expected_output_global_block = global_block(expected_output_node_block)
    expected_edges = expected_output_edge_block.edges
    expected_nodes = expected_output_node_block.nodes
    expected_globals = expected_output_global_block.globals

    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        (output_graph_out, expected_edges_out, expected_nodes_out,
         expected_globals_out) = sess.run(
             (output_graph, expected_edges, expected_nodes, expected_globals))

    self._assert_all_none_or_all_close(expected_edges_out,
                                       output_graph_out.edges)
    self._assert_all_none_or_all_close(expected_nodes_out,
                                       output_graph_out.nodes)
    self._assert_all_none_or_all_close(expected_globals_out,
                                       output_graph_out.globals)
def __init__(self, name="GlobalClassifierNoEdgeInfo"): super(GlobalClassifierNoEdgeInfo, self).__init__(name=name) self._edge_block = blocks.EdgeBlock(edge_model_fn=make_mlp_model, use_edges=False, use_receiver_nodes=True, use_sender_nodes=True, use_globals=False, name='edge_encoder_block') self._node_encoder_block = blocks.NodeBlock( node_model_fn=make_mlp_model, use_received_edges=False, use_sent_edges=False, use_nodes=True, use_globals=False, name='node_encoder_block') self._global_block = blocks.GlobalBlock( global_model_fn=make_mlp_model, use_edges=True, use_nodes=True, use_globals=False, ) self._core = MLPGraphNetwork() # Transforms the outputs into appropriate shapes. global_output_size = 1 global_fn = lambda: snt.Sequential([ snt.nets.MLP([LATENT_SIZE, global_output_size], name='global_output'), tf.sigmoid ]) self._output_transform = modules.GraphIndependent( None, None, global_fn)
def test_output_values(self, use_edges, use_receiver_nodes, use_sender_nodes,
                       use_globals):
    """Compares the output of an EdgeBlock to an explicit computation."""
    input_graph = self._get_input_graph()
    edge_block = blocks.EdgeBlock(
        edge_model_fn=self._edge_model_fn,
        use_edges=use_edges,
        use_receiver_nodes=use_receiver_nodes,
        use_sender_nodes=use_sender_nodes,
        use_globals=use_globals)
    output_graph = edge_block(input_graph)

    model_inputs = []
    if use_edges:
        model_inputs.append(input_graph.edges)
    if use_receiver_nodes:
        model_inputs.append(blocks.broadcast_receiver_nodes_to_edges(input_graph))
    if use_sender_nodes:
        model_inputs.append(blocks.broadcast_sender_nodes_to_edges(input_graph))
    if use_globals:
        model_inputs.append(blocks.broadcast_globals_to_edges(input_graph))
    model_inputs = tf.concat(model_inputs, axis=-1)

    self.assertEqual(input_graph.nodes, output_graph.nodes)
    self.assertEqual(input_graph.globals, output_graph.globals)

    with self.test_session() as sess:
        output_graph_out, model_inputs_out = sess.run(
            (output_graph, model_inputs))

    expected_output_edges = model_inputs_out * self._scale
    self.assertNDArrayNear(
        expected_output_edges, output_graph_out.edges, err=1e-4)
def test_created_variables(self, use_edges, use_receiver_nodes,
                           use_sender_nodes, use_globals,
                           expected_first_dim_w):
    """Verifies the variable names and shapes created by an EdgeBlock."""
    output_size = 10
    expected_var_shapes_dict = {
        "edge_block/mlp/linear_0/b:0": [output_size],
        "edge_block/mlp/linear_0/w:0": [expected_first_dim_w, output_size],
    }

    input_graph = self._get_input_graph()
    edge_block = blocks.EdgeBlock(
        edge_model_fn=functools.partial(snt.nets.MLP,
                                        output_sizes=[output_size]),
        use_edges=use_edges,
        use_receiver_nodes=use_receiver_nodes,
        use_sender_nodes=use_sender_nodes,
        use_globals=use_globals)
    edge_block(input_graph)

    variables = edge_block.get_variables()
    var_shapes_dict = {
        var.name: var.get_shape().as_list() for var in variables
    }
    self.assertDictEqual(expected_var_shapes_dict, var_shapes_dict)
def test_unused_field_can_be_none(self, use_edges, use_nodes, use_globals,
                                  none_field):
    """Checks that the computation handles unused fields left as None."""
    input_graph = self._get_input_graph([none_field])
    edge_block = blocks.EdgeBlock(
        edge_model_fn=self._edge_model_fn,
        use_edges=use_edges,
        use_receiver_nodes=use_nodes,
        use_sender_nodes=use_nodes,
        use_globals=use_globals)
    output_graph = edge_block(input_graph)

    model_inputs = []
    if use_edges:
        model_inputs.append(input_graph.edges)
    if use_nodes:
        model_inputs.append(blocks.broadcast_receiver_nodes_to_edges(input_graph))
        model_inputs.append(blocks.broadcast_sender_nodes_to_edges(input_graph))
    if use_globals:
        model_inputs.append(blocks.broadcast_globals_to_edges(input_graph))
    model_inputs = tf.concat(model_inputs, axis=-1)

    self.assertEqual(input_graph.nodes, output_graph.nodes)
    self.assertEqual(input_graph.globals, output_graph.globals)

    with self.test_session() as sess:
        actual_edges, model_inputs_out = sess.run(
            (output_graph.edges, model_inputs))

    expected_output_edges = model_inputs_out * self._scale
    self.assertNDArrayNear(expected_output_edges, actual_edges, err=1e-4)
def __init__(self, name="SegmentClassifier"): super(SegmentClassifier, self).__init__(name=name) self._edge_block = blocks.EdgeBlock(edge_model_fn=make_mlp_model, use_edges=False, use_receiver_nodes=True, use_sender_nodes=True, use_globals=False, name='edge_encoder_block') self._node_encoder_block = blocks.NodeBlock( node_model_fn=make_mlp_model, use_received_edges=False, use_sent_edges=False, use_nodes=True, use_globals=False, name='node_encoder_block') self._core = InteractionNetwork(edge_model_fn=make_mlp_model, node_model_fn=make_mlp_model, reducer=tf.math.unsorted_segment_sum) # Transforms the outputs into appropriate shapes. edge_output_size = 1 edge_fn = lambda: snt.Sequential([ snt.nets.MLP( [edge_output_size], activation=tf.nn.relu, # default is relu name='edge_output'), tf.sigmoid ]) self._output_transform = modules.GraphIndependent(edge_fn, None, None)
def __init__(self, edge_model_fn, node_encoder_model_fn, node_model_fn, reducer=tf.unsorted_segment_sum, name="oricomm_net"): super(OriCommNet, self).__init__(name=name) with self._enter_variable_scope(): # Computes $\Psi_{com}(x_j)$ in Eq. (2) of 1706.06122 self._edge_block = blocks.EdgeBlock(edge_model_fn=edge_model_fn, use_edges=False, use_receiver_nodes=False, use_sender_nodes=True, use_globals=False) # Computes $\Phi(x_i)$ in Eq. (2) of 1706.06122 self._node_encoder_block = blocks.NodeBlock( node_model_fn=node_encoder_model_fn, use_received_edges=False, use_sent_edges=False, use_nodes=True, use_globals=False, received_edges_reducer=reducer, name="node_encoder_block") # Computes $\Theta(..)$ in Eq.(2) of 1706.06122 self._node_block = blocks.NodeBlock(node_model_fn=node_model_fn, use_received_edges=True, use_sent_edges=False, use_nodes=True, use_globals=False, received_edges_reducer=reducer)
def test_edge_block_options(self, use_edges, use_receiver_nodes,
                            use_sender_nodes, use_globals):
    """Test for configuring the EdgeBlock options."""
    reducer = tf.math.unsorted_segment_sum
    input_graph = self._get_input_graph()

    edge_model_fn = functools.partial(snt.Linear, output_size=10)
    edge_block_opt = {"use_edges": use_edges,
                      "use_receiver_nodes": use_receiver_nodes,
                      "use_sender_nodes": use_sender_nodes,
                      "use_globals": use_globals}
    # Identity node model.
    node_model_fn = lambda: tf.identity
    node_block_opt = {"use_received_edges": False,
                      "use_sent_edges": False,
                      "use_nodes": True,
                      "use_globals": False}
    # Identity global model.
    global_model_fn = lambda: tf.identity
    global_block_opt = {"use_globals": True,
                        "use_nodes": False,
                        "use_edges": False}

    graph_network = modules.GraphNetwork(
        edge_model_fn=edge_model_fn,
        edge_block_opt=edge_block_opt,
        node_model_fn=node_model_fn,
        node_block_opt=node_block_opt,
        global_model_fn=global_model_fn,
        global_block_opt=global_block_opt,
        reducer=reducer)
    output_graph = graph_network(input_graph)

    edge_block = blocks.EdgeBlock(
        edge_model_fn=lambda: graph_network._edge_block._edge_model,
        use_edges=use_edges,
        use_receiver_nodes=use_receiver_nodes,
        use_sender_nodes=use_sender_nodes,
        use_globals=use_globals)

    expected_output_edge_block = edge_block(input_graph)
    expected_output_node_block = expected_output_edge_block
    expected_output_global_block = expected_output_node_block
    expected_edges = expected_output_edge_block.edges
    expected_nodes = expected_output_node_block.nodes
    expected_globals = expected_output_global_block.globals

    with self.test_session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        (output_graph_out, expected_edges_out, expected_nodes_out,
         expected_globals_out) = sess.run(
             (output_graph, expected_edges, expected_nodes, expected_globals))

    self.assertAllEqual(expected_edges_out, output_graph_out.edges)
    self.assertAllEqual(expected_nodes_out, output_graph_out.nodes)
    self.assertAllEqual(expected_globals_out, output_graph_out.globals)
def test_no_input_raises_exception(self):
    """Checks that receiving no input raises an exception."""
    with self.assertRaisesRegexp(ValueError, "At least one of "):
        blocks.EdgeBlock(
            edge_model_fn=self._edge_model_fn,
            use_edges=False,
            use_receiver_nodes=False,
            use_sender_nodes=False,
            use_globals=False)
def __init__(self,
             attention_node_projection_model,
             attention_edge_projection_model,
             query_key_product_model,
             node_model_fn,
             edge_model_fn,
             global_model_fn,
             num_heads,
             key_size,
             value_size,
             edge_block_opt=None,
             global_block_opt=None,
             name="GAT"):
    """
    Args:
      attention_node_projection_model: Model used to project nodes to queries,
        keys and values. The final layer dim should be key_size * num_heads.
      attention_edge_projection_model: Model used to project edges to queries,
        keys and values. The final layer dim should be
        (key_size + value_size) * num_heads.
      query_key_product_model: Model used to compute the "dot product" between
        queries and keys. The final layer dim should be 1.
      node_model_fn: Model applied to the node embeddings at the end.
      edge_model_fn: Model passed to the EdgeBlock, applied to the edge
        embeddings at the end.
      global_model_fn: Model passed to the GlobalBlock.
      num_heads: Number of attention heads.
      key_size: Key dimension.
      value_size: Value dimension.
      edge_block_opt: Additional options to be passed to the EdgeBlock. Can
        contain keys `use_edges`, `use_receiver_nodes`, `use_sender_nodes`,
        `use_globals`. By default, these are all True.
      global_block_opt: Additional options to be passed to the GlobalBlock.
        Can contain the keys `use_edges`, `use_nodes`, `use_globals` (all set
        to True by default), and `edges_reducer`, `nodes_reducer` (defaults to
        `reducer`).
      name: The module name.
    """
    super().__init__(name=name)
    self._attention_node_projection_model = attention_node_projection_model
    self._attention_edge_projection_model = attention_edge_projection_model
    self._query_key_product_model = query_key_product_model
    self.num_heads = num_heads
    self.key_size = key_size
    self.value_size = value_size

    edge_block_opt = _make_default_edge_block_opt(edge_block_opt)
    global_block_opt = _make_default_global_block_opt(
        global_block_opt, tf.unsorted_segment_sum)

    # The attention mechanism does not make sense without using sender nodes.
    assert edge_block_opt['use_sender_nodes']

    with self._enter_variable_scope():
        self._node_model = node_model_fn()
        self._edge_block = blocks.EdgeBlock(
            edge_model_fn=edge_model_fn, **edge_block_opt)
        self._global_block = blocks.GlobalBlock(
            global_model_fn=global_model_fn, **global_block_opt)
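# Usage sketch (illustrative, not from the source): constructing the GAT
# module above. All sizes are assumptions; the projection models end with
# the layer widths required by the docstring (key_size * num_heads for nodes,
# (key_size + value_size) * num_heads for edges, 1 for the query-key product).
import sonnet as snt

num_heads, key_size, value_size = 4, 16, 16
gat = GAT(
    attention_node_projection_model=snt.nets.MLP([64, key_size * num_heads]),
    attention_edge_projection_model=snt.nets.MLP(
        [64, (key_size + value_size) * num_heads]),
    query_key_product_model=snt.nets.MLP([32, 1]),
    node_model_fn=lambda: snt.nets.MLP([64]),
    edge_model_fn=lambda: snt.nets.MLP([64]),
    global_model_fn=lambda: snt.nets.MLP([64]),
    num_heads=num_heads,
    key_size=key_size,
    value_size=value_size)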
def test_optional_arguments(self, scale, offset):
    """Assesses the correctness of the EdgeBlock with optional model kwargs."""
    input_graph = self._get_input_graph()
    edge_block = blocks.EdgeBlock(edge_model_fn=self._edge_model_args_fn)
    output_graph_out = edge_block(
        input_graph, edge_model_kwargs=dict(scale=scale, offset=offset))

    fixed_scale = scale
    fixed_offset = offset
    model_fn = lambda: lambda features: features * fixed_scale + fixed_offset
    hardcoded_edge_block = blocks.EdgeBlock(edge_model_fn=model_fn)
    expected_graph_out = hardcoded_edge_block(input_graph)

    self.assertIs(expected_graph_out.nodes, output_graph_out.nodes)
    self.assertIs(expected_graph_out.globals, output_graph_out.globals)
    self.assertNDArrayNear(
        expected_graph_out.edges.numpy(),
        output_graph_out.edges.numpy(),
        err=1e-4)
def test_same_as_subblocks(self, reducer, none_field=None):
    """Compares the output to explicit subblocks output.

    Args:
      reducer: The reducer used in the `NodeBlock`s.
      none_field: (string, default=None) If not None, the corresponding field
        is removed from the input graph.
    """
    input_graph = self._get_input_graph(none_field)

    comm_net = self._get_model(reducer)
    output_graph = comm_net(input_graph)
    output_nodes = output_graph.nodes

    edge_subblock = blocks.EdgeBlock(
        edge_model_fn=lambda: comm_net._edge_block._edge_model,
        use_edges=False,
        use_receiver_nodes=False,
        use_sender_nodes=True,
        use_globals=False)
    node_encoder_subblock = blocks.NodeBlock(
        node_model_fn=lambda: comm_net._node_encoder_block._node_model,
        use_received_edges=False,
        use_sent_edges=False,
        use_nodes=True,
        use_globals=False,
        received_edges_reducer=reducer)
    node_subblock = blocks.NodeBlock(
        node_model_fn=lambda: comm_net._node_block._node_model,
        use_received_edges=True,
        use_sent_edges=False,
        use_nodes=True,
        use_globals=False,
        received_edges_reducer=reducer)

    edge_block_out = edge_subblock(input_graph)
    encoded_nodes = node_encoder_subblock(input_graph).nodes
    node_input_graph = input_graph.replace(
        edges=edge_block_out.edges, nodes=encoded_nodes)
    node_block_out = node_subblock(node_input_graph)
    expected_nodes = node_block_out.nodes

    self.assertAllEqual(input_graph.globals, output_graph.globals)
    self.assertAllEqual(input_graph.edges, output_graph.edges)
    self.assertAllEqual(input_graph.receivers, output_graph.receivers)
    self.assertAllEqual(input_graph.senders, output_graph.senders)

    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        actual_nodes_output, expected_nodes_output = sess.run(
            [output_nodes, expected_nodes])

    self._assert_all_none_or_all_close(expected_nodes_output,
                                       actual_nodes_output)
def __init__(self,
             edge_fn,
             with_edge_inputs=False,
             with_node_inputs=True,
             encoder_size: list = None,
             core_size: list = None,
             name="EdgeLearnerBase",
             **kwargs):
    super(EdgeLearnerBase, self).__init__(name=name)

    if encoder_size is not None:
        encoder_mlp_fn = partial(make_mlp_model, mlp_size=encoder_size, **kwargs)
    else:
        encoder_mlp_fn = partial(make_mlp_model, **kwargs)

    edge_block_args = dict(use_edges=False,
                           use_receiver_nodes=True,
                           use_sender_nodes=True,
                           use_globals=False)
    node_block_args = dict(use_received_edges=False,
                           use_sent_edges=False,
                           use_nodes=True,
                           use_globals=False)

    if with_edge_inputs:
        edge_block_args['use_edges'] = True
        node_block_args['use_received_edges'] = True
        node_block_args['use_sent_edges'] = True
    if not with_node_inputs:
        edge_block_args['use_receiver_nodes'] = False
        edge_block_args['use_sender_nodes'] = False
        node_block_args['use_nodes'] = False

    self._edge_block = blocks.EdgeBlock(
        edge_model_fn=encoder_mlp_fn,
        **edge_block_args,
        name='edge_encoder_block')
    self._node_encoder_block = blocks.NodeBlock(
        node_model_fn=encoder_mlp_fn,
        **node_block_args,
        name='node_encoder_block')

    if core_size is not None:
        core_mlp_fn = partial(make_mlp_model, mlp_size=core_size, **kwargs)
    else:
        core_mlp_fn = partial(make_mlp_model, **kwargs)

    self._core = InteractionNetwork(
        edge_model_fn=core_mlp_fn,
        node_model_fn=core_mlp_fn,
        reducer=tf.math.unsorted_segment_sum)

    self._output_transform = modules.GraphIndependent(edge_fn, None, None)
def __init__(self,
             edge_model_fn,
             node_model_fn,
             global_model_fn,
             reducer=tf.math.unsorted_segment_sum,
             edge_block_opt=None,
             node_block_opt=None,
             global_block_opt=None,
             name="graph_network"):
    """Initializes the GraphNetwork module.

    Args:
      edge_model_fn: A callable that will be passed to EdgeBlock to perform
        per-edge computations. The callable must return a Sonnet module (or
        equivalent; see EdgeBlock for details).
      node_model_fn: A callable that will be passed to NodeBlock to perform
        per-node computations. The callable must return a Sonnet module (or
        equivalent; see NodeBlock for details).
      global_model_fn: A callable that will be passed to GlobalBlock to
        perform per-global computations. The callable must return a Sonnet
        module (or equivalent; see GlobalBlock for details).
      reducer: Reducer to be used by NodeBlock and GlobalBlock to aggregate
        nodes and edges. Defaults to tf.math.unsorted_segment_sum. This will
        be overridden by the reducers specified in `node_block_opt` and
        `global_block_opt`, if any.
      edge_block_opt: Additional options to be passed to the EdgeBlock. Can
        contain keys `use_edges`, `use_receiver_nodes`, `use_sender_nodes`,
        `use_globals`. By default, these are all True.
      node_block_opt: Additional options to be passed to the NodeBlock. Can
        contain the keys `use_received_edges`, `use_sent_edges`, `use_nodes`,
        `use_globals` (all set to True by default), and
        `received_edges_reducer`, `sent_edges_reducer` (default to `reducer`).
      global_block_opt: Additional options to be passed to the GlobalBlock.
        Can contain the keys `use_edges`, `use_nodes`, `use_globals` (all set
        to True by default), and `edges_reducer`, `nodes_reducer` (default to
        `reducer`).
      name: The module name.
    """
    super(GraphNetwork, self).__init__(name=name)

    edge_block_opt = _make_default_edge_block_opt(edge_block_opt)
    node_block_opt = _make_default_node_block_opt(node_block_opt, reducer)
    global_block_opt = _make_default_global_block_opt(global_block_opt, reducer)

    # with self._enter_variable_scope():
    self._edge_block = blocks.EdgeBlock(
        edge_model_fn=edge_model_fn, **edge_block_opt)
    self._node_block = blocks.NodeBlock(
        node_model_fn=node_model_fn, **node_block_opt)
    self._global_block = blocks.GlobalBlock(
        global_model_fn=global_model_fn, **global_block_opt)
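# Usage sketch (illustrative, not from the source): a full GraphNetwork with
# per-block options, as documented above. The model sizes and option choices
# are assumptions. A call runs the edge block, then the node block
# (aggregating edges per receiver with `reducer`), then the global block.
import sonnet as snt

graph_net = GraphNetwork(
    edge_model_fn=lambda: snt.nets.MLP([32, 32]),
    node_model_fn=lambda: snt.nets.MLP([32, 32]),
    global_model_fn=lambda: snt.nets.MLP([32, 16]),
    edge_block_opt={"use_globals": False},     # edges see edges + endpoints only
    node_block_opt={"use_sent_edges": False})  # nodes aggregate received edges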
def test_missing_field_raises_exception(self, use_edges, use_receiver_nodes,
                                        use_sender_nodes, use_globals,
                                        none_fields):
    """Checks that missing a required field raises an exception."""
    input_graph = self._get_input_graph(none_fields)
    edge_block = blocks.EdgeBlock(
        edge_model_fn=self._edge_model_fn,
        use_edges=use_edges,
        use_receiver_nodes=use_receiver_nodes,
        use_sender_nodes=use_sender_nodes,
        use_globals=use_globals)
    with self.assertRaisesRegexp(ValueError, "field cannot be None"):
        edge_block(input_graph)
def __init__(self, edge_model_fn, node_model_fn, reducer=tf.math.unsorted_segment_sum, name="interaction_network"): super(InteractionNetwork, self).__init__(name=name) self._edge_block = blocks.EdgeBlock(edge_model_fn=edge_model_fn, use_globals=False) self._node_block = blocks.NodeBlock(node_model_fn=node_model_fn, use_received_edges=True, use_sent_edges=True, use_globals=False, received_edges_reducer=reducer)
def test_incompatible_higher_rank_inputs_no_raise(self, use_edges,
                                                  use_receiver_nodes,
                                                  use_sender_nodes,
                                                  use_globals, field):
    """No exception should occur if a differently shaped field is not used."""
    input_graph = self._get_shaped_input_graph()
    input_graph = input_graph.replace(
        **{field: tf.transpose(getattr(input_graph, field), [0, 2, 1, 3])})
    network = blocks.EdgeBlock(
        functools.partial(snt.Conv2D, output_channels=10, kernel_shape=[3, 3]),
        use_edges=use_edges,
        use_receiver_nodes=use_receiver_nodes,
        use_sender_nodes=use_sender_nodes,
        use_globals=use_globals)
    self._assert_build_and_run(network, input_graph)
def test_incompatible_higher_rank_inputs_raises(self, use_edges,
                                                use_receiver_nodes,
                                                use_sender_nodes,
                                                use_globals, field):
    """An exception should be raised if the inputs have incompatible shapes."""
    input_graph = self._get_shaped_input_graph()
    input_graph = input_graph.replace(
        **{field: tf.transpose(getattr(input_graph, field), [0, 2, 1, 3])})
    network = blocks.EdgeBlock(
        functools.partial(snt.Conv2D, output_channels=10, kernel_shape=[3, 3]),
        use_edges=use_edges,
        use_receiver_nodes=use_receiver_nodes,
        use_sender_nodes=use_sender_nodes,
        use_globals=use_globals)
    with self.assertRaisesRegexp(ValueError, "in both shapes must be equal"):
        network(input_graph)
def test_same_as_subblocks(self, reducer, none_field=None):
    """Compares the output to explicit subblocks output.

    Args:
      reducer: The reducer used in the `NodeBlock`s.
      none_field: (string, default=None) If not None, the corresponding field
        is removed from the input graph.
    """
    input_graph = self._get_input_graph(none_field)

    interaction_network = self._get_model(reducer)
    output_graph = interaction_network(input_graph)
    edges_out = output_graph.edges
    nodes_out = output_graph.nodes
    self.assertAllEqual(input_graph.globals, output_graph.globals)

    edge_block = blocks.EdgeBlock(
        edge_model_fn=lambda: interaction_network._edge_block._edge_model,
        use_sender_nodes=True,
        use_edges=True,
        use_receiver_nodes=True,
        use_globals=False)
    node_block = blocks.NodeBlock(
        node_model_fn=lambda: interaction_network._node_block._node_model,
        use_nodes=True,
        use_sent_edges=False,
        use_received_edges=True,
        use_globals=False,
        received_edges_reducer=reducer)

    expected_output_edge_block = edge_block(input_graph)
    expected_output_node_block = node_block(expected_output_edge_block)
    expected_edges = expected_output_edge_block.edges
    expected_nodes = expected_output_node_block.nodes

    with self.test_session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        (actual_edges_out, actual_nodes_out,
         expected_edges_out, expected_nodes_out) = sess.run(
             [edges_out, nodes_out, expected_edges, expected_nodes])

    self.assertAllEqual(expected_edges_out, actual_edges_out)
    self.assertAllEqual(expected_nodes_out, actual_nodes_out)