def __init__(self,
             edge_model_fn,
             node_encoder_model_fn,
             node_model_fn,
             reducer=tf.math.unsorted_segment_sum,
             name="oricomm_net"):
    """Initializes the OriCommNet module.

    Args:
      edge_model_fn: A callable to be passed to EdgeBlock. The callable must
        return a Sonnet module (or equivalent; see EdgeBlock for details).
      node_encoder_model_fn: A callable to be passed to the NodeBlock
        responsible for the first encoding of the nodes. The callable must
        return a Sonnet module (or equivalent; see NodeBlock for details).
      node_model_fn: A callable to be passed to NodeBlock. The callable must
        return a Sonnet module (or equivalent; see NodeBlock for details).
      reducer: Reduction to be used when aggregating the edges in the nodes.
        Defaults to `tf.math.unsorted_segment_sum` (the supported spelling of
        the deprecated `tf.unsorted_segment_sum` alias used elsewhere in this
        file).
      name: The module name.
    """
    super(OriCommNet, self).__init__(name=name)
    with self._enter_variable_scope():
        # Computes $\Psi_{com}(x_j)$ in Eq. (2) of 1706.06122
        self._edge_block = blocks.EdgeBlock(
            edge_model_fn=edge_model_fn,
            use_edges=False,
            use_receiver_nodes=False,
            use_sender_nodes=True,
            use_globals=False)
        # Computes $\Phi(x_i)$ in Eq. (2) of 1706.06122
        self._node_encoder_block = blocks.NodeBlock(
            node_model_fn=node_encoder_model_fn,
            use_received_edges=False,
            use_sent_edges=False,
            use_nodes=True,
            use_globals=False,
            received_edges_reducer=reducer,
            name="node_encoder_block")
        # Computes $\Theta(..)$ in Eq.(2) of 1706.06122
        self._node_block = blocks.NodeBlock(
            node_model_fn=node_model_fn,
            use_received_edges=True,
            use_sent_edges=False,
            use_nodes=True,
            use_globals=False,
            received_edges_reducer=reducer)
def test_same_as_subblocks(self, reducer, none_field=None):
    """Compares the output to explicit subblocks output.

    Args:
      reducer: The reducer used in the `NodeBlock`s.
      none_field: (string, default=None) If not None, the corresponding field
        is removed from the input graph.
    """
    input_graph = self._get_input_graph(none_field)
    comm_net = self._get_model(reducer)
    output_graph = comm_net(input_graph)
    output_nodes = output_graph.nodes
    # Rebuild each sub-block around the model's own sub-modules so the
    # explicit pipeline below shares weights with the module under test.
    edge_subblock = blocks.EdgeBlock(
        edge_model_fn=lambda: comm_net._edge_block._edge_model,
        use_edges=False,
        use_receiver_nodes=False,
        use_sender_nodes=True,
        use_globals=False)
    node_encoder_subblock = blocks.NodeBlock(
        node_model_fn=lambda: comm_net._node_encoder_block._node_model,
        use_received_edges=False,
        use_sent_edges=False,
        use_nodes=True,
        use_globals=False,
        received_edges_reducer=reducer)
    node_subblock = blocks.NodeBlock(
        node_model_fn=lambda: comm_net._node_block._node_model,
        use_received_edges=True,
        use_sent_edges=False,
        use_nodes=True,
        use_globals=False,
        received_edges_reducer=reducer)
    # Explicit pipeline: edge messages plus encoded nodes feed the final
    # node block, mirroring what the module should compute internally.
    edge_block_out = edge_subblock(input_graph)
    encoded_nodes = node_encoder_subblock(input_graph).nodes
    node_input_graph = input_graph.replace(
        edges=edge_block_out.edges, nodes=encoded_nodes)
    node_block_out = node_subblock(node_input_graph)
    expected_nodes = node_block_out.nodes
    # All fields other than the nodes must pass through unchanged.
    self.assertAllEqual(input_graph.globals, output_graph.globals)
    self.assertAllEqual(input_graph.edges, output_graph.edges)
    self.assertAllEqual(input_graph.receivers, output_graph.receivers,)
    self.assertAllEqual(input_graph.senders, output_graph.senders)
    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        actual_nodes_output, expected_nodes_output = sess.run(
            [output_nodes, expected_nodes])
        self._assert_all_none_or_all_close(expected_nodes_output,
                                           actual_nodes_output)
def __init__(self,
             edge_model_fn,
             node_encoder_model_fn,
             node_model_fn,
             reducer=tf.math.unsorted_segment_sum,
             name="comm_net"):
    """Initializes the CommNet module.

    Args:
      edge_model_fn: A callable to be passed to EdgeBlock. The callable must
        return a Sonnet module (or equivalent; see EdgeBlock for details).
      node_encoder_model_fn: A callable to be passed to the NodeBlock
        responsible for the first encoding of the nodes. The callable must
        return a Sonnet module (or equivalent; see NodeBlock for details).
        The shape of this module's output should match the shape of the
        module built by `edge_model_fn`, but for the first and last
        dimension.
      node_model_fn: A callable to be passed to NodeBlock. The callable must
        return a Sonnet module (or equivalent; see NodeBlock for details).
      reducer: Reduction to be used when aggregating the edges in the nodes.
        This should be a callable whose signature matches
        tf.unsorted_segment_sum.
      name: The module name.
    """
    super(CommNet, self).__init__(name=name)
    # NOTE(review): variable-scope entry was deliberately disabled here
    # (`with self._enter_variable_scope():`); the former `if 2 > 1:`
    # placeholder only preserved indentation and has been removed.
    # Computes $\Psi_{com}(x_j)$ in Eq. (2) of 1706.06122
    self._edge_block = blocks.EdgeBlock(
        edge_model_fn=edge_model_fn,
        use_edges=False,
        use_receiver_nodes=False,
        use_sender_nodes=True,
        use_globals=False)
    # Computes $\Phi(x_i)$ in Eq. (2) of 1706.06122
    self._node_encoder_block = blocks.NodeBlock(
        node_model_fn=node_encoder_model_fn,
        use_received_edges=False,
        use_sent_edges=False,
        use_nodes=True,
        use_globals=False,
        received_edges_reducer=reducer,
        name="node_encoder_block")
    # Computes $\Theta(..)$ in Eq.(2) of 1706.06122
    self._node_block = blocks.NodeBlock(
        node_model_fn=node_model_fn,
        use_received_edges=True,
        use_sent_edges=False,
        use_nodes=True,
        use_globals=False,
        received_edges_reducer=reducer)
def test_compatible_higher_rank_no_raise(self):
    """Higher-rank (but mutually compatible) tensors must not raise."""
    transpose_all = lambda v: tf.transpose(v, [0, 2, 1, 3])
    graph = self._get_shaped_input_graph().map(transpose_all)
    conv_fn = functools.partial(
        snt.Conv2D, output_channels=10, kernel_shape=[3, 3])
    node_block = blocks.NodeBlock(conv_fn)
    self._assert_build_and_run(node_block, graph)
def __init__(self, node_model_fn, global_model_fn, name="rwrd_gnn"):
    """Initializes the reward model.

    Args:
      node_model_fn: A callable returning the per-node model (see
        `blocks.NodeBlock` for details).
      global_model_fn: A callable returning the per-graph model (see
        `blocks.GlobalBlock` for details).
      name: The module name.
    """
    super(GraphNeuralNetwork_reward, self).__init__(name=name)
    with self._enter_variable_scope():
        # Nodes are updated from their own features and the globals only;
        # edge information is not used by this model. (A commented-out
        # EdgeBlock previously lived here and has been removed as dead code.)
        self._node_block = blocks.NodeBlock(
            node_model_fn=node_model_fn,
            use_received_edges=False,
            use_sent_edges=False,
            use_nodes=True,
            use_globals=True,
            received_edges_reducer=tf.math.unsorted_segment_sum,
            sent_edges_reducer=tf.math.unsorted_segment_sum,
            name="node_block")
        # Globals are updated from the (sum-)aggregated nodes and the
        # previous globals; edges are ignored.
        self._global_block = blocks.GlobalBlock(
            global_model_fn=global_model_fn,
            use_edges=False,
            use_nodes=True,
            use_globals=True,
            nodes_reducer=tf.math.unsorted_segment_sum,
            edges_reducer=tf.math.unsorted_segment_sum,
            name="global_block")
def __init__(self, conf, name="encoder-attention-tsp"):
    """Inits the module.

    Args:
      conf: configuration object; the fields read here are `embedding_dim`,
        `init_dim` and `encoder_nbr_layers`.
      name: The module name.
    """
    super(Encoder, self).__init__(name=name)
    self.conf = conf
    # Training-mode flag; presumably toggled externally (e.g. before
    # inference) -- TODO confirm against callers.
    self.training = True
    with self._enter_variable_scope():
        # Linear projection of raw node features into the embedding space.
        self._initial_projection = snt.Linear(
            output_size=self.conf.embedding_dim,
            initializers={
                'w': utils.initializer(conf.init_dim),
                'b': utils.initializer(conf.init_dim)
            },
            name="initial_projection")
        # Applies the projection node-wise; no edge or global inputs.
        self._initial_projection_block = blocks.NodeBlock(
            lambda: self._initial_projection,
            use_received_edges=False,
            use_nodes=True,
            use_globals=False,
            name="initial_block_projection")
        # Stack of `encoder_nbr_layers` encoder layers applied after the
        # initial projection.
        self._encoder_layers = [
            EncoderLayer(conf, "encoder_layer_%i" % i)
            for i in range(self.conf.encoder_nbr_layers)
        ]
def test_unused_field_can_be_none(self, use_edges, use_nodes, use_globals, none_field):
    """Checks that computation can handle non-necessary fields left None."""
    graph = self._get_input_graph([none_field])
    block_under_test = blocks.NodeBlock(
        node_model_fn=self._node_model_fn,
        use_received_edges=use_edges,
        use_sent_edges=use_edges,
        use_nodes=use_nodes,
        use_globals=use_globals)
    out_graph = block_under_test(graph)

    # Assemble, by hand, the same inputs the block should concatenate.
    parts = []
    if use_edges:
        parts.append(
            blocks.ReceivedEdgesToNodesAggregator(
                tf.unsorted_segment_sum)(graph))
        parts.append(
            blocks.SentEdgesToNodesAggregator(
                tf.unsorted_segment_sum)(graph))
    if use_nodes:
        parts.append(graph.nodes)
    if use_globals:
        parts.append(blocks.broadcast_globals_to_nodes(graph))
    concatenated = tf.concat(parts, axis=-1)

    # Fields other than nodes must be returned untouched.
    self.assertEqual(graph.edges, out_graph.edges)
    self.assertEqual(graph.globals, out_graph.globals)
    with self.test_session() as sess:
        nodes_out, inputs_out = sess.run((out_graph.nodes, concatenated))
    self.assertNDArrayNear(inputs_out * self._scale, nodes_out, err=1e-4)
def __init__(self, name="DecaySimulator"):
    """Builds the encode-process-decode network used by the decay simulator."""
    super(DecaySimulator, self).__init__(name=name)
    # Per-node transforms: a shared MLP, a GRU cell holding recurrent node
    # state, and a 4-way linear head.
    self._node_linear = make_mlp_model()
    self._node_rnn = snt.GRU(hidden_size=LATENT_SIZE, name='node_rnn')
    self._node_proper = snt.nets.MLP([4], activate_final=False)
    # Edge encoder: edge latents built from sender/receiver node features
    # only (input edge features are ignored).
    self._edge_block = blocks.EdgeBlock(
        edge_model_fn=make_mlp_model,
        use_edges=False,
        use_receiver_nodes=True,
        use_sender_nodes=True,
        use_globals=False,
        name='edge_encoder_block')
    # Node encoder: embeds raw node features only.
    self._node_encoder_block = blocks.NodeBlock(
        node_model_fn=make_mlp_model,
        use_received_edges=False,
        use_sent_edges=False,
        use_nodes=True,
        use_globals=False,
        name='node_encoder_block')
    # Global encoder: sum-aggregates the encoded edges and nodes.
    self._global_encoder_block = blocks.GlobalBlock(
        global_model_fn=make_mlp_model,
        use_edges=True,
        use_nodes=True,
        use_globals=False,
        nodes_reducer=tf.math.unsorted_segment_sum,
        edges_reducer=tf.math.unsorted_segment_sum,
        name='global_encoder_block')
    # Message-passing core; an InteractionNetwork variant was tried and
    # kept below for reference.
    self._core = MLPGraphNetwork()
    # self._core = InteractionNetwork(
    #    edge_model_fn=make_mlp_model,
    #    node_model_fn=make_mlp_model,
    #    reducer=tf.math.unsorted_segment_sum
    # )
    #
    # Transforms the outputs into appropriate shapes.
    node_output_size = 64
    node_fn = lambda: snt.Sequential([
        snt.nets.MLP([node_output_size],
                     activation=tf.nn.relu,  # default is relu
                     name='node_output')
    ])
    global_output_size = 1
    global_fn = lambda: snt.Sequential([
        snt.nets.MLP([global_output_size],
                     activation=tf.nn.relu,  # default is relu
                     name='global_output'),
        tf.sigmoid  # squashes the graph-level output into (0, 1)
    ])
    self._output_transform = modules.GraphIndependent(
        edge_model_fn=None, node_model_fn=node_fn, global_model_fn=global_fn)
def __init__(self, name="DeepGraphInfoMax"):
    """Builds the Deep Graph InfoMax encoder: edge/node encoders + core."""
    super(DeepGraphInfoMax, self).__init__(name=name)
    # Edge encoder: a 2-layer ReLU MLP with dropout over the concatenated
    # sender/receiver node features (input edge features are ignored).
    self._edge_block = blocks.EdgeBlock(
        edge_model_fn=lambda: snt.nets.MLP([LATENT_SIZE] * 2,
                                           activation=tf.nn.relu,
                                           activate_final=True,
                                           use_dropout=True),
        use_edges=False,
        use_receiver_nodes=True,
        use_sender_nodes=True,
        use_globals=False,
        name='edge_encoder_block')
    # Node encoder: embeds raw node features only.
    self._node_encoder_block = blocks.NodeBlock(
        node_model_fn=make_mlp_model,
        use_received_edges=False,
        use_sent_edges=False,
        use_nodes=True,
        use_globals=False,
        name='node_encoder_block')
    # Message-passing core. Uses tf.math.unsorted_segment_sum (the
    # supported spelling; the bare tf.unsorted_segment_sum alias used
    # before is removed in TF2) for consistency with the rest of the file.
    self._core = modules.InteractionNetwork(
        edge_model_fn=make_mlp_model,
        node_model_fn=make_mlp_model,
        reducer=tf.math.unsorted_segment_sum)
def __init__(self, node_model_fn, name=None):
    """Decoder whose node update reads the graph globals only."""
    super(DecoderNetwork, self).__init__(name=name)
    # Every per-element input except the globals is disabled.
    block_config = dict(
        use_received_edges=False,
        use_sent_edges=False,
        use_nodes=False,
        use_globals=True)
    self.node_block = blocks.NodeBlock(node_model_fn, **block_config)
def __init__(self, name="SegmentClassifier"):
    """Builds the segment classifier: encoders, interaction core, edge head."""
    super(SegmentClassifier, self).__init__(name=name)
    # Edge encoder: edge latents from sender/receiver node features only.
    self._edge_block = blocks.EdgeBlock(
        edge_model_fn=make_mlp_model,
        use_edges=False,
        use_receiver_nodes=True,
        use_sender_nodes=True,
        use_globals=False,
        name='edge_encoder_block')
    # Node encoder: embeds raw node features only.
    self._node_encoder_block = blocks.NodeBlock(
        node_model_fn=make_mlp_model,
        use_received_edges=False,
        use_sent_edges=False,
        use_nodes=True,
        use_globals=False,
        name='node_encoder_block')
    # Message-passing core with sum aggregation.
    self._core = InteractionNetwork(
        edge_model_fn=make_mlp_model,
        node_model_fn=make_mlp_model,
        reducer=tf.math.unsorted_segment_sum)
    # Transforms the outputs into appropriate shapes.
    edge_output_size = 1
    edge_fn = lambda: snt.Sequential([
        snt.nets.MLP([edge_output_size],
                     activation=tf.nn.relu,  # default is relu
                     name='edge_output'),
        tf.sigmoid  # per-edge score squashed into (0, 1)
    ])
    self._output_transform = modules.GraphIndependent(edge_fn, None, None)
def test_same_as_subblocks(self, reducer):
    """Compares the output to explicit subblocks output.

    Args:
      reducer: The reducer used in the `NodeBlock` and `GlobalBlock`.
    """
    input_graph = self._get_input_graph()
    edge_model_fn = functools.partial(snt.Linear, output_size=5)
    node_model_fn = functools.partial(snt.Linear, output_size=10)
    global_model_fn = functools.partial(snt.Linear, output_size=15)
    graph_network = modules.GraphNetwork(
        edge_model_fn=edge_model_fn,
        node_model_fn=node_model_fn,
        global_model_fn=global_model_fn,
        reducer=reducer)
    output_graph = graph_network(input_graph)
    # Rebuild each sub-block around the GraphNetwork's own models so the
    # explicit pipeline below shares weights with the module under test.
    edge_block = blocks.EdgeBlock(
        edge_model_fn=lambda: graph_network._edge_block._edge_model,
        use_sender_nodes=True,
        use_edges=True,
        use_receiver_nodes=True,
        use_globals=True)
    node_block = blocks.NodeBlock(
        node_model_fn=lambda: graph_network._node_block._node_model,
        use_nodes=True,
        use_sent_edges=False,
        use_received_edges=True,
        use_globals=True,
        received_edges_reducer=reducer)
    global_block = blocks.GlobalBlock(
        global_model_fn=lambda: graph_network._global_block._global_model,
        use_nodes=True,
        use_edges=True,
        use_globals=True,
        edges_reducer=reducer,
        nodes_reducer=reducer)
    # Explicit edge -> node -> global pipeline, threading each intermediate
    # graph into the next block.
    expected_output_edge_block = edge_block(input_graph)
    expected_output_node_block = node_block(expected_output_edge_block)
    expected_output_global_block = global_block(expected_output_node_block)
    expected_edges = expected_output_edge_block.edges
    expected_nodes = expected_output_node_block.nodes
    expected_globals = expected_output_global_block.globals
    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        (output_graph_out, expected_edges_out, expected_nodes_out,
         expected_globals_out) = sess.run(
             (output_graph, expected_edges, expected_nodes, expected_globals))
        self._assert_all_none_or_all_close(expected_edges_out,
                                           output_graph_out.edges)
        self._assert_all_none_or_all_close(expected_nodes_out,
                                           output_graph_out.nodes)
        self._assert_all_none_or_all_close(expected_globals_out,
                                           output_graph_out.globals)
def __init__(self,
             edge_model_fn,
             node_model_fn,
             reducer=tf.math.unsorted_segment_sum,
             name="interaction_network"):
    """Initializes the InteractionNetwork module.

    Args:
      edge_model_fn: A callable that will be passed to `EdgeBlock` to
        perform per-edge computations. The callable must return a Sonnet
        module (or equivalent; see `blocks.EdgeBlock` for details), and the
        shape of the output of this module must match the one of the input
        nodes, but for the first and last axis.
      node_model_fn: A callable that will be passed to `NodeBlock` to
        perform per-node computations. The callable must return a Sonnet
        module (or equivalent; see `blocks.NodeBlock` for details).
      reducer: Reducer to be used by NodeBlock to aggregate edges. Defaults
        to `tf.math.unsorted_segment_sum` (the supported spelling of the
        removed `tf.unsorted_segment_sum` alias).
      name: The module name.
    """
    super(InteractionNetwork, self).__init__(name=name)
    with self._enter_variable_scope():
        # Edge update uses edges plus both endpoint nodes (EdgeBlock
        # defaults) but no globals.
        self._edge_block = blocks.EdgeBlock(
            edge_model_fn=edge_model_fn, use_globals=False)
        # Node update aggregates both received and sent edge messages.
        self._node_block = blocks.NodeBlock(
            node_model_fn=node_model_fn,
            use_received_edges=True,
            use_sent_edges=True,
            use_globals=False,
            received_edges_reducer=reducer)
def __init__(self, name="FourTopPredictor"):
    """Builds the four-top predictor: graph encoders, core, global MLP head."""
    super(FourTopPredictor, self).__init__(name=name)
    # Edge encoder: edge latents from sender/receiver node features only.
    self._edge_block = blocks.EdgeBlock(
        edge_model_fn=make_mlp_model,
        use_edges=False,
        use_receiver_nodes=True,
        use_sender_nodes=True,
        use_globals=False,
        name='edge_encoder_block')
    # Node encoder: embeds raw node features only.
    self._node_encoder_block = blocks.NodeBlock(
        node_model_fn=make_mlp_model,
        use_received_edges=False,
        use_sent_edges=False,
        use_nodes=True,
        use_globals=False,
        name='node_encoder_block')
    # Global encoder: aggregates the encoded edges and nodes.
    self._global_block = blocks.GlobalBlock(
        global_model_fn=make_mlp_model,
        use_edges=True,
        use_nodes=True,
        use_globals=False,
    )
    # Message-passing core.
    self._core = MLPGraphNetwork()
    # Transforms the outputs into appropriate shapes: one flattened
    # prediction of `n_target_node_features` values per top, for up to
    # `n_max_tops` tops.
    global_output_size = n_target_node_features * n_max_tops
    self._global_nn = snt.nets.MLP(
        [128, 128, global_output_size],
        activation=tf.nn.leaky_relu,  # default is relu, tanh
        dropout_rate=0.30,
        name='global_output')
def __init__(self,
             node_model_fn,
             global_model_fn,
             reducer=tf.math.unsorted_segment_sum,
             name="deep_sets"):
    """Initializes the DeepSets module.

    Args:
      node_model_fn: A callable to be passed to NodeBlock. The callable must
        return a Sonnet module (or equivalent; see NodeBlock for details).
        The shape of this module's output must equal the shape of the input
        graph's global features, but for the first and last axis.
      global_model_fn: A callable to be passed to GlobalBlock. The callable
        must return a Sonnet module (or equivalent; see GlobalBlock for
        details).
      reducer: Reduction to be used when aggregating the nodes in the
        globals. Defaults to `tf.math.unsorted_segment_sum` (the supported
        spelling of the removed `tf.unsorted_segment_sum` alias).
      name: The module name.
    """
    super(DeepSets, self).__init__(name=name)
    with self._enter_variable_scope():
        # Per-node update from the node's own features and the globals;
        # edges are ignored entirely, as in the DeepSets formulation.
        self._node_block = blocks.NodeBlock(
            node_model_fn=node_model_fn,
            use_received_edges=False,
            use_sent_edges=False,
            use_nodes=True,
            use_globals=True)
        # Global update from the aggregated nodes only.
        self._global_block = blocks.GlobalBlock(
            global_model_fn=global_model_fn,
            use_edges=False,
            use_nodes=True,
            use_globals=False,
            nodes_reducer=reducer)
def __init__(self, name="GlobalClassifierNoEdgeInfo"):
    """Builds a graph-level classifier that ignores input edge features."""
    super(GlobalClassifierNoEdgeInfo, self).__init__(name=name)
    # Edge encoder: builds edge latents purely from endpoint node features.
    self._edge_block = blocks.EdgeBlock(
        edge_model_fn=make_mlp_model,
        use_edges=False,
        use_receiver_nodes=True,
        use_sender_nodes=True,
        use_globals=False,
        name='edge_encoder_block')
    # Node encoder: embeds raw node features only.
    self._node_encoder_block = blocks.NodeBlock(
        node_model_fn=make_mlp_model,
        use_received_edges=False,
        use_sent_edges=False,
        use_nodes=True,
        use_globals=False,
        name='node_encoder_block')
    # Global encoder: aggregates the encoded edges and nodes.
    self._global_block = blocks.GlobalBlock(
        global_model_fn=make_mlp_model,
        use_edges=True,
        use_nodes=True,
        use_globals=False,
    )
    # Message-passing core.
    self._core = MLPGraphNetwork()
    # Transforms the outputs into appropriate shapes.
    global_output_size = 1
    global_fn = lambda: snt.Sequential([
        snt.nets.MLP([LATENT_SIZE, global_output_size],
                     name='global_output'),
        tf.sigmoid  # graph-level score squashed into (0, 1)
    ])
    self._output_transform = modules.GraphIndependent(None, None, global_fn)
def test_created_variables(self, use_received_edges, use_sent_edges, use_nodes, use_globals, expected_first_dim_w):
    """Verifies the variable names and shapes created by a NodeBlock."""
    output_size = 10
    expected_shapes = {
        "node_block/mlp/linear_0/b:0": [output_size],
        "node_block/mlp/linear_0/w:0": [expected_first_dim_w, output_size],
    }
    graph = self._get_input_graph()
    mlp_fn = functools.partial(snt.nets.MLP, output_sizes=[output_size])
    block = blocks.NodeBlock(
        node_model_fn=mlp_fn,
        use_received_edges=use_received_edges,
        use_sent_edges=use_sent_edges,
        use_nodes=use_nodes,
        use_globals=use_globals)
    # Connecting the block once is what creates its variables.
    block(graph)
    actual_shapes = {
        variable.name: variable.get_shape().as_list()
        for variable in block.get_variables()
    }
    self.assertDictEqual(expected_shapes, actual_shapes)
def test_no_input_raises_exception(self):
    """Checks that receiving no input raises an exception."""
    everything_disabled = dict(
        use_received_edges=False,
        use_sent_edges=False,
        use_nodes=False,
        use_globals=False)
    with self.assertRaisesRegexp(ValueError, "At least one of "):
        blocks.NodeBlock(node_model_fn=self._node_model_fn,
                         **everything_disabled)
def __init__(self, mlp_size=16, cluster_encoded_size=10, num_heads=10, core_steps=10, name=None):
    """Builds the encode-process-decode pair used by this model.

    Args:
      mlp_size: width of the hidden MLPs used throughout.
      cluster_encoded_size: size of the per-node latent encoding.
      num_heads: number of attention heads in the core networks.
      core_steps: number of core (message-passing) steps, stored for use at
        call time.
      name: The module name.
    """
    super(Model, self).__init__(name=name)
    # Encoder path (EncodeProcessDecode_E): encode -> attention core ->
    # decode, with leaky-ReLU MLPs on edges/globals and linear node
    # projections into the latent size.
    self.epd_encoder = EncodeProcessDecode_E(
        encoder=EncoderNetwork(
            edge_model_fn=lambda: snt.nets.MLP(
                [mlp_size], activate_final=True,
                activation=tf.nn.leaky_relu),
            node_model_fn=lambda: snt.Linear(cluster_encoded_size),
            global_model_fn=lambda: snt.nets.MLP(
                [mlp_size], activate_final=True,
                activation=tf.nn.leaky_relu)),
        core=CoreNetwork(num_heads=num_heads,
                         multi_head_output_size=cluster_encoded_size,
                         input_node_size=cluster_encoded_size),
        decoder=EncoderNetwork(
            edge_model_fn=lambda: snt.nets.MLP(
                [mlp_size], activate_final=True,
                activation=tf.nn.leaky_relu),
            node_model_fn=lambda: snt.Linear(cluster_encoded_size),
            global_model_fn=lambda: snt.nets.MLP(
                [32, 32, 64], activate_final=True,
                activation=tf.nn.leaky_relu)))
    # Decoder path (EncodeProcessDecode_D). The final NodeBlock emits
    # `cluster_encoded_size - 3` features per node; the reason for the -3
    # offset is not visible here -- presumably positional components are
    # handled separately (TODO confirm).
    self.epd_decoder = EncodeProcessDecode_D(
        encoder=DecoderNetwork(
            node_model_fn=lambda: snt.nets.MLP(
                [32, 32, cluster_encoded_size], activate_final=True,
                activation=tf.nn.leaky_relu)),
        core=CoreNetwork(num_heads=num_heads,
                         multi_head_output_size=cluster_encoded_size,
                         input_node_size=cluster_encoded_size),
        decoder=snt.Sequential([
            RelationNetwork(
                edge_model_fn=lambda: snt.nets.MLP(
                    [mlp_size], activate_final=True,
                    activation=tf.nn.leaky_relu),
                global_model_fn=lambda: snt.nets.MLP(
                    [mlp_size], activate_final=True,
                    activation=tf.nn.leaky_relu)),
            blocks.NodeBlock(
                node_model_fn=lambda: snt.nets.MLP(
                    [cluster_encoded_size - 3], activate_final=True,
                    activation=tf.nn.leaky_relu),
                use_received_edges=True,
                use_sent_edges=True,
                use_nodes=True,
                use_globals=True)
        ]))
    # Number of message-passing iterations to run at call time.
    self._core_steps = core_steps
def test_optional_arguments(self, scale, offset):
    """Assesses the correctness of the NodeBlock using arguments."""
    graph = self._get_input_graph()
    block = blocks.NodeBlock(node_model_fn=self._node_model_args_fn)
    actual_graph = block(
        graph, node_model_kwargs=dict(scale=scale, offset=offset))

    # Reference block with scale/offset baked into the model itself.
    def make_reference_model(s=scale, o=offset):
        return lambda features: features * s + o

    reference_block = blocks.NodeBlock(node_model_fn=make_reference_model)
    expected_graph = reference_block(graph)

    # Untouched fields must be the very same objects.
    self.assertIs(expected_graph.edges, actual_graph.edges)
    self.assertIs(expected_graph.globals, actual_graph.globals)
    self.assertNDArrayNear(
        expected_graph.nodes.numpy(), actual_graph.nodes.numpy(), err=1e-4)
def tinet(input_graph):
    """Runs the network: node embedding, four node layers, two global layers,
    and a dense 4-way head.

    Args:
      input_graph: a graph tuple accepted by `blocks.NodeBlock`; node
        features are integer ids in [0, 800) fed to an Embedding layer
        (presumably -- TODO confirm against callers).

    Returns:
      A tensor of 4 unnormalized scores per graph.
    """
    embedding = blocks.NodeBlock(
        node_model_fn=lambda: tf.keras.layers.Embedding(800, 32),
        use_received_edges=False,
        use_sent_edges=False,
        use_nodes=True,
        use_globals=False)
    # Four identical node-update layers; previously four copy-pasted
    # NodeBlock definitions.
    node_layers = [
        blocks.NodeBlock(
            node_model_fn=lambda: tf.layers.Dense(32, activation=tf.nn.relu))
        for _ in range(4)
    ]
    # Two identical global-update layers; previously two copy-pasted
    # GlobalBlock definitions.
    global_layers = [
        blocks.GlobalBlock(
            global_model_fn=lambda: tf.layers.Dense(40,
                                                    activation=tf.nn.relu))
        for _ in range(2)
    ]
    graph = embedding(input_graph)
    for layer in node_layers + global_layers:
        graph = layer(graph)
    return tf.layers.dense(graph.globals, 4, activation=None)
def test_missing_field_raises_exception(self, use_received_edges, use_sent_edges, use_nodes, use_globals, none_fields):
    """Checks that missing a required field raises an exception."""
    graph = self._get_input_graph(none_fields)
    block = blocks.NodeBlock(
        node_model_fn=self._node_model_fn,
        use_received_edges=use_received_edges,
        use_sent_edges=use_sent_edges,
        use_nodes=use_nodes,
        use_globals=use_globals)
    # Connecting the block to a graph missing a required field must fail.
    with self.assertRaisesRegexp(ValueError, "field cannot be None"):
        block(graph)
def __init__(self, edge_fn, with_edge_inputs=False, with_node_inputs=True, encoder_size: list = None, core_size: list = None, name="EdgeLearnerBase", **kwargs):
    """Edge learner: encoders, interaction core, and an edge output head.

    The two `with_*_inputs` flags map directly onto the block options:
    edge-derived inputs are enabled iff `with_edge_inputs`, node-derived
    inputs iff `with_node_inputs`.
    """
    super(EdgeLearnerBase, self).__init__(name=name)

    if encoder_size is None:
        encoder_fn = partial(make_mlp_model, **kwargs)
    else:
        encoder_fn = partial(make_mlp_model, mlp_size=encoder_size, **kwargs)

    # Equivalent to the flag-by-flag overrides: each option is just one of
    # the two input switches.
    edge_options = dict(
        use_edges=with_edge_inputs,
        use_receiver_nodes=with_node_inputs,
        use_sender_nodes=with_node_inputs,
        use_globals=False)
    node_options = dict(
        use_received_edges=with_edge_inputs,
        use_sent_edges=with_edge_inputs,
        use_nodes=with_node_inputs,
        use_globals=False)

    self._edge_block = blocks.EdgeBlock(
        edge_model_fn=encoder_fn, name='edge_encoder_block', **edge_options)
    self._node_encoder_block = blocks.NodeBlock(
        node_model_fn=encoder_fn, name='node_encoder_block', **node_options)

    if core_size is None:
        core_fn = partial(make_mlp_model, **kwargs)
    else:
        core_fn = partial(make_mlp_model, mlp_size=core_size, **kwargs)

    self._core = InteractionNetwork(
        edge_model_fn=core_fn,
        node_model_fn=core_fn,
        reducer=tf.math.unsorted_segment_sum)
    self._output_transform = modules.GraphIndependent(edge_fn, None, None)
def __init__(self,
             edge_model_fn,
             node_model_fn,
             global_model_fn,
             reducer=tf.math.unsorted_segment_sum,
             edge_block_opt=None,
             node_block_opt=None,
             global_block_opt=None,
             name="graph_network"):
    """Initializes the GraphNetwork module.

    Args:
      edge_model_fn: A callable that will be passed to EdgeBlock to perform
        per-edge computations. The callable must return a Sonnet module (or
        equivalent; see EdgeBlock for details).
      node_model_fn: A callable that will be passed to NodeBlock to perform
        per-node computations. The callable must return a Sonnet module (or
        equivalent; see NodeBlock for details).
      global_model_fn: A callable that will be passed to GlobalBlock to
        perform per-global computations. The callable must return a Sonnet
        module (or equivalent; see GlobalBlock for details).
      reducer: Reducer to be used by NodeBlock and GlobalBlock to aggregate
        nodes and edges. Defaults to tf.math.unsorted_segment_sum. This will
        be overridden by the reducers specified in `node_block_opt` and
        `global_block_opt`, if any.
      edge_block_opt: Additional options to be passed to the EdgeBlock. Can
        contain keys `use_edges`, `use_receiver_nodes`, `use_sender_nodes`,
        `use_globals`. By default, these are all True.
      node_block_opt: Additional options to be passed to the NodeBlock. Can
        contain the keys `use_received_edges`, `use_sent_edges`,
        `use_nodes`, `use_globals` (all set to True by default), and
        `received_edges_reducer`, `sent_edges_reducer` (default to
        `reducer`).
      global_block_opt: Additional options to be passed to the GlobalBlock.
        Can contain the keys `use_edges`, `use_nodes`, `use_globals` (all
        set to True by default), and `edges_reducer`, `nodes_reducer`
        (defaults to `reducer`).
      name: The module name.
    """
    super(GraphNetwork, self).__init__(name=name)
    edge_block_opt = _make_default_edge_block_opt(edge_block_opt)
    node_block_opt = _make_default_node_block_opt(node_block_opt, reducer)
    global_block_opt = _make_default_global_block_opt(
        global_block_opt, reducer)
    # NOTE(review): variable-scope entry was deliberately disabled here
    # (`with self._enter_variable_scope():`); the former `if 2 > 1:`
    # placeholder only preserved indentation and has been removed.
    self._edge_block = blocks.EdgeBlock(
        edge_model_fn=edge_model_fn, **edge_block_opt)
    self._node_block = blocks.NodeBlock(
        node_model_fn=node_model_fn, **node_block_opt)
    self._global_block = blocks.GlobalBlock(
        global_model_fn=global_model_fn, **global_block_opt)
def __init__(self, edge_model_fn, node_model_fn, reducer=tf.math.unsorted_segment_sum, name="interaction_network"):
    """Initializes the InteractionNetwork module.

    Args:
      edge_model_fn: A callable returning the per-edge model (see
        `blocks.EdgeBlock` for details).
      node_model_fn: A callable returning the per-node model (see
        `blocks.NodeBlock` for details).
      reducer: Reducer used by the NodeBlock to aggregate edges. Defaults to
        `tf.math.unsorted_segment_sum`.
      name: The module name.
    """
    super(InteractionNetwork, self).__init__(name=name)
    # Edge update uses edges plus both endpoint nodes (EdgeBlock defaults)
    # but no globals.
    self._edge_block = blocks.EdgeBlock(
        edge_model_fn=edge_model_fn, use_globals=False)
    # Node update aggregates both received and sent edge messages.
    self._node_block = blocks.NodeBlock(
        node_model_fn=node_model_fn,
        use_received_edges=True,
        use_sent_edges=True,
        use_globals=False,
        received_edges_reducer=reducer)
def test_missing_aggregation_raises_exception(self, use_received_edges, use_sent_edges, received_edges_reducer, sent_edges_reducer):
    """Checks that missing a required aggregation argument raises an error."""
    block_kwargs = dict(
        node_model_fn=self._node_model_fn,
        use_received_edges=use_received_edges,
        use_sent_edges=use_sent_edges,
        use_nodes=False,
        use_globals=False,
        received_edges_reducer=received_edges_reducer,
        sent_edges_reducer=sent_edges_reducer)
    # Construction itself must fail when a needed reducer is None.
    with self.assertRaisesRegexp(ValueError, "should not be None"):
        blocks.NodeBlock(**block_kwargs)
def __init__(self, edge_model_fn, node_model_fn, global_model_fn, name=None):
    """Encoder: node-wise embedding followed by a relation network.

    Args:
      edge_model_fn: callable returning the per-edge model for the relation
        network.
      node_model_fn: callable returning the per-node embedding model.
      global_model_fn: callable returning the per-graph model for the
        relation network.
      name: The module name.
    """
    super(EncoderNetwork, self).__init__(name=name)
    # Nodes are encoded independently (no edge or global inputs).
    self.node_block = blocks.NodeBlock(
        node_model_fn,
        use_received_edges=False,
        use_sent_edges=False,
        use_nodes=True,
        use_globals=False)
    # Edge/global outputs come from a relation network over the graph.
    self.relation_network = RelationNetwork(
        edge_model_fn=edge_model_fn, global_model_fn=global_model_fn)
def test_incompatible_higher_rank_inputs_no_raise(self, use_received_edges, use_sent_edges, use_nodes, use_globals, field):
    """No exception should occur if a differently shaped field is not used."""
    graph = self._get_shaped_input_graph()
    # Deliberately give one field an incompatible layout; since that field
    # is disabled below, the block must still build and run.
    transposed = tf.transpose(getattr(graph, field), [0, 2, 1, 3])
    graph = graph.replace(**{field: transposed})
    conv_fn = functools.partial(
        snt.Conv2D, output_channels=10, kernel_shape=[3, 3])
    block = blocks.NodeBlock(
        conv_fn,
        use_received_edges=use_received_edges,
        use_sent_edges=use_sent_edges,
        use_nodes=use_nodes,
        use_globals=use_globals)
    self._assert_build_and_run(block, graph)
def test_incompatible_higher_rank_inputs_raises(self, use_received_edges, use_sent_edges, use_nodes, use_globals, field):
    """An exception should be raised if the inputs have incompatible shapes."""
    graph = self._get_shaped_input_graph()
    # Give one *used* field an incompatible layout; connecting the block
    # must then fail with a shape error.
    transposed = tf.transpose(getattr(graph, field), [0, 2, 1, 3])
    graph = graph.replace(**{field: transposed})
    block = blocks.NodeBlock(
        functools.partial(snt.Conv2D, output_channels=10,
                          kernel_shape=[3, 3]),
        use_received_edges=use_received_edges,
        use_sent_edges=use_sent_edges,
        use_nodes=use_nodes,
        use_globals=use_globals)
    with self.assertRaisesRegexp(ValueError, "in both shapes must be equal"):
        block(graph)
def test_same_as_subblocks(self, reducer, none_fields):
    """Compares the output to explicit subblocks output.

    Args:
      reducer: The reducer used in the NodeBlock.
      none_fields: (list of strings) The corresponding fields are removed
        from the input graph.
    """
    input_graph = self._get_input_graph()
    input_graph = input_graph.map(lambda _: None, none_fields)
    deep_sets = self._get_model(reducer)
    output_graph = deep_sets(input_graph)
    output_nodes = output_graph.nodes
    output_globals = output_graph.globals
    # Rebuild the two sub-blocks around the DeepSets module's own models so
    # the explicit pipeline below shares weights with the module under test.
    node_block = blocks.NodeBlock(
        node_model_fn=lambda: deep_sets._node_block._node_model,
        use_received_edges=False,
        use_sent_edges=False,
        use_nodes=True,
        use_globals=True)
    global_block = blocks.GlobalBlock(
        global_model_fn=lambda: deep_sets._global_block._global_model,
        use_edges=False,
        use_nodes=True,
        use_globals=False,
        nodes_reducer=reducer)
    # Explicit node -> global pipeline.
    node_block_out = node_block(input_graph)
    expected_nodes = node_block_out.nodes
    expected_globals = global_block(node_block_out).globals
    # Fields not updated by DeepSets must pass through unchanged.
    self.assertAllEqual(input_graph.edges, output_graph.edges)
    self.assertAllEqual(input_graph.receivers, output_graph.receivers)
    self.assertAllEqual(input_graph.senders, output_graph.senders)
    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        (output_nodes_, output_globals_, expected_nodes_,
         expected_globals_) = sess.run(
             [output_nodes, output_globals, expected_nodes, expected_globals])
        self._assert_all_none_or_all_close(expected_nodes_, output_nodes_)
        self._assert_all_none_or_all_close(expected_globals_, output_globals_)