Example #1
    def get_logits(self, graph_features: gn.graphs.GraphsTuple, node_mask):
        """
        graph_features: Message-propagated graph embeddings. Use
                        self.compute_graph_embeddings to compute and cache
                        these to use with different network heads for value,
                        policy etc.
        node_mask:      Dense 0/1 mask marking the valid node positions.
        """
        # summarize the nodes and append the per-graph aggregate to the globals
        graph_features = graph_features.replace(globals=tf.concat(
            [
                graph_features.globals,
                gn.blocks.NodesToGlobalsAggregator(tf.unsorted_segment_mean)(
                    graph_features.replace(
                        nodes=self.policy_summarize(graph_features.nodes))),
            ],
            axis=-1))

        # broadcast globals and attach them to node features
        graph_features = graph_features.replace(nodes=tf.concat(
            [
                graph_features.nodes,
                gn.blocks.broadcast_globals_to_nodes(graph_features),
            ],
            axis=-1))
        # get logits over nodes
        logits = self.policy_torso(graph_features.nodes)
        # remove the final singleton dimension
        logits = tf.squeeze(logits, axis=-1)
        log_vals = {}
        # record norm *before* adding -INF to invalid spots
        log_vals['opt/logits_norm'] = tf.linalg.norm(logits)

        # map each node to its (graph, node) position in the dense mask
        indices = gn.utils_tf.sparse_to_dense_indices(graph_features.n_node)
        logits = tf.scatter_nd(indices, logits, tf.shape(node_mask))
        logits = tf.where(tf.equal(node_mask, 1), logits,
                          tf.fill(tf.shape(node_mask), -INF))
        return logits, log_vals
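
The scatter-then-mask pattern above is the core of this snippet; here is a minimal, self-contained sketch of it with hypothetical shapes (`INF` stands in for the module-level constant the snippet relies on):

import tensorflow as tf

INF = 1e9  # stand-in for the module-level constant used above

logits = tf.constant([1.0, 2.0, 3.0])            # flat per-node logits
indices = tf.constant([[0, 0], [0, 1], [1, 0]])  # (graph, node) positions
node_mask = tf.constant([[1, 1], [1, 0]])        # 1 marks a real node

# scatter into the dense layout, then push invalid slots to -INF
dense = tf.scatter_nd(indices, logits, tf.shape(node_mask))
masked = tf.where(tf.equal(node_mask, 1), dense,
                  tf.fill(tf.shape(node_mask), -INF))
# masked == [[1., 2.], [3., -1e9]]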
Example #2
    def get_node_features(
            self, graph: gn.graphs.GraphsTuple) -> gn.graphs.GraphsTuple:
        aggregator = self.get_edge_to_node_aggregator()
        edge_to_v_agg = graph.replace(nodes=aggregator(graph))
        globs_node = graph.replace(
            nodes=gn.blocks.broadcast_globals_to_nodes(graph, name='g_to_n'))

        return gn.utils_tf.concat([graph, edge_to_v_agg, globs_node],
                                  axis=1,
                                  name='concat_n_feats')
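
Note that `gn.utils_tf.concat` with `axis=1` concatenates the corresponding node, edge, and global feature tensors of graphs that share the same connectivity, which is what lets the three per-node views above be merged feature-wise.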
Example #3
    def get_global_features(
            self, graph: gn.graphs.GraphsTuple) -> gn.graphs.GraphsTuple:
        edge_to_global_agg = self.get_edge_to_global_aggregator()
        edge_to_glob = graph.replace(globals=edge_to_global_agg(graph))

        node_to_global_agg = self.get_node_to_global_aggregator()
        node_to_glob = graph.replace(globals=node_to_global_agg(graph))

        return gn.utils_tf.concat([graph, edge_to_glob, node_to_glob],
                                  axis=1,
                                  name='concat_g_feats')
Example #4
  def _convolve(self, graph_features: gn.graphs.GraphsTuple):
    for i in range(self.n_prop_layers):
      with tf.variable_scope('prop_layer_%d' % i):
        new_graph_features = self._graphnet_models[i](graph_features)
        # residual connections
        graph_features = graph_features.replace(
            nodes=new_graph_features.nodes + graph_features.nodes,
            edges=new_graph_features.edges + graph_features.edges,
            globals=new_graph_features.globals + graph_features.globals)

        # layer norm (applied to the node features only)
        graph_features = graph_features.replace(
            nodes=self._node_layer_norms[i](graph_features.nodes))
    return graph_features
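
None of these snippets show how `self._graphnet_models` is constructed; a plausible per-layer model using standard graph_nets modules and Sonnet MLPs (layer sizes are illustrative, not taken from the source) would be:

import sonnet as snt
import graph_nets as gn

def make_graphnet_model():
    # one round of message passing; output sizes must match the input
    # embedding sizes for the residual connections above to type-check
    return gn.modules.GraphNetwork(
        edge_model_fn=lambda: snt.nets.MLP([64, 64]),
        node_model_fn=lambda: snt.nets.MLP([64, 64]),
        global_model_fn=lambda: snt.nets.MLP([64, 64]))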
Example #5
    def get_edge_features(
            graph: gn.graphs.GraphsTuple) -> gn.graphs.GraphsTuple:
        senders = graph.replace(
            edges=gn.blocks.broadcast_sender_nodes_to_edges(graph,
                                                            name='sn_to_e'))
        receivers = graph.replace(
            edges=gn.blocks.broadcast_receiver_nodes_to_edges(graph,
                                                              name='rn_to_e'))
        # average the two endpoint (sender/receiver) node features per edge
        nodes = graph.replace(edges=0.5 * (senders.edges + receivers.edges))
        globs = graph.replace(
            edges=gn.blocks.broadcast_globals_to_edges(graph, name='g_to_e'))

        return gn.utils_tf.concat([graph, nodes, globs],
                                  axis=1,
                                  name='concat_e_feats')
Example #6
    def _attn_convolve(self, graph_features: gn.graphs.GraphsTuple):

        num_heads = self.config.num_heads
        key_size = self.config.key_size
        value_size = self.config.node_embed_dim

        for i in range(self._n_prop_layers):
            with tf.variable_scope('attention'):
                nodes = graph_features.nodes
                qkv_size = 2 * key_size + value_size
                # total_size = qkv_size * num_heads  # denote as F

                # [total_num_nodes, d] => [total_num_nodes, F]
                qkv_flat = self._attention_dense_layers[i](nodes)

                qkv = tf.reshape(qkv_flat, [-1, num_heads, qkv_size])
                # q => [total_num_nodes, num_heads, key_size]
                # k => [total_num_nodes, num_heads, key_size]
                # v => [total_num_nodes, num_heads, value_size]
                q, k, v = tf.split(qkv, [key_size, key_size, value_size], -1)

            with tf.variable_scope('prop_layer_%d' % i):
                new_graph_features = self._graphnet_models[i](v, k, q,
                                                              graph_features)
                # residual connections
                graph_features = graph_features.replace(
                    nodes=new_graph_features.nodes + graph_features.nodes,
                    edges=new_graph_features.edges + graph_features.edges,
                    globals=new_graph_features.globals +
                    graph_features.globals)

        return graph_features
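
The split above carves each node's projection into per-head queries, keys, and values; a quick shape check with hypothetical sizes:

import tensorflow as tf

num_heads, key_size, value_size = 4, 16, 32
qkv = tf.zeros([10, num_heads, 2 * key_size + value_size])  # 10 nodes
q, k, v = tf.split(qkv, [key_size, key_size, value_size], axis=-1)
# q.shape == k.shape == [10, 4, 16]; v.shape == [10, 4, 32]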
Example #7
    def _encode(self, graph_features: gn.graphs.GraphsTuple, var_type_mask,
                constraint_type_mask, obj_type_mask):
        nodes = graph_features.nodes
        node_indices = gn.utils_tf.sparse_to_dense_indices(
            graph_features.n_node)
        masks = [var_type_mask, constraint_type_mask, obj_type_mask]
        for i, mask in enumerate(masks):
            mask = tf.reshape(mask, [-1, tf.shape(mask)[-1]])
            masks[i] = tf.gather_nd(params=mask, indices=node_indices)
        var_type_mask, constraint_type_mask, obj_type_mask = masks

        # TODO(arc): remove feature padding from nodes.
        # encode each node with the sub-network matching its type
        nodes = tf.where(
            tf.equal(var_type_mask, 1), self._var_encode_net(nodes),
            tf.where(tf.equal(constraint_type_mask, 1),
                     self._constraint_encode_net(nodes),
                     self._obj_encode_net(nodes)))
        # one type id per node: 0=variable, 1=constraint, 2=objective
        # (infer_shape is a project helper, not shown in this snippet)
        col = tf.fill([infer_shape(nodes)[0]], 0)
        node_types = tf.where(
            tf.equal(var_type_mask, 1), col,
            tf.where(tf.equal(constraint_type_mask, 1), col + 1, col + 2))
        node_types = tf.one_hot(node_types, 3)
        nodes = tf.concat([nodes, node_types], axis=-1)
        graph_features = graph_features.replace(nodes=nodes)
        graph_features = self._encode_net(graph_features)
        return graph_features
Example #8
    def _encode(self,
                input_graph: gn.graphs.GraphsTuple) -> gn.graphs.GraphsTuple:

        if input_graph.globals is not None:
            broadcasted_globals = gn.blocks.broadcast_globals_to_nodes(
                input_graph)
            input_graph = input_graph.replace(
                nodes=tf.concat([input_graph.nodes, broadcasted_globals],
                                axis=-1),
                globals=None)

        latent_graph_0 = self._encoder_network(input_graph)
        return latent_graph_0
Example #9
def create_zero_graph(blue_print: gn.graphs.GraphsTuple,
                      feature_dims: GraphFeatureDimensions):
    graph = blue_print.replace(nodes=None, edges=None, globals=None)
    graph = gn.utils_tf.set_zero_edge_features(graph,
                                               edge_size=feature_dims.edges,
                                               dtype=tf.float64)
    graph = gn.utils_tf.set_zero_node_features(graph,
                                               node_size=feature_dims.nodes,
                                               dtype=tf.float64)
    graph = gn.utils_tf.set_zero_global_features(
        graph, global_size=feature_dims.globals, dtype=tf.float64)
    return graph
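
A minimal usage sketch, assuming `GraphFeatureDimensions` is a plain container of the three feature sizes (the hypothetical namedtuple below stands in for the project's definition):

import collections
import numpy as np
from graph_nets import utils_tf

# hypothetical stand-in for the project's GraphFeatureDimensions
GraphFeatureDimensions = collections.namedtuple(
    'GraphFeatureDimensions', ['nodes', 'edges', 'globals'])

# only the connectivity of the blueprint matters; its features are discarded
blue_print = utils_tf.data_dicts_to_graphs_tuple([{
    'nodes': np.zeros((3, 1), dtype=np.float64),
    'edges': np.zeros((2, 1), dtype=np.float64),
    'globals': np.zeros((1,), dtype=np.float64),
    'senders': np.array([0, 1]),
    'receivers': np.array([1, 2]),
}])

zero_graph = create_zero_graph(
    blue_print, GraphFeatureDimensions(nodes=8, edges=4, globals=2))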
Example #10
    def get_value(self, graph_features: gn.graphs.GraphsTuple):
        """
      graph_embeddings: Message propagated graph embeddings.
                        Use self.compute_graph_embeddings to compute and cache
                        these to use with different network heads for value, policy etc.
    """
        with tf.variable_scope('value_network'):
            agg = gn.blocks.NodesToGlobalsAggregator(tf.unsorted_segment_mean)(
                graph_features.replace(
                    nodes=self.value_summarize(graph_features.nodes)))

            value = tf.concat([agg, graph_features.globals], axis=-1)
            return tf.squeeze(self.value_torso_2(value), axis=-1)
Example #11
  def _encode(
      self, input_graph: gn.graphs.GraphsTuple) -> gn.graphs.GraphsTuple:
    """Encodes the input graph features into a latent graph."""
    # Copy the globals to all of the nodes, if applicable.
    if input_graph.globals is not None:
      broadcasted_globals = gn.blocks.broadcast_globals_to_nodes(input_graph)
      input_graph = input_graph.replace(
          nodes=tf.concat([input_graph.nodes, broadcasted_globals], axis=-1),
          globals=None)

    # Encode the node and edge features.
    latent_graph_0 = self._encoder_network(input_graph)
    return latent_graph_0
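
`self._encoder_network` is not shown in these snippets; in graph_nets code it is commonly a `gn.modules.GraphIndependent` that transforms node and edge features without message passing. A sketch under that assumption, with illustrative layer sizes:

import sonnet as snt
import graph_nets as gn

# hypothetical encoder: independent per-node and per-edge MLPs
encoder_network = gn.modules.GraphIndependent(
    edge_model_fn=lambda: snt.nets.MLP([64, 64]),
    node_model_fn=lambda: snt.nets.MLP([64, 64]))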
Example #12
    def _convolve(self, graph_features: gn.graphs.GraphsTuple):
        for i in range(self.n_prop_layers):
            with tf.variable_scope('prop_layer_%d' % i):
                # one round of message passing
                new_graph_features = self._graphnet_models[i](graph_features)
                # residual connections
                graph_features = graph_features.replace(
                    nodes=new_graph_features.nodes + graph_features.nodes,
                    edges=new_graph_features.edges + graph_features.edges,
                    # residual connection not needed for globals,
                    # since the current global_model_fn is identity
                    globals=new_graph_features.globals)

        return graph_features
Example #13
    def get_auxiliary_loss(self, graph_features: gn.graphs.GraphsTuple, obs):
        """
      Returns a prediction for each node.
      This is useful for supervised node labelling/prediction tasks.
    """
        node_mask = obs['node_mask']
        # broadcast globals and attach them to node features
        graph_features = graph_features.replace(nodes=tf.concat([
            graph_features.nodes,
            gn.blocks.broadcast_globals_to_nodes(graph_features)
        ],
                                                                axis=-1))
        # get logits over nodes
        logits = self.supervised_prediction_torso(graph_features.nodes)
        # remove the final singleton dimension
        logits = tf.squeeze(logits, axis=-1)
        indices = gn.utils_tf.sparse_to_dense_indices(graph_features.n_node)
        preds = tf.scatter_nd(indices, logits, tf.shape(node_mask))

        var_type_mask = obs['var_type_mask']
        auxiliary_loss = tf.reduce_mean(
            tf.boolean_mask((preds - obs['optimal_solution'])**2,
                            tf.cast(var_type_mask, tf.bool)))
        return auxiliary_loss
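
The loss restricts the squared error to variable-type nodes via `tf.boolean_mask`; a tiny worked instance with hypothetical values:

import tensorflow as tf

preds = tf.constant([[0.5, 2.0], [1.0, 0.0]])
target = tf.constant([[1.0, 2.0], [0.0, 0.0]])
var_mask = tf.constant([[1, 0], [1, 0]])   # only column 0 holds variables
loss = tf.reduce_mean(
    tf.boolean_mask((preds - target)**2, tf.cast(var_mask, tf.bool)))
# averages over variable nodes only: ((0.5-1)^2 + (1-0)^2) / 2 = 0.625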