def __call__(self, inputs):
    """ Runs the graph convolution on a batch of node features
        :param inputs: The input feature matrix - (b, N, in) according to the shape notes below
        :return: The output matrix after bias, activation and (optionally) the residual connection
    """
    # Moving the node axis first so that the matmul with self.var_w is applied node by node
    node_major = tf.transpose(inputs, perm=[1, 0, 2])           # (b, N, in) => (N, b, in)
    projected = tf.matmul(node_major, self.var_w)               # (N, b, in) * (N, in, out) => (N, b, out)

    # Back to batch-major, then mixing the nodes with the normalized adjacency matrix
    batch_major = tf.transpose(projected, perm=[1, 0, 2])       # (N, b, out) => (b, N, out)
    conv_out = tf.matmul(self.norm_adjacency, batch_major)      # (b, N, N) * (b, N, out) => (b, N, out)

    # Optional bias
    if self.bias:
        conv_out += self.var_b                                  # (b, N, out) + (1, 1, out) => (b, N, out)

    # Activation, then optional residual connection (the raw inputs are added back to the output)
    activated = self.activation_fn(conv_out)
    if self.residual:
        activated += inputs

    return activated                                            # (b, N, out)
def _build_policy_final(self):
    """ Builds the policy model (final step)

        Reads the tensors stored in `self.outputs` by the initial steps, builds a regular decoder and a diverse
        beam search decoder (both using the same LSTM cell and the same output layer), computes the policy loss
        and registers the resulting tensors through `self.add_meta_information`.
    """
    # NOTE(review): the nesting of the 'decoder_scope' / 'policy_loss' blocks was reconstructed from a copy of this
    # method whose indentation was lost - confirm the scope boundaries against the original file.
    from diplomacy_research.utils.tensorflow import tf
    from diplomacy_research.models.layers.attention import StaticAttentionWrapper
    from diplomacy_research.models.layers.beam_decoder import DiverseBeamSearchDecoder
    from diplomacy_research.models.layers.decoder import MaskedBasicDecoder
    from diplomacy_research.models.layers.dropout import SeededDropoutWrapper
    from diplomacy_research.models.layers.dynamic_decode import dynamic_decode
    from diplomacy_research.models.policy.token_based.helper import CustomHelper, CustomBeamHelper
    from diplomacy_research.utils.tensorflow import cross_entropy, sequence_loss, to_int32, to_float, get_tile_beam

    # Quick function to retrieve hparams and placeholders and function shorthands
    hps = lambda hparam_name: self.hparams[hparam_name]
    pholder = lambda placeholder_name: self.placeholders[placeholder_name]

    # Training loop
    with tf.variable_scope('policy', reuse=tf.AUTO_REUSE):
        with tf.device(self.cluster_config.worker_device if self.cluster_config else None):

            # Features
            player_seeds = self.features['player_seed']                 # tf.int32 - (b,)
            temperature = self.features['temperature']                  # tf.flt32 - (b,)
            dropout_rates = self.features['dropout_rate']               # tf.flt32 - (b,)

            # Placeholders
            stop_gradient_all = pholder('stop_gradient_all')

            # Outputs (from initial steps)
            batch_size = self.outputs['batch_size']
            board_alignments = self.outputs['board_alignments']
            decoder_inputs = self.outputs['decoder_inputs']
            decoder_mask = self.outputs['decoder_mask']
            decoder_type = self.outputs['decoder_type']
            raw_decoder_lengths = self.outputs['raw_decoder_lengths']
            decoder_lengths = self.outputs['decoder_lengths']
            board_state_conv = self.outputs['board_state_conv']
            word_embedding = self.outputs['word_embedding']

            # --- Decoding ---
            with tf.variable_scope('decoder_scope', reuse=tf.AUTO_REUSE):
                # The same cell object is wrapped twice below (regular decoding, then beam search decoding)
                lstm_cell = tf.contrib.rnn.LSTMBlockCell(hps('lstm_size'))

                # decoder output to token
                decoder_output_layer = tf.layers.Dense(units=VOCABULARY_SIZE,
                                                       activation=None,
                                                       kernel_initializer=tf.random_normal_initializer,
                                                       use_bias=True)

                # ======== Regular Decoding ========
                # Applying dropout to input + attention and to output layer
                decoder_cell = SeededDropoutWrapper(cell=lstm_cell,
                                                    seeds=player_seeds,
                                                    input_keep_probs=1. - dropout_rates,
                                                    output_keep_probs=1. - dropout_rates,
                                                    variational_recurrent=hps('use_v_dropout'),
                                                    input_size=hps('word_emb_size') + hps('attn_size'),
                                                    dtype=tf.float32)

                # Apply attention over orderable location at each position
                decoder_cell = StaticAttentionWrapper(cell=decoder_cell,
                                                      memory=board_state_conv,
                                                      alignments=board_alignments,
                                                      sequence_length=raw_decoder_lengths,
                                                      output_attention=False)

                # Setting initial state
                decoder_init_state = decoder_cell.zero_state(batch_size, tf.float32)

                # ---- Helper ----
                # The last token of decoder_inputs is dropped - it is never fed as an input
                helper = CustomHelper(decoder_type=decoder_type,
                                      inputs=decoder_inputs[:, :-1],
                                      embedding=word_embedding,
                                      sequence_length=decoder_lengths,
                                      mask=decoder_mask,
                                      time_major=False,
                                      softmax_temperature=temperature)

                # ---- Decoder ----
                # Mask built from raw_decoder_lengths, padded to the longest entry of decoder_lengths
                sequence_mask = tf.sequence_mask(raw_decoder_lengths,
                                                 maxlen=tf.reduce_max(decoder_lengths),
                                                 dtype=tf.float32)
                # Upper bound on the number of decoding steps (used for both decoders)
                maximum_iterations = TOKENS_PER_ORDER * NB_SUPPLY_CENTERS
                model_decoder = MaskedBasicDecoder(cell=decoder_cell,
                                                   helper=helper,
                                                   initial_state=decoder_init_state,
                                                   output_layer=decoder_output_layer,
                                                   extract_state=True)
                training_results, _, _ = dynamic_decode(decoder=model_decoder,
                                                        output_time_major=False,
                                                        maximum_iterations=maximum_iterations,
                                                        swap_memory=hps('swap_memory'))
                # Snapshot of the global variables - compared after the beam search decoder is built
                global_vars_after_decoder = set(tf.global_variables())

                # ======== Beam Search Decoding ========
                # tile_beam is applied to every per-batch tensor given to the beam search cell
                # (presumably repeating each item beam_width times - confirm in get_tile_beam)
                tile_beam = get_tile_beam(hps('beam_width'))

                # Applying dropout to input + attention and to output layer
                decoder_cell = SeededDropoutWrapper(cell=lstm_cell,
                                                    seeds=tile_beam(player_seeds),
                                                    input_keep_probs=tile_beam(1. - dropout_rates),
                                                    output_keep_probs=tile_beam(1. - dropout_rates),
                                                    variational_recurrent=hps('use_v_dropout'),
                                                    input_size=hps('word_emb_size') + hps('attn_size'),
                                                    dtype=tf.float32)

                # Apply attention over orderable location at each position
                decoder_cell = StaticAttentionWrapper(cell=decoder_cell,
                                                      memory=tile_beam(board_state_conv),
                                                      alignments=tile_beam(board_alignments),
                                                      sequence_length=tile_beam(raw_decoder_lengths),
                                                      output_attention=False)

                # Setting initial state
                decoder_init_state = decoder_cell.zero_state(batch_size * hps('beam_width'), tf.float32)

                # ---- Beam Helper and Decoder ----
                beam_helper = CustomBeamHelper(cell=decoder_cell,
                                               embedding=word_embedding,
                                               mask=decoder_mask,
                                               sequence_length=decoder_lengths,
                                               output_layer=decoder_output_layer,
                                               initial_state=decoder_init_state,
                                               beam_width=hps('beam_width'))
                beam_decoder = DiverseBeamSearchDecoder(beam_helper=beam_helper,
                                                        sequence_length=decoder_lengths,
                                                        nb_groups=hps('beam_groups'))
                beam_results, beam_state, _ = dynamic_decode(decoder=beam_decoder,
                                                             output_time_major=False,
                                                             maximum_iterations=maximum_iterations,
                                                             swap_memory=hps('swap_memory'))

                # Making sure we haven't created new global variables
                # (i.e. the beam search decoder must only use variables created by the regular decoder)
                assert not set(tf.global_variables()) - global_vars_after_decoder, 'New global vars were created'

                # Processing results
                logits = training_results.rnn_output                            # (b, dec_len, VOCAB_SIZE)
                logits_length = tf.shape(logits)[1]                             # dec_len
                # Targets are the decoder inputs shifted by one position, truncated to the number of logits
                decoder_target = decoder_inputs[:, 1:1 + logits_length]

                # Selected tokens are the token that was actually fed at the next position
                # sample_mask is 1. where sample_id == -1: the target token is used there, the sample_id elsewhere
                sample_mask = to_float(tf.math.equal(training_results.sample_id, -1))
                selected_tokens = to_int32(
                    sequence_mask * (sample_mask * to_float(decoder_target)
                                     + (1. - sample_mask) * to_float(training_results.sample_id)))

                # Argmax tokens are the most likely token outputted at each position
                argmax_tokens = to_int32(to_float(tf.argmax(logits, axis=-1)) * sequence_mask)
                # Negative cross-entropy of the selected tokens, zeroed outside of the sequence mask
                log_probs = -1. * cross_entropy(logits=logits, labels=selected_tokens) * sequence_mask

            # Computing policy loss
            with tf.variable_scope('policy_loss'):
                policy_loss = sequence_loss(logits=logits,
                                            targets=decoder_target,
                                            weights=sequence_mask,
                                            average_across_batch=True,
                                            average_across_timesteps=True)
                # No gradient flows through the loss when the stop_gradient_all placeholder is true
                policy_loss = tf.cond(stop_gradient_all,
                                      lambda: tf.stop_gradient(policy_loss),    # pylint: disable=cell-var-from-loop
                                      lambda: policy_loss)                      # pylint: disable=cell-var-from-loop

    # Building output tags
    # 'draw_prob' falls back to zeros shaped like the 'draw_target' feature when no earlier step has set it
    outputs = {'tag/policy/token_based/v005_markovian_film_board_align': True,
               'targets': decoder_inputs[:, 1:],
               'selected_tokens': selected_tokens,
               'argmax_tokens': argmax_tokens,
               'logits': logits,
               'log_probs': log_probs,
               'beam_tokens': tf.transpose(beam_results.predicted_ids, perm=[0, 2, 1]),     # [batch, beam, steps]
               'beam_log_probs': beam_state.log_probs,
               'rnn_states': training_results.rnn_state,
               'policy_loss': policy_loss,
               'draw_prob': self.outputs.get('draw_prob', tf.zeros_like(self.features['draw_target'])),
               'learning_rate': self.learning_rate}

    # Adding features, placeholders and outputs to graph
    self.add_meta_information(outputs)
def _build_policy_final(self):
    """ Builds the policy model (final step)

        Reads the tensors stored in `self.outputs` by the initial steps, builds a regular decoder and a diverse
        beam search decoder (both using the same LSTM cell and the same attention variable scope), computes the
        policy loss over the candidate orders and registers the resulting tensors through
        `self.add_meta_information`.
    """
    # NOTE(review): the nesting of the 'decoder_scope' / 'policy_loss' blocks was reconstructed from a copy of this
    # method whose indentation was lost - confirm the scope boundaries against the original file.
    from diplomacy_research.utils.tensorflow import tf
    from diplomacy_research.models.layers.attention import AttentionWrapper, BahdanauAttention
    from diplomacy_research.models.layers.beam_decoder import DiverseBeamSearchDecoder
    from diplomacy_research.models.layers.decoder import CandidateBasicDecoder
    from diplomacy_research.models.layers.dropout import SeededDropoutWrapper
    from diplomacy_research.models.layers.dynamic_decode import dynamic_decode
    from diplomacy_research.models.policy.order_based.helper import CustomHelper, CustomBeamHelper
    from diplomacy_research.utils.tensorflow import cross_entropy, sequence_loss, to_int32, to_float, get_tile_beam

    # Quick function to retrieve hparams and placeholders and function shorthands
    hps = lambda hparam_name: self.hparams[hparam_name]
    pholder = lambda placeholder_name: self.placeholders[placeholder_name]

    # Training loop
    with tf.variable_scope('policy', reuse=tf.AUTO_REUSE):
        with tf.device(self.cluster_config.worker_device if self.cluster_config else None):

            # Features
            player_seeds = self.features['player_seed']                 # tf.int32 - (b,)
            temperature = self.features['temperature']                  # tf.flt32 - (b,)
            dropout_rates = self.features['dropout_rate']               # tf.flt32 - (b,)

            # Placeholders
            stop_gradient_all = pholder('stop_gradient_all')

            # Outputs (from initial steps)
            batch_size = self.outputs['batch_size']
            decoder_inputs = self.outputs['decoder_inputs']
            decoder_type = self.outputs['decoder_type']
            raw_decoder_lengths = self.outputs['raw_decoder_lengths']
            decoder_lengths = self.outputs['decoder_lengths']
            board_state_conv = self.outputs['board_state_conv']
            order_embedding = self.outputs['order_embedding']
            candidate_embedding = self.outputs['candidate_embedding']
            candidates = self.outputs['candidates']
            max_candidate_length = self.outputs['max_candidate_length']

            # --- Decoding ---
            with tf.variable_scope('decoder_scope', reuse=tf.AUTO_REUSE):
                # The same cell object is wrapped twice below (regular decoding, then beam search decoding)
                lstm_cell = tf.contrib.rnn.LSTMBlockCell(hps('lstm_size'))

                # ======== Regular Decoding ========
                # Applying dropout to input + attention and to output layer
                decoder_cell = SeededDropoutWrapper(cell=lstm_cell,
                                                    seeds=player_seeds,
                                                    input_keep_probs=1. - dropout_rates,
                                                    output_keep_probs=1. - dropout_rates,
                                                    variational_recurrent=hps('use_v_dropout'),
                                                    input_size=hps('order_emb_size') + hps('attn_size'),
                                                    dtype=tf.float32)

                # apply attention over location
                # curr_state [batch, NB_NODES, attn_size]
                # The explicit scope object is given again to the beam search attention further down
                attention_scope = tf.VariableScope(name='policy/decoder_scope/Attention', reuse=tf.AUTO_REUSE)
                attention_mechanism = BahdanauAttention(num_units=hps('attn_size'),
                                                        memory=board_state_conv,
                                                        normalize=True,
                                                        name_or_scope=attention_scope)
                decoder_cell = AttentionWrapper(cell=decoder_cell,
                                                attention_mechanism=attention_mechanism,
                                                output_attention=False,
                                                name_or_scope=attention_scope)

                # Setting initial state
                # The initial attention is the mean of board_state_conv over axis 1
                decoder_init_state = decoder_cell.zero_state(batch_size, tf.float32)
                decoder_init_state = decoder_init_state.clone(attention=tf.reduce_mean(board_state_conv, axis=1))

                # ---- Helper ----
                # The last token of decoder_inputs is dropped - it is never fed as an input
                helper = CustomHelper(decoder_type=decoder_type,
                                      inputs=decoder_inputs[:, :-1],
                                      order_embedding=order_embedding,
                                      candidate_embedding=candidate_embedding,
                                      sequence_length=decoder_lengths,
                                      candidates=candidates,
                                      time_major=False,
                                      softmax_temperature=temperature)

                # ---- Decoder ----
                # Mask built from raw_decoder_lengths, padded to the longest entry of decoder_lengths
                sequence_mask = tf.sequence_mask(raw_decoder_lengths,
                                                 maxlen=tf.reduce_max(decoder_lengths),
                                                 dtype=tf.float32)
                # Upper bound on the number of decoding steps (used for both decoders)
                maximum_iterations = NB_SUPPLY_CENTERS
                model_decoder = CandidateBasicDecoder(cell=decoder_cell,
                                                      helper=helper,
                                                      initial_state=decoder_init_state,
                                                      max_candidate_length=max_candidate_length,
                                                      extract_state=True)
                training_results, _, _ = dynamic_decode(decoder=model_decoder,
                                                        output_time_major=False,
                                                        maximum_iterations=maximum_iterations,
                                                        swap_memory=hps('swap_memory'))
                # Snapshot of the global variables - compared after the beam search decoder is built
                global_vars_after_decoder = set(tf.global_variables())

                # ======== Beam Search Decoding ========
                # tile_beam is applied to every per-batch tensor given to the beam search cell
                # (presumably repeating each item beam_width times - confirm in get_tile_beam)
                tile_beam = get_tile_beam(hps('beam_width'))

                # Applying dropout to input + attention and to output layer
                decoder_cell = SeededDropoutWrapper(cell=lstm_cell,
                                                    seeds=tile_beam(player_seeds),
                                                    input_keep_probs=tile_beam(1. - dropout_rates),
                                                    output_keep_probs=tile_beam(1. - dropout_rates),
                                                    variational_recurrent=hps('use_v_dropout'),
                                                    input_size=hps('order_emb_size') + hps('attn_size'),
                                                    dtype=tf.float32)

                # apply attention over location
                # curr_state [batch, NB_NODES, attn_size]
                attention_mechanism = BahdanauAttention(num_units=hps('attn_size'),
                                                        memory=tile_beam(board_state_conv),
                                                        normalize=True,
                                                        name_or_scope=attention_scope)
                decoder_cell = AttentionWrapper(cell=decoder_cell,
                                                attention_mechanism=attention_mechanism,
                                                output_attention=False,
                                                name_or_scope=attention_scope)

                # Setting initial state
                decoder_init_state = decoder_cell.zero_state(batch_size * hps('beam_width'), tf.float32)
                decoder_init_state = decoder_init_state.clone(attention=tf.reduce_mean(tile_beam(board_state_conv),
                                                                                       axis=1))

                # ---- Beam Helper and Decoder ----
                beam_helper = CustomBeamHelper(cell=decoder_cell,
                                               order_embedding=order_embedding,
                                               candidate_embedding=candidate_embedding,
                                               candidates=candidates,
                                               sequence_length=decoder_lengths,
                                               initial_state=decoder_init_state,
                                               beam_width=hps('beam_width'))
                beam_decoder = DiverseBeamSearchDecoder(beam_helper=beam_helper,
                                                        sequence_length=decoder_lengths,
                                                        nb_groups=hps('beam_groups'))
                beam_results, beam_state, _ = dynamic_decode(decoder=beam_decoder,
                                                             output_time_major=False,
                                                             maximum_iterations=maximum_iterations,
                                                             swap_memory=hps('swap_memory'))

                # Making sure we haven't created new global variables
                # (i.e. the beam search decoder must only use variables created by the regular decoder)
                assert not set(tf.global_variables()) - global_vars_after_decoder, 'New global vars were created'

                # Processing results
                candidate_logits = training_results.rnn_output                  # (b, dec_len, max_cand_len)
                logits_length = tf.shape(candidate_logits)[1]                   # dec_len
                # Targets are the decoder inputs shifted by one position, truncated to the number of logits
                decoder_target = decoder_inputs[:, 1:1 + logits_length]

                # Selected tokens are the token that was actually fed at the next position
                # sample_mask is 1. where sample_id == -1: the target token is used there, the sample_id elsewhere
                sample_mask = to_float(tf.math.equal(training_results.sample_id, -1))
                selected_tokens = to_int32(
                    sequence_mask * (sample_mask * to_float(decoder_target)
                                     + (1. - sample_mask) * to_float(training_results.sample_id)))

                # Computing ArgMax tokens
                # The one-hot of the best candidate position picks the matching entry of `candidates`
                argmax_id = to_int32(tf.argmax(candidate_logits, axis=-1))
                max_nb_candidate = tf.shape(candidate_logits)[2]
                candidate_ids = \
                    tf.reduce_sum(tf.one_hot(argmax_id, max_nb_candidate, dtype=tf.int32) * candidates, axis=-1)
                argmax_tokens = to_int32(to_float(candidate_ids) * sequence_mask)

                # Extracting the position of the target candidate
                # (argmax over the equality mask: index of the first matching candidate, 0 when nothing matches)
                tokens_labels = tf.argmax(to_int32(tf.math.equal(selected_tokens[:, :, None], candidates)), -1)
                target_labels = tf.argmax(to_int32(tf.math.equal(decoder_target[:, :, None], candidates)), -1)

                # Log Probs
                log_probs = -1. * cross_entropy(logits=candidate_logits, labels=tokens_labels) * sequence_mask

            # Computing policy loss
            with tf.variable_scope('policy_loss'):
                policy_loss = sequence_loss(logits=candidate_logits,
                                            targets=target_labels,
                                            weights=sequence_mask,
                                            average_across_batch=True,
                                            average_across_timesteps=True)
                # No gradient flows through the loss when the stop_gradient_all placeholder is true
                policy_loss = tf.cond(stop_gradient_all,
                                      lambda: tf.stop_gradient(policy_loss),    # pylint: disable=cell-var-from-loop
                                      lambda: policy_loss)                      # pylint: disable=cell-var-from-loop

    # Building output tags
    # 'draw_prob' falls back to zeros shaped like the 'draw_target' feature when no earlier step has set it
    outputs = {'tag/policy/order_based/v001_markovian_no_film': True,
               'targets': decoder_inputs[:, 1:],
               'selected_tokens': selected_tokens,
               'argmax_tokens': argmax_tokens,
               'logits': candidate_logits,
               'log_probs': log_probs,
               'beam_tokens': tf.transpose(beam_results.predicted_ids, perm=[0, 2, 1]),     # [batch, beam, steps]
               'beam_log_probs': beam_state.log_probs,
               'rnn_states': training_results.rnn_state,
               'policy_loss': policy_loss,
               'draw_prob': self.outputs.get('draw_prob', tf.zeros_like(self.features['draw_target'])),
               'learning_rate': self.learning_rate}

    # Adding features, placeholders and outputs to graph
    self.add_meta_information(outputs)
def _build_policy_final(self):
    """ Builds the policy model (final step)

        Reads the tensors stored in `self.outputs` by the initial steps, builds a regular decoder and a diverse
        beam search decoder around a TransformerCell (both created under the same transformer variable scope),
        computes the policy loss over the candidate orders and registers the resulting tensors through
        `self.add_meta_information`.
    """
    # NOTE(review): the nesting of the 'decoder_scope' / 'policy_loss' blocks was reconstructed from a copy of this
    # method whose indentation was lost - confirm the scope boundaries against the original file.
    from diplomacy_research.utils.tensorflow import tf
    from diplomacy_research.models.layers.attention import StaticAttentionWrapper
    from diplomacy_research.models.layers.beam_decoder import DiverseBeamSearchDecoder
    from diplomacy_research.models.layers.decoder import CandidateBasicDecoder
    from diplomacy_research.models.layers.dropout import SeededDropoutWrapper
    from diplomacy_research.models.layers.dynamic_decode import dynamic_decode
    from diplomacy_research.models.layers.initializers import uniform
    from diplomacy_research.models.layers.transformer import TransformerCell
    from diplomacy_research.models.layers.wrappers import IdentityCell
    from diplomacy_research.models.policy.order_based.helper import CustomHelper, CustomBeamHelper
    from diplomacy_research.utils.tensorflow import cross_entropy, sequence_loss, to_int32, to_float, get_tile_beam

    # Quick function to retrieve hparams and placeholders and function shorthands
    hps = lambda hparam_name: self.hparams[hparam_name]
    pholder = lambda placeholder_name: self.placeholders[placeholder_name]

    # Training loop
    with tf.variable_scope('policy', reuse=tf.AUTO_REUSE):
        with tf.device(self.cluster_config.worker_device if self.cluster_config else None):

            # Features
            player_seeds = self.features['player_seed']                 # tf.int32 - (b,)
            temperature = self.features['temperature']                  # tf.flt32 - (b,)
            dropout_rates = self.features['dropout_rate']               # tf.flt32 - (b,)

            # Placeholders
            stop_gradient_all = pholder('stop_gradient_all')

            # Outputs (from initial steps)
            batch_size = self.outputs['batch_size']
            board_alignments = self.outputs['board_alignments']
            decoder_inputs = self.outputs['decoder_inputs']
            decoder_type = self.outputs['decoder_type']
            raw_decoder_lengths = self.outputs['raw_decoder_lengths']
            decoder_lengths = self.outputs['decoder_lengths']
            board_state_conv = self.outputs['board_state_conv']
            order_embedding = self.outputs['order_embedding']
            candidate_embedding = self.outputs['candidate_embedding']
            candidates = self.outputs['candidates']
            max_candidate_length = self.outputs['max_candidate_length']

            # Creating the position embedding (it is always created here - it is never read from self.outputs)
            # Embeddings needs to be cached locally on the worker, otherwise TF can't compute their gradients
            with tf.variable_scope('position_embedding_scope'):
                caching_device = self.cluster_config.caching_device if self.cluster_config else None
                position_embedding = uniform(name='position_embedding',
                                             shape=[NB_SUPPLY_CENTERS, hps('trsf_emb_size')],
                                             scale=1.,
                                             caching_device=caching_device)

            # Past Attentions
            # Both are None in this model - they are still forwarded to the transformer cells below
            past_attentions, message_lengths = None, None

            # --- Decoding ---
            with tf.variable_scope('decoder_scope', reuse=tf.AUTO_REUSE):
                feeder_cell = IdentityCell(output_size=hps('trsf_emb_size') + hps('attn_size'))

                # ======== Regular Decoding ========
                # Applying dropout on the feeder cell (only input_keep_probs is set here)
                feeder_cell = SeededDropoutWrapper(cell=feeder_cell,
                                                   seeds=player_seeds,
                                                   input_keep_probs=1. - dropout_rates,
                                                   variational_recurrent=hps('use_v_dropout'),
                                                   input_size=hps('trsf_emb_size') + hps('attn_size'),
                                                   dtype=tf.float32)

                # Apply attention over orderable location at each position
                feeder_cell = StaticAttentionWrapper(cell=feeder_cell,
                                                     memory=board_state_conv,
                                                     alignments=board_alignments,
                                                     sequence_length=raw_decoder_lengths,
                                                     output_attention=False)

                # Setting initial state
                feeder_cell_init_state = feeder_cell.zero_state(batch_size, tf.float32)

                # ---- Helper ----
                # The last token of decoder_inputs is dropped - it is never fed as an input
                helper = CustomHelper(decoder_type=decoder_type,
                                      inputs=decoder_inputs[:, :-1],
                                      order_embedding=order_embedding,
                                      candidate_embedding=candidate_embedding,
                                      sequence_length=decoder_lengths,
                                      candidates=candidates,
                                      time_major=False,
                                      softmax_temperature=temperature)

                # ---- Transformer Cell ----
                # First use of the transformer scope (reuse=False); the beam search cell re-opens the same
                # scope name with reuse=True
                trsf_scope = tf.VariableScope(name='policy/training_scope/transformer', reuse=False)
                transformer_cell = TransformerCell(nb_layers=hps('trsf_nb_layers'),
                                                   nb_heads=hps('trsf_nb_heads'),
                                                   word_embedding=order_embedding,
                                                   position_embedding=position_embedding,
                                                   batch_size=batch_size,
                                                   feeder_cell=feeder_cell,
                                                   feeder_init_state=feeder_cell_init_state,
                                                   past_attentions=past_attentions,
                                                   past_seq_lengths=message_lengths,
                                                   scope=trsf_scope,
                                                   name='transformer')
                transformer_cell_init_state = transformer_cell.zero_state(batch_size, tf.float32)

                # ---- Invariants ----
                # Shape given to dynamic_decode for 'past_attentions' (None marks the dimensions left unspecified)
                invariants_map = {
                    'past_attentions': tf.TensorShape([None,                    # batch size
                                                       hps('trsf_nb_layers'),   # nb_layers
                                                       2,                       # key, value
                                                       hps('trsf_nb_heads'),    # nb heads
                                                       None,                    # Seq len
                                                       hps('trsf_emb_size') // hps('trsf_nb_heads')])}  # Head size

                # ---- Decoder ----
                # Mask built from raw_decoder_lengths, padded to the longest entry of decoder_lengths
                sequence_mask = tf.sequence_mask(raw_decoder_lengths,
                                                 maxlen=tf.reduce_max(decoder_lengths),
                                                 dtype=tf.float32)
                # Upper bound on the number of decoding steps (used for both decoders)
                maximum_iterations = NB_SUPPLY_CENTERS
                model_decoder = CandidateBasicDecoder(cell=transformer_cell,
                                                      helper=helper,
                                                      initial_state=transformer_cell_init_state,
                                                      max_candidate_length=max_candidate_length,
                                                      extract_state=True)
                training_results, _, _ = dynamic_decode(decoder=model_decoder,
                                                        output_time_major=False,
                                                        maximum_iterations=maximum_iterations,
                                                        invariants_map=invariants_map,
                                                        swap_memory=hps('swap_memory'))
                # Snapshot of the global variables - compared after the beam search decoder is built
                global_vars_after_decoder = set(tf.global_variables())

                # ======== Beam Search Decoding ========
                # tile_beam is applied to every per-batch tensor given to the beam search cell
                # (presumably repeating each item beam_width times - confirm in get_tile_beam)
                tile_beam = get_tile_beam(hps('beam_width'))
                beam_feeder_cell = IdentityCell(output_size=hps('trsf_emb_size') + hps('attn_size'))

                # Applying dropout on the feeder cell (only input_keep_probs is set here)
                beam_feeder_cell = SeededDropoutWrapper(cell=beam_feeder_cell,
                                                        seeds=tile_beam(player_seeds),
                                                        input_keep_probs=tile_beam(1. - dropout_rates),
                                                        variational_recurrent=hps('use_v_dropout'),
                                                        input_size=hps('trsf_emb_size') + hps('attn_size'),
                                                        dtype=tf.float32)

                # Apply attention over orderable location at each position
                beam_feeder_cell = StaticAttentionWrapper(cell=beam_feeder_cell,
                                                          memory=tile_beam(board_state_conv),
                                                          alignments=tile_beam(board_alignments),
                                                          sequence_length=tile_beam(raw_decoder_lengths),
                                                          output_attention=False)

                # Setting initial state
                beam_feeder_init_state = beam_feeder_cell.zero_state(batch_size * hps('beam_width'), tf.float32)

                # ---- Transformer Cell ----
                # Same scope name as the regular decoder, this time with reuse=True
                # NOTE(review): tile_beam receives None here (past_attentions / message_lengths) - confirm that
                # the function returned by get_tile_beam accepts None
                trsf_scope = tf.VariableScope(name='policy/training_scope/transformer', reuse=True)
                beam_trsf_cell = TransformerCell(nb_layers=hps('trsf_nb_layers'),
                                                 nb_heads=hps('trsf_nb_heads'),
                                                 word_embedding=order_embedding,
                                                 position_embedding=position_embedding,
                                                 batch_size=batch_size * hps('beam_width'),
                                                 feeder_cell=beam_feeder_cell,
                                                 feeder_init_state=beam_feeder_init_state,
                                                 past_attentions=tile_beam(past_attentions),
                                                 past_seq_lengths=tile_beam(message_lengths),
                                                 scope=trsf_scope,
                                                 name='transformer')
                beam_trsf_cell_init_state = beam_trsf_cell.zero_state(batch_size * hps('beam_width'), tf.float32)

                # ---- Beam Helper and Decoder ----
                beam_helper = CustomBeamHelper(cell=beam_trsf_cell,
                                               order_embedding=order_embedding,
                                               candidate_embedding=candidate_embedding,
                                               candidates=candidates,
                                               sequence_length=decoder_lengths,
                                               initial_state=beam_trsf_cell_init_state,
                                               beam_width=hps('beam_width'))
                beam_decoder = DiverseBeamSearchDecoder(beam_helper=beam_helper,
                                                        sequence_length=decoder_lengths,
                                                        nb_groups=hps('beam_groups'))
                beam_results, beam_state, _ = dynamic_decode(decoder=beam_decoder,
                                                             output_time_major=False,
                                                             maximum_iterations=maximum_iterations,
                                                             invariants_map=invariants_map,
                                                             swap_memory=hps('swap_memory'))

                # Making sure we haven't created new global variables
                # (i.e. the beam search decoder must only use variables created by the regular decoder)
                assert not set(tf.global_variables()) - global_vars_after_decoder, 'New global vars were created'

                # Processing results
                candidate_logits = training_results.rnn_output                  # (b, dec_len, max_cand_len)
                logits_length = tf.shape(candidate_logits)[1]                   # dec_len
                # Targets are the decoder inputs shifted by one position, truncated to the number of logits
                decoder_target = decoder_inputs[:, 1:1 + logits_length]

                # Selected tokens are the token that was actually fed at the next position
                # sample_mask is 1. where sample_id == -1: the target token is used there, the sample_id elsewhere
                sample_mask = to_float(tf.math.equal(training_results.sample_id, -1))
                selected_tokens = to_int32(
                    sequence_mask * (sample_mask * to_float(decoder_target)
                                     + (1. - sample_mask) * to_float(training_results.sample_id)))

                # Computing ArgMax tokens
                # The one-hot of the best candidate position picks the matching entry of `candidates`
                argmax_id = to_int32(tf.argmax(candidate_logits, axis=-1))
                max_nb_candidate = tf.shape(candidate_logits)[2]
                candidate_ids = \
                    tf.reduce_sum(tf.one_hot(argmax_id, max_nb_candidate, dtype=tf.int32) * candidates, axis=-1)
                argmax_tokens = to_int32(to_float(candidate_ids) * sequence_mask)

                # Extracting the position of the target candidate
                # (argmax over the equality mask: index of the first matching candidate, 0 when nothing matches)
                tokens_labels = tf.argmax(to_int32(tf.math.equal(selected_tokens[:, :, None], candidates)), -1)
                target_labels = tf.argmax(to_int32(tf.math.equal(decoder_target[:, :, None], candidates)), -1)

                # Log Probs
                log_probs = -1. * cross_entropy(logits=candidate_logits, labels=tokens_labels) * sequence_mask

            # Computing policy loss
            with tf.variable_scope('policy_loss'):
                policy_loss = sequence_loss(logits=candidate_logits,
                                            targets=target_labels,
                                            weights=sequence_mask,
                                            average_across_batch=True,
                                            average_across_timesteps=True)
                # No gradient flows through the loss when the stop_gradient_all placeholder is true
                policy_loss = tf.cond(stop_gradient_all,
                                      lambda: tf.stop_gradient(policy_loss),    # pylint: disable=cell-var-from-loop
                                      lambda: policy_loss)                      # pylint: disable=cell-var-from-loop

    # Building output tags
    # 'draw_prob' falls back to zeros shaped like the 'draw_target' feature when no earlier step has set it
    outputs = {'tag/policy/order_based/v015_film_transformer_gpt': True,
               'targets': decoder_inputs[:, 1:],
               'selected_tokens': selected_tokens,
               'argmax_tokens': argmax_tokens,
               'logits': candidate_logits,
               'log_probs': log_probs,
               'beam_tokens': tf.transpose(beam_results.predicted_ids, perm=[0, 2, 1]),     # [batch, beam, steps]
               'beam_log_probs': beam_state.log_probs,
               'rnn_states': training_results.rnn_state,
               'policy_loss': policy_loss,
               'draw_prob': self.outputs.get('draw_prob', tf.zeros_like(self.features['draw_target'])),
               'learning_rate': self.learning_rate}

    # Adding features, placeholders and outputs to graph
    self.add_meta_information(outputs)