def pack_graph_embeddings(self, graph_features, ge):
  """Pack graph embeddings `ge` for transport from actors to learners.

  Scatters the sparsely-packed left/right node embeddings into dense
  (batch, max_nodes, emb_dim) tensors so they can be stacked and shipped.
  For now we only handle BipartiteGraphs.

  Args:
    graph_features: dict whose keys must be exactly the
      `BipartiteGraphsTuple` fields; used to infer the dense
      (batch, max_nodes) layout.
    ge: `BipartiteGraphsTuple` of computed embeddings in sparse packing.

  Returns:
    A `BipartiteGraphsTuple` with densely scattered left/right nodes and
    the embedding globals/counts. Edge-related fields are dummy tensors --
    they are not used.

  Raises:
    Exception: if `graph_features` keys don't match `BipartiteGraphsTuple`.
  """
  if set(graph_features.keys()) != set(
      gn.graphs.BipartiteGraphsTuple._fields):
    raise Exception(f'Not handled {graph_features.keys()}')

  graph_features = gn.graphs.BipartiteGraphsTuple(**graph_features)
  # Scatter sparse per-graph node embeddings into dense
  # (batch, max_nodes, emb_dim) tensors; padding rows stay zero.
  left_nodes = tf.scatter_nd(
      gn.utils_tf.sparse_to_dense_indices(ge.n_left_nodes), ge.left_nodes,
      infer_shape(graph_features.left_nodes)[:2] +
      [infer_shape(ge.left_nodes)[-1]])
  right_nodes = tf.scatter_nd(
      gn.utils_tf.sparse_to_dense_indices(ge.n_right_nodes), ge.right_nodes,
      infer_shape(graph_features.right_nodes)[:2] +
      [infer_shape(ge.right_nodes)[-1]])
  bs = infer_shape(graph_features.left_nodes)[0]

  def dummy_tensor():
    # (batch,) zeros placeholder for the unused edge fields.
    # (PEP 8 E731: use `def` instead of assigning a lambda to a name.)
    return tf.fill(tf.expand_dims(bs, 0), 0)

  return gn.graphs.BipartiteGraphsTuple(
      left_nodes=left_nodes,
      right_nodes=right_nodes,
      globals=ge.globals,
      n_left_nodes=ge.n_left_nodes,
      n_right_nodes=ge.n_right_nodes,
      # edges is dummy tensor -- not used
      edges=tf.expand_dims(tf.expand_dims(dummy_tensor(), -1), -1),
      senders=dummy_tensor(),
      receivers=dummy_tensor(),
      n_edge=dummy_tensor(),
  )
def get_value(self, graph_features: gn.graphs.GraphsTuple):
  """Value head over message-propagated graph embeddings.

  graph_embeddings: Message propagated graph embeddings.
  Use self.compute_graph_embeddings to compute and cache these to use with
  different network heads for value, policy etc.

  Mean-pools the (MLP-transformed) left and right node features per graph,
  concatenates them with the graph globals, and applies the value torso.

  Returns:
    A (num_graphs,) tensor of value estimates.
  """
  with tf.variable_scope('value_network'):
    num_graphs = infer_shape(graph_features.n_left_nodes)[0]

    def _mean_aggregate(nodes, n_nodes):
      # Mean-pool sparsely-packed node features per graph. The slice drops
      # any padding rows beyond the real nodes.
      segment_ids = gn.utils_tf.repeat(tf.range(num_graphs), n_nodes, axis=0)
      return tf.unsorted_segment_mean(nodes[:infer_shape(segment_ids)[0]],
                                      segment_ids, num_graphs)

    # Apply mlp to left and right nodes, then aggregate each side separately.
    left_agg = _mean_aggregate(
        self.value_torso_1_left(graph_features.left_nodes),
        graph_features.n_left_nodes)
    right_agg = _mean_aggregate(
        self.value_torso_1_right(graph_features.right_nodes),
        graph_features.n_right_nodes)

    value = tf.concat([left_agg, right_agg, graph_features.globals], axis=-1)
    return tf.squeeze(self.value_torso_2(value), axis=-1)
def _add_stop_action(self, graph_features):
  """Compute per-graph logits for the binary stop/switch action.

  Mean-pools each graph's left-node features, concatenates the graph
  globals, and feeds the result through the switch torso.
  """
  with tf.variable_scope('switch_network'):
    batch = infer_shape(graph_features.n_left_nodes)[0]
    # Map every (sparsely packed) left node to the graph it belongs to.
    segment_ids = gn.utils_tf.repeat(
        tf.range(batch), graph_features.n_left_nodes, axis=0)
    # Drop padding rows beyond the real nodes, then mean-pool per graph.
    valid_nodes = graph_features.left_nodes[:infer_shape(segment_ids)[0]]
    pooled = tf.unsorted_segment_mean(valid_nodes, segment_ids, batch)
    # Concatenate globals and project down to a single logit per graph.
    features = tf.concat([pooled, graph_features.globals], axis=-1)
    return tf.squeeze(self.switch_torso(features), axis=-1)
def get_node_embeddings(self, obs, graph_embeddings):
  """Return dense left-node embeddings of shape (B, N_max, d).

  Each node's feature vector is augmented with its graph's globals, then
  scattered from the sparse packing into a dense, zero-padded layout whose
  leading dims come from obs['node_mask']. Also returns the per-graph
  left-node counts.
  """
  gf = graph_embeddings
  # Tile each graph's globals once per left node of that graph.
  tiled_globals = gn.utils_tf.repeat(gf.globals, gf.n_left_nodes, axis=0)
  # Truncate padding rows, then append the globals to every node feature.
  nodes = tf.concat(
      [gf.left_nodes[:infer_shape(tiled_globals)[0]], tiled_globals],
      axis=-1)
  # Scatter the sparse nodes into (B, N_max, d).
  dense_shape = infer_shape(obs['node_mask']) + [infer_shape(nodes)[-1]]
  scatter_idx = gn.utils_tf.sparse_to_dense_indices(gf.n_left_nodes)
  dense_nodes = tf.scatter_nd(scatter_idx, nodes, dense_shape)
  return dense_nodes, gf.n_left_nodes
def get_value(self, graph_embeddings, obs):
  """Value estimate for a batch of observations.

  Depending on config, either flattens the dense node embeddings and runs
  a small MLP, or delegates to the underlying GCN model's value head.

  TODO(review): Give this the optimal solution as well in the inputs?
  """
  if not self.config.use_mlp_value_func:
    # Delegate to the GCN model's own value head.
    return self._gcn_model.get_value(graph_embeddings)

  node_embs, _ = self.get_node_embeddings(obs, graph_embeddings)  # (N, L1, d)
  # Flatten each sample's node embeddings into one feature vector.
  flat = tf.reshape(node_embs, [infer_shape(node_embs)[0], -1])
  with tf.variable_scope('value_network'):
    self.value = snt.nets.MLP(
        [32, 32, 1],
        initializers=dict(
            w=glorot_uniform(self.seed),
            b=initializers.init_ops.Constant(0.0)),
        activate_final=False,
        activation=get_activation_from_str('relu'))
    return tf.squeeze(self.value(flat), axis=-1)
def get_auxiliary_loss(self, graph_features: gn.graphs.GraphsTuple, obs):
  """
  Returns a prediction for each node.
  This is useful for supervised node labelling/prediction tasks.

  Computes a mean-squared error between per-node predictions and the
  supervised targets in obs['optimal_solution'].
  """
  # broadcast globals and attach them to node features
  broadcasted_globals = gn.utils_tf.repeat(graph_features.globals,
                                           graph_features.n_left_nodes,
                                           axis=0)
  left_nodes = tf.concat([graph_features.left_nodes, broadcasted_globals],
                         axis=-1)
  # get logits over nodes
  preds = self.supervised_prediction_torso(left_nodes)
  # remove the final singleton dimension
  preds = tf.squeeze(preds, axis=-1)
  indices = gn.utils_tf.sparse_to_dense_indices(graph_features.n_left_nodes)
  # BUG FIX: tf.gather_nd takes (params, indices); the original call passed
  # (indices, params, shape), which both swaps the arguments and feeds a
  # shape list into the `batch_dims` parameter. Gather the dense targets at
  # the sparse node positions so opt_sol aligns with `preds`.
  # NOTE(review): assumes obs['optimal_solution'] is dense (B, N_max) --
  # confirm against the observation spec.
  opt_sol = tf.gather_nd(obs['optimal_solution'], indices)
  auxiliary_loss = tf.reduce_mean((preds - opt_sol)**2)
  return auxiliary_loss
def body_fn(i, decoder_inputs, logitss, actions, node_mask): dec = self._trans.decode(decoder_inputs, memory, src_masks, True) # (N, T2, d_model) # Q -> (N, d_model) Q = self._out_projection_mlp(dec[:, -1]) # Q -> (N, 1, d_model) Q = tf.expand_dims(Q, 1) # dot product outputs = tf.matmul(Q, tf.transpose(xs, [0, 2, 1])) # (N, 1, T1) outputs = tf.squeeze(outputs, 1) # (N, T1) # scale outputs /= (Q.get_shape().as_list()[-1]**0.5) # [N, T1] logits = tf.where(node_mask, outputs, tf.fill(tf.shape(node_mask), np.float32(-1e9))) # dont sample if sampled_actions is provided. assert sampled_actions is None act = sample_from_logits(logits, self.seed) # (N,) logitss = tf.concat( [logitss, tf.expand_dims(logits, axis=1)], axis=1) actions = tf.concat([actions, tf.expand_dims(act, axis=1)], axis=-1) # update node_masks to remove the current selected node for the next # decoding iteration. indices = tf.stack([tf.range(N), act], axis=-1) action_mask = tf.scatter_nd(indices, tf.ones((N, ), tf.bool), infer_shape(node_mask)) node_mask = tf.logical_and(node_mask, tf.logical_not(action_mask)) embs = tf.gather_nd(xs, indices) # (N, d) embs = tf.expand_dims(embs, 1) # (N, 1, d) decoder_inputs = tf.concat((decoder_inputs, embs), 1) return i + 1, decoder_inputs, logitss, actions, node_mask
def calc_logits(decoder_inputs, node_mask):
  """Compute per-step node logits for a full (teacher-forced) decode.

  Decodes all T2 positions at once and scores every node at every step,
  applying the same node mask to each step.

  NOTE(review): relies on enclosing-scope variables `memory`, `src_masks`,
  `xs`, and `self` -- this is a closure, not a standalone function.
  """
  # calculate logits with sampled actions.
  dec = self._trans.decode(decoder_inputs, memory, src_masks,
                           True)  # (N, T2, d_model)
  # Q -> (N, T2, d_model): project every decoded position.
  Q = snt.BatchApply(self._out_projection_mlp)(dec)
  # dot product against every node embedding
  outputs = tf.matmul(Q, tf.transpose(xs, [0, 2, 1]))  # (N, T2, T1)
  # scale by sqrt(d_model)
  outputs /= (Q.get_shape().as_list()[-1]**0.5)
  # (N, 1, T1)
  node_mask = tf.expand_dims(node_mask, 1)
  # (N, T2, T1): the same node mask applies to every decode step.
  node_mask = tf.tile(node_mask, [1, infer_shape(outputs)[1], 1])
  # [N, T2, T1]: large negative fill so masked nodes get ~zero probability.
  logits = tf.where(node_mask, outputs,
                    tf.fill(tf.shape(node_mask), np.float32(-1e9)))
  return logits
def get_actions(self,
                graph_embeddings,
                obs,
                step_types=None,
                sampled_actions=None,
                log_features=False):
  """Autoregressively decode node-selection actions with a transformer.

  step_types: [T + 1, B] of stepTypes.
  If sampled_actions are provided, then sampling is omitted and logits are
  recomputed for the given actions instead (teacher forcing).

  Difference between src_masks, node_mask and action_mask:
    src_masks: Indicates that certain node inputs to the encoder are from
      padding and should not be considered.
    node_masks: Sampling should happen only at these nodes during decoding
      phase.
    action_mask: Masks where the previous action has taken place.

  Returns:
    (logits, actions), both padded along the step axis to obs['max_k'][0];
    when sampled_actions is given, also returns a dict of logged features.
  """
  xs, n_node = self.get_node_embeddings(obs, graph_embeddings)  # (N, L1, d)
  node_mask = tf.reshape(obs['node_mask'], infer_shape(xs)[:-1])
  node_mask = tf.cast(node_mask, tf.bool)
  n_actions = tf.reshape(obs['n_actions'], (infer_shape(xs)[0], ))
  if sampled_actions is not None:
    # Remove some of the samples from the batch to match the size of the
    # actions.
    # This happens because the graph_embeddings are calculated for T + 1
    # observations whereas actions are only available for T steps.
    # (T + 1) * B -> T * B
    t_times_b = infer_shape(sampled_actions)[0]
    xs = xs[:t_times_b]
    n_node = n_node[:t_times_b]
    node_mask = node_mask[:t_times_b]
    n_actions = n_actions[:t_times_b]

  N = infer_shape(xs)[0]
  T1 = infer_shape(xs)[1]  # T1 and L1 used interchangeably.
  # compute src_mask to remove padding nodes interfering.
  indices = gn.utils_tf.sparse_to_dense_indices(n_node)
  # src_masks -> (N, L1)
  src_masks = tf.scatter_nd(indices,
                            tf.ones(infer_shape(indices)[:1], tf.bool),
                            infer_shape(xs)[:2])

  if sampled_actions is not None:
    # NOTE(review): the bool parameter `log_features` is rebound here to a
    # dict; the truthiness check at the end then returns these features.
    log_features = dict()
    log_features.update(
        dict(xs=xs, src_masks=src_masks, step_types=step_types[:-1]))
    log_features.update(dict(actions=sampled_actions))

  xs = snt.BatchApply(self._node_emb_mlp)(xs)  # (N, L1, d_model)
  memory = xs
  d = self.config.d_model

  def cond_fn(i, *_):
    # until all samples have n_actions sampled.
    return i < tf.reduce_max(n_actions)

  def body_fn(i, decoder_inputs, logitss, actions, node_mask):
    # One decode step: score nodes, sample an action, extend decoder input.
    dec = self._trans.decode(decoder_inputs, memory, src_masks,
                             True)  # (N, T2, d_model)
    # Q -> (N, d_model): query from the last decoded position only.
    Q = self._out_projection_mlp(dec[:, -1])
    # Q -> (N, 1, d_model)
    Q = tf.expand_dims(Q, 1)
    # dot product
    outputs = tf.matmul(Q, tf.transpose(xs, [0, 2, 1]))  # (N, 1, T1)
    outputs = tf.squeeze(outputs, 1)  # (N, T1)
    # scale
    outputs /= (Q.get_shape().as_list()[-1]**0.5)  # [N, T1]
    logits = tf.where(node_mask, outputs,
                      tf.fill(tf.shape(node_mask), np.float32(-1e9)))
    # dont sample if sampled_actions is provided.
    assert sampled_actions is None
    act = sample_from_logits(logits, self.seed)  # (N,)
    logitss = tf.concat([logitss, tf.expand_dims(logits, axis=1)], axis=1)
    actions = tf.concat([actions, tf.expand_dims(act, axis=1)], axis=-1)
    # update node_masks to remove the current selected node for the next
    # decoding iteration.
    indices = tf.stack([tf.range(N), act], axis=-1)
    action_mask = tf.scatter_nd(indices, tf.ones((N, ), tf.bool),
                                infer_shape(node_mask))
    node_mask = tf.logical_and(node_mask, tf.logical_not(action_mask))
    # Selected node's embedding becomes the next decoder input token.
    embs = tf.gather_nd(xs, indices)  # (N, d)
    embs = tf.expand_dims(embs, 1)  # (N, 1, d)
    decoder_inputs = tf.concat((decoder_inputs, embs), 1)
    return i + 1, decoder_inputs, logitss, actions, node_mask

  if sampled_actions is None:
    i = tf.constant(0)
    # Empty accumulators grown one step per loop iteration.
    logitss = tf.constant([], shape=(N, 0, T1), dtype=tf.float32)
    actions = tf.constant([], shape=(N, 0), dtype=tf.int32)
    # Feed in zero embedding as the start sentinel.
    decoder_inputs = tf.zeros((N, 1, d), tf.float32)  # (N, 1, d)
    i, _, logits, actions, _ = tf.while_loop(
        cond_fn,
        body_fn,
        (i, decoder_inputs, logitss, actions, node_mask),
        shape_invariants=(
            i.get_shape(),
            tf.TensorShape([N, None, d]),
            tf.TensorShape([N, None, T1]),
            tf.TensorShape([N, None]),
            node_mask.get_shape(),
        ),
        return_same_structure=True,
        back_prop=False,
    )
  else:

    def calc_logits(decoder_inputs, node_mask):
      # calculate logits with sampled actions (all steps in one pass).
      dec = self._trans.decode(decoder_inputs, memory, src_masks,
                               True)  # (N, T2, d_model)
      # Q -> (N, T2, d_model)
      Q = snt.BatchApply(self._out_projection_mlp)(dec)
      # dot product
      outputs = tf.matmul(Q, tf.transpose(xs, [0, 2, 1]))  # (N, T2, T1)
      # scale
      outputs /= (Q.get_shape().as_list()[-1]**0.5)
      # (N, 1, T1)
      node_mask = tf.expand_dims(node_mask, 1)
      # (N, T2, T1): same node mask for every decode step.
      node_mask = tf.tile(node_mask, [1, infer_shape(outputs)[1], 1])
      # [N, T2, T1]
      logits = tf.where(node_mask, outputs,
                        tf.fill(tf.shape(node_mask), np.float32(-1e9)))
      return logits

    # Teacher forcing: embeddings of the given actions are the decoder input.
    decoder_inputs = tf.gather_nd(xs,
                                  tf.expand_dims(sampled_actions, -1),
                                  batch_dims=1)
    logits = calc_logits(decoder_inputs, node_mask)
    actions = sampled_actions

  # Finally logits -> (N, T2, T1)
  # Finally actions -> (N, T2)
  # now pad actions -> (N, max_k) and logits -> (N, max_k, T1)
  pad_d = obs['max_k'][0] - infer_shape(logits)[1]
  logits = tf.pad(logits, ([0, 0], [0, pad_d], [0, 0]))
  pad_d = obs['max_k'][0] - infer_shape(actions)[1]
  actions = tf.pad(actions, ([0, 0], [0, pad_d]))

  if log_features:
    return logits, actions, log_features
  return logits, actions