예제 #1
0
    def __init__(self,
                 edge_model_fn,
                 node_encoder_model_fn,
                 node_model_fn,
                 reducer=tf.unsorted_segment_sum,
                 name="oricomm_net"):
        """Initializes the OriCommNet module.

        Args:
          edge_model_fn: A callable passed to the `EdgeBlock`. It must return a
            Sonnet module (or equivalent; see `blocks.EdgeBlock`).
          node_encoder_model_fn: A callable passed to the `NodeBlock` that
            encodes the nodes from their own features only. It must return a
            Sonnet module (or equivalent; see `blocks.NodeBlock`).
          node_model_fn: A callable passed to the `NodeBlock` that combines the
            received edges (aggregated with `reducer`) with the node features.
          reducer: Callable used to aggregate received edges per node; its
            signature should match `tf.unsorted_segment_sum`.
          name: The module name.
        """
        super(OriCommNet, self).__init__(name=name)

        # All three sub-blocks must be created inside this module's variable
        # scope so that their variables live under (and are collected by) this
        # module. Previously the two NodeBlocks were constructed after the
        # `with` statement had already exited.
        with self._enter_variable_scope():
            # Computes $\Psi_{com}(x_j)$ in Eq. (2) of 1706.06122
            self._edge_block = blocks.EdgeBlock(edge_model_fn=edge_model_fn,
                                                use_edges=False,
                                                use_receiver_nodes=False,
                                                use_sender_nodes=True,
                                                use_globals=False)
            # Computes $\Phi(x_i)$ in Eq. (2) of 1706.06122
            self._node_encoder_block = blocks.NodeBlock(
                node_model_fn=node_encoder_model_fn,
                use_received_edges=False,
                use_sent_edges=False,
                use_nodes=True,
                use_globals=False,
                received_edges_reducer=reducer,
                name="node_encoder_block")
            # Computes $\Theta(..)$ in Eq.(2) of 1706.06122
            self._node_block = blocks.NodeBlock(node_model_fn=node_model_fn,
                                                use_received_edges=True,
                                                use_sent_edges=False,
                                                use_nodes=True,
                                                use_globals=False,
                                                received_edges_reducer=reducer)
예제 #2
0
  def test_same_as_subblocks(self, reducer, none_field=None):
    """Checks the model's node output against its sub-blocks chained by hand.

    Three stand-alone blocks reuse the internal models of the model under
    test: an `EdgeBlock` fed by sender nodes, a `NodeBlock` that encodes the
    nodes, and a `NodeBlock` that combines the received edges with the
    encoded nodes. Chaining them must reproduce the model's output nodes.

    Args:
      reducer: The reducer used in the `NodeBlock`s.
      none_field: (string, default=None) If not None, the corresponding field
        is removed from the input graph.
    """
    graph = self._get_input_graph(none_field)

    model = self._get_model(reducer)
    model_out = model(graph)
    actual_nodes = model_out.nodes

    edge_stage = blocks.EdgeBlock(
        edge_model_fn=lambda: model._edge_block._edge_model,
        use_edges=False,
        use_receiver_nodes=False,
        use_sender_nodes=True,
        use_globals=False)
    encoder_stage = blocks.NodeBlock(
        node_model_fn=lambda: model._node_encoder_block._node_model,
        use_received_edges=False,
        use_sent_edges=False,
        use_nodes=True,
        use_globals=False,
        received_edges_reducer=reducer)
    update_stage = blocks.NodeBlock(
        node_model_fn=lambda: model._node_block._node_model,
        use_received_edges=True,
        use_sent_edges=False,
        use_nodes=True,
        use_globals=False,
        received_edges_reducer=reducer)

    stage_edges = edge_stage(graph).edges
    stage_nodes = encoder_stage(graph).nodes
    expected_nodes = update_stage(
        graph.replace(edges=stage_edges, nodes=stage_nodes)).nodes

    # Every field other than the nodes must be left untouched by the model.
    for field in ("globals", "edges", "receivers", "senders"):
      self.assertAllEqual(getattr(graph, field), getattr(model_out, field))

    with self.test_session() as session:
      session.run(tf.global_variables_initializer())
      actual_value, expected_value = session.run(
          [actual_nodes, expected_nodes])

    self._assert_all_none_or_all_close(expected_value, actual_value)
예제 #3
0
    def __init__(self,
                 edge_model_fn,
                 node_encoder_model_fn,
                 node_model_fn,
                 reducer=tf.math.unsorted_segment_sum,
                 name="comm_net"):
        """Initializes the CommNet module.

        Args:
          edge_model_fn: A callable to be passed to EdgeBlock. The callable must
            return a Sonnet module (or equivalent; see EdgeBlock for details).
          node_encoder_model_fn: A callable to be passed to the NodeBlock
            responsible for the first encoding of the nodes. The callable must
            return a Sonnet module (or equivalent; see NodeBlock for details).
            The shape of this module's output should match the shape of the
            module built by `edge_model_fn`, but for the first and last
            dimension.
          node_model_fn: A callable to be passed to NodeBlock. The callable must
            return a Sonnet module (or equivalent; see NodeBlock for details).
          reducer: Reduction to be used when aggregating the edges in the nodes.
            This should be a callable whose signature matches
            tf.math.unsorted_segment_sum.
          name: The module name.
        """
        super(CommNet, self).__init__(name=name)

        # Sub-blocks are built directly in the constructor. The Sonnet-1
        # `_enter_variable_scope()` context is intentionally not used here
        # (presumably this code targets a Sonnet version without it -- confirm).
        # This replaces a former always-true `if 2 > 1:` placeholder.

        # Computes $\Psi_{com}(x_j)$ in Eq. (2) of 1706.06122
        self._edge_block = blocks.EdgeBlock(edge_model_fn=edge_model_fn,
                                            use_edges=False,
                                            use_receiver_nodes=False,
                                            use_sender_nodes=True,
                                            use_globals=False)
        # Computes $\Phi(x_i)$ in Eq. (2) of 1706.06122
        self._node_encoder_block = blocks.NodeBlock(
            node_model_fn=node_encoder_model_fn,
            use_received_edges=False,
            use_sent_edges=False,
            use_nodes=True,
            use_globals=False,
            received_edges_reducer=reducer,
            name="node_encoder_block")
        # Computes $\Theta(..)$ in Eq.(2) of 1706.06122
        self._node_block = blocks.NodeBlock(node_model_fn=node_model_fn,
                                            use_received_edges=True,
                                            use_sent_edges=False,
                                            use_nodes=True,
                                            use_globals=False,
                                            received_edges_reducer=reducer)
예제 #4
0
 def test_compatible_higher_rank_no_raise(self):
   """A NodeBlock wrapping a Conv2D must build and run on rank-4 features."""
   def _swap_middle_axes(tensor):
     return tf.transpose(tensor, [0, 2, 1, 3])

   graph = self._get_shaped_input_graph().map(_swap_middle_axes)
   conv_fn = functools.partial(
       snt.Conv2D, output_channels=10, kernel_shape=[3, 3])
   network = blocks.NodeBlock(conv_fn)
   self._assert_build_and_run(network, graph)
예제 #5
0
파일: GNN.py 프로젝트: YannickCharles/RQ1
    def __init__(self, node_model_fn, global_model_fn, name="rwrd_gnn"):
        """Initializes the reward model.

        Builds a node update that reads node and global features, followed by
        a global update that reads the nodes (aggregated by summation) and the
        globals. Neither block is configured to read edges.

        Args:
          node_model_fn: Callable returning the module used by the node block.
          global_model_fn: Callable returning the module used by the global
            block.
          name: The module name.
        """
        super(GraphNeuralNetwork_reward, self).__init__(name=name)

        segment_sum = tf.math.unsorted_segment_sum

        with self._enter_variable_scope():
            self._node_block = blocks.NodeBlock(
                node_model_fn=node_model_fn,
                use_received_edges=False,
                use_sent_edges=False,
                use_nodes=True,
                use_globals=True,
                # Edge reducers are presumably unused while both edge inputs
                # are disabled; kept to mirror the original configuration.
                received_edges_reducer=segment_sum,
                sent_edges_reducer=segment_sum,
                name="node_block")

            self._global_block = blocks.GlobalBlock(
                global_model_fn=global_model_fn,
                use_edges=False,
                use_nodes=True,
                use_globals=True,
                nodes_reducer=segment_sum,
                edges_reducer=segment_sum,
                name="global_block")
예제 #6
0
    def __init__(self, conf, name="encoder-attention-tsp"):
        """Inits the encoder module.

        Args:
            conf: Configuration object. This constructor reads its
                `embedding_dim`, `init_dim` and `encoder_nbr_layers` attributes
                and hands it to every `EncoderLayer`.
            name: The module name.
        """
        super(Encoder, self).__init__(name=name)
        self.conf = conf
        self.training = True
        with self._enter_variable_scope():
            # Linear projection of node features to `conf.embedding_dim`.
            self._initial_projection = snt.Linear(
                output_size=conf.embedding_dim,
                initializers=dict(
                    w=utils.initializer(conf.init_dim),
                    b=utils.initializer(conf.init_dim),
                ),
                name="initial_projection")
            # Node-only block that applies the projection above to the nodes.
            self._initial_projection_block = blocks.NodeBlock(
                lambda: self._initial_projection,
                use_received_edges=False,
                use_nodes=True,
                use_globals=False,
                name="initial_block_projection")
            layer_count = conf.encoder_nbr_layers
            self._encoder_layers = [
                EncoderLayer(conf, "encoder_layer_{}".format(index))
                for index in range(layer_count)
            ]
예제 #7
0
  def test_unused_field_can_be_none(
      self, use_edges, use_nodes, use_globals, none_field):
    """A NodeBlock must still run when a field it does not need is None."""
    graph = self._get_input_graph([none_field])
    block = blocks.NodeBlock(
        node_model_fn=self._node_model_fn,
        use_received_edges=use_edges,
        use_sent_edges=use_edges,
        use_nodes=use_nodes,
        use_globals=use_globals)
    result = block(graph)

    # Rebuild by hand the concatenated features the node model should see,
    # in order: received edges, sent edges, nodes, broadcast globals.
    pieces = []
    if use_edges:
      pieces += [
          blocks.ReceivedEdgesToNodesAggregator(
              tf.unsorted_segment_sum)(graph),
          blocks.SentEdgesToNodesAggregator(
              tf.unsorted_segment_sum)(graph),
      ]
    if use_nodes:
      pieces.append(graph.nodes)
    if use_globals:
      pieces.append(blocks.broadcast_globals_to_nodes(graph))
    expected_inputs = tf.concat(pieces, axis=-1)

    # Edges and globals must be passed through as-is.
    self.assertEqual(graph.edges, result.edges)
    self.assertEqual(graph.globals, result.globals)

    with self.test_session() as session:
      actual_nodes, inputs_value = session.run(
          (result.nodes, expected_inputs))

    self.assertNDArrayNear(inputs_value * self._scale, actual_nodes, err=1e-4)
예제 #8
0
    def __init__(self, name="DecaySimulator"):
        """Builds the helper, encoder, core and output modules.

        Args:
          name: The module name.
        """
        super(DecaySimulator, self).__init__(name=name)

        # Node-level helper modules (how they are used is not shown here).
        self._node_linear = make_mlp_model()
        self._node_rnn = snt.GRU(hidden_size=LATENT_SIZE, name='node_rnn')
        self._node_proper = snt.nets.MLP([4], activate_final=False)

        # Encoders: edges from their two endpoint nodes, nodes from their own
        # features, globals from summed edges and nodes.
        self._edge_block = blocks.EdgeBlock(
            edge_model_fn=make_mlp_model,
            use_edges=False,
            use_receiver_nodes=True,
            use_sender_nodes=True,
            use_globals=False,
            name='edge_encoder_block')
        self._node_encoder_block = blocks.NodeBlock(
            node_model_fn=make_mlp_model,
            use_received_edges=False,
            use_sent_edges=False,
            use_nodes=True,
            use_globals=False,
            name='node_encoder_block')
        self._global_encoder_block = blocks.GlobalBlock(
            global_model_fn=make_mlp_model,
            use_edges=True,
            use_nodes=True,
            use_globals=False,
            nodes_reducer=tf.math.unsorted_segment_sum,
            edges_reducer=tf.math.unsorted_segment_sum,
            name='global_encoder_block')

        self._core = MLPGraphNetwork()

        # Output heads: 64-wide node features and a single sigmoid global.
        node_output_size = 64
        global_output_size = 1

        def node_fn():
            return snt.Sequential([
                snt.nets.MLP([node_output_size],
                             activation=tf.nn.relu,
                             name='node_output')
            ])

        def global_fn():
            return snt.Sequential([
                snt.nets.MLP([global_output_size],
                             activation=tf.nn.relu,
                             name='global_output'),
                tf.sigmoid,
            ])

        self._output_transform = modules.GraphIndependent(
            edge_model_fn=None,
            node_model_fn=node_fn,
            global_model_fn=global_fn)
예제 #9
0
    def __init__(self, name="DeepGraphInfoMax"):
        """Builds the edge/node encoders and the interaction-network core.

        Args:
          name: The module name.
        """
        super(DeepGraphInfoMax, self).__init__(name=name)

        def edge_model_fn():
            # Two ReLU layers of width LATENT_SIZE (final layer activated too),
            # with dropout enabled.
            return snt.nets.MLP([LATENT_SIZE] * 2,
                                activation=tf.nn.relu,
                                activate_final=True,
                                use_dropout=True)

        # Edges are encoded from their sender and receiver nodes only.
        self._edge_block = blocks.EdgeBlock(
            edge_model_fn=edge_model_fn,
            use_edges=False,
            use_receiver_nodes=True,
            use_sender_nodes=True,
            use_globals=False,
            name='edge_encoder_block')
        # Nodes are encoded from their own features only.
        self._node_encoder_block = blocks.NodeBlock(
            node_model_fn=make_mlp_model,
            use_received_edges=False,
            use_sent_edges=False,
            use_nodes=True,
            use_globals=False,
            name='node_encoder_block')

        self._core = modules.InteractionNetwork(
            edge_model_fn=make_mlp_model,
            node_model_fn=make_mlp_model,
            reducer=tf.unsorted_segment_sum)
예제 #10
0
 def __init__(self, node_model_fn, name=None):
     """Creates a decoder whose node model is fed only the global features.

     Args:
       node_model_fn: Callable returning the module applied to the nodes.
       name: Optional module name, forwarded to the base class.
     """
     super(DecoderNetwork, self).__init__(name=name)
     # Edge and node inputs are all switched off; only globals reach the
     # node model.
     self.node_block = blocks.NodeBlock(
         node_model_fn=node_model_fn,
         use_received_edges=False,
         use_sent_edges=False,
         use_nodes=False,
         use_globals=True)
예제 #11
0
    def __init__(self, name="SegmentClassifier"):
        """Builds the encoders, the interaction core and an edge read-out.

        Args:
          name: The module name.
        """
        super(SegmentClassifier, self).__init__(name=name)

        # Edges are encoded from their sender and receiver nodes only.
        self._edge_block = blocks.EdgeBlock(
            edge_model_fn=make_mlp_model,
            use_edges=False,
            use_receiver_nodes=True,
            use_sender_nodes=True,
            use_globals=False,
            name='edge_encoder_block')
        # Nodes are encoded from their own features only.
        self._node_encoder_block = blocks.NodeBlock(
            node_model_fn=make_mlp_model,
            use_received_edges=False,
            use_sent_edges=False,
            use_nodes=True,
            use_globals=False,
            name='node_encoder_block')

        self._core = InteractionNetwork(
            edge_model_fn=make_mlp_model,
            node_model_fn=make_mlp_model,
            reducer=tf.math.unsorted_segment_sum)

        def edge_fn():
            # One output unit per edge, squashed into (0, 1) by a sigmoid.
            return snt.Sequential([
                snt.nets.MLP([1],
                             activation=tf.nn.relu,
                             name='edge_output'),
                tf.sigmoid,
            ])

        # Only the edges get an output model; no node or global model is set.
        self._output_transform = modules.GraphIndependent(
            edge_model_fn=edge_fn,
            node_model_fn=None,
            global_model_fn=None)
예제 #12
0
  def test_same_as_subblocks(self, reducer):
    """Checks GraphNetwork output against its three blocks chained by hand.

    Args:
      reducer: The reducer used in the `NodeBlock` and `GlobalBlock`.
    """
    graph = self._get_input_graph()

    network = modules.GraphNetwork(
        edge_model_fn=functools.partial(snt.Linear, output_size=5),
        node_model_fn=functools.partial(snt.Linear, output_size=10),
        global_model_fn=functools.partial(snt.Linear, output_size=15),
        reducer=reducer)
    network_out = network(graph)

    # Stand-alone blocks reusing the network's own models; running them as an
    # edge -> node -> global chain must reproduce the network's output.
    manual_edge_block = blocks.EdgeBlock(
        edge_model_fn=lambda: network._edge_block._edge_model,
        use_sender_nodes=True,
        use_edges=True,
        use_receiver_nodes=True,
        use_globals=True)
    manual_node_block = blocks.NodeBlock(
        node_model_fn=lambda: network._node_block._node_model,
        use_nodes=True,
        use_sent_edges=False,
        use_received_edges=True,
        use_globals=True,
        received_edges_reducer=reducer)
    manual_global_block = blocks.GlobalBlock(
        global_model_fn=lambda: network._global_block._global_model,
        use_nodes=True,
        use_edges=True,
        use_globals=True,
        edges_reducer=reducer,
        nodes_reducer=reducer)

    after_edges = manual_edge_block(graph)
    after_nodes = manual_node_block(after_edges)
    after_globals = manual_global_block(after_nodes)

    with self.test_session() as session:
      session.run(tf.global_variables_initializer())
      network_value, edges_value, nodes_value, globals_value = session.run(
          (network_out, after_edges.edges, after_nodes.nodes,
           after_globals.globals))

    for expected, actual in ((edges_value, network_value.edges),
                             (nodes_value, network_value.nodes),
                             (globals_value, network_value.globals)):
      self._assert_all_none_or_all_close(expected, actual)
예제 #13
0
    def __init__(self,
                 edge_model_fn,
                 node_model_fn,
                 reducer=tf.unsorted_segment_sum,
                 name="interaction_network"):
        """Initializes the InteractionNetwork module.

        Args:
          edge_model_fn: A callable that will be passed to `EdgeBlock` to
            perform per-edge computations. The callable must return a Sonnet
            module (or equivalent; see `blocks.EdgeBlock` for details), and the
            shape of the output of this module must match the one of the input
            nodes, but for the first and last axis.
          node_model_fn: A callable that will be passed to `NodeBlock` to
            perform per-node computations. The callable must return a Sonnet
            module (or equivalent; see `blocks.NodeBlock` for details).
          reducer: Reducer used by the NodeBlock to aggregate both the received
            and the sent edges of each node. Defaults to
            tf.unsorted_segment_sum.
          name: The module name.
        """
        super(InteractionNetwork, self).__init__(name=name)

        with self._enter_variable_scope():
            self._edge_block = blocks.EdgeBlock(edge_model_fn=edge_model_fn,
                                                use_globals=False)
            # The node update reads both received and sent edges, so the
            # caller's `reducer` must be used for both aggregations. Without
            # `sent_edges_reducer`, sent edges silently fell back to the
            # NodeBlock default regardless of `reducer`.
            self._node_block = blocks.NodeBlock(node_model_fn=node_model_fn,
                                                use_received_edges=True,
                                                use_sent_edges=True,
                                                use_globals=False,
                                                received_edges_reducer=reducer,
                                                sent_edges_reducer=reducer)
예제 #14
0
    def __init__(self, name="FourTopPredictor"):
        """Builds the encoders, the core network and the global read-out MLP.

        Args:
          name: The module name.
        """
        super(FourTopPredictor, self).__init__(name=name)

        # Edge encoder: edges built from their sender and receiver nodes.
        self._edge_block = blocks.EdgeBlock(
            edge_model_fn=make_mlp_model,
            use_edges=False,
            use_receiver_nodes=True,
            use_sender_nodes=True,
            use_globals=False,
            name='edge_encoder_block')
        # Node encoder: each node encoded from its own features only.
        self._node_encoder_block = blocks.NodeBlock(
            node_model_fn=make_mlp_model,
            use_received_edges=False,
            use_sent_edges=False,
            use_nodes=True,
            use_globals=False,
            name='node_encoder_block')
        # Global encoder: globals built from the edges and nodes.
        self._global_block = blocks.GlobalBlock(
            global_model_fn=make_mlp_model,
            use_edges=True,
            use_nodes=True,
            use_globals=False)

        self._core = MLPGraphNetwork()

        # Read-out head; presumably one value per target feature for each of
        # up to `n_max_tops` tops -- confirm against the loss.
        output_width = n_target_node_features * n_max_tops
        self._global_nn = snt.nets.MLP(
            [128, 128, output_width],
            activation=tf.nn.leaky_relu,
            dropout_rate=0.30,
            name='global_output')
예제 #15
0
    def __init__(self,
                 node_model_fn,
                 global_model_fn,
                 reducer=tf.unsorted_segment_sum,
                 name="deep_sets"):
        """Initializes the DeepSets module.

        Args:
          node_model_fn: A callable to be passed to NodeBlock. The callable
            must return a Sonnet module (or equivalent; see NodeBlock for
            details). The shape of this module's output must equal the shape
            of the input graph's global features, but for the first and last
            axis.
          global_model_fn: A callable to be passed to GlobalBlock. The callable
            must return a Sonnet module (or equivalent; see GlobalBlock for
            details).
          reducer: Reduction used to aggregate the nodes into the globals.
            This should be a callable whose signature matches
            tf.unsorted_segment_sum.
          name: The module name.
        """
        super(DeepSets, self).__init__(name=name)

        with self._enter_variable_scope():
            # Per-node update from each node's own features plus the globals.
            self._node_block = blocks.NodeBlock(
                node_model_fn=node_model_fn,
                use_received_edges=False,
                use_sent_edges=False,
                use_nodes=True,
                use_globals=True)
            # Global update from the nodes aggregated with `reducer`; the
            # incoming globals are not fed to the global model.
            self._global_block = blocks.GlobalBlock(
                global_model_fn=global_model_fn,
                use_edges=False,
                use_nodes=True,
                use_globals=False,
                nodes_reducer=reducer)
예제 #16
0
    def __init__(self, name="GlobalClassifierNoEdgeInfo"):
        """Builds the encoders, the core and a sigmoid global read-out.

        Args:
          name: The module name.
        """
        super(GlobalClassifierNoEdgeInfo, self).__init__(name=name)

        # Encoders, in construction order: edges, nodes, globals.
        self._edge_block = blocks.EdgeBlock(
            edge_model_fn=make_mlp_model,
            use_edges=False,
            use_receiver_nodes=True,
            use_sender_nodes=True,
            use_globals=False,
            name='edge_encoder_block')
        self._node_encoder_block = blocks.NodeBlock(
            node_model_fn=make_mlp_model,
            use_received_edges=False,
            use_sent_edges=False,
            use_nodes=True,
            use_globals=False,
            name='node_encoder_block')
        self._global_block = blocks.GlobalBlock(
            global_model_fn=make_mlp_model,
            use_edges=True,
            use_nodes=True,
            use_globals=False)

        self._core = MLPGraphNetwork()

        def global_fn():
            # Hidden layer of LATENT_SIZE, then one unit squashed by a sigmoid.
            return snt.Sequential([
                snt.nets.MLP([LATENT_SIZE, 1], name='global_output'),
                tf.sigmoid,
            ])

        # Only the globals get an output model; no edge or node model is set.
        self._output_transform = modules.GraphIndependent(
            edge_model_fn=None,
            node_model_fn=None,
            global_model_fn=global_fn)
예제 #17
0
    def test_created_variables(self, use_received_edges, use_sent_edges,
                               use_nodes, use_globals, expected_first_dim_w):
        """Checks the names and shapes of the variables a NodeBlock creates.

        The node model is a one-layer MLP of width 10, so the block must own
        exactly one bias and one weight matrix. The weight's input dimension
        depends on which inputs are enabled.
        """
        width = 10
        prefix = "node_block/mlp/linear_0/"
        expected = {
            prefix + "b:0": [width],
            prefix + "w:0": [expected_first_dim_w, width],
        }

        graph = self._get_input_graph()
        mlp_fn = functools.partial(snt.nets.MLP, output_sizes=[width])
        block = blocks.NodeBlock(node_model_fn=mlp_fn,
                                 use_received_edges=use_received_edges,
                                 use_sent_edges=use_sent_edges,
                                 use_nodes=use_nodes,
                                 use_globals=use_globals)
        block(graph)

        actual = dict(
            (variable.name, variable.get_shape().as_list())
            for variable in block.get_variables())
        self.assertDictEqual(expected, actual)
예제 #18
0
 def test_no_input_raises_exception(self):
     """Checks that a NodeBlock with every input disabled raises ValueError."""
     # `assertRaisesRegexp` is a deprecated alias that was removed from
     # unittest in Python 3.12; `assertRaisesRegex` is the supported name.
     with self.assertRaisesRegex(ValueError, "At least one of "):
         blocks.NodeBlock(node_model_fn=self._node_model_fn,
                          use_received_edges=False,
                          use_sent_edges=False,
                          use_nodes=False,
                          use_globals=False)
예제 #19
0
    def __init__(self,
                 mlp_size=16,
                 cluster_encoded_size=10,
                 num_heads=10,
                 core_steps=10,
                 name=None):
        """Builds the encode-process-decode encoder and decoder networks.

        Args:
          mlp_size: Width of the one-layer leaky-ReLU MLPs used by most of the
            edge and global models.
          cluster_encoded_size: Node latent size. It is also passed as both
            `multi_head_output_size` and `input_node_size` to each core.
          num_heads: Passed as `num_heads` to each `CoreNetwork`.
          core_steps: Stored as `self._core_steps`; not used in this
            constructor.
          name: Optional module name, forwarded to the base class.
        """
        super(Model, self).__init__(name=name)

        def leaky_mlp(sizes):
            # Factory for an MLP whose every layer, including the last, uses a
            # leaky ReLU. A fresh list is passed on each call.
            return lambda: snt.nets.MLP(
                list(sizes), activate_final=True, activation=tf.nn.leaky_relu)

        def node_linear():
            return snt.Linear(cluster_encoded_size)

        def make_core():
            return CoreNetwork(num_heads=num_heads,
                               multi_head_output_size=cluster_encoded_size,
                               input_node_size=cluster_encoded_size)

        # Encoder side: EncoderNetwork -> CoreNetwork -> EncoderNetwork.
        # Modules are created in the same order as the original nested call.
        encode_in = EncoderNetwork(edge_model_fn=leaky_mlp([mlp_size]),
                                   node_model_fn=node_linear,
                                   global_model_fn=leaky_mlp([mlp_size]))
        encode_core = make_core()
        encode_out = EncoderNetwork(edge_model_fn=leaky_mlp([mlp_size]),
                                    node_model_fn=node_linear,
                                    global_model_fn=leaky_mlp([32, 32, 64]))
        self.epd_encoder = EncodeProcessDecode_E(
            encoder=encode_in, core=encode_core, decoder=encode_out)

        # Decoder side: DecoderNetwork -> CoreNetwork -> relation + node update.
        decode_in = DecoderNetwork(
            node_model_fn=leaky_mlp([32, 32, cluster_encoded_size]))
        decode_core = make_core()
        decode_out = snt.Sequential([
            RelationNetwork(edge_model_fn=leaky_mlp([mlp_size]),
                            global_model_fn=leaky_mlp([mlp_size])),
            blocks.NodeBlock(
                node_model_fn=leaky_mlp([cluster_encoded_size - 3]),
                use_received_edges=True,
                use_sent_edges=True,
                use_nodes=True,
                use_globals=True),
        ])
        self.epd_decoder = EncodeProcessDecode_D(
            encoder=decode_in, core=decode_core, decoder=decode_out)

        self._core_steps = core_steps
예제 #20
0
  def test_optional_arguments(self, scale, offset):
    """Runtime node-model kwargs must match a model with the values baked in."""
    graph = self._get_input_graph()

    kwargs_block = blocks.NodeBlock(node_model_fn=self._node_model_args_fn)
    with_kwargs = kwargs_block(
        graph, node_model_kwargs=dict(scale=scale, offset=offset))

    def hardcoded_model_fn():
      return lambda features: features * scale + offset

    hardcoded_block = blocks.NodeBlock(node_model_fn=hardcoded_model_fn)
    baked_in = hardcoded_block(graph)

    # Untouched fields must be forwarded as the very same objects.
    self.assertIs(baked_in.edges, with_kwargs.edges)
    self.assertIs(baked_in.globals, with_kwargs.globals)
    self.assertNDArrayNear(
        baked_in.nodes.numpy(), with_kwargs.nodes.numpy(), err=1e-4)
예제 #21
0
def tinet(input_graph):
    """Runs a fixed stack of graph blocks and returns a 4-unit linear output.

    The stack, applied in order, is:
      * a node-only NodeBlock wrapping `tf.keras.layers.Embedding(800, 32)`,
      * four NodeBlocks, each with a 32-unit ReLU `tf.layers.Dense` node model,
      * two GlobalBlocks, each with a 40-unit ReLU `tf.layers.Dense` global
        model.
    The globals of the last block's output go through
    `tf.layers.dense(..., 4, activation=None)`.

    Args:
      input_graph: Graph fed to the first block. Its nodes go into an
        Embedding layer, so presumably they hold integer ids -- TODO confirm
        against callers.

    Returns:
      The tensor produced by the final `tf.layers.dense` call.
    """

    def relu_dense(units):
        # Zero-argument factory, as the blocks' `*_model_fn` arguments expect.
        return lambda: tf.layers.Dense(units, activation=tf.nn.relu)

    # Build every block up front, in the same order they are applied.
    pipeline = [
        blocks.NodeBlock(
            node_model_fn=lambda: tf.keras.layers.Embedding(800, 32),
            use_received_edges=False,
            use_sent_edges=False,
            use_nodes=True,
            use_globals=False)
    ]
    pipeline += [
        blocks.NodeBlock(node_model_fn=relu_dense(32)) for _ in range(4)
    ]
    pipeline += [
        blocks.GlobalBlock(global_model_fn=relu_dense(40)) for _ in range(2)
    ]

    graph = input_graph
    for block in pipeline:
        graph = block(graph)

    return tf.layers.dense(graph.globals, 4, activation=None)
예제 #22
0
 def test_missing_field_raises_exception(self, use_received_edges,
                                         use_sent_edges, use_nodes,
                                         use_globals, none_fields):
     """Checks that missing a required field raises an exception.

     Args:
       use_received_edges: Forwarded to `blocks.NodeBlock`.
       use_sent_edges: Forwarded to `blocks.NodeBlock`.
       use_nodes: Forwarded to `blocks.NodeBlock`.
       use_globals: Forwarded to `blocks.NodeBlock`.
       none_fields: Passed to `self._get_input_graph`; presumably names the
         graph fields to set to None -- confirm against the helper.
     """
     input_graph = self._get_input_graph(none_fields)
     node_block = blocks.NodeBlock(node_model_fn=self._node_model_fn,
                                   use_received_edges=use_received_edges,
                                   use_sent_edges=use_sent_edges,
                                   use_nodes=use_nodes,
                                   use_globals=use_globals)
     # `assertRaisesRegexp` is a deprecated alias that Python 3.12 removed.
     with self.assertRaisesRegex(ValueError, "field cannot be None"):
         node_block(input_graph)
예제 #23
0
    def __init__(self,
                 edge_fn,
                 with_edge_inputs=False,
                 with_node_inputs=True,
                 encoder_size: list = None,
                 core_size: list = None,
                 name="EdgeLearnerBase",
                 **kwargs):
        """Builds encoder blocks, an InteractionNetwork core and an output head.

        Args:
          edge_fn: Edge model passed to `modules.GraphIndependent` for the
            output transform; its node and global models are None.
          with_edge_inputs: If truthy, the edge encoder reads input edges and
            the node encoder reads both received and sent edges.
          with_node_inputs: If falsy, neither encoder reads node features.
          encoder_size: Optional `mlp_size` for the encoder MLPs. When None,
            `mlp_size` is not passed to `make_mlp_model` at all.
          core_size: Optional `mlp_size` for the core MLPs, handled the same
            way as `encoder_size`.
          name: Module name.
          **kwargs: Extra keyword arguments forwarded to every
            `make_mlp_model` call.
        """
        super(EdgeLearnerBase, self).__init__(name=name)

        def mlp_fn_for(size):
            # Forward `mlp_size` only when a size was explicitly requested.
            if size is None:
                return partial(make_mlp_model, **kwargs)
            return partial(make_mlp_model, mlp_size=size, **kwargs)

        encoder_mlp_fn = mlp_fn_for(encoder_size)

        read_edges = bool(with_edge_inputs)
        read_nodes = bool(with_node_inputs)
        edge_block_args = dict(use_edges=read_edges,
                               use_receiver_nodes=read_nodes,
                               use_sender_nodes=read_nodes,
                               use_globals=False)
        node_block_args = dict(use_received_edges=read_edges,
                               use_sent_edges=read_edges,
                               use_nodes=read_nodes,
                               use_globals=False)

        self._edge_block = blocks.EdgeBlock(edge_model_fn=encoder_mlp_fn,
                                            **edge_block_args,
                                            name='edge_encoder_block')
        self._node_encoder_block = blocks.NodeBlock(
            node_model_fn=encoder_mlp_fn,
            **node_block_args,
            name='node_encoder_block')

        core_mlp_fn = mlp_fn_for(core_size)
        self._core = InteractionNetwork(edge_model_fn=core_mlp_fn,
                                        node_model_fn=core_mlp_fn,
                                        reducer=tf.math.unsorted_segment_sum)

        self._output_transform = modules.GraphIndependent(edge_fn, None, None)
예제 #24
0
    def __init__(self,
                 edge_model_fn,
                 node_model_fn,
                 global_model_fn,
                 reducer=tf.math.unsorted_segment_sum,
                 edge_block_opt=None,
                 node_block_opt=None,
                 global_block_opt=None,
                 name="graph_network"):
        """Initializes the GraphNetwork module.

        Args:
          edge_model_fn: A callable that will be passed to EdgeBlock to perform
            per-edge computations. The callable must return a Sonnet module (or
            equivalent; see EdgeBlock for details).
          node_model_fn: A callable that will be passed to NodeBlock to perform
            per-node computations. The callable must return a Sonnet module (or
            equivalent; see NodeBlock for details).
          global_model_fn: A callable that will be passed to GlobalBlock to
            perform per-global computations. The callable must return a Sonnet
            module (or equivalent; see GlobalBlock for details).
          reducer: Reducer to be used by NodeBlock and GlobalBlock to aggregate
            nodes and edges. Defaults to tf.math.unsorted_segment_sum. This will
            be overridden by the reducers specified in `node_block_opt` and
            `global_block_opt`, if any.
          edge_block_opt: Additional options to be passed to the EdgeBlock. Can
            contain keys `use_edges`, `use_receiver_nodes`, `use_sender_nodes`,
            `use_globals`. By default, these are all True.
          node_block_opt: Additional options to be passed to the NodeBlock. Can
            contain the keys `use_received_edges`, `use_sent_edges`,
            `use_nodes`, `use_globals` (all set to True by default), and
            `received_edges_reducer`, `sent_edges_reducer` (default to
            `reducer`).
          global_block_opt: Additional options to be passed to the GlobalBlock.
            Can contain the keys `use_edges`, `use_nodes`, `use_globals` (all
            set to True by default), and `edges_reducer`, `nodes_reducer`
            (defaults to `reducer`).
          name: The module name.
        """
        super(GraphNetwork, self).__init__(name=name)
        # Fill in the documented defaults for any option not supplied.
        edge_block_opt = _make_default_edge_block_opt(edge_block_opt)
        node_block_opt = _make_default_node_block_opt(node_block_opt, reducer)
        global_block_opt = _make_default_global_block_opt(
            global_block_opt, reducer)

        # The sub-blocks are built directly; no explicit variable scope is
        # entered (this replaces an always-true `if 2 > 1:` placeholder).
        self._edge_block = blocks.EdgeBlock(edge_model_fn=edge_model_fn,
                                            **edge_block_opt)
        self._node_block = blocks.NodeBlock(node_model_fn=node_model_fn,
                                            **node_block_opt)
        self._global_block = blocks.GlobalBlock(
            global_model_fn=global_model_fn, **global_block_opt)
예제 #25
0
파일: base.py 프로젝트: rkunnawa/root_gnn
 def __init__(self,
              edge_model_fn,
              node_model_fn,
              reducer=tf.math.unsorted_segment_sum,
              name="interaction_network"):
     """Builds the edge and node blocks of the interaction network.

     Args:
       edge_model_fn: Callable passed to `blocks.EdgeBlock` as `edge_model_fn`.
       node_model_fn: Callable passed to `blocks.NodeBlock` as `node_model_fn`.
       reducer: Segment reducer the NodeBlock uses to aggregate both the
         received and the sent edges of each node.
       name: Module name.
     """
     super(InteractionNetwork, self).__init__(name=name)
     # Neither block reads the graph globals.
     self._edge_block = blocks.EdgeBlock(edge_model_fn=edge_model_fn,
                                         use_globals=False)
     # Sent edges are aggregated as well. Pass `reducer` for them too;
     # otherwise they would use NodeBlock's default `sent_edges_reducer`
     # regardless of the caller's choice.
     self._node_block = blocks.NodeBlock(node_model_fn=node_model_fn,
                                         use_received_edges=True,
                                         use_sent_edges=True,
                                         use_globals=False,
                                         received_edges_reducer=reducer,
                                         sent_edges_reducer=reducer)
예제 #26
0
 def test_missing_aggregation_raises_exception(
     self, use_received_edges, use_sent_edges,
     received_edges_reducer, sent_edges_reducer):
   """Checks that missing a required aggregation argument raises an error.

   Constructing the NodeBlock (without calling it) must raise a ValueError
   whose message contains "should not be None".

   Args:
     use_received_edges: Forwarded to `blocks.NodeBlock`.
     use_sent_edges: Forwarded to `blocks.NodeBlock`.
     received_edges_reducer: Forwarded to `blocks.NodeBlock`.
     sent_edges_reducer: Forwarded to `blocks.NodeBlock`.
   """
   # `assertRaisesRegexp` is a deprecated alias that Python 3.12 removed.
   with self.assertRaisesRegex(ValueError, "should not be None"):
     blocks.NodeBlock(
         node_model_fn=self._node_model_fn,
         use_received_edges=use_received_edges,
         use_sent_edges=use_sent_edges,
         use_nodes=False,
         use_globals=False,
         received_edges_reducer=received_edges_reducer,
         sent_edges_reducer=sent_edges_reducer)
예제 #27
0
 def __init__(self,
              edge_model_fn,
              node_model_fn,
              global_model_fn,
              name=None):
     """Creates a node-only NodeBlock and a RelationNetwork.

     Args:
       edge_model_fn: Passed to `RelationNetwork` as `edge_model_fn`.
       node_model_fn: Model factory for a NodeBlock that reads node features
         only (no received/sent edges, no globals).
       global_model_fn: Passed to `RelationNetwork` as `global_model_fn`.
       name: Module name, forwarded to the base-class constructor.
     """
     super(EncoderNetwork, self).__init__(name=name)
     node_only = dict(use_received_edges=False,
                      use_sent_edges=False,
                      use_nodes=True,
                      use_globals=False)
     self.node_block = blocks.NodeBlock(node_model_fn, **node_only)
     self.relation_network = RelationNetwork(
         edge_model_fn=edge_model_fn, global_model_fn=global_model_fn)
예제 #28
0
 def test_incompatible_higher_rank_inputs_no_raise(self, use_received_edges,
                                                   use_sent_edges,
                                                   use_nodes, use_globals,
                                                   field):
     """No exception should occur if a differently shaped field is not used.

     `field` of the shaped input graph has its axes 1 and 2 swapped. The test
     is meant for parameterizations where the NodeBlock does not read that
     field, so building and running the block must still succeed.
     """
     graph = self._get_shaped_input_graph()
     swapped = tf.transpose(getattr(graph, field), [0, 2, 1, 3])
     graph = graph.replace(**{field: swapped})

     conv_fn = functools.partial(snt.Conv2D,
                                 output_channels=10,
                                 kernel_shape=[3, 3])
     block = blocks.NodeBlock(conv_fn,
                              use_received_edges=use_received_edges,
                              use_sent_edges=use_sent_edges,
                              use_nodes=use_nodes,
                              use_globals=use_globals)
     self._assert_build_and_run(block, graph)
예제 #29
0
 def test_incompatible_higher_rank_inputs_raises(self, use_received_edges,
                                                 use_sent_edges, use_nodes,
                                                 use_globals, field):
     """An exception should be raised if the inputs have incompatible shapes.

     `field` of the shaped input graph has its axes 1 and 2 swapped. Calling
     the NodeBlock is then expected to fail with a ValueError mentioning that
     dimensions "in both shapes must be equal".
     """
     input_graph = self._get_shaped_input_graph()
     input_graph = input_graph.replace(
         **{field: tf.transpose(getattr(input_graph, field), [0, 2, 1, 3])})
     network = blocks.NodeBlock(functools.partial(snt.Conv2D,
                                                  output_channels=10,
                                                  kernel_shape=[3, 3]),
                                use_received_edges=use_received_edges,
                                use_sent_edges=use_sent_edges,
                                use_nodes=use_nodes,
                                use_globals=use_globals)
     # `assertRaisesRegexp` is a deprecated alias that Python 3.12 removed.
     with self.assertRaisesRegex(ValueError,
                                 "in both shapes must be equal"):
         network(input_graph)
예제 #30
0
  def test_same_as_subblocks(self, reducer, none_fields):
    """Compares the output to explicit subblocks output.

    Recomputes the model's result by hand. A NodeBlock and a GlobalBlock that
    wrap the model's own `_node_model` and `_global_model` objects are applied
    in sequence, and their nodes/globals are compared with the model's output.
    Edges, receivers and senders of the output must equal those of the input.

    Args:
      reducer: The reducer passed to `self._get_model`; the reference
        GlobalBlock below uses it as its `nodes_reducer`.
      none_fields: (list of strings) The corresponding fields are removed from
        the input graph.
    """
    input_graph = self._get_input_graph()
    # Replace every field named in `none_fields` with None.
    input_graph = input_graph.map(lambda _: None, none_fields)

    deep_sets = self._get_model(reducer)

    output_graph = deep_sets(input_graph)
    output_nodes = output_graph.nodes
    output_globals = output_graph.globals

    # Reference blocks. The lambdas return the existing sub-modules of
    # `deep_sets` instead of constructing new ones, so the reference path
    # reuses the model's own node and global model objects.
    # Node update: reads the nodes and the globals, but no edges.
    node_block = blocks.NodeBlock(
        node_model_fn=lambda: deep_sets._node_block._node_model,
        use_received_edges=False,
        use_sent_edges=False,
        use_nodes=True,
        use_globals=True)
    # Global update: aggregates nodes with `reducer`; it reads neither edges
    # nor the incoming globals.
    global_block = blocks.GlobalBlock(
        global_model_fn=lambda: deep_sets._global_block._global_model,
        use_edges=False,
        use_nodes=True,
        use_globals=False,
        nodes_reducer=reducer)

    # The global block is fed the node block's output, not the raw input.
    node_block_out = node_block(input_graph)
    expected_nodes = node_block_out.nodes
    expected_globals = global_block(node_block_out).globals

    # Fields the model is not expected to modify must pass through unchanged.
    self.assertAllEqual(input_graph.edges, output_graph.edges)
    self.assertAllEqual(input_graph.receivers, output_graph.receivers)
    self.assertAllEqual(input_graph.senders, output_graph.senders)

    # Graph mode: initialize variables, then evaluate the model and reference
    # tensors together in a single `sess.run`.
    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      (output_nodes_, output_globals_, expected_nodes_,
       expected_globals_) = sess.run(
           [output_nodes, output_globals, expected_nodes, expected_globals])

    # Helper defined elsewhere in the test class; judging by its name it
    # accepts both-None values as well as numerically close arrays.
    self._assert_all_none_or_all_close(expected_nodes_, output_nodes_)
    self._assert_all_none_or_all_close(expected_globals_, output_globals_)