def _mask_attn_weights(self, attn_weights):
    """ Masks the attention weights
        :param attn_weights: The attention weights - [batch, nb_head, seq_len, seq_len + past_length]
        :return: A tensor of 0. and 1. of the same shape and dtype as attn_weights
    """
    seq_len = array_ops.shape(attn_weights)[-2]
    total_len = array_ops.shape(attn_weights)[-1]

    # 1) Creating the attention mask matrix (with the lower triangle set to 1. on the right)
    # e.g. if seq_len == 3, and total_len == 10
    # the attention mask would be:      - [seq_len, total_len]
    #   [[1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
    #    [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
    #    [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]
    num_lower = math_ops.cast(-1, dtypes.int32)
    num_upper = total_len - seq_len
    attn_mask = gen_array_ops.matrix_band_part(array_ops.ones([seq_len, total_len]), num_lower, num_upper)

    # No past_attentions/context - We just add two leading dimensions to attn_mask and can return it
    if self._past_seq_lengths is None:
        return attn_mask[None, None, :, :]

    # If we have a context with varying sequence length, we also need to mask the items after the end of sequence
    # e.g.
    #   [[1., 1., 1., 0., 0., 0., 0., 1., 1., 1.],     # => length of 3 (padded to 7) + seq_len of 3
    #    [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],     # => length of 7 (padded to 7) + seq_len of 3
    #    [1., 1., 1., 1., 1., 0., 0., 1., 1., 1.]]     # => length of 5 (padded to 7) + seq_len of 3
    #
    # The resulting attention mask would be the product of the two.
    #   [[1., 1., 1., 0., 0., 0., 0., 1., 0., 0.],
    #    [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
    #    [1., 1., 1., 1., 1., 0., 0., 1., 1., 1.]]
    seq_mask = array_ops.sequence_mask(self._past_seq_lengths, dtype=dtypes.float32)    # [b, max_len]
    seq_mask = pad_axis(seq_mask, axis=-1, min_size=total_len)                          # [b, total_len]

    # Returning the multiplication of the two masks
    return gen_math_ops.mul(attn_mask[None, None, :, :], seq_mask[:, None, None, :])    # [b, nb_heads, seq, total]
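# Illustrative sketch (not part of the model): the same causal mask built with the
# public tf.linalg.band_part API instead of the internal gen_array_ops wrapper.
# num_lower=-1 keeps the entire lower triangle; num_upper = total_len - seq_len
# shifts the diagonal right so every decoding step also attends to the full past.
def _sketch_band_part_mask():
    """ Standalone check of the attention mask example in the comment above """
    import tensorflow as tf
    seq_len, total_len = 3, 10
    attn_mask = tf.linalg.band_part(tf.ones([seq_len, total_len]), -1, total_len - seq_len)
    # Rows contain 8, 9, and 10 leading ones respectively, matching the matrix above
    return attn_mask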
def _build_policy_initial(self):
    """ Builds the policy model (initial step) """
    from diplomacy_research.utils.tensorflow import tf
    from diplomacy_research.models.layers.initializers import uniform
    from diplomacy_research.utils.tensorflow import build_sparse_batched_tensor, pad_axis, to_float, to_bool

    if not self.placeholders:
        self.placeholders = self.get_placeholders()

    # Shorthand functions to retrieve hparams and placeholders
    hps = lambda hparam_name: self.hparams[hparam_name]
    pholder = lambda placeholder_name: self.placeholders[placeholder_name]

    # Training loop
    with tf.variable_scope('policy', reuse=tf.AUTO_REUSE):
        with tf.device(self.cluster_config.worker_device if self.cluster_config else None):

            # Features
            board_state = to_float(self.features['board_state'])            # tf.float32 - (b, NB_NODES, NB_FEATURES)
            board_alignments = to_float(self.features['board_alignments'])  # (b, NB_NODES * len)
            decoder_inputs = self.features['decoder_inputs']                # tf.int32 - (b, <= 1 + TOK/ORD * NB_SCS)
            decoder_lengths = self.features['decoder_lengths']              # tf.int32 - (b,)
            current_power = self.features['current_power']                  # tf.int32 - (b,)
            current_season = self.features['current_season']                # tf.int32 - (b,)
            dropout_rates = self.features['dropout_rate']                   # tf.float32 - (b,)

            # Batch size
            batch_size = tf.shape(board_state)[0]

            # Reshaping board alignments
            board_alignments = tf.reshape(board_alignments, [batch_size, -1, NB_NODES])
            board_alignments /= tf.math.maximum(1., tf.reduce_sum(board_alignments, axis=-1, keepdims=True))

            # Building decoder mask
            decoder_mask_indices = self.features['decoder_mask_indices']    # tf.int64 - (b, 3 * len)
            decoder_mask_shape = self.proto_fields['decoder_mask'].shape

            # Overriding dropout_rates if pholder('dropout_rate') > 0
            dropout_rates = tf.cond(tf.greater(pholder('dropout_rate'), 0.),
                                    true_fn=lambda: tf.zeros_like(dropout_rates) + pholder('dropout_rate'),
                                    false_fn=lambda: dropout_rates)

            # Padding inputs
            board_alignments = pad_axis(board_alignments, axis=1, min_size=tf.reduce_max(decoder_lengths))
            decoder_inputs = pad_axis(decoder_inputs, axis=-1, min_size=2)
            decoder_mask_indices = pad_axis(decoder_mask_indices, axis=-1, min_size=len(decoder_mask_shape))

            # Reshaping to (b, len, 3)
            # decoder_mask is -- tf.bool (batch, TOK/ORD * NB_SC, VOCAB_SIZE, VOCAB_SIZE)
            decoder_mask_indices = tf.reshape(decoder_mask_indices, [batch_size, -1, len(decoder_mask_shape)])
            decoder_mask = build_sparse_batched_tensor(decoder_mask_indices,
                                                       value=True,
                                                       dtype=tf.bool,
                                                       dense_shape=decoder_mask_shape)

            # Making sure all RNN lengths are at least 1
            # No need to trim, because the fields are variable length
            raw_decoder_lengths = decoder_lengths
            decoder_lengths = tf.math.maximum(1, decoder_lengths)

            # Placeholders
            decoder_type = tf.reduce_max(pholder('decoder_type'))
            is_training = pholder('is_training')

            # Computing FiLM Gammas and Betas
            with tf.variable_scope('film_scope'):
                power_embedding = uniform(name='power_embedding',
                                          shape=[NB_POWERS, hps('power_emb_size')],
                                          scale=1.)
                current_power_mask = tf.one_hot(current_power, NB_POWERS, dtype=tf.float32)
                current_power_embedding = tf.reduce_sum(power_embedding[None]
                                                        * current_power_mask[:, :, None], axis=1)  # (b, power_emb)
                film_embedding_input = current_power_embedding

                # Also conditioning on current_season
                season_embedding = uniform(name='season_embedding',
                                           shape=[NB_SEASONS, hps('season_emb_size')],
                                           scale=1.)
                current_season_mask = tf.one_hot(current_season, NB_SEASONS, dtype=tf.float32)
                current_season_embedding = tf.reduce_sum(season_embedding[None]
                                                         * current_season_mask[:, :, None], axis=1)  # (b, season_emb)
                film_embedding_input = tf.concat([film_embedding_input, current_season_embedding], axis=1)

                film_output_dims = [hps('gcn_size')] * (hps('nb_graph_conv') - 1) + [hps('attn_size')]
                film_weights = tf.layers.Dense(units=2 * sum(film_output_dims),         # (b, 1, 750)
                                               use_bias=True,
                                               activation=None)(film_embedding_input)[:, None, :]
                film_gammas, film_betas = tf.split(film_weights, 2, axis=2)             # (b, 1, 750)
                film_gammas = tf.split(film_gammas, film_output_dims, axis=2)
                film_betas = tf.split(film_betas, film_output_dims, axis=2)

                # Storing as temporary output
                self.add_output('_board_state_conv_film_gammas', film_gammas)
                self.add_output('_board_state_conv_film_betas', film_betas)

            # Creating graph convolution
            with tf.variable_scope('graph_conv_scope'):
                assert hps('nb_graph_conv') >= 2

                # Encoding board state
                board_state_0yr_conv = self.encode_board(board_state, name='board_state_conv')
                board_state_conv = self.get_board_state_conv(board_state_0yr_conv, is_training)

            # Creating word embedding vector (to embed word_ix)
            # Embeddings need to be cached locally on the worker, otherwise TF can't compute their gradients
            with tf.variable_scope('word_embedding_scope'):
                # embedding: (voc_size, 256)
                caching_device = self.cluster_config.caching_device if self.cluster_config else None
                word_embedding = uniform(name='word_embedding',
                                         shape=[VOCABULARY_SIZE, hps('word_emb_size')],
                                         scale=1.,
                                         caching_device=caching_device)

            # Building output tags
            outputs = {'batch_size': batch_size,
                       'board_alignments': board_alignments,
                       'decoder_inputs': decoder_inputs,
                       'decoder_mask': decoder_mask,
                       'decoder_type': decoder_type,
                       'raw_decoder_lengths': raw_decoder_lengths,
                       'decoder_lengths': decoder_lengths,
                       'board_state_conv': board_state_conv,
                       'board_state_0yr_conv': board_state_0yr_conv,
                       'word_embedding': word_embedding,
                       'in_retreat_phase': tf.math.logical_and(    # 1) board not empty, 2) disl. units present
                           tf.reduce_sum(board_state[:], axis=[1, 2]) > 0,
                           tf.math.logical_not(to_bool(tf.reduce_min(board_state[:, :, 23], -1))))}

    # Adding to graph
    self.add_meta_information(outputs)
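# Illustrative sketch (an assumption about how encode_board consumes the stored
# '_board_state_conv_film_gammas' / '_betas', not a copy of it): FiLM modulates each
# graph-conv layer of width film_output_dims[i] feature-wise, h <- gamma * h + beta,
# with gamma/beta broadcast over the node axis. Shapes below are placeholders.
def _sketch_film_modulation():
    """ Feature-wise linear modulation of one graph-conv layer's output """
    import tensorflow as tf
    batch_size, nb_nodes, width = 2, 81, 16
    h = tf.zeros([batch_size, nb_nodes, width])    # one graph-conv layer's output
    gamma = tf.ones([batch_size, 1, width])        # one entry of the film_gammas list
    beta = tf.zeros([batch_size, 1, width])        # one entry of the film_betas list
    return gamma * h + beta                        # (b, nb_nodes, width)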
def _build_policy_initial(self):
    """ Builds the policy model (initial step) """
    from diplomacy_research.utils.tensorflow import tf
    from diplomacy_research.models.layers.initializers import uniform
    from diplomacy_research.utils.tensorflow import build_sparse_batched_tensor, pad_axis, to_float, to_bool

    if not self.placeholders:
        self.placeholders = self.get_placeholders()

    # Shorthand functions to retrieve hparams and placeholders
    hps = lambda hparam_name: self.hparams[hparam_name]
    pholder = lambda placeholder_name: self.placeholders[placeholder_name]

    # Training loop
    with tf.variable_scope('policy', reuse=tf.AUTO_REUSE):
        with tf.device(self.cluster_config.worker_device if self.cluster_config else None):

            # Features
            board_state = to_float(self.features['board_state'])    # tf.float32 - (b, NB_NODES, NB_FEATURES)
            decoder_inputs = self.features['decoder_inputs']        # tf.int32 - (b, <= 1 + TOK/ORD * NB_SCS)
            decoder_lengths = self.features['decoder_lengths']      # tf.int32 - (b,)
            dropout_rates = self.features['dropout_rate']           # tf.float32 - (b,)

            # Batch size
            batch_size = tf.shape(board_state)[0]

            # Building decoder mask
            decoder_mask_indices = self.features['decoder_mask_indices']    # tf.int64 - (b, 3 * len)
            decoder_mask_shape = self.proto_fields['decoder_mask'].shape

            # Overriding dropout_rates if pholder('dropout_rate') > 0
            dropout_rates = tf.cond(tf.greater(pholder('dropout_rate'), 0.),
                                    true_fn=lambda: tf.zeros_like(dropout_rates) + pholder('dropout_rate'),
                                    false_fn=lambda: dropout_rates)

            # Padding inputs
            decoder_inputs = pad_axis(decoder_inputs, axis=-1, min_size=2)
            decoder_mask_indices = pad_axis(decoder_mask_indices, axis=-1, min_size=len(decoder_mask_shape))

            # Reshaping to (b, len, 3)
            # decoder_mask is -- tf.bool (batch, TOK/ORD * NB_SC, VOCAB_SIZE, VOCAB_SIZE)
            decoder_mask_indices = tf.reshape(decoder_mask_indices, [batch_size, -1, len(decoder_mask_shape)])
            decoder_mask = build_sparse_batched_tensor(decoder_mask_indices,
                                                       value=True,
                                                       dtype=tf.bool,
                                                       dense_shape=decoder_mask_shape)

            # Making sure all RNN lengths are at least 1
            # No need to trim, because the fields are variable length
            raw_decoder_lengths = decoder_lengths
            decoder_lengths = tf.math.maximum(1, decoder_lengths)

            # Placeholders
            decoder_type = tf.reduce_max(pholder('decoder_type'))

            # Creating word embedding vector (to embed word_ix)
            # Embeddings need to be cached locally on the worker, otherwise TF can't compute their gradients
            with tf.variable_scope('word_embedding_scope'):
                # embedding: (voc_size, 256)
                caching_device = self.cluster_config.caching_device if self.cluster_config else None
                word_embedding = uniform(name='word_embedding',
                                         shape=[VOCABULARY_SIZE, hps('word_emb_size')],
                                         scale=1.,
                                         caching_device=caching_device)

            # Building output tags
            outputs = {'batch_size': batch_size,
                       'decoder_inputs': decoder_inputs,
                       'decoder_mask': decoder_mask,
                       'decoder_type': decoder_type,
                       'raw_decoder_lengths': raw_decoder_lengths,
                       'decoder_lengths': decoder_lengths,
                       'board_state_conv': tf.zeros([batch_size, NB_NODES, 0], dtype=tf.float32),
                       'board_state_0yr_conv': tf.zeros([batch_size, NB_NODES, 0], dtype=tf.float32),
                       'word_embedding': word_embedding,
                       'in_retreat_phase': tf.math.logical_and(    # 1) board not empty, 2) disl. units present
                           tf.reduce_sum(board_state[:], axis=[1, 2]) > 0,
                           tf.math.logical_not(to_bool(tf.reduce_min(board_state[:, :, 23], -1))))}

    # Adding to graph
    self.add_meta_information(outputs)
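# Illustrative sketch (an assumption about build_sparse_batched_tensor's contract:
# one boolean mask of `dense_shape` per batch element, True at the given coordinates).
# A dense equivalent can be built with tf.scatter_nd by prepending the batch index
# to each coordinate; duplicate or zero-padded indices collapse to True after the cast.
def _sketch_dense_mask_from_indices(mask_indices, dense_shape):
    """ mask_indices: (b, len, rank) int coordinates; dense_shape: python list of ints """
    import tensorflow as tf
    mask_indices = tf.cast(mask_indices, tf.int32)
    batch_size, nb_indices = tf.shape(mask_indices)[0], tf.shape(mask_indices)[1]
    batch_ix = tf.tile(tf.range(batch_size)[:, None, None], [1, nb_indices, 1])
    full_ix = tf.concat([batch_ix, mask_indices], axis=-1)          # (b, len, 1 + rank)
    updates = tf.ones([batch_size, nb_indices], dtype=tf.int32)
    out_shape = tf.concat([[batch_size], dense_shape], axis=0)      # (b,) + dense_shape
    dense = tf.scatter_nd(tf.reshape(full_ix, [-1, 1 + len(dense_shape)]),
                          tf.reshape(updates, [-1]),
                          out_shape)
    return tf.cast(dense, tf.bool)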
def _build_policy_initial(self):
    """ Builds the policy model (initial step) """
    from diplomacy_research.utils.tensorflow import tf
    from diplomacy_research.models.layers.initializers import uniform
    from diplomacy_research.utils.tensorflow import pad_axis, to_int32, to_float, to_bool

    if not self.placeholders:
        self.placeholders = self.get_placeholders()

    # Shorthand functions to retrieve hparams and placeholders
    hps = lambda hparam_name: self.hparams[hparam_name]
    pholder = lambda placeholder_name: self.placeholders[placeholder_name]

    # Training loop
    with tf.variable_scope('policy', reuse=tf.AUTO_REUSE):
        with tf.device(self.cluster_config.worker_device if self.cluster_config else None):

            # Features
            board_state = to_float(self.features['board_state'])    # tf.float32 - (b, NB_NODES, NB_FEATURES)
            decoder_inputs = self.features['decoder_inputs']        # tf.int32 - (b, <= 1 + NB_SCS)
            decoder_lengths = self.features['decoder_lengths']      # tf.int32 - (b,)
            candidates = self.features['candidates']                # tf.int32 - (b, nb_locs * MAX_CANDIDATES)
            dropout_rates = self.features['dropout_rate']           # tf.float32 - (b,)

            # Batch size
            batch_size = tf.shape(board_state)[0]

            # Overriding dropout_rates if pholder('dropout_rate') > 0
            dropout_rates = tf.cond(tf.greater(pholder('dropout_rate'), 0.),
                                    true_fn=lambda: tf.zeros_like(dropout_rates) + pholder('dropout_rate'),
                                    false_fn=lambda: dropout_rates)

            # Padding decoder_inputs and candidates
            decoder_inputs = pad_axis(decoder_inputs, axis=-1, min_size=2)
            candidates = pad_axis(candidates, axis=-1, min_size=MAX_CANDIDATES)

            # Making sure all RNN lengths are at least 1
            # No need to trim, because the fields are variable length
            raw_decoder_lengths = decoder_lengths
            decoder_lengths = tf.math.maximum(1, decoder_lengths)

            # Placeholders
            decoder_type = tf.reduce_max(pholder('decoder_type'))
            is_training = pholder('is_training')

            # Reshaping candidates
            candidates = tf.reshape(candidates, [batch_size, -1, MAX_CANDIDATES])
            candidates = candidates[:, :tf.reduce_max(decoder_lengths), :]      # tf.int32 - (b, nb_locs, MAX_CAN)

            # Creating graph convolution
            with tf.variable_scope('graph_conv_scope'):
                assert hps('nb_graph_conv') >= 2

                # Encoding board state
                board_state_0yr_conv = self.encode_board(board_state, name='board_state_conv')
                board_state_conv = self.get_board_state_conv(board_state_0yr_conv, is_training)

            # Creating order embedding vector (to embed order_ix)
            # Embeddings need to be cached locally on the worker, otherwise TF can't compute their gradients
            with tf.variable_scope('order_embedding_scope'):
                # embedding: (order_vocab_size, 64)
                caching_device = self.cluster_config.caching_device if self.cluster_config else None
                partitioner = tf.fixed_size_partitioner(NB_PARTITIONS) if hps('use_partitioner') else None
                order_embedding = uniform(name='order_embedding',
                                          shape=[ORDER_VOCABULARY_SIZE, hps('order_emb_size')],
                                          scale=1.,
                                          partitioner=partitioner,
                                          caching_device=caching_device)

            # Creating candidate embedding
            with tf.variable_scope('candidate_embedding_scope'):
                # embedding: (order_vocab_size, 64)
                caching_device = self.cluster_config.caching_device if self.cluster_config else None
                partitioner = tf.fixed_size_partitioner(NB_PARTITIONS) if hps('use_partitioner') else None
                candidate_embedding = uniform(name='candidate_embedding',
                                              shape=[ORDER_VOCABULARY_SIZE, hps('lstm_size') + 1],
                                              scale=1.,
                                              partitioner=partitioner,
                                              caching_device=caching_device)

            # Trimming to the maximum number of candidates
            candidate_lengths = tf.reduce_sum(to_int32(tf.math.greater(candidates, PAD_ID)), -1)  # int32 - (b, nb_locs)
            max_candidate_length = tf.math.maximum(1, tf.reduce_max(candidate_lengths))
            candidates = candidates[:, :, :max_candidate_length]

            # Building output tags
            outputs = {'batch_size': batch_size,
                       'decoder_inputs': decoder_inputs,
                       'decoder_type': decoder_type,
                       'raw_decoder_lengths': raw_decoder_lengths,
                       'decoder_lengths': decoder_lengths,
                       'board_state_conv': board_state_conv,
                       'board_state_0yr_conv': board_state_0yr_conv,
                       'order_embedding': order_embedding,
                       'candidate_embedding': candidate_embedding,
                       'candidates': candidates,
                       'max_candidate_length': max_candidate_length,
                       'in_retreat_phase': tf.math.logical_and(    # 1) board not empty, 2) disl. units present
                           tf.reduce_sum(board_state[:], axis=[1, 2]) > 0,
                           tf.math.logical_not(to_bool(tf.reduce_min(board_state[:, :, 23], -1))))}

    # Adding to graph
    self.add_meta_information(outputs)
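# Illustrative sketch with concrete values (pad_id = 0 is an assumption here): trimming
# the candidate axis to the longest candidate list actually present in the batch.
def _sketch_trim_candidates():
    """ Shows how the trailing candidate padding above gets removed """
    import tensorflow as tf
    pad_id = 0
    candidates = tf.constant([[[5, 9, 0, 0],
                               [7, 0, 0, 0]]])       # (1, nb_locs=2, MAX_CANDIDATES=4)
    lengths = tf.reduce_sum(tf.cast(tf.greater(candidates, pad_id), tf.int32), -1)   # [[2, 1]]
    max_len = tf.maximum(1, tf.reduce_max(lengths))  # -> 2
    return candidates[:, :, :max_len]                # -> shape (1, 2, 2)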
def _build_policy_initial(self):
    """ Builds the policy model (initial step) """
    from diplomacy_research.utils.tensorflow import tf
    from diplomacy_research.models.layers.initializers import uniform
    from diplomacy_research.utils.tensorflow import pad_axis, to_int32, to_float, to_bool

    if not self.placeholders:
        self.placeholders = self.get_placeholders()

    # Shorthand functions to retrieve hparams and placeholders
    hps = lambda hparam_name: self.hparams[hparam_name]
    pholder = lambda placeholder_name: self.placeholders[placeholder_name]

    # Training loop
    with tf.variable_scope('policy', reuse=tf.AUTO_REUSE):
        with tf.device(self.cluster_config.worker_device if self.cluster_config else None):

            # Features
            board_state = to_float(self.features['board_state'])            # tf.float32 - (b, NB_NODES, NB_FEATURES)
            board_alignments = to_float(self.features['board_alignments'])  # (b, NB_NODES * len)
            prev_orders_state = to_float(self.features['prev_orders_state'])  # (b, NB_PRV_OD, NB_ND, NB_OD_FT)
            decoder_inputs = self.features['decoder_inputs']                # tf.int32 - (b, <= 1 + NB_SCS)
            decoder_lengths = self.features['decoder_lengths']              # tf.int32 - (b,)
            candidates = self.features['candidates']                        # tf.int32 - (b, nb_locs * MAX_CANDIDATES)
            current_power = self.features['current_power']                  # tf.int32 - (b,)
            current_season = self.features['current_season']                # tf.int32 - (b,)
            dropout_rates = self.features['dropout_rate']                   # tf.float32 - (b,)

            # Batch size
            batch_size = tf.shape(board_state)[0]

            # Reshaping board alignments
            board_alignments = tf.reshape(board_alignments, [batch_size, -1, NB_NODES])
            board_alignments /= tf.math.maximum(1., tf.reduce_sum(board_alignments, axis=-1, keepdims=True))

            # Overriding dropout_rates if pholder('dropout_rate') > 0
            dropout_rates = tf.cond(tf.greater(pholder('dropout_rate'), 0.),
                                    true_fn=lambda: tf.zeros_like(dropout_rates) + pholder('dropout_rate'),
                                    false_fn=lambda: dropout_rates)

            # Padding inputs
            board_alignments = pad_axis(board_alignments, axis=1, min_size=tf.reduce_max(decoder_lengths))
            decoder_inputs = pad_axis(decoder_inputs, axis=-1, min_size=2)
            candidates = pad_axis(candidates, axis=-1, min_size=MAX_CANDIDATES)

            # Making sure all RNN lengths are at least 1
            # No need to trim, because the fields are variable length
            raw_decoder_lengths = decoder_lengths
            decoder_lengths = tf.math.maximum(1, decoder_lengths)

            # Placeholders
            decoder_type = tf.reduce_max(pholder('decoder_type'))
            is_training = pholder('is_training')

            # Reshaping candidates
            candidates = tf.reshape(candidates, [batch_size, -1, MAX_CANDIDATES])
            candidates = candidates[:, :tf.reduce_max(decoder_lengths), :]      # tf.int32 - (b, nb_locs, MAX_CAN)

            # Computing FiLM Gammas and Betas
            with tf.variable_scope('film_scope'):
                power_embedding = uniform(name='power_embedding',
                                          shape=[NB_POWERS, hps('power_emb_size')],
                                          scale=1.)
                current_power_mask = tf.one_hot(current_power, NB_POWERS, dtype=tf.float32)
                current_power_embedding = tf.reduce_sum(power_embedding[None]
                                                        * current_power_mask[:, :, None], axis=1)  # (b, power_emb)
                film_embedding_input = current_power_embedding

                # Also conditioning on current_season
                season_embedding = uniform(name='season_embedding',
                                           shape=[NB_SEASONS, hps('season_emb_size')],
                                           scale=1.)
                current_season_mask = tf.one_hot(current_season, NB_SEASONS, dtype=tf.float32)
                current_season_embedding = tf.reduce_sum(season_embedding[None]
                                                         * current_season_mask[:, :, None], axis=1)  # (b, season_emb)
                film_embedding_input = tf.concat([film_embedding_input, current_season_embedding], axis=1)

                film_output_dims = [hps('gcn_size')] * (hps('nb_graph_conv') - 1) + [hps('attn_size') // 2]

                # For board_state
                board_film_weights = tf.layers.Dense(units=2 * sum(film_output_dims),       # (b, 1, 750)
                                                     use_bias=True,
                                                     activation=None)(film_embedding_input)[:, None, :]
                board_film_gammas, board_film_betas = tf.split(board_film_weights, 2, axis=2)   # (b, 1, 750)
                board_film_gammas = tf.split(board_film_gammas, film_output_dims, axis=2)
                board_film_betas = tf.split(board_film_betas, film_output_dims, axis=2)

                # For prev_orders
                prev_ord_film_weights = tf.layers.Dense(units=2 * sum(film_output_dims),    # (b, 1, 750)
                                                        use_bias=True,
                                                        activation=None)(film_embedding_input)[:, None, :]
                prev_ord_film_weights = tf.tile(prev_ord_film_weights, [NB_PREV_ORDERS, 1, 1])  # (b * n_pr, 1, 750)
                prev_ord_film_gammas, prev_ord_film_betas = tf.split(prev_ord_film_weights, 2, axis=2)
                prev_ord_film_gammas = tf.split(prev_ord_film_gammas, film_output_dims, axis=2)
                prev_ord_film_betas = tf.split(prev_ord_film_betas, film_output_dims, axis=2)

                # Storing as temporary output
                self.add_output('_board_state_conv_film_gammas', board_film_gammas)
                self.add_output('_board_state_conv_film_betas', board_film_betas)
                self.add_output('_prev_orders_conv_film_gammas', prev_ord_film_gammas)
                self.add_output('_prev_orders_conv_film_betas', prev_ord_film_betas)

            # Creating graph convolution
            with tf.variable_scope('graph_conv_scope'):
                assert hps('nb_graph_conv') >= 2
                assert hps('attn_size') % 2 == 0

                # Encoding board state
                board_state_0yr_conv = self.encode_board(board_state, name='board_state_conv')

                # Encoding prev_orders
                prev_orders_state = tf.reshape(prev_orders_state,
                                               [batch_size * NB_PREV_ORDERS, NB_NODES, NB_ORDERS_FEATURES])
                prev_ord_conv = self.encode_board(prev_orders_state, name='prev_orders_conv')

                # Splitting back into (b, nb_prev, NB_NODES, attn_size // 2)
                # Reducing the prev ord conv using avg
                prev_ord_conv = tf.reshape(prev_ord_conv,
                                           [batch_size, NB_PREV_ORDERS, NB_NODES, hps('attn_size') // 2])
                prev_ord_conv = tf.reduce_mean(prev_ord_conv, axis=1)

                # Concatenating the current board conv with the prev ord conv
                # The final board_state_conv should be of dimension (b, NB_NODES, attn_size)
                board_state_conv = self.get_board_state_conv(board_state_0yr_conv, is_training, prev_ord_conv)

            # Creating order embedding vector (to embed order_ix)
            # Embeddings need to be cached locally on the worker, otherwise TF can't compute their gradients
            with tf.variable_scope('order_embedding_scope'):
                # embedding: (order_vocab_size, 64)
                caching_device = self.cluster_config.caching_device if self.cluster_config else None
                partitioner = tf.fixed_size_partitioner(NB_PARTITIONS) if hps('use_partitioner') else None
                order_embedding = uniform(name='order_embedding',
                                          shape=[ORDER_VOCABULARY_SIZE, hps('order_emb_size')],
                                          scale=1.,
                                          partitioner=partitioner,
                                          caching_device=caching_device)

            # Creating candidate embedding
            with tf.variable_scope('candidate_embedding_scope'):
                # embedding: (order_vocab_size, 64)
                caching_device = self.cluster_config.caching_device if self.cluster_config else None
                partitioner = tf.fixed_size_partitioner(NB_PARTITIONS) if hps('use_partitioner') else None
                candidate_embedding = uniform(name='candidate_embedding',
                                              shape=[ORDER_VOCABULARY_SIZE, hps('lstm_size') + 1],
                                              scale=1.,
                                              partitioner=partitioner,
                                              caching_device=caching_device)

            # Trimming to the maximum number of candidates
            candidate_lengths = tf.reduce_sum(to_int32(tf.math.greater(candidates, PAD_ID)), -1)  # int32 - (b, nb_locs)
            max_candidate_length = tf.math.maximum(1, tf.reduce_max(candidate_lengths))
            candidates = candidates[:, :, :max_candidate_length]

            # Building output tags
            outputs = {'batch_size': batch_size,
                       'board_alignments': board_alignments,
                       'decoder_inputs': decoder_inputs,
                       'decoder_type': decoder_type,
                       'raw_decoder_lengths': raw_decoder_lengths,
                       'decoder_lengths': decoder_lengths,
                       'board_state_conv': board_state_conv,
                       'board_state_0yr_conv': board_state_0yr_conv,
                       'prev_ord_conv': prev_ord_conv,
                       'order_embedding': order_embedding,
                       'candidate_embedding': candidate_embedding,
                       'candidates': candidates,
                       'max_candidate_length': max_candidate_length,
                       'in_retreat_phase': tf.math.logical_and(    # 1) board not empty, 2) disl. units present
                           tf.reduce_sum(board_state[:], axis=[1, 2]) > 0,
                           tf.math.logical_not(to_bool(tf.reduce_min(board_state[:, :, 23], -1))))}

    # Adding to graph
    self.add_meta_information(outputs)
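# Illustrative sketch of the previous-orders reduction above: the NB_PREV_ORDERS
# states are folded into the batch axis so encode_board sees them as ordinary boards,
# then split back out and averaged. Shapes are placeholders; only the pattern matters.
def _sketch_prev_orders_reduction():
    """ (b * nb_prev, nodes, feats) encoder output -> (b, nodes, feats) average """
    import tensorflow as tf
    batch_size, nb_prev, nb_nodes, nb_feats = 4, 3, 81, 20
    encoded = tf.zeros([batch_size * nb_prev, nb_nodes, nb_feats])  # stand-in encoder output
    per_prev = tf.reshape(encoded, [batch_size, nb_prev, nb_nodes, nb_feats])
    return tf.reduce_mean(per_prev, axis=1)                         # (b, nb_nodes, nb_feats)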