示例#1
0
    def _build(self,
               input_graph,
               hidden_size=50,
               attn_scale=1.0,
               attn_dropout_keep_prob=1.0,
               regularizer=None,
               is_training=False):
        """Runs one round of attention-weighted message passing.

        Edge values and scalar attention logits are both produced by linear
        layers over each edge's own features concatenated with its receiver
        and sender node features. The logits are scaled by `attn_scale` and
        normalized with `self._normalizer` via
        `modules._received_edges_normalizer` (presumably across the edges that
        share a receiver -- confirm against that helper). They are then passed
        through `slim.dropout` and multiplied into the edge values. The
        weighted values are summed onto receiver nodes.

        Args:
          input_graph: Graph whose `nodes` and `edges` have the same,
            statically known, last dimension.
          hidden_size: Currently unused. It is only consumed by an MLP variant
            of the logits model. Kept for interface compatibility.
          attn_scale: Multiplier applied to the attention logits before
            normalization.
          attn_dropout_keep_prob: Keep probability for dropout on the
            normalized attention weights.
          regularizer: Regularizer applied to the `w` weights of both linear
            layers.
          is_training: Forwarded to `slim.dropout`.

        Returns:
          `input_graph` with `nodes` replaced by the per-receiver sum of the
          attended edge values and `edges` replaced by the edge values.

        Raises:
          ValueError: If the node and edge feature sizes differ.
        """
        node_values = input_graph.nodes
        edge_values = input_graph.edges

        value_dims = node_values.shape[-1].value
        edge_dims = edge_values.shape[-1].value
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if value_dims != edge_dims:
            raise ValueError(
                'Node and edge feature sizes must match; got {} and {}.'.format(
                    value_dims, edge_dims))

        # Edge values from edge + receiver node + sender node features.
        # - edge_values = [total_num_edges, value_dims]
        edge_value_block = blocks.EdgeBlock(
            edge_model_fn=lambda: snt.Linear(
                output_size=value_dims, regularizers={'w': regularizer}),
            use_edges=True,
            use_receiver_nodes=True,
            use_sender_nodes=True,
            use_globals=False,
            name='update_edge_values')
        edge_values = edge_value_block(input_graph).edges
        tf.summary.histogram('mpnn/edge_values', edge_values)

        # One attention logit per edge, computed from the same input features.
        # - attention_weights_logits = [total_num_edges, 1]
        logits_block = blocks.EdgeBlock(
            edge_model_fn=lambda: snt.Linear(output_size=1,
                                             regularizers={'w': regularizer}),
            use_edges=True,
            use_receiver_nodes=True,
            use_sender_nodes=True,
            use_globals=False,
            name='update_attention_logits')
        attention_weights_logits = attn_scale * logits_block(input_graph).edges
        tf.summary.histogram('mpnn/logits', attention_weights_logits)

        normalized_attention_weight = modules._received_edges_normalizer(
            input_graph.replace(edges=attention_weights_logits),
            normalizer=self._normalizer)
        normalized_attention_weight = slim.dropout(normalized_attention_weight,
                                                   attn_dropout_keep_prob,
                                                   is_training=is_training)

        # Weight each edge value by its attention weight (broadcast over the
        # last axis).
        # - attended_edges = [total_num_edges, value_dims]
        attended_edges = edge_values * normalized_attention_weight

        # Sum the attended values arriving at each node.
        # - aggregated_attended_values = [total_num_nodes, value_dims]
        received_edges_aggregator = blocks.ReceivedEdgesToNodesAggregator(
            reducer=tf.math.unsorted_segment_sum)
        aggregated_attended_values = received_edges_aggregator(
            input_graph.replace(edges=attended_edges))

        return input_graph.replace(nodes=aggregated_attended_values,
                                   edges=edge_values)
示例#2
0
    def __init__(self, name="DecaySimulator"):
        """Creates all sub-modules of the decay simulator.

        Sub-modules are constructed in a fixed order. Keep it: Sonnet may
        derive default module names from creation order.

        Args:
          name: The module name.
        """
        super(DecaySimulator, self).__init__(name=name)

        # Per-node models.
        self._node_linear = make_mlp_model()
        self._node_rnn = snt.GRU(hidden_size=LATENT_SIZE, name='node_rnn')
        self._node_proper = snt.nets.MLP([4], activate_final=False)

        # Encoders:
        # - edges come from receiver/sender node features;
        # - nodes come from their own features;
        # - globals come from summed nodes and edges.
        sum_reducer = tf.math.unsorted_segment_sum
        self._edge_block = blocks.EdgeBlock(
            edge_model_fn=make_mlp_model,
            use_edges=False,
            use_receiver_nodes=True,
            use_sender_nodes=True,
            use_globals=False,
            name='edge_encoder_block')
        self._node_encoder_block = blocks.NodeBlock(
            node_model_fn=make_mlp_model,
            use_received_edges=False,
            use_sent_edges=False,
            use_nodes=True,
            use_globals=False,
            name='node_encoder_block')
        self._global_encoder_block = blocks.GlobalBlock(
            global_model_fn=make_mlp_model,
            use_edges=True,
            use_nodes=True,
            use_globals=False,
            nodes_reducer=sum_reducer,
            edges_reducer=sum_reducer,
            name='global_encoder_block')

        self._core = MLPGraphNetwork()

        # Output heads that map latent features to the final shapes.
        node_out_dim = 64
        global_out_dim = 1

        def node_head():
            # A single MLP layer of `node_out_dim` units applied per node.
            return snt.Sequential([
                snt.nets.MLP([node_out_dim],
                             activation=tf.nn.relu,
                             name='node_output'),
            ])

        def global_head():
            # A single-unit MLP on the global features, then a sigmoid.
            return snt.Sequential([
                snt.nets.MLP([global_out_dim],
                             activation=tf.nn.relu,
                             name='global_output'),
                tf.sigmoid,
            ])

        self._output_transform = modules.GraphIndependent(
            edge_model_fn=None,
            node_model_fn=node_head,
            global_model_fn=global_head)
    def __init__(self, name="DeepGraphInfoMax"):
        """Creates the edge/node encoders and the interaction-network core.

        Args:
          name: The module name.
        """
        super(DeepGraphInfoMax, self).__init__(name=name)

        def edge_mlp():
            # Two LATENT_SIZE layers with ReLU, including after the last
            # layer (activate_final=True), with dropout enabled.
            return snt.nets.MLP([LATENT_SIZE] * 2,
                                activation=tf.nn.relu,
                                activate_final=True,
                                use_dropout=True)

        # Edge features are built from receiver and sender node features only.
        self._edge_block = blocks.EdgeBlock(
            edge_model_fn=edge_mlp,
            use_edges=False,
            use_receiver_nodes=True,
            use_sender_nodes=True,
            use_globals=False,
            name='edge_encoder_block')
        # Node features are encoded from the nodes alone.
        self._node_encoder_block = blocks.NodeBlock(
            node_model_fn=make_mlp_model,
            use_received_edges=False,
            use_sent_edges=False,
            use_nodes=True,
            use_globals=False,
            name='node_encoder_block')

        self._core = modules.InteractionNetwork(
            edge_model_fn=make_mlp_model,
            node_model_fn=make_mlp_model,
            reducer=tf.unsorted_segment_sum)
示例#4
0
    def __init__(self, name, edge_model_fn, node_model_fn, global_model_fn):
        """
        description: initializes the model

        :param name: The module name.
        :param edge_model_fn: Function passed to the edge block; in this paper, it is an MLP.
        :param node_model_fn: Function stored as `self._node_model_fn` (not called here);
            in this paper, for the task of bAbI, it is an LSTM over timesteps.
        :param global_model_fn: Function passed to the global block; in this paper, it is an MLP.
        """

        super(RRN, self).__init__(name=name)

        # NOTE(review): unlike modules that wrap construction in
        # `with self._enter_variable_scope():`, these blocks are built directly
        # in `__init__`. Confirm this is intended.

        # Edge model sees only the receiver and sender node features.
        self._edge_block = blocks.EdgeBlock(edge_model_fn=edge_model_fn,
                                            use_edges=False,
                                            use_receiver_nodes=True,
                                            use_sender_nodes=True,
                                            use_globals=False)

        # Global model sees only the node features, aggregated by summation.
        self._global_block = blocks.GlobalBlock(
            global_model_fn=global_model_fn,
            use_edges=False,
            use_nodes=True,
            use_globals=False,
            nodes_reducer=tf.unsorted_segment_sum)

        self._node_model_fn = node_model_fn
示例#5
0
文件: GNN.py 项目: YannickCharles/RQ1
    def __init__(self, edge_model_fn, node_model_fn, name="fwd_gnn"):
        """Initializes the transition model.

        Args:
          edge_model_fn: Callable passed to the `EdgeBlock`. Its model sees the
            receiver and sender node features only.
          node_model_fn: Callable invoked once here; its result is stored as
            `self._node_model`.
          name: The module name.
        """
        super(GraphNeuralNetwork_transition, self).__init__(name=name)

        endpoint_inputs_only = dict(use_edges=False,
                                    use_receiver_nodes=True,
                                    use_sender_nodes=True,
                                    use_globals=False)

        with self._enter_variable_scope():
            self._edge_block = blocks.EdgeBlock(edge_model_fn=edge_model_fn,
                                                name='edge_block',
                                                **endpoint_inputs_only)
            # Aggregates sent edges onto nodes with a segment sum.
            self._sent_edges_aggregator = blocks.SentEdgesToNodesAggregator(
                reducer=tf.math.unsorted_segment_sum)
            self._node_model = node_model_fn()
示例#6
0
    def __init__(self, name="FourTopPredictor"):
        """Creates the encoders, the core network and the global output MLP.

        Args:
          name: The module name.
        """
        super(FourTopPredictor, self).__init__(name=name)

        # Encoders: edges from their endpoint nodes, nodes from themselves.
        self._edge_block = blocks.EdgeBlock(
            edge_model_fn=make_mlp_model,
            use_edges=False,
            use_receiver_nodes=True,
            use_sender_nodes=True,
            use_globals=False,
            name='edge_encoder_block')
        self._node_encoder_block = blocks.NodeBlock(
            node_model_fn=make_mlp_model,
            use_received_edges=False,
            use_sent_edges=False,
            use_nodes=True,
            use_globals=False,
            name='node_encoder_block')

        # Global features are computed from edges and nodes (not prior globals).
        self._global_block = blocks.GlobalBlock(
            global_model_fn=make_mlp_model,
            use_edges=True,
            use_nodes=True,
            use_globals=False)

        self._core = MLPGraphNetwork()

        # Output head:
        # - two hidden layers of 128 units with leaky ReLU and 30% dropout;
        # - one output per (target feature, top) pair.
        out_dim = n_target_node_features * n_max_tops
        hidden_and_output = [128, 128, out_dim]
        self._global_nn = snt.nets.MLP(hidden_and_output,
                                       activation=tf.nn.leaky_relu,
                                       dropout_rate=0.30,
                                       name='global_output')
示例#7
0
    def __init__(self,
                 edge_model_fn,
                 node_model_fn,
                 reducer=tf.unsorted_segment_sum,
                 name="interaction_network"):
        """Initializes the InteractionNetwork module.

        Args:
          edge_model_fn: A callable that will be passed to `EdgeBlock` to
            perform per-edge computations. The callable must return a Sonnet
            module (or equivalent; see `blocks.EdgeBlock` for details). The
            shape of this module's output must match that of the input nodes,
            except for the first and last axes.
          node_model_fn: A callable that will be passed to `NodeBlock` to
            perform per-node computations. The callable must return a Sonnet
            module (or equivalent; see `blocks.NodeBlock` for details).
          reducer: Reducer used by the NodeBlock to aggregate both received
            and sent edges. Defaults to tf.unsorted_segment_sum.
          name: The module name.
        """
        super(InteractionNetwork, self).__init__(name=name)

        with self._enter_variable_scope():
            self._edge_block = blocks.EdgeBlock(edge_model_fn=edge_model_fn,
                                                use_globals=False)
            # Both received and sent edges feed the node model. Pass `reducer`
            # for both, so that a user-supplied reducer is not silently
            # ignored for sent edges (previously they always used the
            # NodeBlock default).
            self._node_block = blocks.NodeBlock(node_model_fn=node_model_fn,
                                                use_received_edges=True,
                                                use_sent_edges=True,
                                                use_globals=False,
                                                received_edges_reducer=reducer,
                                                sent_edges_reducer=reducer)
示例#8
0
    def __init__(self,
                 edge_model_fn,
                 node_model_fn,
                 reducer=tf.unsorted_segment_sum,
                 name="tcomm_net"):
        """Builds the TCommNet sub-blocks.

        Args:
          edge_model_fn: Callable passed to both edge blocks. Each block sees
            the sender node features only.
          node_model_fn: Callable passed to the `blocks.toLNodeBlock`.
          reducer: Reducer for received edges in the node block.
          name: The module name.
        """
        super(TCommNet, self).__init__(name=name)

        sender_only = dict(use_edges=False,
                           use_receiver_nodes=False,
                           use_sender_nodes=True,
                           use_globals=False)

        with self._enter_variable_scope():
            # `LEdgeBlock` / `toLNodeBlock` are project-specific blocks;
            # `use_reverse_edges` is an option of `LEdgeBlock`.
            self._edge_block = blocks.LEdgeBlock(edge_model_fn=edge_model_fn,
                                                 use_reverse_edges=True,
                                                 **sender_only)
            self._edge_block2 = blocks.EdgeBlock(edge_model_fn=edge_model_fn,
                                                 **sender_only)
            self._node_block = blocks.toLNodeBlock(
                node_model_fn=node_model_fn,
                use_received_edges=True,
                use_sent_edges=False,
                use_nodes=True,
                use_globals=False,
                received_edges_reducer=reducer)
示例#9
0
    def __init__(self,
                 edge_model_fn,
                 global_model_fn,
                 reducer=tf.unsorted_segment_sum,
                 name="relation_network"):
        """Initializes the RelationNetwork module.

        Args:
          edge_model_fn: A callable that will be passed to EdgeBlock to perform
            per-edge computations. The callable must return a Sonnet module (or
            equivalent; see EdgeBlock for details).
          global_model_fn: A callable that will be passed to GlobalBlock to
            perform per-global computations. The callable must return a Sonnet
            module (or equivalent; see GlobalBlock for details).
          reducer: Reducer to be used by GlobalBlock to aggregate edges.
            Defaults to tf.unsorted_segment_sum.
          name: The module name.
        """
        super(RelationNetwork, self).__init__(name=name)

        edge_inputs = dict(use_edges=False,
                           use_receiver_nodes=True,
                           use_sender_nodes=True,
                           use_globals=False)
        global_inputs = dict(use_edges=True,
                             use_nodes=False,
                             use_globals=False,
                             edges_reducer=reducer)

        with self._enter_variable_scope():
            # Pairwise relations from the two endpoint nodes of every edge.
            self._edge_block = blocks.EdgeBlock(edge_model_fn=edge_model_fn,
                                                **edge_inputs)
            # Global output from the aggregated relations only.
            self._global_block = blocks.GlobalBlock(
                global_model_fn=global_model_fn, **global_inputs)
示例#10
0
    def __init__(self,
                 edge_model_fn,
                 node_model_fn,
                 reducer=tf.unsorted_segment_sum,
                 name="comm_net"):
        """Initializes the CommNet module.

        Args:
          edge_model_fn: A callable to be passed to EdgeBlock. The callable must
            return a Sonnet module (or equivalent; see EdgeBlock for details).
          node_model_fn: A callable meant for the node update. The callable must
            return a Sonnet module (or equivalent; see NodeBlock for details).
          reducer: Reduction to be used when aggregating the edges in the nodes.
            This should be a callable whose signature matches
            tf.unsorted_segment_sum.
          name: The module name.
        """
        # NOTE(review): `node_model_fn` and `reducer` are accepted but not used
        # anywhere in this body as shown. Confirm whether node blocks are
        # missing here.
        super(CommNet, self).__init__(name=name)

        with self._enter_variable_scope():
            # Computes $\Psi_{com}(x_j)$ in Eq. (2) of 1706.06122
            self._edge_block = blocks.EdgeBlock(edge_model_fn=edge_model_fn,
                                                use_edges=False,
                                                use_receiver_nodes=False,
                                                use_sender_nodes=True,
                                                use_globals=False)
示例#11
0
 def test_compatible_higher_rank_no_raise(self):
   """An EdgeBlock on rank-4 inputs (axes 1 and 2 swapped) must not raise."""
   def swap_middle_axes(tensor):
     return tf.transpose(tensor, [0, 2, 1, 3])

   graph = self._get_shaped_input_graph().map(swap_middle_axes)
   conv_fn = functools.partial(
       snt.Conv2D, output_channels=10, kernel_shape=[3, 3])
   self._assert_build_and_run(blocks.EdgeBlock(conv_fn), graph)
示例#12
0
  def test_same_as_subblocks(self, reducer):
    """GraphNetwork output must match chaining its own models block by block.

    Args:
      reducer: The reducer used in the `NodeBlock` and `GlobalBlock`.
    """
    graph = self._get_input_graph()

    network = modules.GraphNetwork(
        edge_model_fn=functools.partial(snt.Linear, output_size=5),
        node_model_fn=functools.partial(snt.Linear, output_size=10),
        global_model_fn=functools.partial(snt.Linear, output_size=15),
        reducer=reducer)
    actual = network(graph)

    # Re-use the network's already-built models inside standalone blocks.
    edge_block = blocks.EdgeBlock(
        edge_model_fn=lambda: network._edge_block._edge_model,
        use_edges=True,
        use_receiver_nodes=True,
        use_sender_nodes=True,
        use_globals=True)
    node_block = blocks.NodeBlock(
        node_model_fn=lambda: network._node_block._node_model,
        use_received_edges=True,
        use_sent_edges=False,
        use_nodes=True,
        use_globals=True,
        received_edges_reducer=reducer)
    global_block = blocks.GlobalBlock(
        global_model_fn=lambda: network._global_block._global_model,
        use_edges=True,
        use_nodes=True,
        use_globals=True,
        edges_reducer=reducer,
        nodes_reducer=reducer)

    after_edges = edge_block(graph)
    after_nodes = node_block(after_edges)
    after_globals = global_block(after_nodes)

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      actual_out, edges_out, nodes_out, globals_out = sess.run(
          (actual, after_edges.edges, after_nodes.nodes, after_globals.globals))

    for expected, got in ((edges_out, actual_out.edges),
                          (nodes_out, actual_out.nodes),
                          (globals_out, actual_out.globals)):
      self._assert_all_none_or_all_close(expected, got)
示例#13
0
    def __init__(self, name="GlobalClassifierNoEdgeInfo"):
        """Creates the encoders, the MLP graph-network core and a sigmoid head.

        Args:
          name: The module name.
        """
        super(GlobalClassifierNoEdgeInfo, self).__init__(name=name)

        # Edge features come from endpoint nodes only (input edges unused).
        self._edge_block = blocks.EdgeBlock(
            edge_model_fn=make_mlp_model,
            use_edges=False,
            use_receiver_nodes=True,
            use_sender_nodes=True,
            use_globals=False,
            name='edge_encoder_block')

        self._node_encoder_block = blocks.NodeBlock(
            node_model_fn=make_mlp_model,
            use_received_edges=False,
            use_sent_edges=False,
            use_nodes=True,
            use_globals=False,
            name='node_encoder_block')

        self._global_block = blocks.GlobalBlock(
            global_model_fn=make_mlp_model,
            use_edges=True,
            use_nodes=True,
            use_globals=False)

        self._core = MLPGraphNetwork()

        def global_head():
            # MLP [LATENT_SIZE, 1] on the globals, then a sigmoid.
            return snt.Sequential([
                snt.nets.MLP([LATENT_SIZE, 1], name='global_output'),
                tf.sigmoid,
            ])

        # Only the globals get an output model; no edge or node model is given.
        self._output_transform = modules.GraphIndependent(
            edge_model_fn=None, node_model_fn=None, global_model_fn=global_head)
示例#14
0
  def test_output_values(
      self, use_edges, use_receiver_nodes, use_sender_nodes, use_globals):
    """EdgeBlock output must equal the scaled concatenation of its inputs."""
    graph = self._get_input_graph()
    block = blocks.EdgeBlock(
        edge_model_fn=self._edge_model_fn,
        use_edges=use_edges,
        use_receiver_nodes=use_receiver_nodes,
        use_sender_nodes=use_sender_nodes,
        use_globals=use_globals)
    result = block(graph)

    # Rebuild the model input by hand, in the order this test expects.
    feature_sources = (
        (use_edges, lambda: graph.edges),
        (use_receiver_nodes,
         lambda: blocks.broadcast_receiver_nodes_to_edges(graph)),
        (use_sender_nodes,
         lambda: blocks.broadcast_sender_nodes_to_edges(graph)),
        (use_globals, lambda: blocks.broadcast_globals_to_edges(graph)),
    )
    concatenated = tf.concat(
        [build() for enabled, build in feature_sources if enabled], axis=-1)

    # Nodes and globals must pass through untouched.
    self.assertEqual(graph.nodes, result.nodes)
    self.assertEqual(graph.globals, result.globals)

    with self.test_session() as sess:
      result_out, inputs_out = sess.run((result, concatenated))

    self.assertNDArrayNear(
        inputs_out * self._scale, result_out.edges, err=1e-4)
示例#15
0
    def test_created_variables(self, use_edges, use_receiver_nodes,
                               use_sender_nodes, use_globals,
                               expected_first_dim_w):
        """Checks the names and shapes of the variables an EdgeBlock creates."""
        units = 10
        graph = self._get_input_graph()

        block = blocks.EdgeBlock(
            edge_model_fn=functools.partial(snt.nets.MLP,
                                            output_sizes=[units]),
            use_edges=use_edges,
            use_receiver_nodes=use_receiver_nodes,
            use_sender_nodes=use_sender_nodes,
            use_globals=use_globals)
        block(graph)

        actual_shapes = {
            variable.name: variable.get_shape().as_list()
            for variable in block.get_variables()
        }
        self.assertDictEqual(
            {
                "edge_block/mlp/linear_0/b:0": [units],
                "edge_block/mlp/linear_0/w:0": [expected_first_dim_w, units],
            },
            actual_shapes)
示例#16
0
  def test_unused_field_can_be_none(
      self, use_edges, use_nodes, use_globals, none_field):
    """An EdgeBlock must work when a field it does not consume is None."""
    graph = self._get_input_graph([none_field])
    block = blocks.EdgeBlock(
        edge_model_fn=self._edge_model_fn,
        use_edges=use_edges,
        use_receiver_nodes=use_nodes,
        use_sender_nodes=use_nodes,
        use_globals=use_globals)
    result = block(graph)

    pieces = []
    if use_edges:
      pieces.append(graph.edges)
    if use_nodes:
      pieces.extend([blocks.broadcast_receiver_nodes_to_edges(graph),
                     blocks.broadcast_sender_nodes_to_edges(graph)])
    if use_globals:
      pieces.append(blocks.broadcast_globals_to_edges(graph))
    stacked_inputs = tf.concat(pieces, axis=-1)

    # Nodes and globals must pass through untouched.
    self.assertEqual(graph.nodes, result.nodes)
    self.assertEqual(graph.globals, result.globals)

    with self.test_session() as sess:
      edges_out, inputs_out = sess.run((result.edges, stacked_inputs))

    self.assertNDArrayNear(inputs_out * self._scale, edges_out, err=1e-4)
示例#17
0
    def __init__(self, name="SegmentClassifier"):
        """Creates the encoders, the interaction-network core and the edge head.

        Args:
          name: The module name.
        """
        super(SegmentClassifier, self).__init__(name=name)

        # Edge features come from endpoint nodes only (input edges unused).
        self._edge_block = blocks.EdgeBlock(
            edge_model_fn=make_mlp_model,
            use_edges=False,
            use_receiver_nodes=True,
            use_sender_nodes=True,
            use_globals=False,
            name='edge_encoder_block')
        self._node_encoder_block = blocks.NodeBlock(
            node_model_fn=make_mlp_model,
            use_received_edges=False,
            use_sent_edges=False,
            use_nodes=True,
            use_globals=False,
            name='node_encoder_block')

        self._core = InteractionNetwork(edge_model_fn=make_mlp_model,
                                        node_model_fn=make_mlp_model,
                                        reducer=tf.math.unsorted_segment_sum)

        def edge_head():
            # One output per edge, squashed to (0, 1) by a sigmoid.
            return snt.Sequential([
                snt.nets.MLP([1], activation=tf.nn.relu, name='edge_output'),
                tf.sigmoid,
            ])

        # Only edges get an output model.
        self._output_transform = modules.GraphIndependent(
            edge_model_fn=edge_head, node_model_fn=None, global_model_fn=None)
示例#18
0
    def __init__(self,
                 edge_model_fn,
                 node_encoder_model_fn,
                 node_model_fn,
                 reducer=tf.unsorted_segment_sum,
                 name="oricomm_net"):
        """Initializes the OriCommNet module.

        Args:
          edge_model_fn: Callable passed to the `EdgeBlock`. Its model sees
            the sender node features only.
          node_encoder_model_fn: Callable passed to the node-encoder
            `NodeBlock`, whose model sees the node features only.
          node_model_fn: Callable passed to the main `NodeBlock`, whose model
            sees node features and received edges.
          reducer: Reducer used to aggregate received edges in the node blocks.
          name: The module name.
        """
        super(OriCommNet, self).__init__(name=name)

        # All three sub-blocks are now built inside the module's variable
        # scope, as the comments' placement and sibling modules indicate.
        # Previously the two node blocks were created after leaving the `with`,
        # so they were not nested under this module's scope.
        # NOTE: this renames the node blocks' variables; checkpoints written by
        # the old code need remapping.
        with self._enter_variable_scope():
            # Computes $\Psi_{com}(x_j)$ in Eq. (2) of 1706.06122
            self._edge_block = blocks.EdgeBlock(edge_model_fn=edge_model_fn,
                                                use_edges=False,
                                                use_receiver_nodes=False,
                                                use_sender_nodes=True,
                                                use_globals=False)
            # Computes $\Phi(x_i)$ in Eq. (2) of 1706.06122
            self._node_encoder_block = blocks.NodeBlock(
                node_model_fn=node_encoder_model_fn,
                use_received_edges=False,
                use_sent_edges=False,
                use_nodes=True,
                use_globals=False,
                received_edges_reducer=reducer,
                name="node_encoder_block")
            # Computes $\Theta(..)$ in Eq.(2) of 1706.06122
            self._node_block = blocks.NodeBlock(node_model_fn=node_model_fn,
                                                use_received_edges=True,
                                                use_sent_edges=False,
                                                use_nodes=True,
                                                use_globals=False,
                                                received_edges_reducer=reducer)
示例#19
0
  def test_edge_block_options(self,
                              use_edges,
                              use_receiver_nodes,
                              use_sender_nodes,
                              use_globals):
    """GraphNetwork must honour the EdgeBlock options given in edge_block_opt.

    Node and global models are identities that consume only their own field,
    so the expected graph differs from the input only in its edges.
    """
    reducer = tf.math.unsorted_segment_sum
    graph = self._get_input_graph()

    def identity_model_fn():
      return tf.identity

    network = modules.GraphNetwork(
        edge_model_fn=functools.partial(snt.Linear, output_size=10),
        edge_block_opt=dict(use_edges=use_edges,
                            use_receiver_nodes=use_receiver_nodes,
                            use_sender_nodes=use_sender_nodes,
                            use_globals=use_globals),
        node_model_fn=identity_model_fn,
        node_block_opt=dict(use_received_edges=False,
                            use_sent_edges=False,
                            use_nodes=True,
                            use_globals=False),
        global_model_fn=identity_model_fn,
        global_block_opt=dict(use_globals=True,
                              use_nodes=False,
                              use_edges=False),
        reducer=reducer)
    actual = network(graph)

    reference_block = blocks.EdgeBlock(
        edge_model_fn=lambda: network._edge_block._edge_model,
        use_edges=use_edges,
        use_receiver_nodes=use_receiver_nodes,
        use_sender_nodes=use_sender_nodes,
        use_globals=use_globals)
    expected = reference_block(graph)

    with self.test_session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      actual_out, edges_out, nodes_out, globals_out = sess.run(
          (actual, expected.edges, expected.nodes, expected.globals))

    self.assertAllEqual(edges_out, actual_out.edges)
    self.assertAllEqual(nodes_out, actual_out.nodes)
    self.assertAllEqual(globals_out, actual_out.globals)
示例#20
0
 def test_no_input_raises_exception(self):
     """Checks that an EdgeBlock with every input disabled raises ValueError."""
     # `assertRaisesRegexp` is a deprecated alias (removed in Python 3.12);
     # `assertRaisesRegex` is the supported name.
     with self.assertRaisesRegex(ValueError, "At least one of "):
         blocks.EdgeBlock(edge_model_fn=self._edge_model_fn,
                          use_edges=False,
                          use_receiver_nodes=False,
                          use_sender_nodes=False,
                          use_globals=False)
示例#21
0
    def __init__(self,
                 attention_node_projection_model,
                 attention_edge_projection_model,
                 query_key_product_model,
                 node_model_fn,
                 edge_model_fn,
                 global_model_fn,
                 num_heads,
                 key_size,
                 value_size,
                 edge_block_opt=None,
                 global_block_opt=None,
                 name="GAT"):
        """Initializes the GAT module.

        Args:
          attention_node_projection_model: Model used for projection to get
            query, key and values. Stored as-is (not called here). Final layer
            dim should be key_size * num_heads.
          attention_edge_projection_model: Model used for projection to get
            query, key and values. Stored as-is (not called here). Final layer
            dim should be (key_size + value_size) * num_heads.
          query_key_product_model: Model used to find the "dot product" between
            queries and keys. Stored as-is. Final layer dim should be 1.
          node_model_fn: Callable returning the model applied to node
            embeddings at the end. Called once here; its result is stored as
            `self._node_model`.
          edge_model_fn: Callable passed to the `EdgeBlock` (the edge model).
          global_model_fn: Callable passed to the `GlobalBlock`.
          num_heads: Number of attention heads.
          key_size: Key dimension.
          value_size: Value dimension.
          edge_block_opt: Additional options to be passed to the EdgeBlock. Can
            contain keys `use_edges`, `use_receiver_nodes`, `use_sender_nodes`,
            `use_globals`. By default, these are all True. `use_sender_nodes`
            must be True.
          global_block_opt: Additional options to be passed to the GlobalBlock.
            Can contain the keys `use_edges`, `use_nodes`, `use_globals` (all
            set to True by default), and `edges_reducer`, `nodes_reducer`. The
            default reducer is `tf.unsorted_segment_sum`, which is what is
            handed to `_make_default_global_block_opt`.
          name: The module name.

        Raises:
          ValueError: If `edge_block_opt['use_sender_nodes']` is falsy.
        """
        super().__init__(name=name)
        self._attention_node_projection_model = attention_node_projection_model
        self._attention_edge_projection_model = attention_edge_projection_model
        self._query_key_product_model = query_key_product_model
        self.num_heads = num_heads
        self.key_size = key_size
        self.value_size = value_size

        edge_block_opt = _make_default_edge_block_opt(edge_block_opt)
        global_block_opt = _make_default_global_block_opt(
            global_block_opt, tf.unsorted_segment_sum)
        # Does not make sense without using sender nodes. Raise explicitly
        # rather than `assert`, which is stripped under `python -O`.
        if not edge_block_opt['use_sender_nodes']:
            raise ValueError(
                "GAT requires edge_block_opt['use_sender_nodes'] to be True.")
        with self._enter_variable_scope():
            self._node_model = node_model_fn()
            self._edge_block = blocks.EdgeBlock(edge_model_fn=edge_model_fn,
                                                **edge_block_opt)
            self._global_block = blocks.GlobalBlock(
                global_model_fn=global_model_fn, **global_block_opt)
示例#22
0
  def test_optional_arguments(self, scale, offset):
    """Checks that `edge_model_kwargs` reach the edge model at call time.

    A block whose model receives `scale`/`offset` as call-time keyword
    arguments must match a block whose model has the same values captured in
    a closure.
    """
    graph = self._get_input_graph()

    kwargs_block = blocks.EdgeBlock(edge_model_fn=self._edge_model_args_fn)
    call_kwargs = {"scale": scale, "offset": offset}
    actual = kwargs_block(graph, edge_model_kwargs=call_kwargs)

    def make_fixed_model():
      def fixed_model(features):
        return features * scale + offset
      return fixed_model

    reference_block = blocks.EdgeBlock(edge_model_fn=make_fixed_model)
    expected = reference_block(graph)

    # Nodes and globals are passed through untouched by an EdgeBlock.
    self.assertIs(expected.nodes, actual.nodes)
    self.assertIs(expected.globals, actual.globals)
    self.assertNDArrayNear(
        expected.edges.numpy(), actual.edges.numpy(), err=1e-4)
示例#23
0
  def test_same_as_subblocks(self, reducer, none_field=None):
    """Checks the model's node output against its sub-blocks chained by hand.

    Args:
      reducer: Reducer used to build the model and passed to the manual
        `NodeBlock`s as `received_edges_reducer`.
      none_field: (string, default=None) If not None, the corresponding field
        is removed from the input graph.
    """
    graph = self._get_input_graph(none_field)

    model = self._get_model(reducer)
    model_output = model(graph)
    actual_nodes = model_output.nodes

    # Each manual stage wraps the model's own sub-module, so the hand-built
    # pipeline evaluates the very same modules as `model`.
    edge_stage = blocks.EdgeBlock(
        edge_model_fn=lambda: model._edge_block._edge_model,
        use_edges=False,
        use_receiver_nodes=False,
        use_sender_nodes=True,
        use_globals=False)
    encoder_stage = blocks.NodeBlock(
        node_model_fn=lambda: model._node_encoder_block._node_model,
        use_received_edges=False,
        use_sent_edges=False,
        use_nodes=True,
        use_globals=False,
        received_edges_reducer=reducer)
    update_stage = blocks.NodeBlock(
        node_model_fn=lambda: model._node_block._node_model,
        use_received_edges=True,
        use_sent_edges=False,
        use_nodes=True,
        use_globals=False,
        received_edges_reducer=reducer)

    messages = edge_stage(graph).edges
    encoded = encoder_stage(graph).nodes
    staged_graph = graph.replace(edges=messages, nodes=encoded)
    expected_nodes = update_stage(staged_graph).nodes

    # Every field other than the nodes must come out unchanged.
    for field_name in ("globals", "edges", "receivers", "senders"):
      self.assertAllEqual(getattr(graph, field_name),
                          getattr(model_output, field_name))

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      actual_values, expected_values = sess.run(
          [actual_nodes, expected_nodes])

    self._assert_all_none_or_all_close(expected_values, actual_values)
示例#24
0
    def __init__(self,
                 edge_fn,
                 with_edge_inputs=False,
                 with_node_inputs=True,
                 encoder_size: list = None,
                 core_size: list = None,
                 name="EdgeLearnerBase",
                 **kwargs):
        """Builds the encoder blocks, interaction-network core and output head.

        Args:
          edge_fn: Edge model handed to `modules.GraphIndependent` as the
            output transform (node and global models are None).
          with_edge_inputs: If truthy, the edge encoder reads edge features
            and the node encoder reads received and sent edges.
          with_node_inputs: If falsy, neither encoder reads node features.
          encoder_size: Optional `mlp_size` for the encoder MLPs.
          core_size: Optional `mlp_size` for the core MLPs.
          name: The module name.
          **kwargs: Extra keyword arguments forwarded to `make_mlp_model`.
        """
        super(EdgeLearnerBase, self).__init__(name=name)

        def mlp_factory(size):
            # Forward `mlp_size` only when one was given, so that
            # `make_mlp_model` otherwise keeps its own default.
            if size is None:
                return partial(make_mlp_model, **kwargs)
            return partial(make_mlp_model, mlp_size=size, **kwargs)

        use_edge_inputs = bool(with_edge_inputs)
        use_node_inputs = bool(with_node_inputs)

        encoder_mlp_fn = mlp_factory(encoder_size)
        self._edge_block = blocks.EdgeBlock(
            edge_model_fn=encoder_mlp_fn,
            use_edges=use_edge_inputs,
            use_receiver_nodes=use_node_inputs,
            use_sender_nodes=use_node_inputs,
            use_globals=False,
            name='edge_encoder_block')
        self._node_encoder_block = blocks.NodeBlock(
            node_model_fn=encoder_mlp_fn,
            use_received_edges=use_edge_inputs,
            use_sent_edges=use_edge_inputs,
            use_nodes=use_node_inputs,
            use_globals=False,
            name='node_encoder_block')

        core_mlp_fn = mlp_factory(core_size)
        self._core = InteractionNetwork(edge_model_fn=core_mlp_fn,
                                        node_model_fn=core_mlp_fn,
                                        reducer=tf.math.unsorted_segment_sum)

        self._output_transform = modules.GraphIndependent(edge_fn, None, None)
示例#25
0
    def __init__(self,
                 edge_model_fn,
                 node_model_fn,
                 global_model_fn,
                 reducer=tf.math.unsorted_segment_sum,
                 edge_block_opt=None,
                 node_block_opt=None,
                 global_block_opt=None,
                 name="graph_network"):
        """Initializes the GraphNetwork module.

        Args:
          edge_model_fn: A callable that will be passed to EdgeBlock to
            perform per-edge computations. The callable must return a Sonnet
            module (or equivalent; see EdgeBlock for details).
          node_model_fn: A callable that will be passed to NodeBlock to
            perform per-node computations. The callable must return a Sonnet
            module (or equivalent; see NodeBlock for details).
          global_model_fn: A callable that will be passed to GlobalBlock to
            perform per-global computations. The callable must return a
            Sonnet module (or equivalent; see GlobalBlock for details).
          reducer: Reducer to be used by NodeBlock and GlobalBlock to
            aggregate nodes and edges. Defaults to
            `tf.math.unsorted_segment_sum`. This will be overridden by the
            reducers specified in `node_block_opt` and `global_block_opt`, if
            any.
          edge_block_opt: Additional options to be passed to the EdgeBlock.
            Can contain keys `use_edges`, `use_receiver_nodes`,
            `use_sender_nodes`, `use_globals`. By default, these are all True.
          node_block_opt: Additional options to be passed to the NodeBlock.
            Can contain the keys `use_received_edges`, `use_sent_edges`,
            `use_nodes`, `use_globals` (all set to True by default), and
            `received_edges_reducer`, `sent_edges_reducer` (default to
            `reducer`).
          global_block_opt: Additional options to be passed to the
            GlobalBlock. Can contain the keys `use_edges`, `use_nodes`,
            `use_globals` (all set to True by default), and `edges_reducer`,
            `nodes_reducer` (defaults to `reducer`).
          name: The module name.
        """
        super(GraphNetwork, self).__init__(name=name)
        edge_block_opt = _make_default_edge_block_opt(edge_block_opt)
        node_block_opt = _make_default_node_block_opt(node_block_opt, reducer)
        global_block_opt = _make_default_global_block_opt(
            global_block_opt, reducer)

        self._edge_block = blocks.EdgeBlock(edge_model_fn=edge_model_fn,
                                            **edge_block_opt)
        self._node_block = blocks.NodeBlock(node_model_fn=node_model_fn,
                                            **node_block_opt)
        self._global_block = blocks.GlobalBlock(
            global_model_fn=global_model_fn, **global_block_opt)
示例#26
0
 def test_missing_field_raises_exception(
     self, use_edges, use_receiver_nodes, use_sender_nodes, use_globals,
     none_fields):
   """Checks that missing a required field raises a `ValueError`.

   The fields listed in `none_fields` are removed from the input graph, and
   the block is configured with the given `use_*` flags.
   """
   input_graph = self._get_input_graph(none_fields)
   edge_block = blocks.EdgeBlock(
       edge_model_fn=self._edge_model_fn,
       use_edges=use_edges,
       use_receiver_nodes=use_receiver_nodes,
       use_sender_nodes=use_sender_nodes,
       use_globals=use_globals)
   # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias,
   # which was removed from unittest in Python 3.12.
   with self.assertRaisesRegex(ValueError, "field cannot be None"):
     edge_block(input_graph)
示例#27
0
文件: base.py 项目: rkunnawa/root_gnn
 def __init__(self,
              edge_model_fn,
              node_model_fn,
              reducer=tf.math.unsorted_segment_sum,
              name="interaction_network"):
     """Initializes the InteractionNetwork module.

     Args:
       edge_model_fn: Callable passed to the `EdgeBlock` for per-edge
         computations.
       node_model_fn: Callable passed to the `NodeBlock` for per-node
         computations.
       reducer: Reducer the `NodeBlock` uses to aggregate both received and
         sent edges. Defaults to `tf.math.unsorted_segment_sum`.
       name: The module name.
     """
     super(InteractionNetwork, self).__init__(name=name)
     # Neither block reads the graph globals.
     self._edge_block = blocks.EdgeBlock(edge_model_fn=edge_model_fn,
                                         use_globals=False)
     # Both received and sent edges are aggregated, so the caller's reducer
     # is applied to both; otherwise sent edges would silently use whatever
     # default reducer `NodeBlock` has.
     self._node_block = blocks.NodeBlock(node_model_fn=node_model_fn,
                                         use_received_edges=True,
                                         use_sent_edges=True,
                                         use_globals=False,
                                         received_edges_reducer=reducer,
                                         sent_edges_reducer=reducer)
示例#28
0
 def test_incompatible_higher_rank_inputs_no_raise(self, use_edges,
                                                   use_receiver_nodes,
                                                   use_sender_nodes,
                                                   use_globals, field):
     """No exception should occur if a differently shaped field is not used.

     The selected `field` gets axes 1 and 2 swapped before the graph is fed to
     a Conv2D-based `EdgeBlock`; building and running must still succeed.
     """
     graph = self._get_shaped_input_graph()
     swapped_value = tf.transpose(getattr(graph, field), [0, 2, 1, 3])
     graph = graph.replace(**{field: swapped_value})

     conv_model_fn = functools.partial(snt.Conv2D,
                                       output_channels=10,
                                       kernel_shape=[3, 3])
     edge_block = blocks.EdgeBlock(conv_model_fn,
                                   use_edges=use_edges,
                                   use_receiver_nodes=use_receiver_nodes,
                                   use_sender_nodes=use_sender_nodes,
                                   use_globals=use_globals)
     self._assert_build_and_run(edge_block, graph)
示例#29
0
 def test_incompatible_higher_rank_inputs_raises(self, use_edges,
                                                 use_receiver_nodes,
                                                 use_sender_nodes,
                                                 use_globals, field):
     """An exception should be raised if the inputs have incompatible shapes.

     The selected `field` gets axes 1 and 2 swapped, so it no longer matches
     the other inputs the Conv2D-based `EdgeBlock` concatenates with it.
     """
     input_graph = self._get_shaped_input_graph()
     input_graph = input_graph.replace(
         **{field: tf.transpose(getattr(input_graph, field), [0, 2, 1, 3])})
     network = blocks.EdgeBlock(functools.partial(snt.Conv2D,
                                                  output_channels=10,
                                                  kernel_shape=[3, 3]),
                                use_edges=use_edges,
                                use_receiver_nodes=use_receiver_nodes,
                                use_sender_nodes=use_sender_nodes,
                                use_globals=use_globals)
     # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias,
     # which was removed from unittest in Python 3.12.
     with self.assertRaisesRegex(ValueError,
                                 "in both shapes must be equal"):
         network(input_graph)
示例#30
0
  def test_same_as_subblocks(self, reducer, none_field=None):
    """Checks the network's output against its sub-blocks run by hand.

    Args:
      reducer: Reducer used to build the model and passed to the manual
        `NodeBlock` as `received_edges_reducer`.
      none_field: (string, default=None) If not None, the corresponding field
        is removed from the input graph.
    """
    graph = self._get_input_graph(none_field)

    net = self._get_model(reducer)
    net_output = net(graph)
    actual_edges = net_output.edges
    actual_nodes = net_output.nodes
    # The network must leave the globals untouched.
    self.assertAllEqual(graph.globals, net_output.globals)

    # Wrap the network's own sub-modules so both paths evaluate the same
    # modules.
    manual_edge_block = blocks.EdgeBlock(
        edge_model_fn=lambda: net._edge_block._edge_model,
        use_edges=True,
        use_receiver_nodes=True,
        use_sender_nodes=True,
        use_globals=False)
    manual_node_block = blocks.NodeBlock(
        node_model_fn=lambda: net._node_block._node_model,
        use_received_edges=True,
        use_sent_edges=False,
        use_nodes=True,
        use_globals=False,
        received_edges_reducer=reducer)

    edge_stage_out = manual_edge_block(graph)
    expected_edges = edge_stage_out.edges
    expected_nodes = manual_node_block(edge_stage_out).nodes

    with self.test_session() as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      (actual_edges_val, actual_nodes_val,
       expected_edges_val, expected_nodes_val) = sess.run(
           [actual_edges, actual_nodes, expected_edges, expected_nodes])

    self.assertAllEqual(expected_edges_val, actual_edges_val)
    self.assertAllEqual(expected_nodes_val, actual_nodes_val)