# Assumed context (not shown in the source): `tf` is TensorFlow 1.x and `gn`
# is the DeepMind graph_nets library (possibly extended with helpers such as
# `sparse_to_dense_indices`); `INF`, `infer_shape` and `GraphFeatureDimensions`
# are defined elsewhere in the codebase.


def get_logits(self, graph_features: gn.graphs.GraphsTuple, node_mask):
  """Computes per-node logits.

  graph_features: Message-propagated graph embeddings. Use
    self.compute_graph_embeddings to compute and cache these to use with
    different network heads for value, policy etc.
  """
  # summarize node features into a graph-level aggregate and append to globals
  graph_features = graph_features.replace(globals=tf.concat([
      graph_features.globals,
      gn.blocks.NodesToGlobalsAggregator(tf.unsorted_segment_mean)(
          graph_features.replace(
              nodes=self.policy_summarize(graph_features.nodes)))
  ], axis=-1))
  # broadcast globals and attach them to node features
  graph_features = graph_features.replace(nodes=tf.concat([
      graph_features.nodes,
      gn.blocks.broadcast_globals_to_nodes(graph_features)
  ], axis=-1))
  # get logits over nodes
  logits = self.policy_torso(graph_features.nodes)
  # remove the final singleton dimension
  logits = tf.squeeze(logits, axis=-1)
  log_vals = {}
  # record norm *before* adding -INF to invalid spots
  log_vals['opt/logits_norm'] = tf.linalg.norm(logits)
  # scatter logits back to the dense [batch, max_nodes] layout and mask out
  # invalid nodes with -INF
  indices = gn.utils_tf.sparse_to_dense_indices(graph_features.n_node)
  logits = tf.scatter_nd(indices, logits, tf.shape(node_mask))
  logits = tf.where(tf.equal(node_mask, 1), logits,
                    tf.fill(tf.shape(node_mask), -INF))
  return logits, log_vals

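# Hedged usage sketch (not from the source): since invalid nodes are already
# at -INF, the masked logits can parameterize a categorical action
# distribution directly. `model` and `sample_node_action` are hypothetical
# names; requires tf.random.categorical (TF >= 1.14).
def sample_node_action(model, graph_features, node_mask):
  logits, log_vals = model.get_logits(graph_features, node_mask)
  # -INF entries receive zero probability under the sampler's softmax.
  action = tf.squeeze(tf.random.categorical(logits, num_samples=1), axis=-1)
  return action, log_vals
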
def get_node_features(
    self, graph: gn.graphs.GraphsTuple) -> gn.graphs.GraphsTuple:
  """Concatenates aggregated edge features and broadcast globals to nodes."""
  aggregator = self.get_edge_to_node_aggregator()
  edge_to_v_agg = graph.replace(nodes=aggregator(graph))
  globs_node = graph.replace(
      nodes=gn.blocks.broadcast_globals_to_nodes(graph, name='g_to_n'))
  return gn.utils_tf.concat([graph, edge_to_v_agg, globs_node],
                            axis=1,
                            name='concat_n_feats')

def get_global_features(
    self, graph: gn.graphs.GraphsTuple) -> gn.graphs.GraphsTuple:
  """Concatenates aggregated edge and node features to the globals."""
  edge_to_global_agg = self.get_edge_to_global_aggregator()
  edge_to_glob = graph.replace(globals=edge_to_global_agg(graph))
  node_to_global_agg = self.get_node_to_global_aggregator()
  node_to_glob = graph.replace(globals=node_to_global_agg(graph))
  return gn.utils_tf.concat([graph, edge_to_glob, node_to_glob],
                            axis=1,
                            name='concat_g_feats')

def _convolve(self, graph_features: gn.graphs.GraphsTuple):
  for i in range(self.n_prop_layers):
    with tf.variable_scope('prop_layer_%d' % i):
      new_graph_features = self._graphnet_models[i](graph_features)
      # residual connections
      graph_features = graph_features.replace(
          nodes=new_graph_features.nodes + graph_features.nodes,
          edges=new_graph_features.edges + graph_features.edges,
          globals=new_graph_features.globals + graph_features.globals)
      # layer norm
      graph_features = graph_features.replace(
          nodes=self._node_layer_norms[i](graph_features.nodes))
  return graph_features

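# Sketch of how the per-layer models above might be built (an assumption; the
# source only shows their use): one gn.modules.GraphNetwork per propagation
# layer plus a Sonnet LayerNorm over node features. Sizes are illustrative.
import sonnet as snt
import graph_nets as gn

n_prop_layers, embed_dim = 4, 32
graphnet_models = [
    gn.modules.GraphNetwork(
        edge_model_fn=lambda: snt.nets.MLP([embed_dim]),
        node_model_fn=lambda: snt.nets.MLP([embed_dim]),
        global_model_fn=lambda: snt.nets.MLP([embed_dim]))
    for _ in range(n_prop_layers)
]
node_layer_norms = [snt.LayerNorm() for _ in range(n_prop_layers)]
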
def get_edge_features(
    graph: gn.graphs.GraphsTuple) -> gn.graphs.GraphsTuple:
  senders = graph.replace(
      edges=gn.blocks.broadcast_sender_nodes_to_edges(graph, name='sn_to_e'))
  receivers = graph.replace(
      edges=gn.blocks.broadcast_receiver_nodes_to_edges(graph, name='rn_to_e'))
  # average the two endpoint node features for each edge
  nodes = graph.replace(edges=0.5 * (senders.edges + receivers.edges))
  globs = graph.replace(
      edges=gn.blocks.broadcast_globals_to_edges(graph, name='g_to_e'))
  return gn.utils_tf.concat([graph, nodes, globs],
                            axis=1,
                            name='concat_e_feats')

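# Minimal shape check for get_edge_features (a sketch, assuming the standard
# graph_nets API): endpoint-node averages (N features) and broadcast globals
# (G features) are appended to the raw edge features (E features).
import numpy as np
import tensorflow as tf
import graph_nets as gn

graph = gn.utils_tf.data_dicts_to_graphs_tuple([{
    'nodes': np.zeros((2, 4), dtype=np.float32),    # N = 4
    'edges': np.zeros((1, 3), dtype=np.float32),    # E = 3
    'globals': np.zeros((5,), dtype=np.float32),    # G = 5
    'senders': np.array([0]),
    'receivers': np.array([1]),
}])
out = get_edge_features(graph)
# out.edges has feature size E + N + G = 3 + 4 + 5 = 12
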
def _attn_convolve(self, graph_features: gn.graphs.GraphsTuple):
  num_heads = self.config.num_heads
  key_size = self.config.key_size
  value_size = self.config.node_embed_dim
  for i in range(self._n_prop_layers):
    with tf.variable_scope('attention'):
      nodes = graph_features.nodes
      qkv_size = 2 * key_size + value_size
      # total_size = qkv_size * num_heads  # denote as F
      # [total_num_nodes, d] => [total_num_nodes, F]
      qkv_flat = self._attention_dense_layers[i](nodes)
      qkv = tf.reshape(qkv_flat, [-1, num_heads, qkv_size])
      # q => [total_num_nodes, num_heads, key_size]
      # k => [total_num_nodes, num_heads, key_size]
      # v => [total_num_nodes, num_heads, value_size]
      q, k, v = tf.split(qkv, [key_size, key_size, value_size], -1)
    with tf.variable_scope('prop_layer_%d' % i):
      new_graph_features = self._graphnet_models[i](v, k, q, graph_features)
      # residual connections
      graph_features = graph_features.replace(
          nodes=new_graph_features.nodes + graph_features.nodes,
          edges=new_graph_features.edges + graph_features.edges,
          globals=new_graph_features.globals + graph_features.globals)
  # return the propagated features (mirrors _convolve; the original snippet
  # was missing this return)
  return graph_features

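# Standalone sketch of one attention step (an assumption: _graphnet_models[i]
# here behaves like gn.modules.SelfAttention, whose call takes node values,
# keys, queries and an attention graph). Shapes mirror the qkv split above.
import numpy as np
import tensorflow as tf
import graph_nets as gn

num_heads, key_size, value_size = 2, 8, 16
graph = gn.utils_tf.data_dicts_to_graphs_tuple([{
    'nodes': np.zeros((3, 4), dtype=np.float32),
    'edges': np.zeros((2, 1), dtype=np.float32),
    'globals': np.zeros((1,), dtype=np.float32),
    'senders': np.array([0, 1]),
    'receivers': np.array([1, 2]),
}])
q = tf.zeros([3, num_heads, key_size])
k = tf.zeros([3, num_heads, key_size])
v = tf.zeros([3, num_heads, value_size])
attended = gn.modules.SelfAttention()(v, k, q, graph)
# attended.nodes: [total_num_nodes, num_heads, value_size]
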
def _encode(self, graph_features: gn.graphs.GraphsTuple, var_type_mask,
            constraint_type_mask, obj_type_mask):
  nodes = graph_features.nodes
  node_indices = gn.utils_tf.sparse_to_dense_indices(graph_features.n_node)
  # gather the dense [batch, max_nodes] masks down to the packed node layout
  masks = [var_type_mask, constraint_type_mask, obj_type_mask]
  for i, mask in enumerate(masks):
    mask = tf.reshape(mask, [-1, tf.shape(mask)[-1]])
    masks[i] = tf.gather_nd(params=mask, indices=node_indices)
  var_type_mask, constraint_type_mask, obj_type_mask = masks
  # TODO(arc): remove feature padding from nodes.
  # encode each node with the network matching its type
  nodes = tf.where(
      tf.equal(var_type_mask, 1),
      self._var_encode_net(nodes),
      tf.where(tf.equal(constraint_type_mask, 1),
               self._constraint_encode_net(nodes),
               self._obj_encode_net(nodes)))
  # append a one-hot node-type indicator:
  # 0 = variable, 1 = constraint, 2 = objective
  col = tf.fill([infer_shape(nodes)[0]], 0)
  node_types = tf.where(
      tf.equal(var_type_mask, 1), col,
      tf.where(tf.equal(constraint_type_mask, 1), col + 1, col + 2))
  node_types = tf.one_hot(node_types, 3)
  nodes = tf.concat([nodes, node_types], axis=-1)
  graph_features = graph_features.replace(nodes=nodes)
  graph_features = self._encode_net(graph_features)
  return graph_features

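# Toy illustration of the nested tf.where trick above (a sketch): the three
# mutually exclusive masks map to integer type ids 0/1/2, then one-hot.
import tensorflow as tf

var_mask = tf.constant([1, 0, 0])
con_mask = tf.constant([0, 1, 0])
col = tf.zeros([3], dtype=tf.int32)
node_types = tf.where(tf.equal(var_mask, 1), col,
                      tf.where(tf.equal(con_mask, 1), col + 1, col + 2))
one_hot_types = tf.one_hot(node_types, 3)  # [[1,0,0],[0,1,0],[0,0,1]]
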
def create_zero_graph(blue_print: gn.graphs.GraphsTuple,
                      feature_dims: GraphFeatureDimensions):
  graph = blue_print.replace(nodes=None, edges=None, globals=None)
  graph = gn.utils_tf.set_zero_edge_features(graph,
                                             edge_size=feature_dims.edges,
                                             dtype=tf.float64)
  graph = gn.utils_tf.set_zero_node_features(graph,
                                             node_size=feature_dims.nodes,
                                             dtype=tf.float64)
  graph = gn.utils_tf.set_zero_global_features(
      graph, global_size=feature_dims.globals, dtype=tf.float64)
  return graph

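# Hedged usage sketch: create_zero_graph keeps the blueprint's topology
# (senders, receivers, n_node, n_edge) and swaps in zero features of the
# requested sizes. GraphFeatureDimensions is assumed to be a simple container
# of integer sizes, and `some_graph` stands for any existing GraphsTuple.
dims = GraphFeatureDimensions(nodes=32, edges=16, globals=8)
zero_graph = create_zero_graph(blue_print=some_graph, feature_dims=dims)
# zero_graph.nodes: [total_num_nodes, 32] zeros of dtype tf.float64
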
def get_value(self, graph_features: gn.graphs.GraphsTuple):
  """Computes the state value.

  graph_features: Message-propagated graph embeddings. Use
    self.compute_graph_embeddings to compute and cache these to use with
    different network heads for value, policy etc.
  """
  with tf.variable_scope('value_network'):
    # mean-pool summarized node features and combine with the globals
    agg = gn.blocks.NodesToGlobalsAggregator(tf.unsorted_segment_mean)(
        graph_features.replace(
            nodes=self.value_summarize(graph_features.nodes)))
    value = tf.concat([agg, graph_features.globals], axis=-1)
    return tf.squeeze(self.value_torso_2(value), axis=-1)

def _encode(
    self, input_graph: gn.graphs.GraphsTuple) -> gn.graphs.GraphsTuple:
  """Encodes the input graph features into a latent graph."""
  # Copy the globals to all of the nodes, if applicable.
  if input_graph.globals is not None:
    broadcasted_globals = gn.blocks.broadcast_globals_to_nodes(input_graph)
    input_graph = input_graph.replace(
        nodes=tf.concat([input_graph.nodes, broadcasted_globals], axis=-1),
        globals=None)
  # Encode the node and edge features.
  latent_graph_0 = self._encoder_network(input_graph)
  return latent_graph_0

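# The encoder network itself is not shown; a common choice (an assumption, not
# confirmed by the source) is a gn.modules.GraphIndependent that encodes node
# and edge features with separate MLPs and no message passing.
import sonnet as snt
import graph_nets as gn

encoder_network = gn.modules.GraphIndependent(
    edge_model_fn=lambda: snt.nets.MLP([32]),
    node_model_fn=lambda: snt.nets.MLP([32]))
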
def _convolve(self, graph_features: gn.graphs.GraphsTuple):
  for i in range(self.n_prop_layers):
    with tf.variable_scope('prop_layer_%d' % i):
      # one round of message passing
      new_graph_features = self._graphnet_models[i](graph_features)
      # residual connections
      graph_features = graph_features.replace(
          nodes=new_graph_features.nodes + graph_features.nodes,
          edges=new_graph_features.edges + graph_features.edges,
          # residual connection not needed for globals,
          # since the current global_model_fn is identity
          globals=new_graph_features.globals)
  return graph_features

def get_auxiliary_loss(self, graph_features: gn.graphs.GraphsTuple, obs):
  """Returns an auxiliary supervised loss over per-node predictions.

  This is useful for supervised node labelling/prediction tasks.
  """
  node_mask = obs['node_mask']
  # broadcast globals and attach them to node features
  graph_features = graph_features.replace(nodes=tf.concat([
      graph_features.nodes,
      gn.blocks.broadcast_globals_to_nodes(graph_features)
  ], axis=-1))
  # get logits over nodes
  logits = self.supervised_prediction_torso(graph_features.nodes)
  # remove the final singleton dimension
  logits = tf.squeeze(logits, axis=-1)
  indices = gn.utils_tf.sparse_to_dense_indices(graph_features.n_node)
  preds = tf.scatter_nd(indices, logits, tf.shape(node_mask))
  # mean squared error against the optimal solution, restricted to
  # variable nodes
  var_type_mask = obs['var_type_mask']
  auxiliary_loss = tf.reduce_mean(
      tf.boolean_mask((preds - obs['optimal_solution'])**2,
                      tf.cast(var_type_mask, tf.bool)))
  return auxiliary_loss

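# Sketch of how this loss might enter the overall objective (an assumption;
# `model`, `policy_loss`, `value_loss` and the coefficients are hypothetical
# names, not from the source):
aux_loss_coef = 0.1
total_loss = (policy_loss + value_loss_coef * value_loss +
              aux_loss_coef * model.get_auxiliary_loss(graph_features, obs))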