def _build_policy_final(self):
    """ Builds the final step of the token-based policy model

        Reads the tensors stored in `self.outputs` by the initial step, builds a regular decoder and a
        diverse beam search decoder (both sharing the same LSTM cell and the same output projection),
        computes the policy loss, and registers the resulting tensors with `self.add_meta_information`.
    """
    from diplomacy_research.utils.tensorflow import tf
    from diplomacy_research.models.layers.attention import StaticAttentionWrapper
    from diplomacy_research.models.layers.beam_decoder import DiverseBeamSearchDecoder
    from diplomacy_research.models.layers.decoder import MaskedBasicDecoder
    from diplomacy_research.models.layers.dropout import SeededDropoutWrapper
    from diplomacy_research.models.layers.dynamic_decode import dynamic_decode
    from diplomacy_research.models.policy.token_based.helper import CustomHelper, CustomBeamHelper
    from diplomacy_research.utils.tensorflow import cross_entropy, sequence_loss, to_int32, to_float, get_tile_beam

    def hparam(name):
        """ Returns the hyper-parameter stored under `name` """
        return self.hparams[name]

    # Ops are pinned on the worker device when running in a cluster
    worker_device = self.cluster_config.worker_device if self.cluster_config else None

    with tf.variable_scope('policy', reuse=tf.AUTO_REUSE), tf.device(worker_device):

        # Features
        seeds = self.features['player_seed']                        # tf.int32 - (b,)
        softmax_temperature = self.features['temperature']          # tf.flt32 - (b,)
        drop_rates = self.features['dropout_rate']                  # tf.flt32 - (b,)

        # Placeholders
        freeze_gradients = self.placeholders['stop_gradient_all']

        # Tensors computed by the initial step
        batch_size = self.outputs['batch_size']
        board_alignments = self.outputs['board_alignments']
        decoder_inputs = self.outputs['decoder_inputs']
        decoder_mask = self.outputs['decoder_mask']
        decoder_type = self.outputs['decoder_type']
        raw_decoder_lengths = self.outputs['raw_decoder_lengths']
        decoder_lengths = self.outputs['decoder_lengths']
        board_state_conv = self.outputs['board_state_conv']
        word_embedding = self.outputs['word_embedding']

        # --- Decoding ---
        with tf.variable_scope('decoder_scope', reuse=tf.AUTO_REUSE):
            shared_lstm = tf.contrib.rnn.LSTMBlockCell(hparam('lstm_size'))

            # Projection from the decoder output to one logit per vocabulary token
            token_projection = tf.layers.Dense(units=VOCABULARY_SIZE,
                                               activation=None,
                                               kernel_initializer=tf.random_normal_initializer,
                                               use_bias=True)

            # The cell receives the word embedding concatenated with the attention
            cell_input_size = hparam('word_emb_size') + hparam('attn_size')

            # ======== Regular Decoding ========
            # Seeded dropout on the cell input and on the cell output
            train_cell = SeededDropoutWrapper(cell=shared_lstm,
                                              seeds=seeds,
                                              input_keep_probs=1. - drop_rates,
                                              output_keep_probs=1. - drop_rates,
                                              variational_recurrent=hparam('use_v_dropout'),
                                              input_size=cell_input_size,
                                              dtype=tf.float32)

            # Attention over the board state, using the alignments provided for each position
            train_cell = StaticAttentionWrapper(cell=train_cell,
                                                memory=board_state_conv,
                                                alignments=board_alignments,
                                                sequence_length=raw_decoder_lengths,
                                                output_attention=False)
            train_init_state = train_cell.zero_state(batch_size, tf.float32)

            # ---- Helper ----
            train_helper = CustomHelper(decoder_type=decoder_type,
                                        inputs=decoder_inputs[:, :-1],
                                        embedding=word_embedding,
                                        sequence_length=decoder_lengths,
                                        mask=decoder_mask,
                                        time_major=False,
                                        softmax_temperature=softmax_temperature)

            # ---- Decoder ----
            # The mask is built from the raw lengths, so padded positions are zeroed out
            sequence_mask = tf.sequence_mask(raw_decoder_lengths,
                                             maxlen=tf.reduce_max(decoder_lengths),
                                             dtype=tf.float32)
            max_steps = TOKENS_PER_ORDER * NB_SUPPLY_CENTERS
            train_decoder = MaskedBasicDecoder(cell=train_cell,
                                               helper=train_helper,
                                               initial_state=train_init_state,
                                               output_layer=token_projection,
                                               extract_state=True)
            train_results, _, _ = dynamic_decode(decoder=train_decoder,
                                                 output_time_major=False,
                                                 maximum_iterations=max_steps,
                                                 swap_memory=hparam('swap_memory'))

            # Snapshot used below to verify that the beam decoder only re-uses existing variables
            vars_before_beam = set(tf.global_variables())

            # ======== Beam Search Decoding ========
            beam_width = hparam('beam_width')
            tile_beam = get_tile_beam(beam_width)

            # Same wrappers as above, with every per-item tensor tiled for the beam
            beam_cell = SeededDropoutWrapper(cell=shared_lstm,
                                             seeds=tile_beam(seeds),
                                             input_keep_probs=tile_beam(1. - drop_rates),
                                             output_keep_probs=tile_beam(1. - drop_rates),
                                             variational_recurrent=hparam('use_v_dropout'),
                                             input_size=cell_input_size,
                                             dtype=tf.float32)
            beam_cell = StaticAttentionWrapper(cell=beam_cell,
                                               memory=tile_beam(board_state_conv),
                                               alignments=tile_beam(board_alignments),
                                               sequence_length=tile_beam(raw_decoder_lengths),
                                               output_attention=False)
            beam_init_state = beam_cell.zero_state(batch_size * beam_width, tf.float32)

            # ---- Beam Helper and Decoder ----
            beam_helper = CustomBeamHelper(cell=beam_cell,
                                           embedding=word_embedding,
                                           mask=decoder_mask,
                                           sequence_length=decoder_lengths,
                                           output_layer=token_projection,
                                           initial_state=beam_init_state,
                                           beam_width=beam_width)
            beam_decoder = DiverseBeamSearchDecoder(beam_helper=beam_helper,
                                                    sequence_length=decoder_lengths,
                                                    nb_groups=hparam('beam_groups'))
            beam_results, beam_state, _ = dynamic_decode(decoder=beam_decoder,
                                                         output_time_major=False,
                                                         maximum_iterations=max_steps,
                                                         swap_memory=hparam('swap_memory'))

            # The beam decoder must not have created any new global variable
            new_global_vars = set(tf.global_variables()) - vars_before_beam
            assert not new_global_vars, 'New global vars were created'

        # Processing results
        logits = train_results.rnn_output                           # (b, dec_len, VOCAB_SIZE)
        logits_length = tf.shape(logits)[1]                         # dec_len
        decoder_target = decoder_inputs[:, 1:1 + logits_length]

        # Selected tokens: the target where sample_id is -1, the sampled id otherwise
        sample_mask = to_float(tf.math.equal(train_results.sample_id, -1))
        from_target = sample_mask * to_float(decoder_target)
        from_sample = (1. - sample_mask) * to_float(train_results.sample_id)
        selected_tokens = to_int32(sequence_mask * (from_target + from_sample))

        # Argmax tokens: highest logit at each position (zeroed outside of the sequence)
        greedy_ids = tf.argmax(logits, axis=-1)
        argmax_tokens = to_int32(to_float(greedy_ids) * sequence_mask)

        # Log-probability of the selected tokens (zeroed outside of the sequence)
        negative_xent = -1. * cross_entropy(logits=logits, labels=selected_tokens)
        log_probs = negative_xent * sequence_mask

        # Computing policy loss
        with tf.variable_scope('policy_loss'):
            raw_policy_loss = sequence_loss(logits=logits,
                                            targets=decoder_target,
                                            weights=sequence_mask,
                                            average_across_batch=True,
                                            average_across_timesteps=True)
            policy_loss = tf.cond(freeze_gradients,
                                  lambda: tf.stop_gradient(raw_policy_loss),
                                  lambda: raw_policy_loss)

    # Building output tags
    beam_tokens = tf.transpose(beam_results.predicted_ids, perm=[0, 2, 1])     # [batch, beam, steps]
    draw_prob = self.outputs.get('draw_prob', tf.zeros_like(self.features['draw_target']))
    outputs = {'tag/policy/token_based/v005_markovian_film_board_align': True,
               'targets': decoder_inputs[:, 1:],
               'selected_tokens': selected_tokens,
               'argmax_tokens': argmax_tokens,
               'logits': logits,
               'log_probs': log_probs,
               'beam_tokens': beam_tokens,
               'beam_log_probs': beam_state.log_probs,
               'rnn_states': train_results.rnn_state,
               'policy_loss': policy_loss,
               'draw_prob': draw_prob,
               'learning_rate': self.learning_rate}

    # Adding features, placeholders and outputs to graph
    self.add_meta_information(outputs)
def _build_policy_final(self):
    """ Builds the final step of the order-based policy model

        Reads the tensors stored in `self.outputs` by the initial step, builds a regular decoder and a
        diverse beam search decoder (both sharing the same LSTM cell and the same attention scope),
        computes the policy loss, and registers the resulting tensors with `self.add_meta_information`.
    """
    from diplomacy_research.utils.tensorflow import tf
    from diplomacy_research.models.layers.attention import AttentionWrapper, BahdanauAttention
    from diplomacy_research.models.layers.beam_decoder import DiverseBeamSearchDecoder
    from diplomacy_research.models.layers.decoder import CandidateBasicDecoder
    from diplomacy_research.models.layers.dropout import SeededDropoutWrapper
    from diplomacy_research.models.layers.dynamic_decode import dynamic_decode
    from diplomacy_research.models.policy.order_based.helper import CustomHelper, CustomBeamHelper
    from diplomacy_research.utils.tensorflow import cross_entropy, sequence_loss, to_int32, to_float, get_tile_beam

    def hparam(name):
        """ Returns the hyper-parameter stored under `name` """
        return self.hparams[name]

    # Ops are pinned on the worker device when running in a cluster
    worker_device = self.cluster_config.worker_device if self.cluster_config else None

    with tf.variable_scope('policy', reuse=tf.AUTO_REUSE), tf.device(worker_device):

        # Features
        seeds = self.features['player_seed']                        # tf.int32 - (b,)
        softmax_temperature = self.features['temperature']          # tf.flt32 - (b,)
        drop_rates = self.features['dropout_rate']                  # tf.flt32 - (b,)

        # Placeholders
        freeze_gradients = self.placeholders['stop_gradient_all']

        # Tensors computed by the initial step
        batch_size = self.outputs['batch_size']
        decoder_inputs = self.outputs['decoder_inputs']
        decoder_type = self.outputs['decoder_type']
        raw_decoder_lengths = self.outputs['raw_decoder_lengths']
        decoder_lengths = self.outputs['decoder_lengths']
        board_state_conv = self.outputs['board_state_conv']
        order_embedding = self.outputs['order_embedding']
        candidate_embedding = self.outputs['candidate_embedding']
        candidates = self.outputs['candidates']
        max_candidate_length = self.outputs['max_candidate_length']

        # --- Decoding ---
        with tf.variable_scope('decoder_scope', reuse=tf.AUTO_REUSE):
            shared_lstm = tf.contrib.rnn.LSTMBlockCell(hparam('lstm_size'))

            # The cell receives the order embedding concatenated with the attention
            cell_input_size = hparam('order_emb_size') + hparam('attn_size')

            # ======== Regular Decoding ========
            # Seeded dropout on the cell input and on the cell output
            train_cell = SeededDropoutWrapper(cell=shared_lstm,
                                              seeds=seeds,
                                              input_keep_probs=1. - drop_rates,
                                              output_keep_probs=1. - drop_rates,
                                              variational_recurrent=hparam('use_v_dropout'),
                                              input_size=cell_input_size,
                                              dtype=tf.float32)

            # Attention over the locations - memory is [batch, NB_NODES, attn_size]
            # The scope object is shared with the beam search decoder below
            attention_scope = tf.VariableScope(name='policy/decoder_scope/Attention', reuse=tf.AUTO_REUSE)
            train_attention = BahdanauAttention(num_units=hparam('attn_size'),
                                                memory=board_state_conv,
                                                normalize=True,
                                                name_or_scope=attention_scope)
            train_cell = AttentionWrapper(cell=train_cell,
                                          attention_mechanism=train_attention,
                                          output_attention=False,
                                          name_or_scope=attention_scope)

            # Initial state: zeros, with the attention set to the mean of the board state over axis 1
            train_init_state = train_cell.zero_state(batch_size, tf.float32)
            train_init_state = train_init_state.clone(attention=tf.reduce_mean(board_state_conv, axis=1))

            # ---- Helper ----
            train_helper = CustomHelper(decoder_type=decoder_type,
                                        inputs=decoder_inputs[:, :-1],
                                        order_embedding=order_embedding,
                                        candidate_embedding=candidate_embedding,
                                        sequence_length=decoder_lengths,
                                        candidates=candidates,
                                        time_major=False,
                                        softmax_temperature=softmax_temperature)

            # ---- Decoder ----
            # The mask is built from the raw lengths, so padded positions are zeroed out
            sequence_mask = tf.sequence_mask(raw_decoder_lengths,
                                             maxlen=tf.reduce_max(decoder_lengths),
                                             dtype=tf.float32)
            max_steps = NB_SUPPLY_CENTERS
            train_decoder = CandidateBasicDecoder(cell=train_cell,
                                                  helper=train_helper,
                                                  initial_state=train_init_state,
                                                  max_candidate_length=max_candidate_length,
                                                  extract_state=True)
            train_results, _, _ = dynamic_decode(decoder=train_decoder,
                                                 output_time_major=False,
                                                 maximum_iterations=max_steps,
                                                 swap_memory=hparam('swap_memory'))

            # Snapshot used below to verify that the beam decoder only re-uses existing variables
            vars_before_beam = set(tf.global_variables())

            # ======== Beam Search Decoding ========
            beam_width = hparam('beam_width')
            tile_beam = get_tile_beam(beam_width)

            # Same wrappers as above, with every per-item tensor tiled for the beam
            beam_cell = SeededDropoutWrapper(cell=shared_lstm,
                                             seeds=tile_beam(seeds),
                                             input_keep_probs=tile_beam(1. - drop_rates),
                                             output_keep_probs=tile_beam(1. - drop_rates),
                                             variational_recurrent=hparam('use_v_dropout'),
                                             input_size=cell_input_size,
                                             dtype=tf.float32)
            beam_attention = BahdanauAttention(num_units=hparam('attn_size'),
                                               memory=tile_beam(board_state_conv),
                                               normalize=True,
                                               name_or_scope=attention_scope)
            beam_cell = AttentionWrapper(cell=beam_cell,
                                         attention_mechanism=beam_attention,
                                         output_attention=False,
                                         name_or_scope=attention_scope)

            # Initial state: zeros, with the attention set to the mean of the tiled board state
            beam_init_state = beam_cell.zero_state(batch_size * beam_width, tf.float32)
            beam_init_state = beam_init_state.clone(attention=tf.reduce_mean(tile_beam(board_state_conv), axis=1))

            # ---- Beam Helper and Decoder ----
            beam_helper = CustomBeamHelper(cell=beam_cell,
                                           order_embedding=order_embedding,
                                           candidate_embedding=candidate_embedding,
                                           candidates=candidates,
                                           sequence_length=decoder_lengths,
                                           initial_state=beam_init_state,
                                           beam_width=beam_width)
            beam_decoder = DiverseBeamSearchDecoder(beam_helper=beam_helper,
                                                    sequence_length=decoder_lengths,
                                                    nb_groups=hparam('beam_groups'))
            beam_results, beam_state, _ = dynamic_decode(decoder=beam_decoder,
                                                         output_time_major=False,
                                                         maximum_iterations=max_steps,
                                                         swap_memory=hparam('swap_memory'))

            # The beam decoder must not have created any new global variable
            new_global_vars = set(tf.global_variables()) - vars_before_beam
            assert not new_global_vars, 'New global vars were created'

        # Processing results
        candidate_logits = train_results.rnn_output                 # (b, dec_len, max_cand_len)
        logits_length = tf.shape(candidate_logits)[1]               # dec_len
        decoder_target = decoder_inputs[:, 1:1 + logits_length]

        # Selected tokens: the target where sample_id is -1, the sampled id otherwise
        sample_mask = to_float(tf.math.equal(train_results.sample_id, -1))
        from_target = sample_mask * to_float(decoder_target)
        from_sample = (1. - sample_mask) * to_float(train_results.sample_id)
        selected_tokens = to_int32(sequence_mask * (from_target + from_sample))

        # Argmax tokens: the candidate at the position of the highest logit (zeroed outside of the sequence)
        argmax_id = to_int32(tf.argmax(candidate_logits, axis=-1))
        max_nb_candidate = tf.shape(candidate_logits)[2]
        argmax_one_hot = tf.one_hot(argmax_id, max_nb_candidate, dtype=tf.int32)
        candidate_ids = tf.reduce_sum(argmax_one_hot * candidates, axis=-1)
        argmax_tokens = to_int32(to_float(candidate_ids) * sequence_mask)

        # Position (within the candidates) of the selected tokens and of the targets
        selected_matches = to_int32(tf.math.equal(selected_tokens[:, :, None], candidates))
        tokens_labels = tf.argmax(selected_matches, -1)
        target_matches = to_int32(tf.math.equal(decoder_target[:, :, None], candidates))
        target_labels = tf.argmax(target_matches, -1)

        # Log-probability of the selected tokens (zeroed outside of the sequence)
        negative_xent = -1. * cross_entropy(logits=candidate_logits, labels=tokens_labels)
        log_probs = negative_xent * sequence_mask

        # Computing policy loss
        with tf.variable_scope('policy_loss'):
            raw_policy_loss = sequence_loss(logits=candidate_logits,
                                            targets=target_labels,
                                            weights=sequence_mask,
                                            average_across_batch=True,
                                            average_across_timesteps=True)
            policy_loss = tf.cond(freeze_gradients,
                                  lambda: tf.stop_gradient(raw_policy_loss),
                                  lambda: raw_policy_loss)

    # Building output tags
    beam_tokens = tf.transpose(beam_results.predicted_ids, perm=[0, 2, 1])     # [batch, beam, steps]
    draw_prob = self.outputs.get('draw_prob', tf.zeros_like(self.features['draw_target']))
    outputs = {'tag/policy/order_based/v001_markovian_no_film': True,
               'targets': decoder_inputs[:, 1:],
               'selected_tokens': selected_tokens,
               'argmax_tokens': argmax_tokens,
               'logits': candidate_logits,
               'log_probs': log_probs,
               'beam_tokens': beam_tokens,
               'beam_log_probs': beam_state.log_probs,
               'rnn_states': train_results.rnn_state,
               'policy_loss': policy_loss,
               'draw_prob': draw_prob,
               'learning_rate': self.learning_rate}

    # Adding features, placeholders and outputs to graph
    self.add_meta_information(outputs)
def _build_policy_initial(self):
    """ Builds the initial step of the policy model (FiLM variant)

        Pre-processes the features, computes the FiLM gammas / betas from the current power and season,
        encodes the board state and the previous orders, creates the order and candidate embeddings,
        and registers the resulting tensors with `self.add_meta_information`.
    """
    from diplomacy_research.utils.tensorflow import tf
    from diplomacy_research.models.layers.initializers import uniform
    from diplomacy_research.utils.tensorflow import pad_axis, to_int32, to_float, to_bool

    # Placeholders are created on first use
    if not self.placeholders:
        self.placeholders = self.get_placeholders()

    def hparam(name):
        """ Returns the hyper-parameter stored under `name` """
        return self.hparams[name]

    def placeholder(name):
        """ Returns the placeholder stored under `name` """
        return self.placeholders[name]

    # Ops are pinned on the worker device when running in a cluster
    worker_device = self.cluster_config.worker_device if self.cluster_config else None

    with tf.variable_scope('policy', reuse=tf.AUTO_REUSE), tf.device(worker_device):

        # Features
        board_state = to_float(self.features['board_state'])                # tf.flt32 - (b, NB_NODES, NB_FEATURES)
        board_alignments = to_float(self.features['board_alignments'])      # (b, NB_NODES * len)
        prev_orders_state = to_float(self.features['prev_orders_state'])    # (b, NB_PRV_OD, NB_ND, NB_OD_FT)
        decoder_inputs = self.features['decoder_inputs']                    # tf.int32 - (b, <= 1 + NB_SCS)
        decoder_lengths = self.features['decoder_lengths']                  # tf.int32 - (b,)
        candidates = self.features['candidates']                            # tf.int32 - (b, nb_locs * MAX_CANDIDATES)
        current_power = self.features['current_power']                      # tf.int32 - (b,)
        current_season = self.features['current_season']                    # tf.int32 - (b,)
        feature_dropout_rates = self.features['dropout_rate']               # tf.flt32 - (b,)

        # Batch size
        batch_size = tf.shape(board_state)[0]

        # Alignments are reshaped to (b, -1, NB_NODES), then each row is divided by its sum (or by 1. if smaller)
        board_alignments = tf.reshape(board_alignments, [batch_size, -1, NB_NODES])
        alignment_totals = tf.reduce_sum(board_alignments, axis=-1, keepdims=True)
        board_alignments = board_alignments / tf.math.maximum(1., alignment_totals)

        # A strictly positive 'dropout_rate' placeholder overrides the per-item dropout rates
        # NOTE(review): the resulting tensor is not referenced again in this function - confirm it is still needed
        override_rate = placeholder('dropout_rate')
        dropout_rates = tf.cond(tf.greater(override_rate, 0.),
                                true_fn=lambda: tf.zeros_like(feature_dropout_rates) + override_rate,
                                false_fn=lambda: feature_dropout_rates)

        # Padding alignments, decoder_inputs and candidates to their minimum sizes
        board_alignments = pad_axis(board_alignments, axis=1, min_size=tf.reduce_max(decoder_lengths))
        decoder_inputs = pad_axis(decoder_inputs, axis=-1, min_size=2)
        candidates = pad_axis(candidates, axis=-1, min_size=MAX_CANDIDATES)

        # Making sure all RNN lengths are at least 1 (the raw lengths are kept for masking)
        # No need to trim, because the fields are variable length
        raw_decoder_lengths = decoder_lengths
        decoder_lengths = tf.math.maximum(1, decoder_lengths)

        # Placeholders
        decoder_type = tf.reduce_max(placeholder('decoder_type'))
        is_training = placeholder('is_training')

        # Reshaping candidates and keeping only the first max(decoder_lengths) locations
        candidates = tf.reshape(candidates, [batch_size, -1, MAX_CANDIDATES])
        longest_sequence = tf.reduce_max(decoder_lengths)
        candidates = candidates[:, :longest_sequence, :]                    # tf.int32 - (b, nb_locs, MAX_CAN)

        # Computing FiLM Gammas and Betas
        with tf.variable_scope('film_scope'):

            # Conditioning on the current power (one-hot selection of a row of the embedding)
            power_embedding = uniform(name='power_embedding',
                                      shape=[NB_POWERS, hparam('power_emb_size')],
                                      scale=1.)
            power_one_hot = tf.one_hot(current_power, NB_POWERS, dtype=tf.float32)
            power_vector = tf.reduce_sum(power_embedding[None] * power_one_hot[:, :, None],
                                         axis=1)                            # (b, power_emb)

            # Also conditioning on current_season
            season_embedding = uniform(name='season_embedding',
                                       shape=[NB_SEASONS, hparam('season_emb_size')],
                                       scale=1.)
            season_one_hot = tf.one_hot(current_season, NB_SEASONS, dtype=tf.float32)
            season_vector = tf.reduce_sum(season_embedding[None] * season_one_hot[:, :, None],
                                          axis=1)                           # (b, season_emb)
            film_embedding_input = tf.concat([power_vector, season_vector], axis=1)

            # One gamma / beta per graph conv layer: gcn_size for all layers but the last (attn_size // 2)
            film_output_dims = [hparam('gcn_size')] * (hparam('nb_graph_conv') - 1) + [hparam('attn_size') // 2]
            nb_film_units = 2 * sum(film_output_dims)

            # For board_state
            board_film_layer = tf.layers.Dense(units=nb_film_units, use_bias=True, activation=None)
            board_film_weights = board_film_layer(film_embedding_input)[:, None, :]    # (b, 1, nb_film_units)
            board_film_gammas, board_film_betas = tf.split(board_film_weights, 2, axis=2)
            board_film_gammas = tf.split(board_film_gammas, film_output_dims, axis=2)
            board_film_betas = tf.split(board_film_betas, film_output_dims, axis=2)

            # For prev_orders
            prev_ord_film_layer = tf.layers.Dense(units=nb_film_units, use_bias=True, activation=None)
            prev_ord_film_weights = prev_ord_film_layer(film_embedding_input)[:, None, :]

            # NOTE(review): tf.tile repeats the whole batch NB_PREV_ORDERS times (row k <- item k % b), whereas
            # prev_orders_state is flattened batch-major below (row k <- item k // NB_PREV_ORDERS). If encode_board
            # applies these rows element-wise, they would be misaligned when b > 1 - confirm against encode_board
            prev_ord_film_weights = tf.tile(prev_ord_film_weights,
                                            [NB_PREV_ORDERS, 1, 1])         # (NB_PREV_ORDERS * b, 1, nb_film_units)
            prev_ord_film_gammas, prev_ord_film_betas = tf.split(prev_ord_film_weights, 2, axis=2)
            prev_ord_film_gammas = tf.split(prev_ord_film_gammas, film_output_dims, axis=2)
            prev_ord_film_betas = tf.split(prev_ord_film_betas, film_output_dims, axis=2)

            # Storing as temporary output
            self.add_output('_board_state_conv_film_gammas', board_film_gammas)
            self.add_output('_board_state_conv_film_betas', board_film_betas)
            self.add_output('_prev_orders_conv_film_gammas', prev_ord_film_gammas)
            self.add_output('_prev_orders_conv_film_betas', prev_ord_film_betas)

        # Creating graph convolution
        with tf.variable_scope('graph_conv_scope'):
            assert hparam('nb_graph_conv') >= 2
            assert hparam('attn_size') % 2 == 0

            # Encoding board state
            board_state_0yr_conv = self.encode_board(board_state, name='board_state_conv')

            # Encoding prev_orders, with the previous orders folded into the batch dimension
            flat_prev_orders_shape = [batch_size * NB_PREV_ORDERS, NB_NODES, NB_ORDERS_FEATURES]
            prev_orders_state = tf.reshape(prev_orders_state, flat_prev_orders_shape)
            prev_ord_conv = self.encode_board(prev_orders_state, name='prev_orders_conv')

            # Splitting back into (b, nb_prev, NB_NODES, attn_size // 2)
            # Reducing the prev ord conv using avg
            unfolded_shape = [batch_size, NB_PREV_ORDERS, NB_NODES, hparam('attn_size') // 2]
            prev_ord_conv = tf.reshape(prev_ord_conv, unfolded_shape)
            prev_ord_conv = tf.reduce_mean(prev_ord_conv, axis=1)

            # Concatenating the current board conv with the prev ord conv
            # The final board_state_conv should be of dimension (b, NB_NODE, attn_size)
            board_state_conv = self.get_board_state_conv(board_state_0yr_conv, is_training, prev_ord_conv)

        # Settings shared by the two embeddings below
        # Embeddings needs to be cached locally on the worker, otherwise TF can't compute their gradients
        embedding_settings = {
            'scale': 1.,
            'partitioner': tf.fixed_size_partitioner(NB_PARTITIONS) if hparam('use_partitioner') else None,
            'caching_device': self.cluster_config.caching_device if self.cluster_config else None,
        }

        # Creating order embedding vector (to embed order_ix)
        with tf.variable_scope('order_embedding_scope'):
            order_embedding = uniform(name='order_embedding',
                                      shape=[ORDER_VOCABULARY_SIZE, hparam('order_emb_size')],
                                      **embedding_settings)

        # Creating candidate embedding
        with tf.variable_scope('candidate_embedding_scope'):
            candidate_embedding = uniform(name='candidate_embedding',
                                          shape=[ORDER_VOCABULARY_SIZE, hparam('lstm_size') + 1],
                                          **embedding_settings)

        # Trimming to the maximum number of candidates (entries greater than PAD_ID are counted)
        is_real_candidate = to_int32(tf.math.greater(candidates, PAD_ID))
        candidate_lengths = tf.reduce_sum(is_real_candidate, -1)            # int32 - (b,)
        max_candidate_length = tf.math.maximum(1, tf.reduce_max(candidate_lengths))
        candidates = candidates[:, :, :max_candidate_length]

    # In retreat phase if 1) board not empty, 2) disl. units present
    # NOTE(review): presumably feature 23 is 0 on nodes holding a dislodged unit - confirm against the board encoding
    board_not_empty = tf.reduce_sum(board_state[:], axis=[1, 2]) > 0
    has_dislodged = tf.math.logical_not(to_bool(tf.reduce_min(board_state[:, :, 23], -1)))
    in_retreat_phase = tf.math.logical_and(board_not_empty, has_dislodged)

    # Building output tags
    outputs = {'batch_size': batch_size,
               'board_alignments': board_alignments,
               'decoder_inputs': decoder_inputs,
               'decoder_type': decoder_type,
               'raw_decoder_lengths': raw_decoder_lengths,
               'decoder_lengths': decoder_lengths,
               'board_state_conv': board_state_conv,
               'board_state_0yr_conv': board_state_0yr_conv,
               'prev_ord_conv': prev_ord_conv,
               'order_embedding': order_embedding,
               'candidate_embedding': candidate_embedding,
               'candidates': candidates,
               'max_candidate_length': max_candidate_length,
               'in_retreat_phase': in_retreat_phase}

    # Adding to graph
    self.add_meta_information(outputs)
def _build_policy_initial(self):
    """ Builds the initial step of the policy model (variant without FiLM)

        Pre-processes the features, encodes the board state, creates the order and candidate embeddings,
        and registers the resulting tensors with `self.add_meta_information`.
    """
    from diplomacy_research.utils.tensorflow import tf
    from diplomacy_research.models.layers.initializers import uniform
    from diplomacy_research.utils.tensorflow import pad_axis, to_int32, to_float, to_bool

    # Placeholders are created on first use
    if not self.placeholders:
        self.placeholders = self.get_placeholders()

    def hparam(name):
        """ Returns the hyper-parameter stored under `name` """
        return self.hparams[name]

    def placeholder(name):
        """ Returns the placeholder stored under `name` """
        return self.placeholders[name]

    # Ops are pinned on the worker device when running in a cluster
    worker_device = self.cluster_config.worker_device if self.cluster_config else None

    with tf.variable_scope('policy', reuse=tf.AUTO_REUSE), tf.device(worker_device):

        # Features
        board_state = to_float(self.features['board_state'])        # tf.flt32 - (b, NB_NODES, NB_FEATURES)
        decoder_inputs = self.features['decoder_inputs']            # tf.int32 - (b, <= 1 + NB_SCS)
        decoder_lengths = self.features['decoder_lengths']          # tf.int32 - (b,)
        candidates = self.features['candidates']                    # tf.int32 - (b, nb_locs * MAX_CANDIDATES)
        feature_dropout_rates = self.features['dropout_rate']       # tf.flt32 - (b,)

        # Batch size
        batch_size = tf.shape(board_state)[0]

        # A strictly positive 'dropout_rate' placeholder overrides the per-item dropout rates
        # NOTE(review): the resulting tensor is not referenced again in this function - confirm it is still needed
        override_rate = placeholder('dropout_rate')
        dropout_rates = tf.cond(tf.greater(override_rate, 0.),
                                true_fn=lambda: tf.zeros_like(feature_dropout_rates) + override_rate,
                                false_fn=lambda: feature_dropout_rates)

        # Padding decoder_inputs and candidates to their minimum sizes
        decoder_inputs = pad_axis(decoder_inputs, axis=-1, min_size=2)
        candidates = pad_axis(candidates, axis=-1, min_size=MAX_CANDIDATES)

        # Making sure all RNN lengths are at least 1 (the raw lengths are kept for masking)
        # No need to trim, because the fields are variable length
        raw_decoder_lengths = decoder_lengths
        decoder_lengths = tf.math.maximum(1, decoder_lengths)

        # Placeholders
        decoder_type = tf.reduce_max(placeholder('decoder_type'))
        is_training = placeholder('is_training')

        # Reshaping candidates and keeping only the first max(decoder_lengths) locations
        candidates = tf.reshape(candidates, [batch_size, -1, MAX_CANDIDATES])
        longest_sequence = tf.reduce_max(decoder_lengths)
        candidates = candidates[:, :longest_sequence, :]            # tf.int32 - (b, nb_locs, MAX_CAN)

        # Creating graph convolution
        with tf.variable_scope('graph_conv_scope'):
            assert hparam('nb_graph_conv') >= 2

            # Encoding board state
            board_state_0yr_conv = self.encode_board(board_state, name='board_state_conv')
            board_state_conv = self.get_board_state_conv(board_state_0yr_conv, is_training)

        # Settings shared by the two embeddings below
        # Embeddings needs to be cached locally on the worker, otherwise TF can't compute their gradients
        embedding_settings = {
            'scale': 1.,
            'partitioner': tf.fixed_size_partitioner(NB_PARTITIONS) if hparam('use_partitioner') else None,
            'caching_device': self.cluster_config.caching_device if self.cluster_config else None,
        }

        # Creating order embedding vector (to embed order_ix)
        with tf.variable_scope('order_embedding_scope'):
            order_embedding = uniform(name='order_embedding',
                                      shape=[ORDER_VOCABULARY_SIZE, hparam('order_emb_size')],
                                      **embedding_settings)

        # Creating candidate embedding
        with tf.variable_scope('candidate_embedding_scope'):
            candidate_embedding = uniform(name='candidate_embedding',
                                          shape=[ORDER_VOCABULARY_SIZE, hparam('lstm_size') + 1],
                                          **embedding_settings)

        # Trimming to the maximum number of candidates (entries greater than PAD_ID are counted)
        is_real_candidate = to_int32(tf.math.greater(candidates, PAD_ID))
        candidate_lengths = tf.reduce_sum(is_real_candidate, -1)    # int32 - (b,)
        max_candidate_length = tf.math.maximum(1, tf.reduce_max(candidate_lengths))
        candidates = candidates[:, :, :max_candidate_length]

    # In retreat phase if 1) board not empty, 2) disl. units present
    # NOTE(review): presumably feature 23 is 0 on nodes holding a dislodged unit - confirm against the board encoding
    board_not_empty = tf.reduce_sum(board_state[:], axis=[1, 2]) > 0
    has_dislodged = tf.math.logical_not(to_bool(tf.reduce_min(board_state[:, :, 23], -1)))
    in_retreat_phase = tf.math.logical_and(board_not_empty, has_dislodged)

    # Building output tags
    outputs = {'batch_size': batch_size,
               'decoder_inputs': decoder_inputs,
               'decoder_type': decoder_type,
               'raw_decoder_lengths': raw_decoder_lengths,
               'decoder_lengths': decoder_lengths,
               'board_state_conv': board_state_conv,
               'board_state_0yr_conv': board_state_0yr_conv,
               'order_embedding': order_embedding,
               'candidate_embedding': candidate_embedding,
               'candidates': candidates,
               'max_candidate_length': max_candidate_length,
               'in_retreat_phase': in_retreat_phase}

    # Adding to graph
    self.add_meta_information(outputs)
def _build_policy_final(self):
    """ Builds the final step of the order-based policy model (transformer variant)

        Reads the tensors stored in `self.outputs` by the initial step, builds a regular decoder and a
        diverse beam search decoder around a `TransformerCell` (the beam version re-uses the variables of
        the regular one), computes the policy loss, and registers the resulting tensors with
        `self.add_meta_information`.
    """
    from diplomacy_research.utils.tensorflow import tf
    from diplomacy_research.models.layers.attention import StaticAttentionWrapper
    from diplomacy_research.models.layers.beam_decoder import DiverseBeamSearchDecoder
    from diplomacy_research.models.layers.decoder import CandidateBasicDecoder
    from diplomacy_research.models.layers.dropout import SeededDropoutWrapper
    from diplomacy_research.models.layers.dynamic_decode import dynamic_decode
    from diplomacy_research.models.layers.initializers import uniform
    from diplomacy_research.models.layers.transformer import TransformerCell
    from diplomacy_research.models.layers.wrappers import IdentityCell
    from diplomacy_research.models.policy.order_based.helper import CustomHelper, CustomBeamHelper
    from diplomacy_research.utils.tensorflow import cross_entropy, sequence_loss, to_int32, to_float, get_tile_beam

    def hparam(name):
        """ Returns the hyper-parameter stored under `name` """
        return self.hparams[name]

    # Ops are pinned on the worker device when running in a cluster
    worker_device = self.cluster_config.worker_device if self.cluster_config else None

    with tf.variable_scope('policy', reuse=tf.AUTO_REUSE), tf.device(worker_device):

        # Features
        seeds = self.features['player_seed']                        # tf.int32 - (b,)
        softmax_temperature = self.features['temperature']          # tf.flt32 - (b,)
        drop_rates = self.features['dropout_rate']                  # tf.flt32 - (b,)

        # Placeholders
        freeze_gradients = self.placeholders['stop_gradient_all']

        # Tensors computed by the initial step
        batch_size = self.outputs['batch_size']
        board_alignments = self.outputs['board_alignments']
        decoder_inputs = self.outputs['decoder_inputs']
        decoder_type = self.outputs['decoder_type']
        raw_decoder_lengths = self.outputs['raw_decoder_lengths']
        decoder_lengths = self.outputs['decoder_lengths']
        board_state_conv = self.outputs['board_state_conv']
        order_embedding = self.outputs['order_embedding']
        candidate_embedding = self.outputs['candidate_embedding']
        candidates = self.outputs['candidates']
        max_candidate_length = self.outputs['max_candidate_length']

        # Position embedding with NB_SUPPLY_CENTERS rows (the maximum number of decoding steps used below)
        # Embeddings needs to be cached locally on the worker, otherwise TF can't compute their gradients
        with tf.variable_scope('position_embedding_scope'):
            caching_device = self.cluster_config.caching_device if self.cluster_config else None
            position_embedding = uniform(name='position_embedding',
                                         shape=[NB_SUPPLY_CENTERS, hparam('trsf_emb_size')],
                                         scale=1.,
                                         caching_device=caching_device)

        # No past attentions (and no past sequence lengths) are given to the transformer
        past_attentions, message_lengths = None, None

        # --- Decoding ---
        with tf.variable_scope('decoder_scope', reuse=tf.AUTO_REUSE):

            # The feeder cell outputs the transformer embedding concatenated with the attention
            feeder_size = hparam('trsf_emb_size') + hparam('attn_size')
            transformer_scope_name = 'policy/training_scope/transformer'

            # ======== Regular Decoding ========
            # Seeded dropout on the feeder cell input only (no output keep probs are given)
            train_feeder = IdentityCell(output_size=feeder_size)
            train_feeder = SeededDropoutWrapper(cell=train_feeder,
                                                seeds=seeds,
                                                input_keep_probs=1. - drop_rates,
                                                variational_recurrent=hparam('use_v_dropout'),
                                                input_size=feeder_size,
                                                dtype=tf.float32)

            # Attention over the board state, using the alignments provided for each position
            train_feeder = StaticAttentionWrapper(cell=train_feeder,
                                                  memory=board_state_conv,
                                                  alignments=board_alignments,
                                                  sequence_length=raw_decoder_lengths,
                                                  output_attention=False)
            train_feeder_init_state = train_feeder.zero_state(batch_size, tf.float32)

            # ---- Helper ----
            train_helper = CustomHelper(decoder_type=decoder_type,
                                        inputs=decoder_inputs[:, :-1],
                                        order_embedding=order_embedding,
                                        candidate_embedding=candidate_embedding,
                                        sequence_length=decoder_lengths,
                                        candidates=candidates,
                                        time_major=False,
                                        softmax_temperature=softmax_temperature)

            # ---- Transformer Cell ----
            # reuse=False: this scope is opened here first, the beam version below re-opens it with reuse=True
            train_trsf_scope = tf.VariableScope(name=transformer_scope_name, reuse=False)
            train_trsf_cell = TransformerCell(nb_layers=hparam('trsf_nb_layers'),
                                              nb_heads=hparam('trsf_nb_heads'),
                                              word_embedding=order_embedding,
                                              position_embedding=position_embedding,
                                              batch_size=batch_size,
                                              feeder_cell=train_feeder,
                                              feeder_init_state=train_feeder_init_state,
                                              past_attentions=past_attentions,
                                              past_seq_lengths=message_lengths,
                                              scope=train_trsf_scope,
                                              name='transformer')
            train_trsf_init_state = train_trsf_cell.zero_state(batch_size, tf.float32)

            # ---- Invariants ----
            # Shape invariant of the 'past_attentions' tensor, passed to both dynamic_decode calls
            head_size = hparam('trsf_emb_size') // hparam('trsf_nb_heads')
            past_attentions_shape = tf.TensorShape([None,                       # batch size
                                                    hparam('trsf_nb_layers'),   # nb_layers
                                                    2,                          # key, value
                                                    hparam('trsf_nb_heads'),    # nb heads
                                                    None,                       # Seq len
                                                    head_size])                 # Head size
            invariants_map = {'past_attentions': past_attentions_shape}

            # ---- Decoder ----
            # The mask is built from the raw lengths, so padded positions are zeroed out
            sequence_mask = tf.sequence_mask(raw_decoder_lengths,
                                             maxlen=tf.reduce_max(decoder_lengths),
                                             dtype=tf.float32)
            max_steps = NB_SUPPLY_CENTERS
            train_decoder = CandidateBasicDecoder(cell=train_trsf_cell,
                                                  helper=train_helper,
                                                  initial_state=train_trsf_init_state,
                                                  max_candidate_length=max_candidate_length,
                                                  extract_state=True)
            train_results, _, _ = dynamic_decode(decoder=train_decoder,
                                                 output_time_major=False,
                                                 maximum_iterations=max_steps,
                                                 invariants_map=invariants_map,
                                                 swap_memory=hparam('swap_memory'))

            # Snapshot used below to verify that the beam decoder only re-uses existing variables
            vars_before_beam = set(tf.global_variables())

            # ======== Beam Search Decoding ========
            beam_width = hparam('beam_width')
            tile_beam = get_tile_beam(beam_width)

            # Same feeder as above, with every per-item tensor tiled for the beam
            beam_feeder = IdentityCell(output_size=feeder_size)
            beam_feeder = SeededDropoutWrapper(cell=beam_feeder,
                                               seeds=tile_beam(seeds),
                                               input_keep_probs=tile_beam(1. - drop_rates),
                                               variational_recurrent=hparam('use_v_dropout'),
                                               input_size=feeder_size,
                                               dtype=tf.float32)
            beam_feeder = StaticAttentionWrapper(cell=beam_feeder,
                                                 memory=tile_beam(board_state_conv),
                                                 alignments=tile_beam(board_alignments),
                                                 sequence_length=tile_beam(raw_decoder_lengths),
                                                 output_attention=False)
            beam_feeder_init_state = beam_feeder.zero_state(batch_size * beam_width, tf.float32)

            # ---- Transformer Cell ----
            # Same scope name as the regular decoder, this time with reuse=True
            beam_trsf_scope = tf.VariableScope(name=transformer_scope_name, reuse=True)
            beam_trsf_cell = TransformerCell(nb_layers=hparam('trsf_nb_layers'),
                                             nb_heads=hparam('trsf_nb_heads'),
                                             word_embedding=order_embedding,
                                             position_embedding=position_embedding,
                                             batch_size=batch_size * beam_width,
                                             feeder_cell=beam_feeder,
                                             feeder_init_state=beam_feeder_init_state,
                                             past_attentions=tile_beam(past_attentions),
                                             past_seq_lengths=tile_beam(message_lengths),
                                             scope=beam_trsf_scope,
                                             name='transformer')
            beam_trsf_init_state = beam_trsf_cell.zero_state(batch_size * beam_width, tf.float32)

            # ---- Beam Helper and Decoder ----
            beam_helper = CustomBeamHelper(cell=beam_trsf_cell,
                                           order_embedding=order_embedding,
                                           candidate_embedding=candidate_embedding,
                                           candidates=candidates,
                                           sequence_length=decoder_lengths,
                                           initial_state=beam_trsf_init_state,
                                           beam_width=beam_width)
            beam_decoder = DiverseBeamSearchDecoder(beam_helper=beam_helper,
                                                    sequence_length=decoder_lengths,
                                                    nb_groups=hparam('beam_groups'))
            beam_results, beam_state, _ = dynamic_decode(decoder=beam_decoder,
                                                         output_time_major=False,
                                                         maximum_iterations=max_steps,
                                                         invariants_map=invariants_map,
                                                         swap_memory=hparam('swap_memory'))

            # The beam decoder must not have created any new global variable
            new_global_vars = set(tf.global_variables()) - vars_before_beam
            assert not new_global_vars, 'New global vars were created'

        # Processing results
        candidate_logits = train_results.rnn_output                 # (b, dec_len, max_cand_len)
        logits_length = tf.shape(candidate_logits)[1]               # dec_len
        decoder_target = decoder_inputs[:, 1:1 + logits_length]

        # Selected tokens: the target where sample_id is -1, the sampled id otherwise
        sample_mask = to_float(tf.math.equal(train_results.sample_id, -1))
        from_target = sample_mask * to_float(decoder_target)
        from_sample = (1. - sample_mask) * to_float(train_results.sample_id)
        selected_tokens = to_int32(sequence_mask * (from_target + from_sample))

        # Argmax tokens: the candidate at the position of the highest logit (zeroed outside of the sequence)
        argmax_id = to_int32(tf.argmax(candidate_logits, axis=-1))
        max_nb_candidate = tf.shape(candidate_logits)[2]
        argmax_one_hot = tf.one_hot(argmax_id, max_nb_candidate, dtype=tf.int32)
        candidate_ids = tf.reduce_sum(argmax_one_hot * candidates, axis=-1)
        argmax_tokens = to_int32(to_float(candidate_ids) * sequence_mask)

        # Position (within the candidates) of the selected tokens and of the targets
        selected_matches = to_int32(tf.math.equal(selected_tokens[:, :, None], candidates))
        tokens_labels = tf.argmax(selected_matches, -1)
        target_matches = to_int32(tf.math.equal(decoder_target[:, :, None], candidates))
        target_labels = tf.argmax(target_matches, -1)

        # Log-probability of the selected tokens (zeroed outside of the sequence)
        negative_xent = -1. * cross_entropy(logits=candidate_logits, labels=tokens_labels)
        log_probs = negative_xent * sequence_mask

        # Computing policy loss
        with tf.variable_scope('policy_loss'):
            raw_policy_loss = sequence_loss(logits=candidate_logits,
                                            targets=target_labels,
                                            weights=sequence_mask,
                                            average_across_batch=True,
                                            average_across_timesteps=True)
            policy_loss = tf.cond(freeze_gradients,
                                  lambda: tf.stop_gradient(raw_policy_loss),
                                  lambda: raw_policy_loss)

    # Building output tags
    beam_tokens = tf.transpose(beam_results.predicted_ids, perm=[0, 2, 1])     # [batch, beam, steps]
    draw_prob = self.outputs.get('draw_prob', tf.zeros_like(self.features['draw_target']))
    outputs = {'tag/policy/order_based/v015_film_transformer_gpt': True,
               'targets': decoder_inputs[:, 1:],
               'selected_tokens': selected_tokens,
               'argmax_tokens': argmax_tokens,
               'logits': candidate_logits,
               'log_probs': log_probs,
               'beam_tokens': beam_tokens,
               'beam_log_probs': beam_state.log_probs,
               'rnn_states': train_results.rnn_state,
               'policy_loss': policy_loss,
               'draw_prob': draw_prob,
               'learning_rate': self.learning_rate}

    # Adding features, placeholders and outputs to graph
    self.add_meta_information(outputs)