def _build_value_initial(self):
    """ Builds the value model (initial step)

        Estimates the value of the board for the current power with `self.get_board_value` and
        trains it with a mean squared error against the 'value_target' feature. The loss is wrapped
        in tf.stop_gradient when the 'stop_gradient_all' placeholder is True.
        Registers 'state_value' and 'value_loss' through `self.add_meta_information`.
    """
    from diplomacy_research.utils.tensorflow import tf
    from diplomacy_research.utils.tensorflow import to_float

    # Registering the placeholders (merged into the existing ones when some are already present)
    if self.placeholders:
        self.placeholders.update(self.get_placeholders())
    else:
        self.placeholders = self.get_placeholders()

    with tf.variable_scope('value', reuse=tf.AUTO_REUSE):
        with tf.device(self.cluster_config.worker_device if self.cluster_config else None):

            # Inputs
            board = to_float(self.features['board_state'])          # tf.float32 - (b, NB_NODES, NB_FEATURES)
            power_index = self.features['current_power']              # tf.int32   - (b,)
            target = self.features['value_target']                    # tf.float32 - (b,)
            stop_grad_flag = self.placeholders['stop_gradient_all']

            # Value of the board from the point of view of the current power
            state_value = self.get_board_value(board, power_index)

            # Mean squared error, frozen with stop_gradient when requested
            with tf.variable_scope('value_loss'):
                mse = tf.reduce_mean(tf.square(target - state_value))
                value_loss = tf.cond(stop_grad_flag,
                                     lambda: tf.stop_gradient(mse),
                                     lambda: mse)

    self.add_meta_information({'tag/value/v001_val_relu_7': True,
                               'state_value': state_value,
                               'value_loss': value_loss})
def _build_policy_final(self):
    """ Builds the policy model (final step)

        Token-based policy. An LSTM cell wrapped with seeded dropout and a static attention over
        `board_state_conv` (using the precomputed `board_alignments`) decodes one token per step,
        projected to VOCABULARY_SIZE logits by a shared dense output layer.

        Two decoders share the same `lstm_cell` and `decoder_output_layer`:
            - a regular decoder (`CustomHelper` + `MaskedBasicDecoder`), used to compute the logits,
              selected / argmax tokens, log-probs and the policy loss;
            - a diverse beam search decoder (`CustomBeamHelper` + `DiverseBeamSearchDecoder`) on inputs
              tiled `beam_width` times, exposed as 'beam_tokens' and 'beam_log_probs'.

        Reads the features 'player_seed', 'temperature', 'dropout_rate', the placeholder
        'stop_gradient_all' and the outputs of the initial step, then registers its results with
        `self.add_meta_information`.
    """
    from diplomacy_research.utils.tensorflow import tf
    from diplomacy_research.models.layers.attention import StaticAttentionWrapper
    from diplomacy_research.models.layers.beam_decoder import DiverseBeamSearchDecoder
    from diplomacy_research.models.layers.decoder import MaskedBasicDecoder
    from diplomacy_research.models.layers.dropout import SeededDropoutWrapper
    from diplomacy_research.models.layers.dynamic_decode import dynamic_decode
    from diplomacy_research.models.policy.token_based.helper import CustomHelper, CustomBeamHelper
    from diplomacy_research.utils.tensorflow import cross_entropy, sequence_loss, to_int32, to_float, get_tile_beam

    # Quick function to retrieve hparams and placeholders and function shorthands
    hps = lambda hparam_name: self.hparams[hparam_name]
    pholder = lambda placeholder_name: self.placeholders[placeholder_name]

    # Training loop
    with tf.variable_scope('policy', reuse=tf.AUTO_REUSE):
        with tf.device(self.cluster_config.worker_device if self.cluster_config else None):

            # Features
            player_seeds = self.features['player_seed']                 # tf.int32 - (b,)
            temperature = self.features['temperature']                  # tf.float32 - (b,)
            dropout_rates = self.features['dropout_rate']               # tf.float32 - (b,)

            # Placeholders
            stop_gradient_all = pholder('stop_gradient_all')

            # Outputs (from initial steps)
            batch_size = self.outputs['batch_size']
            board_alignments = self.outputs['board_alignments']
            decoder_inputs = self.outputs['decoder_inputs']
            decoder_mask = self.outputs['decoder_mask']
            decoder_type = self.outputs['decoder_type']
            raw_decoder_lengths = self.outputs['raw_decoder_lengths']
            decoder_lengths = self.outputs['decoder_lengths']
            board_state_conv = self.outputs['board_state_conv']
            word_embedding = self.outputs['word_embedding']

            # --- Decoding ---
            with tf.variable_scope('decoder_scope', reuse=tf.AUTO_REUSE):
                # Shared by the regular and the beam search decoders
                lstm_cell = tf.contrib.rnn.LSTMBlockCell(hps('lstm_size'))

                # decoder output to token (one logit per vocabulary entry)
                # NOTE(review): an initializer class (not an instance) is passed as kernel_initializer;
                # presumably TF 1.x instantiates it - confirm
                decoder_output_layer = tf.layers.Dense(units=VOCABULARY_SIZE,
                                                       activation=None,
                                                       kernel_initializer=tf.random_normal_initializer,
                                                       use_bias=True)

                # ======== Regular Decoding ========
                # Applying dropout to input + attention and to output layer
                decoder_cell = SeededDropoutWrapper(cell=lstm_cell,
                                                    seeds=player_seeds,
                                                    input_keep_probs=1. - dropout_rates,
                                                    output_keep_probs=1. - dropout_rates,
                                                    variational_recurrent=hps('use_v_dropout'),
                                                    input_size=hps('word_emb_size') + hps('attn_size'),
                                                    dtype=tf.float32)

                # Apply attention over orderable location at each position
                decoder_cell = StaticAttentionWrapper(cell=decoder_cell,
                                                      memory=board_state_conv,
                                                      alignments=board_alignments,
                                                      sequence_length=raw_decoder_lengths,
                                                      output_attention=False)

                # Setting initial state
                decoder_init_state = decoder_cell.zero_state(batch_size, tf.float32)

                # ---- Helper ----
                # All but the last input token are fed to the decoder
                helper = CustomHelper(decoder_type=decoder_type,
                                      inputs=decoder_inputs[:, :-1],
                                      embedding=word_embedding,
                                      sequence_length=decoder_lengths,
                                      mask=decoder_mask,
                                      time_major=False,
                                      softmax_temperature=temperature)

                # ---- Decoder ----
                # 1. for positions before raw_decoder_lengths, 0. after; width = max of decoder_lengths
                sequence_mask = tf.sequence_mask(raw_decoder_lengths,
                                                 maxlen=tf.reduce_max(decoder_lengths),
                                                 dtype=tf.float32)
                # Upper bound on the number of decoding steps
                maximum_iterations = TOKENS_PER_ORDER * NB_SUPPLY_CENTERS
                model_decoder = MaskedBasicDecoder(cell=decoder_cell,
                                                   helper=helper,
                                                   initial_state=decoder_init_state,
                                                   output_layer=decoder_output_layer,
                                                   extract_state=True)
                training_results, _, _ = dynamic_decode(decoder=model_decoder,
                                                        output_time_major=False,
                                                        maximum_iterations=maximum_iterations,
                                                        swap_memory=hps('swap_memory'))
                # Snapshot of the global variables; the beam decoder below must not create new ones
                global_vars_after_decoder = set(tf.global_variables())

                # ======== Beam Search Decoding ========
                # Every per-example input is tiled beam_width times
                tile_beam = get_tile_beam(hps('beam_width'))

                # Applying dropout to input + attention and to output layer
                decoder_cell = SeededDropoutWrapper(cell=lstm_cell,
                                                    seeds=tile_beam(player_seeds),
                                                    input_keep_probs=tile_beam(1. - dropout_rates),
                                                    output_keep_probs=tile_beam(1. - dropout_rates),
                                                    variational_recurrent=hps('use_v_dropout'),
                                                    input_size=hps('word_emb_size') + hps('attn_size'),
                                                    dtype=tf.float32)

                # Apply attention over orderable location at each position
                decoder_cell = StaticAttentionWrapper(cell=decoder_cell,
                                                      memory=tile_beam(board_state_conv),
                                                      alignments=tile_beam(board_alignments),
                                                      sequence_length=tile_beam(raw_decoder_lengths),
                                                      output_attention=False)

                # Setting initial state
                decoder_init_state = decoder_cell.zero_state(batch_size * hps('beam_width'), tf.float32)

                # ---- Beam Helper and Decoder ----
                beam_helper = CustomBeamHelper(cell=decoder_cell,
                                               embedding=word_embedding,
                                               mask=decoder_mask,
                                               sequence_length=decoder_lengths,
                                               output_layer=decoder_output_layer,
                                               initial_state=decoder_init_state,
                                               beam_width=hps('beam_width'))
                beam_decoder = DiverseBeamSearchDecoder(beam_helper=beam_helper,
                                                        sequence_length=decoder_lengths,
                                                        nb_groups=hps('beam_groups'))
                beam_results, beam_state, _ = dynamic_decode(decoder=beam_decoder,
                                                             output_time_major=False,
                                                             maximum_iterations=maximum_iterations,
                                                             swap_memory=hps('swap_memory'))

                # Making sure we haven't created new global variables
                assert not set(tf.global_variables()) - global_vars_after_decoder, 'New global vars were created'

                # Processing results
                logits = training_results.rnn_output                        # (b, dec_len, VOCAB_SIZE)
                logits_length = tf.shape(logits)[1]                         # dec_len
                # Targets are the inputs shifted by one position, trimmed to the decoded length
                decoder_target = decoder_inputs[:, 1:1 + logits_length]

                # Selected tokens are the token that was actually fed at the next position
                # sample_mask is 1. where sample_id == -1; there decoder_target is used, elsewhere sample_id
                sample_mask = to_float(tf.math.equal(training_results.sample_id, -1))
                selected_tokens = to_int32(
                    sequence_mask * (sample_mask * to_float(decoder_target)
                                     + (1. - sample_mask) * to_float(training_results.sample_id)))

                # Argmax tokens are the most likely token outputted at each position
                argmax_tokens = to_int32(to_float(tf.argmax(logits, axis=-1)) * sequence_mask)
                # Log-probability of each selected token, 0. on masked positions
                # (assumes cross_entropy returns the per-position cross-entropy - confirm)
                log_probs = -1. * cross_entropy(logits=logits, labels=selected_tokens) * sequence_mask

            # Computing policy loss
            with tf.variable_scope('policy_loss'):
                policy_loss = sequence_loss(logits=logits,
                                            targets=decoder_target,
                                            weights=sequence_mask,
                                            average_across_batch=True,
                                            average_across_timesteps=True)
                policy_loss = tf.cond(stop_gradient_all,
                                      lambda: tf.stop_gradient(policy_loss),      # pylint: disable=cell-var-from-loop
                                      lambda: policy_loss)                        # pylint: disable=cell-var-from-loop

    # Building output tags
    outputs = {'tag/policy/token_based/v005_markovian_film_board_align': True,
               'targets': decoder_inputs[:, 1:],
               'selected_tokens': selected_tokens,
               'argmax_tokens': argmax_tokens,
               'logits': logits,
               'log_probs': log_probs,
               'beam_tokens': tf.transpose(beam_results.predicted_ids, perm=[0, 2, 1]),     # [batch, beam, steps]
               'beam_log_probs': beam_state.log_probs,
               'rnn_states': training_results.rnn_state,
               'policy_loss': policy_loss,
               # Falls back to a zeros tensor when no draw model output is available
               'draw_prob': self.outputs.get('draw_prob', tf.zeros_like(self.features['draw_target'])),
               'learning_rate': self.learning_rate}

    # Adding features, placeholders and outputs to graph
    self.add_meta_information(outputs)
def _build_draw_initial(self):
    """ Builds the draw model (initial step)

        A graph convolution over the board state is flattened and projected to an embedding. Two
        ReLU layers and a sigmoid layer then produce one draw probability per power. The probability
        of the current power is trained with a mean squared error against the 'draw_target' feature,
        wrapped in tf.stop_gradient when the 'stop_gradient_all' placeholder is True.
        Registers 'draw_prob' and 'draw_loss' through `self.add_meta_information`.
    """
    from diplomacy_research.utils.tensorflow import tf
    from diplomacy_research.models.layers.graph_convolution import GraphConvolution, preprocess_adjacency
    from diplomacy_research.utils.tensorflow import to_float

    # Registering the placeholders (merged into the existing ones when some are already present)
    if self.placeholders:
        self.placeholders.update(self.get_placeholders())
    else:
        self.placeholders = self.get_placeholders()

    def hparam(name):
        """ Returns the hyper-parameter `name` """
        return self.hparams[name]

    with tf.variable_scope('draw', reuse=tf.AUTO_REUSE):
        with tf.device(self.cluster_config.worker_device if self.cluster_config else None):

            # Inputs
            board = to_float(self.features['board_state'])          # tf.float32 - (b, NB_NODES, NB_FEATURES)
            power_index = self.features['current_power']              # tf.int32   - (b,)
            target = self.features['draw_target']                     # tf.float32 - (b,)
            stop_grad_flag = self.placeholders['stop_gradient_all']

            # Normalized adjacency matrix, tiled once per board in the batch
            nb_boards = tf.shape(board)[0]
            adjacency = preprocess_adjacency(get_adjacency_matrix())
            adjacency = tf.tile(tf.expand_dims(adjacency, axis=0), [nb_boards, 1, 1])

            # Board embedding: one graph convolution, flattened, then a dense ReLU projection
            with tf.variable_scope('graph_conv_scope'):
                graph_conv = GraphConvolution(input_dim=NB_FEATURES,
                                              output_dim=hparam('draw_gcn_1_output_size'),
                                              norm_adjacency=adjacency,
                                              activation_fn=tf.nn.relu,
                                              bias=True)
                node_features = graph_conv(board)
                flat_nodes = tf.reshape(node_features,
                                        shape=[-1, NB_NODES * hparam('draw_gcn_1_output_size')])
                board_embedding = tf.layers.Dense(units=hparam('draw_embedding_size'),
                                                  activation=tf.nn.relu,
                                                  use_bias=True)(flat_nodes)

            # Draw probability of every power, then keeping only the current power
            with tf.variable_scope('draw_scope'):
                power_mask = tf.one_hot(power_index, NB_POWERS, dtype=tf.float32)
                hidden = board_embedding
                for size_hparam in ('draw_h1_size', 'draw_h2_size'):
                    hidden = tf.layers.Dense(units=hparam(size_hparam),
                                             activation=tf.nn.relu,
                                             use_bias=True)(hidden)
                all_draw_probs = tf.layers.Dense(units=NB_POWERS,                    # (b, NB_POWERS)
                                                 activation=tf.nn.sigmoid,
                                                 use_bias=True)(hidden)
                draw_prob = tf.reduce_sum(all_draw_probs * power_mask, axis=1)      # (b,)

            # Mean squared error against the draw target, frozen with stop_gradient when requested
            with tf.variable_scope('draw_loss'):
                mse = tf.reduce_mean(tf.square(target - draw_prob))
                draw_loss = tf.cond(stop_grad_flag,
                                    lambda: tf.stop_gradient(mse),
                                    lambda: mse)

    self.add_meta_information({'tag/draw/v001_draw_relu': True,
                               'draw_prob': draw_prob,
                               'draw_loss': draw_loss})
def build(self):
    """ Builds the RL model using the correct optimizer

        Builds the PPO objective on top of the model outputs:
            - a clipped surrogate policy loss on the ratio of new / old sequence probabilities,
            - a squared error loss for the value baseline (each term capped at 'clip_value_threshold'),
            - a policy-gradient loss for the draw action,
            - an entropy bonus on the policy logits.
        It then creates the optimizer op and stores the losses as outputs. Finally it registers the
        'ppo_policy_baseline' and 'ppo_increase_version' queues on `self.queue_dataset`.
    """
    from diplomacy_research.utils.tensorflow import tf, tfp, normalize, to_float
    from diplomacy_research.models.layers.avg_grad_optimizer import AvgGradOptimizer

    # Quick function to retrieve hparams and placeholders and function shorthands
    hps = lambda hparam_name: self.model.hparams[hparam_name]

    # Small constant added inside the logs of the draw log-likelihood.
    # Without it, a draw_prob of exactly 0. or 1. (e.g. a constant zeros tensor, or a sigmoid that
    # saturates in float32) gives 0. * log(0.) = 0. * -inf = NaN, which poisons the total PPO loss.
    draw_log_epsilon = 1e-8

    # Training loop
    with tf.variable_scope('policy', reuse=tf.AUTO_REUSE):
        with tf.device(self.cluster_config.worker_device if self.cluster_config else None):

            # Placeholders
            stop_gradient_all = self.model.placeholders['stop_gradient_all']

            # Features
            decoder_lengths = self.model.features['decoder_lengths']        # tf.int32   - (b,)
            draw_action = self.model.features['draw_action']                # tf.bool    - (b,)
            reward_target = self.model.features['reward_target']            # tf.float32 - (b,)
            value_target = self.model.features['value_target']              # tf.float32 - (b,)
            old_log_probs = self.model.features['old_log_probs']            # tf.float32 - (b, dec_len)

            # Making sure all RNN lengths are at least 1
            # The raw lengths are kept to mask out positions past the real length of each sequence
            raw_decoder_lengths = decoder_lengths
            decoder_lengths = tf.math.maximum(1, decoder_lengths)

            # Retrieving model outputs
            baseline = values = self.model.outputs['state_value']          # (b,)
            logits = self.model.outputs['logits']                           # (b, dec, VOCAB)
            # Trimming to the maximum decoder length in the batch
            sequence_mask = tf.sequence_mask(raw_decoder_lengths,           # (b, dec)
                                             maxlen=tf.reduce_max(decoder_lengths),
                                             dtype=tf.float32)

            # Computing Baseline Mean Square Error Loss
            # Each squared error is capped at 'clip_value_threshold', then summed over the batch
            with tf.variable_scope('baseline_scope'):
                baseline_mse_loss = tf.minimum(tf.square(value_target - values), hps('clip_value_threshold'))
                baseline_mse_loss = tf.reduce_sum(baseline_mse_loss)                                # ()

            # Calculating surrogate loss
            with tf.variable_scope('policy_gradient_scope'):
                new_policy_log_probs = self.model.outputs['log_probs'] * sequence_mask              # (b, dec_len)
                old_policy_log_probs = old_log_probs * sequence_mask                                # (b, dec_len)

                # Ratio of the probability of the whole sequence under the new and the old policy
                new_sum_log_probs = tf.reduce_sum(new_policy_log_probs, axis=-1)                    # (b,)
                old_sum_log_probs = tf.reduce_sum(old_policy_log_probs, axis=-1)                    # (b,)
                ratio = tf.math.exp(new_sum_log_probs - old_sum_log_probs)                          # (b,)
                clipped_ratio = tf.clip_by_value(ratio, 1. - hps('epsilon'), 1. + hps('epsilon'))   # (b,)

                # Advantages are normalized and treated as constants
                advantages = tf.stop_gradient(normalize(reward_target - baseline))                  # (b,)
                surrogate_loss_1 = ratio * advantages                                               # (b,)
                surrogate_loss_2 = clipped_ratio * advantages                                       # (b,)
                surrogate_loss = -tf.reduce_mean(tf.math.minimum(surrogate_loss_1, surrogate_loss_2))  # ()

            # Calculating policy gradient for draw action
            with tf.variable_scope('draw_gradient_scope'):
                draw_action = to_float(draw_action)                                                 # (b,)
                draw_prob = self.model.outputs['draw_prob']                                         # (b,)
                # Log-likelihood of the taken draw decision under a Bernoulli(draw_prob)
                log_prob_of_draw = draw_action * tf.log(draw_prob + draw_log_epsilon) \
                                   + (1. - draw_action) * tf.log(1. - draw_prob + draw_log_epsilon)
                draw_gradient_loss = -1. * log_prob_of_draw * advantages                            # (b,)
                draw_gradient_loss = tf.reduce_mean(draw_gradient_loss)                             # ()

            # Calculating entropy loss
            with tf.variable_scope('entropy_scope'):
                entropy = tfp.distributions.Categorical(logits=logits).entropy()
                entropy_loss = -tf.reduce_mean(entropy)                                             # ()

            # Scopes
            scope = ['policy', 'value', 'draw']
            global_ignored_scope = None if not hps('ignored_scope') else hps('ignored_scope').split(',')

            # Creating PPO loss
            ppo_loss = surrogate_loss \
                       + hps('value_coeff') * baseline_mse_loss \
                       + hps('draw_coeff') * draw_gradient_loss \
                       + hps('entropy_coeff') * entropy_loss
            ppo_loss = tf.cond(stop_gradient_all,
                               lambda: tf.stop_gradient(ppo_loss),          # pylint: disable=cell-var-from-loop
                               lambda: ppo_loss)                            # pylint: disable=cell-var-from-loop
            cost_and_scope = [(ppo_loss, scope, None)]

            # Creating optimizer op
            ppo_op = self.model.create_optimizer_op(cost_and_scope=cost_and_scope,
                                                    ignored_scope=global_ignored_scope,
                                                    max_gradient_norm=hps('max_gradient_norm'))

            # Making sure we are not using the AvgGradOptimizer, but directly the AdamOptimizer
            assert not isinstance(self.model.optimizer, AvgGradOptimizer), 'PPO does not use AvgGradOptimizer'

    # Storing outputs
    self._add_output('rl_policy_loss', surrogate_loss)
    self._add_output('rl_value_loss', baseline_mse_loss)
    self._add_output('rl_draw_loss', draw_gradient_loss)
    self._add_output('rl_entropy_loss', entropy_loss)
    self._add_output('rl_total_loss', ppo_loss)
    self._add_output('optimizer_op', ppo_op)

    # --------------------------------------
    # Hooks
    # --------------------------------------
    def hook_baseline_pre_condition(dataset):
        """ Pre-Condition: First queue to run """
        if not hasattr(dataset, 'last_queue') or dataset.last_queue == '':
            return True
        return False

    def hook_baseline_post_queue(dataset):
        """ Post-Queue: Marks the baseline queue as processed """
        dataset.last_queue = 'ppo_policy_baseline'

    # --------------------------------------
    # Queues
    # --------------------------------------
    self.queue_dataset.create_queue('ppo_policy_baseline',
                                    placeholders={self.model.placeholders['decoder_type']: [TRAINING_DECODER]},
                                    outputs=[self.model.outputs[output_name]
                                             for output_name in ['optimizer_op'] + self.get_evaluation_tags()],
                                    pre_condition=hook_baseline_pre_condition,
                                    post_queue=hook_baseline_post_queue)
    self.queue_dataset.create_queue('ppo_increase_version',
                                    placeholders={self.model.placeholders['decoder_type']: [GREEDY_DECODER]},
                                    outputs=[tf.assign_add(self.version_step, 1)],
                                    with_status=True)
def _build_policy_final(self):
    """ Builds the policy model (final step)

        Order-based policy. An LSTM cell wrapped with seeded dropout and a Bahdanau attention over
        `board_state_conv` picks, at each step, one of the `candidates`. The logits are over the
        candidate positions, up to `max_candidate_length`.

        Two decoders share the same `lstm_cell` and the same attention variable scope:
            - a regular decoder (`CustomHelper` + `CandidateBasicDecoder`), used to compute the logits,
              selected / argmax tokens, log-probs and the policy loss;
            - a diverse beam search decoder (`CustomBeamHelper` + `DiverseBeamSearchDecoder`) on inputs
              tiled `beam_width` times, exposed as 'beam_tokens' and 'beam_log_probs'.

        Reads the features 'player_seed', 'temperature', 'dropout_rate', the placeholder
        'stop_gradient_all' and the outputs of the initial step, then registers its results with
        `self.add_meta_information`.
    """
    from diplomacy_research.utils.tensorflow import tf
    from diplomacy_research.models.layers.attention import AttentionWrapper, BahdanauAttention
    from diplomacy_research.models.layers.beam_decoder import DiverseBeamSearchDecoder
    from diplomacy_research.models.layers.decoder import CandidateBasicDecoder
    from diplomacy_research.models.layers.dropout import SeededDropoutWrapper
    from diplomacy_research.models.layers.dynamic_decode import dynamic_decode
    from diplomacy_research.models.policy.order_based.helper import CustomHelper, CustomBeamHelper
    from diplomacy_research.utils.tensorflow import cross_entropy, sequence_loss, to_int32, to_float, get_tile_beam

    # Quick function to retrieve hparams and placeholders and function shorthands
    hps = lambda hparam_name: self.hparams[hparam_name]
    pholder = lambda placeholder_name: self.placeholders[placeholder_name]

    # Training loop
    with tf.variable_scope('policy', reuse=tf.AUTO_REUSE):
        with tf.device(self.cluster_config.worker_device if self.cluster_config else None):

            # Features
            player_seeds = self.features['player_seed']                 # tf.int32 - (b,)
            temperature = self.features['temperature']                  # tf.float32 - (b,)
            dropout_rates = self.features['dropout_rate']               # tf.float32 - (b,)

            # Placeholders
            stop_gradient_all = pholder('stop_gradient_all')

            # Outputs (from initial steps)
            batch_size = self.outputs['batch_size']
            decoder_inputs = self.outputs['decoder_inputs']
            decoder_type = self.outputs['decoder_type']
            raw_decoder_lengths = self.outputs['raw_decoder_lengths']
            decoder_lengths = self.outputs['decoder_lengths']
            board_state_conv = self.outputs['board_state_conv']
            order_embedding = self.outputs['order_embedding']
            candidate_embedding = self.outputs['candidate_embedding']
            candidates = self.outputs['candidates']
            max_candidate_length = self.outputs['max_candidate_length']

            # --- Decoding ---
            with tf.variable_scope('decoder_scope', reuse=tf.AUTO_REUSE):
                # Shared by the regular and the beam search decoders
                lstm_cell = tf.contrib.rnn.LSTMBlockCell(hps('lstm_size'))

                # ======== Regular Decoding ========
                # Applying dropout to input + attention and to output layer
                decoder_cell = SeededDropoutWrapper(cell=lstm_cell,
                                                    seeds=player_seeds,
                                                    input_keep_probs=1. - dropout_rates,
                                                    output_keep_probs=1. - dropout_rates,
                                                    variational_recurrent=hps('use_v_dropout'),
                                                    input_size=hps('order_emb_size') + hps('attn_size'),
                                                    dtype=tf.float32)

                # apply attention over location
                # curr_state [batch, NB_NODES, attn_size]
                # The same explicit scope is reused by the beam search attention below
                attention_scope = tf.VariableScope(name='policy/decoder_scope/Attention', reuse=tf.AUTO_REUSE)
                attention_mechanism = BahdanauAttention(num_units=hps('attn_size'),
                                                        memory=board_state_conv,
                                                        normalize=True,
                                                        name_or_scope=attention_scope)
                decoder_cell = AttentionWrapper(cell=decoder_cell,
                                                attention_mechanism=attention_mechanism,
                                                output_attention=False,
                                                name_or_scope=attention_scope)

                # Setting initial state
                # The initial attention is the mean of board_state_conv over axis 1 instead of zeros
                decoder_init_state = decoder_cell.zero_state(batch_size, tf.float32)
                decoder_init_state = decoder_init_state.clone(attention=tf.reduce_mean(board_state_conv, axis=1))

                # ---- Helper ----
                # All but the last input token are fed to the decoder
                helper = CustomHelper(decoder_type=decoder_type,
                                      inputs=decoder_inputs[:, :-1],
                                      order_embedding=order_embedding,
                                      candidate_embedding=candidate_embedding,
                                      sequence_length=decoder_lengths,
                                      candidates=candidates,
                                      time_major=False,
                                      softmax_temperature=temperature)

                # ---- Decoder ----
                # 1. for positions before raw_decoder_lengths, 0. after; width = max of decoder_lengths
                sequence_mask = tf.sequence_mask(raw_decoder_lengths,
                                                 maxlen=tf.reduce_max(decoder_lengths),
                                                 dtype=tf.float32)
                # Upper bound on the number of decoding steps
                maximum_iterations = NB_SUPPLY_CENTERS
                model_decoder = CandidateBasicDecoder(cell=decoder_cell,
                                                      helper=helper,
                                                      initial_state=decoder_init_state,
                                                      max_candidate_length=max_candidate_length,
                                                      extract_state=True)
                training_results, _, _ = dynamic_decode(decoder=model_decoder,
                                                        output_time_major=False,
                                                        maximum_iterations=maximum_iterations,
                                                        swap_memory=hps('swap_memory'))
                # Snapshot of the global variables; the beam decoder below must not create new ones
                global_vars_after_decoder = set(tf.global_variables())

                # ======== Beam Search Decoding ========
                # Every per-example input is tiled beam_width times
                tile_beam = get_tile_beam(hps('beam_width'))

                # Applying dropout to input + attention and to output layer
                decoder_cell = SeededDropoutWrapper(cell=lstm_cell,
                                                    seeds=tile_beam(player_seeds),
                                                    input_keep_probs=tile_beam(1. - dropout_rates),
                                                    output_keep_probs=tile_beam(1. - dropout_rates),
                                                    variational_recurrent=hps('use_v_dropout'),
                                                    input_size=hps('order_emb_size') + hps('attn_size'),
                                                    dtype=tf.float32)

                # apply attention over location
                # curr_state [batch, NB_NODES, attn_size]
                attention_mechanism = BahdanauAttention(num_units=hps('attn_size'),
                                                        memory=tile_beam(board_state_conv),
                                                        normalize=True,
                                                        name_or_scope=attention_scope)
                decoder_cell = AttentionWrapper(cell=decoder_cell,
                                                attention_mechanism=attention_mechanism,
                                                output_attention=False,
                                                name_or_scope=attention_scope)

                # Setting initial state
                decoder_init_state = decoder_cell.zero_state(batch_size * hps('beam_width'), tf.float32)
                decoder_init_state = decoder_init_state.clone(attention=tf.reduce_mean(tile_beam(board_state_conv),
                                                                                       axis=1))

                # ---- Beam Helper and Decoder ----
                beam_helper = CustomBeamHelper(cell=decoder_cell,
                                               order_embedding=order_embedding,
                                               candidate_embedding=candidate_embedding,
                                               candidates=candidates,
                                               sequence_length=decoder_lengths,
                                               initial_state=decoder_init_state,
                                               beam_width=hps('beam_width'))
                beam_decoder = DiverseBeamSearchDecoder(beam_helper=beam_helper,
                                                        sequence_length=decoder_lengths,
                                                        nb_groups=hps('beam_groups'))
                beam_results, beam_state, _ = dynamic_decode(decoder=beam_decoder,
                                                             output_time_major=False,
                                                             maximum_iterations=maximum_iterations,
                                                             swap_memory=hps('swap_memory'))

                # Making sure we haven't created new global variables
                assert not set(tf.global_variables()) - global_vars_after_decoder, 'New global vars were created'

                # Processing results
                candidate_logits = training_results.rnn_output              # (b, dec_len, max_cand_len)
                logits_length = tf.shape(candidate_logits)[1]               # dec_len
                # Targets are the inputs shifted by one position, trimmed to the decoded length
                decoder_target = decoder_inputs[:, 1:1 + logits_length]

                # Selected tokens are the token that was actually fed at the next position
                # sample_mask is 1. where sample_id == -1; there decoder_target is used, elsewhere sample_id
                sample_mask = to_float(tf.math.equal(training_results.sample_id, -1))
                selected_tokens = to_int32(
                    sequence_mask * (sample_mask * to_float(decoder_target)
                                     + (1. - sample_mask) * to_float(training_results.sample_id)))

                # Computing ArgMax tokens
                # The argmax is a candidate position; it is mapped back to a token through `candidates`
                argmax_id = to_int32(tf.argmax(candidate_logits, axis=-1))
                max_nb_candidate = tf.shape(candidate_logits)[2]
                candidate_ids = \
                    tf.reduce_sum(tf.one_hot(argmax_id, max_nb_candidate, dtype=tf.int32) * candidates, axis=-1)
                argmax_tokens = to_int32(to_float(candidate_ids) * sequence_mask)

                # Extracting the position of the target candidate
                # (index of the first candidate equal to the token)
                tokens_labels = tf.argmax(to_int32(tf.math.equal(selected_tokens[:, :, None], candidates)), -1)
                target_labels = tf.argmax(to_int32(tf.math.equal(decoder_target[:, :, None], candidates)), -1)

                # Log Probs
                # (assumes cross_entropy returns the per-position cross-entropy - confirm)
                log_probs = -1. * cross_entropy(logits=candidate_logits, labels=tokens_labels) * sequence_mask

            # Computing policy loss
            with tf.variable_scope('policy_loss'):
                policy_loss = sequence_loss(logits=candidate_logits,
                                            targets=target_labels,
                                            weights=sequence_mask,
                                            average_across_batch=True,
                                            average_across_timesteps=True)
                policy_loss = tf.cond(stop_gradient_all,
                                      lambda: tf.stop_gradient(policy_loss),      # pylint: disable=cell-var-from-loop
                                      lambda: policy_loss)                        # pylint: disable=cell-var-from-loop

    # Building output tags
    outputs = {'tag/policy/order_based/v001_markovian_no_film': True,
               'targets': decoder_inputs[:, 1:],
               'selected_tokens': selected_tokens,
               'argmax_tokens': argmax_tokens,
               'logits': candidate_logits,
               'log_probs': log_probs,
               'beam_tokens': tf.transpose(beam_results.predicted_ids, perm=[0, 2, 1]),     # [batch, beam, steps]
               'beam_log_probs': beam_state.log_probs,
               'rnn_states': training_results.rnn_state,
               'policy_loss': policy_loss,
               # Falls back to a zeros tensor when no draw model output is available
               'draw_prob': self.outputs.get('draw_prob', tf.zeros_like(self.features['draw_target'])),
               'learning_rate': self.learning_rate}

    # Adding features, placeholders and outputs to graph
    self.add_meta_information(outputs)
def _build_value_final(self):
    """ Builds the value model (final step)

        Reuses the 'rnn_states' output of the policy model, which must already be built. It takes the
        state at position 0 and predicts one value per power with a dense ReLU layer followed by a
        linear layer. The value of the current power is trained with a mean squared error against the
        'value_target' feature, wrapped in tf.stop_gradient when the 'stop_gradient_all' placeholder
        is True. When the 'stop_gradient_value' hparam is set, the policy states are wrapped in
        tf.stop_gradient before use.
    """
    from diplomacy_research.utils.tensorflow import tf

    # Registering the placeholders (merged into the existing ones when some are already present)
    if self.placeholders:
        self.placeholders.update(self.get_placeholders())
    else:
        self.placeholders = self.get_placeholders()

    def hparam(name):
        """ Returns the hyper-parameter `name` """
        return self.hparams[name]

    with tf.variable_scope('value', reuse=tf.AUTO_REUSE):
        with tf.device(self.cluster_config.worker_device if self.cluster_config else None):

            # The policy model must already have produced its RNN states
            assert 'rnn_states' in self.outputs

            # Inputs
            rnn_states = self.outputs['rnn_states']
            power_index = self.features['current_power']              # tf.int32   - (b,)
            target = self.features['value_target']                    # tf.float32 - (b,)
            stop_grad_flag = self.placeholders['stop_gradient_all']

            # Optionally preventing the value loss from back-propagating into the policy states
            if hparam('stop_gradient_value'):
                value_input = tf.stop_gradient(rnn_states)
            else:
                value_input = rnn_states
            first_state = value_input[:, 0, :]                         # (b, lstm_size)

            # Dense ReLU layer, then a linear layer with one value per power
            hidden = tf.layers.Dense(units=hparam('value_h1_size'),
                                     use_bias=True,
                                     activation=tf.nn.relu)(first_state)
            per_power_values = tf.layers.Dense(units=NB_POWERS,          # (b, NB_POWERS)
                                               use_bias=True,
                                               activation=None)(hidden)

            # Keeping only the value of the current power
            power_mask = tf.one_hot(power_index, NB_POWERS, dtype=tf.float32)
            state_value = tf.reduce_sum(power_mask * per_power_values, axis=-1)      # (b,)

            # Mean squared error, frozen with stop_gradient when requested
            with tf.variable_scope('value_loss'):
                mse = tf.reduce_mean(tf.square(target - state_value))
                value_loss = tf.cond(stop_grad_flag,
                                     lambda: tf.stop_gradient(mse),
                                     lambda: mse)

    self.add_meta_information({'tag/value/v003_rnn_step_0': True,
                               'state_value': state_value,
                               'value_loss': value_loss})
def _build_policy_final(self):
    """ Builds the policy model (final step)

        Order-based transformer policy. At each step an `IdentityCell` feeder, wrapped with seeded
        input dropout and a static attention over `board_state_conv` (using `board_alignments`),
        feeds a `TransformerCell` that uses `order_embedding` and a learned `position_embedding`. The
        decoder picks one of the `candidates` per step; the logits are over the candidate positions.

        Two decoders are built:
            - a regular decoder (`CustomHelper` + `CandidateBasicDecoder`) whose transformer scope is
              created with reuse=False, used for logits, selected / argmax tokens, log-probs and loss;
            - a diverse beam search decoder on inputs tiled `beam_width` times, whose transformer scope
              has the same name with reuse=True, exposed as 'beam_tokens' and 'beam_log_probs'.

        Reads the features 'player_seed', 'temperature', 'dropout_rate', the placeholder
        'stop_gradient_all' and the outputs of the initial step, then registers its results with
        `self.add_meta_information`.
    """
    from diplomacy_research.utils.tensorflow import tf
    from diplomacy_research.models.layers.attention import StaticAttentionWrapper
    from diplomacy_research.models.layers.beam_decoder import DiverseBeamSearchDecoder
    from diplomacy_research.models.layers.decoder import CandidateBasicDecoder
    from diplomacy_research.models.layers.dropout import SeededDropoutWrapper
    from diplomacy_research.models.layers.dynamic_decode import dynamic_decode
    from diplomacy_research.models.layers.initializers import uniform
    from diplomacy_research.models.layers.transformer import TransformerCell
    from diplomacy_research.models.layers.wrappers import IdentityCell
    from diplomacy_research.models.policy.order_based.helper import CustomHelper, CustomBeamHelper
    from diplomacy_research.utils.tensorflow import cross_entropy, sequence_loss, to_int32, to_float, get_tile_beam

    # Quick function to retrieve hparams and placeholders and function shorthands
    hps = lambda hparam_name: self.hparams[hparam_name]
    pholder = lambda placeholder_name: self.placeholders[placeholder_name]

    # Training loop
    with tf.variable_scope('policy', reuse=tf.AUTO_REUSE):
        with tf.device(self.cluster_config.worker_device if self.cluster_config else None):

            # Features
            player_seeds = self.features['player_seed']                 # tf.int32 - (b,)
            temperature = self.features['temperature']                  # tf.float32 - (b,)
            dropout_rates = self.features['dropout_rate']               # tf.float32 - (b,)

            # Placeholders
            stop_gradient_all = pholder('stop_gradient_all')

            # Outputs (from initial steps)
            batch_size = self.outputs['batch_size']
            board_alignments = self.outputs['board_alignments']
            decoder_inputs = self.outputs['decoder_inputs']
            decoder_type = self.outputs['decoder_type']
            raw_decoder_lengths = self.outputs['raw_decoder_lengths']
            decoder_lengths = self.outputs['decoder_lengths']
            board_state_conv = self.outputs['board_state_conv']
            order_embedding = self.outputs['order_embedding']
            candidate_embedding = self.outputs['candidate_embedding']
            candidates = self.outputs['candidates']
            max_candidate_length = self.outputs['max_candidate_length']

            # Creating a position embedding with NB_SUPPLY_CENTERS positions (always created here,
            # it is not read from the outputs)
            # Embeddings needs to be cached locally on the worker, otherwise TF can't compute their gradients
            with tf.variable_scope('position_embedding_scope'):
                caching_device = self.cluster_config.caching_device if self.cluster_config else None
                position_embedding = uniform(name='position_embedding',
                                             shape=[NB_SUPPLY_CENTERS, hps('trsf_emb_size')],
                                             scale=1.,
                                             caching_device=caching_device)

            # Past Attentions
            # None here; NOTE(review): presumably tile_beam(None) returns None below - confirm
            past_attentions, message_lengths = None, None

            # --- Decoding ---
            with tf.variable_scope('decoder_scope', reuse=tf.AUTO_REUSE):
                feeder_cell = IdentityCell(output_size=hps('trsf_emb_size') + hps('attn_size'))

                # ======== Regular Decoding ========
                # Applying dropout to the feeder cell input (only input_keep_probs is set)
                feeder_cell = SeededDropoutWrapper(cell=feeder_cell,
                                                   seeds=player_seeds,
                                                   input_keep_probs=1. - dropout_rates,
                                                   variational_recurrent=hps('use_v_dropout'),
                                                   input_size=hps('trsf_emb_size') + hps('attn_size'),
                                                   dtype=tf.float32)

                # Apply attention over orderable location at each position
                feeder_cell = StaticAttentionWrapper(cell=feeder_cell,
                                                     memory=board_state_conv,
                                                     alignments=board_alignments,
                                                     sequence_length=raw_decoder_lengths,
                                                     output_attention=False)

                # Setting initial state
                feeder_cell_init_state = feeder_cell.zero_state(batch_size, tf.float32)

                # ---- Helper ----
                # All but the last input token are fed to the decoder
                helper = CustomHelper(decoder_type=decoder_type,
                                      inputs=decoder_inputs[:, :-1],
                                      order_embedding=order_embedding,
                                      candidate_embedding=candidate_embedding,
                                      sequence_length=decoder_lengths,
                                      candidates=candidates,
                                      time_major=False,
                                      softmax_temperature=temperature)

                # ---- Transformer Cell ----
                # Created with reuse=False; the beam search cell below uses the same scope name with reuse=True
                trsf_scope = tf.VariableScope(name='policy/training_scope/transformer', reuse=False)
                transformer_cell = TransformerCell(nb_layers=hps('trsf_nb_layers'),
                                                   nb_heads=hps('trsf_nb_heads'),
                                                   word_embedding=order_embedding,
                                                   position_embedding=position_embedding,
                                                   batch_size=batch_size,
                                                   feeder_cell=feeder_cell,
                                                   feeder_init_state=feeder_cell_init_state,
                                                   past_attentions=past_attentions,
                                                   past_seq_lengths=message_lengths,
                                                   scope=trsf_scope,
                                                   name='transformer')
                transformer_cell_init_state = transformer_cell.zero_state(batch_size, tf.float32)

                # ---- Invariants ----
                # Shape invariant passed to dynamic_decode for 'past_attentions'
                # (batch size and sequence length are left unknown)
                invariants_map = {
                    'past_attentions': tf.TensorShape([None,                                        # batch size
                                                       hps('trsf_nb_layers'),                       # nb_layers
                                                       2,                                           # key, value
                                                       hps('trsf_nb_heads'),                        # nb heads
                                                       None,                                        # Seq len
                                                       hps('trsf_emb_size') // hps('trsf_nb_heads')])}  # Head size

                # ---- Decoder ----
                # 1. for positions before raw_decoder_lengths, 0. after; width = max of decoder_lengths
                sequence_mask = tf.sequence_mask(raw_decoder_lengths,
                                                 maxlen=tf.reduce_max(decoder_lengths),
                                                 dtype=tf.float32)
                # Upper bound on the number of decoding steps
                maximum_iterations = NB_SUPPLY_CENTERS
                model_decoder = CandidateBasicDecoder(cell=transformer_cell,
                                                      helper=helper,
                                                      initial_state=transformer_cell_init_state,
                                                      max_candidate_length=max_candidate_length,
                                                      extract_state=True)
                training_results, _, _ = dynamic_decode(decoder=model_decoder,
                                                        output_time_major=False,
                                                        maximum_iterations=maximum_iterations,
                                                        invariants_map=invariants_map,
                                                        swap_memory=hps('swap_memory'))
                # Snapshot of the global variables; the beam decoder below must not create new ones
                global_vars_after_decoder = set(tf.global_variables())

                # ======== Beam Search Decoding ========
                # Every per-example input is tiled beam_width times
                tile_beam = get_tile_beam(hps('beam_width'))
                beam_feeder_cell = IdentityCell(output_size=hps('trsf_emb_size') + hps('attn_size'))

                # Applying dropout to the feeder cell input (only input_keep_probs is set)
                beam_feeder_cell = SeededDropoutWrapper(cell=beam_feeder_cell,
                                                        seeds=tile_beam(player_seeds),
                                                        input_keep_probs=tile_beam(1. - dropout_rates),
                                                        variational_recurrent=hps('use_v_dropout'),
                                                        input_size=hps('trsf_emb_size') + hps('attn_size'),
                                                        dtype=tf.float32)

                # Apply attention over orderable location at each position
                beam_feeder_cell = StaticAttentionWrapper(cell=beam_feeder_cell,
                                                          memory=tile_beam(board_state_conv),
                                                          alignments=tile_beam(board_alignments),
                                                          sequence_length=tile_beam(raw_decoder_lengths),
                                                          output_attention=False)

                # Setting initial state
                beam_feeder_init_state = beam_feeder_cell.zero_state(batch_size * hps('beam_width'), tf.float32)

                # ---- Transformer Cell ----
                # Same scope name as the training transformer, with reuse=True
                trsf_scope = tf.VariableScope(name='policy/training_scope/transformer', reuse=True)
                beam_trsf_cell = TransformerCell(nb_layers=hps('trsf_nb_layers'),
                                                 nb_heads=hps('trsf_nb_heads'),
                                                 word_embedding=order_embedding,
                                                 position_embedding=position_embedding,
                                                 batch_size=batch_size * hps('beam_width'),
                                                 feeder_cell=beam_feeder_cell,
                                                 feeder_init_state=beam_feeder_init_state,
                                                 past_attentions=tile_beam(past_attentions),
                                                 past_seq_lengths=tile_beam(message_lengths),
                                                 scope=trsf_scope,
                                                 name='transformer')
                beam_trsf_cell_init_state = beam_trsf_cell.zero_state(batch_size * hps('beam_width'), tf.float32)

                # ---- Beam Helper and Decoder ----
                beam_helper = CustomBeamHelper(cell=beam_trsf_cell,
                                               order_embedding=order_embedding,
                                               candidate_embedding=candidate_embedding,
                                               candidates=candidates,
                                               sequence_length=decoder_lengths,
                                               initial_state=beam_trsf_cell_init_state,
                                               beam_width=hps('beam_width'))
                beam_decoder = DiverseBeamSearchDecoder(beam_helper=beam_helper,
                                                        sequence_length=decoder_lengths,
                                                        nb_groups=hps('beam_groups'))
                beam_results, beam_state, _ = dynamic_decode(decoder=beam_decoder,
                                                             output_time_major=False,
                                                             maximum_iterations=maximum_iterations,
                                                             invariants_map=invariants_map,
                                                             swap_memory=hps('swap_memory'))

                # Making sure we haven't created new global variables
                assert not set(tf.global_variables()) - global_vars_after_decoder, 'New global vars were created'

                # Processing results
                candidate_logits = training_results.rnn_output              # (b, dec_len, max_cand_len)
                logits_length = tf.shape(candidate_logits)[1]               # dec_len
                # Targets are the inputs shifted by one position, trimmed to the decoded length
                decoder_target = decoder_inputs[:, 1:1 + logits_length]

                # Selected tokens are the token that was actually fed at the next position
                # sample_mask is 1. where sample_id == -1; there decoder_target is used, elsewhere sample_id
                sample_mask = to_float(tf.math.equal(training_results.sample_id, -1))
                selected_tokens = to_int32(
                    sequence_mask * (sample_mask * to_float(decoder_target)
                                     + (1. - sample_mask) * to_float(training_results.sample_id)))

                # Computing ArgMax tokens
                # The argmax is a candidate position; it is mapped back to a token through `candidates`
                argmax_id = to_int32(tf.argmax(candidate_logits, axis=-1))
                max_nb_candidate = tf.shape(candidate_logits)[2]
                candidate_ids = \
                    tf.reduce_sum(tf.one_hot(argmax_id, max_nb_candidate, dtype=tf.int32) * candidates, axis=-1)
                argmax_tokens = to_int32(to_float(candidate_ids) * sequence_mask)

                # Extracting the position of the target candidate
                # (index of the first candidate equal to the token)
                tokens_labels = tf.argmax(to_int32(tf.math.equal(selected_tokens[:, :, None], candidates)), -1)
                target_labels = tf.argmax(to_int32(tf.math.equal(decoder_target[:, :, None], candidates)), -1)

                # Log Probs
                # (assumes cross_entropy returns the per-position cross-entropy - confirm)
                log_probs = -1. * cross_entropy(logits=candidate_logits, labels=tokens_labels) * sequence_mask

            # Computing policy loss
            with tf.variable_scope('policy_loss'):
                policy_loss = sequence_loss(logits=candidate_logits,
                                            targets=target_labels,
                                            weights=sequence_mask,
                                            average_across_batch=True,
                                            average_across_timesteps=True)
                policy_loss = tf.cond(stop_gradient_all,
                                      lambda: tf.stop_gradient(policy_loss),      # pylint: disable=cell-var-from-loop
                                      lambda: policy_loss)                        # pylint: disable=cell-var-from-loop

    # Building output tags
    outputs = {'tag/policy/order_based/v015_film_transformer_gpt': True,
               'targets': decoder_inputs[:, 1:],
               'selected_tokens': selected_tokens,
               'argmax_tokens': argmax_tokens,
               'logits': candidate_logits,
               'log_probs': log_probs,
               'beam_tokens': tf.transpose(beam_results.predicted_ids, perm=[0, 2, 1]),     # [batch, beam, steps]
               'beam_log_probs': beam_state.log_probs,
               'rnn_states': training_results.rnn_state,
               'policy_loss': policy_loss,
               # Falls back to a zeros tensor when no draw model output is available
               'draw_prob': self.outputs.get('draw_prob', tf.zeros_like(self.features['draw_target'])),
               'learning_rate': self.learning_rate}

    # Adding features, placeholders and outputs to graph
    self.add_meta_information(outputs)
def gradients(ys, xs, grad_ys=None, checkpoints='collection', aggregation_method=None, **kwargs):
    """ Memory-efficient drop-in for tf.gradients using gradient checkpointing.

        Authors: Tim Salimans and Yaroslav Bulatov
        Memory efficient gradient implementation inspired by "Training Deep Nets with Sublinear Memory Cost"
        by Chen et al. 2016 (https://arxiv.org/abs/1604.06174)

        Only the checkpoint tensors are re-used from the forward pass; the forward ops between checkpoints
        are copied into the graph and re-computed while back-propagating.

        :param ys: Tensor or list of tensors. (Only a `list` is used as-is; any other value, including a
                   tuple, is wrapped as a single element.)
        :param xs: Tensor or list of tensors (same wrapping rule as ys).
        :param grad_ys: List of tensors holding the gradients received by the ys. Same length as ys.
        :param checkpoints: One of
            1) a list consisting of tensors from the forward pass of the neural net
               that we should re-use when calculating the gradients in the backward pass
               all other tensors that do not appear in this list will be re-computed
            2) The string 'speed': checkpoint all outputs of convolutions and matmuls. these ops are usually
               the most expensive, so checkpointing them maximizes the running speed
               (this is a good option if nonlinearities, concats, batchnorms, etc are taking up a lot of memory)
            3) The string 'memory': try to minimize the memory usage
               (currently using a very simple strategy that identifies a number of bottleneck tensors
               in the graph to checkpoint)
            4) The string 'collection': look for a tensorflow collection named 'checkpoints',
               which holds the tensors to checkpoint
        :param aggregation_method: Forwarded to every TF_GRADIENTS call and to the automatic checkpoint selector.
        :param kwargs: Optional kwargs forwarded to TF_GRADIENTS (and to the automatic checkpoint selector).
        :return: The gradients: a list with one entry per tensor in xs (an entry may be None).
        :raises KeyError: if `checkpoints` is a string other than 'collection', 'speed' or 'memory'.
        :raises RuntimeError: if no usable checkpoint tensor remains after filtering.
    """
    # pylint: disable=invalid-name
    # Computes forwards and backwards ops
    # Forward ops are all ops that are candidates for recomputation
    ys = [ys] if not isinstance(ys, list) else ys
    xs = [xs] if not isinstance(xs, list) else xs
    bwd_ops = graph_editor.get_backward_walk_ops([y.op for y in ys], inclusive=True)
    fwd_ops = graph_editor.get_forward_walk_ops([x.op for x in xs], inclusive=True, within_ops=bwd_ops)
    debug_print("bwd_ops: %s", bwd_ops)
    debug_print("fwd_ops: %s", fwd_ops)

    # Exclude ops with no inputs, or ops linked to xs or variables
    # (the '/assign', '/Assign' and '/read' name filters target variable assign / read ops)
    xs_ops = _to_ops(xs)
    fwd_ops = [op for op in fwd_ops if (op.inputs and
                                        not op in xs_ops and
                                        not '/assign' in op.name and
                                        not '/Assign' in op.name and
                                        not '/read' in op.name)]

    # Computes the list of tensors that can be recomputed from fwd_ops
    ts_all = graph_editor.filter_ts(fwd_ops, True)
    ts_all = [t for t in ts_all if '/read' not in t.name]
    ts_all = set(ts_all) - set(xs) - set(ys)

    # Construct list of tensors to checkpoint during forward pass, if not given as input
    # An unknown string raises KeyError from the dict lookup below
    if not isinstance(checkpoints, list):
        checkpoints = {'collection': _get_collection_checkpoints,
                       'speed': _get_speed_checkpoints,
                       'memory': _get_memory_checkpoints}[checkpoints](fwd_ops, ts_all, ys=ys, xs=xs, grad_ys=grad_ys, aggregation_method=aggregation_method, **kwargs)
    # Only keep checkpoints that are recomputable tensors of the forward subgraph
    # NOTE: list order is not preserved (round-trip through a set)
    checkpoints = list(set(checkpoints).intersection(ts_all))

    # At this point automatic selection happened and checkpoints is list of nodes
    assert isinstance(checkpoints, list)
    debug_print("Checkpoint nodes used: %s", checkpoints)

    # Better error handling of special cases
    # xs are already handled as checkpoint nodes, so no need to include them
    xs_intersect_checkpoints = set(xs).intersection(set(checkpoints))
    if xs_intersect_checkpoints:
        debug_print('Warning, some input nodes are also checkpoint nodes: %s', xs_intersect_checkpoints)
    ys_intersect_checkpoints = set(ys).intersection(set(checkpoints))
    debug_print('ys: %s, checkpoints: %s, intersect: %s', ys, checkpoints, ys_intersect_checkpoints)

    # Saving an output node (ys) gives no benefit in memory while creating new edge cases, exclude them
    if ys_intersect_checkpoints:
        debug_print('Warning, some output nodes are also checkpoints nodes: %s', ys_intersect_checkpoints)

    # Remove initial and terminal nodes from checkpoints list if present
    # Only keeping checkpoints not in a control flow context
    # NOTE(review): presumably because ops inside tf.while_loop / tf.cond cannot be safely copied and
    # rerouted by the graph editor - TODO confirm
    checkpoints = list(set(checkpoints) - set(ys) - set(xs))
    checkpoints = [ckpt for ckpt in checkpoints if ckpt._op._control_flow_context is None]     # pylint: disable=protected-access

    # Check that we have some nodes to checkpoint
    if not checkpoints:
        raise RuntimeError('No checkpoints nodes found or given as input!')

    # Disconnect dependencies between checkpointed tensors
    # Each checkpoint gets a stop_gradient twin; copied ops are later rerouted to read from the twin,
    # so back-propagation through a copy stops at the checkpoint boundary
    checkpoints_disconnected = {}
    for ckpt in checkpoints:
        if ckpt.op and ckpt.op.name is not None:
            grad_node = tf.stop_gradient(ckpt, name=ckpt.op.name + '_stop_grad')
        else:
            grad_node = tf.stop_gradient(ckpt)
        checkpoints_disconnected[ckpt] = grad_node

    # Partial derivatives to the checkpointed tensors and xs
    # First pass: copy the ops between ys and the checkpoints, and differentiate the copied ys
    ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys], stop_at_ts=checkpoints, within_ops=fwd_ops)
    debug_print('Found %s ops to copy within fwd_ops %s, seed %s, stop_at %s',
                len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints)
    debug_print('ops_to_copy = %s', ops_to_copy)
    debug_print('Processing list %s', ys)
    _, info = graph_editor.copy_with_input_replacements(graph_editor.sgv(ops_to_copy), {})
    # Copied ops are pinned to the same device as their originals
    for origin_op, op in info._transformed_ops.items():                                       # pylint: disable=protected-access
        op._set_device(origin_op.node_def.device)                                             # pylint: disable=protected-access
    copied_ops = info._transformed_ops.values()                                               # pylint: disable=protected-access
    debug_print('Copied %s to %s', ops_to_copy, copied_ops)
    graph_editor.reroute_ts(checkpoints_disconnected.values(), checkpoints_disconnected.keys(), can_modify=copied_ops)
    debug_print('Rewired %s in place of %s restricted to %s',
                checkpoints_disconnected.values(), checkpoints_disconnected.keys(), copied_ops)

    # Get gradients with respect to current boundary + original x's
    # NOTE(review): uses output 0 of each copied op; a y that is not the first output of its op would map to
    # the wrong tensor - confirm all ys are first outputs
    copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys]                         # pylint: disable=protected-access
    boundary = list(checkpoints_disconnected.values())
    dv = TF_GRADIENTS(ys=copied_ys, xs=boundary + xs, grad_ys=grad_ys, aggregation_method=aggregation_method, **kwargs)
    debug_print('Got gradients %s', dv)
    debug_print('for %s', copied_ys)
    debug_print('with respect to %s', boundary + xs)

    # Adding control inputs to the graph
    # Presumably makes the copied (recomputation) ops and gradient ops wait on the original ys ops
    # (my_add_control_inputs is not visible here - TODO confirm)
    # NOTE(review): grad_ys holds tensors while the other entries are ops; confirm the helper accepts both
    inputs_to_do_before = [y.op for y in ys]
    if grad_ys is not None:
        inputs_to_do_before += grad_ys
    wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
    my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

    # Partial derivatives to the checkpointed nodes
    # Dictionary of "node: backprop" for nodes in the boundary (dv follows the order of boundary + xs)
    d_checkpoints = {r: dr for r, dr in zip(checkpoints_disconnected.keys(), dv[:len(checkpoints_disconnected)])}

    # Partial derivatives to xs (usually the params of the neural net)
    d_xs = dv[len(checkpoints_disconnected):]

    # Incorporate derivatives flowing through the checkpointed nodes
    # NOTE(review): tf_toposort presumably groups checkpoints into topologically ordered lists; iterating in
    # reverse processes the lists closest to ys first - helper not visible, TODO confirm
    checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops)
    for ts in checkpoints_sorted_lists[::-1]:
        debug_print('Processing list %s', ts)
        checkpoints_other = [r for r in checkpoints if r not in ts]
        checkpoints_disconnected_other = [checkpoints_disconnected[r] for r in checkpoints_other]

        # Copy part of the graph below current checkpoint node, stopping at other checkpoints nodes
        ops_to_copy = fast_backward_ops(within_ops=fwd_ops, seed_ops=[r.op for r in ts], stop_at_ts=checkpoints_other)
        debug_print('Found %s ops to copy within %s, seed %s, stop_at %s',
                    len(ops_to_copy), fwd_ops, [r.op for r in ts], checkpoints_other)
        debug_print('ops_to_copy = %s', ops_to_copy)

        # We are done!
        # (break, not continue: any remaining lists are not processed)
        if not ops_to_copy:
            break
        _, info = graph_editor.copy_with_input_replacements(graph_editor.sgv(ops_to_copy), {})
        for origin_op, op in info._transformed_ops.items():                                   # pylint: disable=protected-access
            op._set_device(origin_op.node_def.device)                                         # pylint: disable=protected-access
        copied_ops = info._transformed_ops.values()                                           # pylint: disable=protected-access
        debug_print('Copied %s to %s', ops_to_copy, copied_ops)
        graph_editor.reroute_ts(checkpoints_disconnected_other, checkpoints_other, can_modify=copied_ops)
        debug_print('Rewired %s in place of %s restricted to %s',
                    checkpoints_disconnected_other, checkpoints_other, copied_ops)

        # Gradient flowing through the checkpointed node
        # The accumulated backprop of each checkpoint in ts is fed as grad_ys of its recomputed copy
        boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts]                      # pylint: disable=protected-access
        substitute_backprops = [d_checkpoints[r] for r in ts]
        dv = TF_GRADIENTS(ys=boundary,
                          xs=checkpoints_disconnected_other + xs,
                          grad_ys=substitute_backprops,
                          aggregation_method=aggregation_method,
                          **kwargs)
        debug_print("Got gradients %s", dv)
        debug_print("for %s", boundary)
        debug_print("with respect to %s", checkpoints_disconnected_other + xs)
        debug_print("with boundary backprop substitutions %s", substitute_backprops)

        # Adding control inputs
        # NOTE(review): d_checkpoints[r] can be None (dv entries may be None, see the checks below), in which
        # case `.op` raises AttributeError - confirm every checkpoint in ts always receives a gradient
        inputs_to_do_before = [d_checkpoints[r].op for r in ts]
        wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
        my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

        # Partial derivatives to the checkpointed nodes (accumulated across lists)
        for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]):
            if dr is not None:
                if d_checkpoints[r] is None:
                    d_checkpoints[r] = dr
                else:
                    d_checkpoints[r] += dr

        # Partial derivatives to xs (usually the params of the neural net)
        d_xs_new = dv[len(checkpoints_other):]
        for j in range(len(xs)):
            if d_xs_new[j] is not None:
                if d_xs[j] is None:
                    d_xs[j] = _unsparsify(d_xs_new[j])
                else:
                    d_xs[j] += _unsparsify(d_xs_new[j])

    # Returning the new gradients
    return d_xs
def build(self):
    """ Builds the RL model using the correct optimizer

        Creates the REINFORCE loss (policy gradient + draw gradient + entropy bonus) with the 'value_target'
        feature as baseline, the optimizer op, the AvgGradOptimizer update / init ops, and registers the
        'reinforce_policy', 'reinforce_update' and 'optimizer_init' queues on self.queue_dataset.
    """
    from diplomacy_research.utils.tensorflow import tf, tfp, normalize, to_float
    from diplomacy_research.models.layers.avg_grad_optimizer import AvgGradOptimizer

    # Small constant added inside the logs of the draw likelihood.
    # It prevents a draw probability of exactly 0. or 1. (e.g. a zero-filled default 'draw_prob') from
    # producing 0. * log(0.) = NaN or an infinite loss.
    log_epsilon = 1e-8

    # Quick function to retrieve hparams and placeholders and function shorthands
    hps = lambda hparam_name: self.model.hparams[hparam_name]

    # Training loop
    with tf.variable_scope('policy', reuse=tf.AUTO_REUSE):
        with tf.device(self.cluster_config.worker_device if self.cluster_config else None):

            # Placeholders
            stop_gradient_all = self.model.placeholders['stop_gradient_all']

            # Features
            decoder_lengths = self.model.features['decoder_lengths']            # tf.int32   - (b,)
            draw_action = self.model.features['draw_action']                    # tf.bool    - (b,)
            reward_target = self.model.features['reward_target']                # tf.float32 - (b,)
            value_target = self.model.features['value_target']                  # tf.float32 - (b,)

            # Making sure all RNN lengths are at least 1
            # The raw lengths are still used for the mask, so zero-length sequences stay fully masked
            raw_decoder_lengths = decoder_lengths
            decoder_lengths = tf.math.maximum(1, decoder_lengths)

            # Retrieving model outputs
            # Using the provided value target as a fixed baseline (e.g. moving average)
            # rather than a parameterized value function
            baseline = value_target                                             # (b,)
            logits = self.model.outputs['logits']                               # (b, dec_len, VOCAB_SIZE)
            sequence_mask = tf.sequence_mask(raw_decoder_lengths,               # (b, dec)
                                             maxlen=tf.reduce_max(decoder_lengths),
                                             dtype=tf.float32)

            # Calculating policy gradient loss
            with tf.variable_scope('policy_gradient_scope'):
                log_prob_of_tokens = self.model.outputs['log_probs'] * sequence_mask        # (b, dec_len)

                # Calculating loss and optimizer op
                advantages = tf.stop_gradient(normalize(reward_target - baseline))         # (b,)
                policy_gradient_loss = -tf.reduce_sum(log_prob_of_tokens, axis=-1) * advantages  # (b,)
                policy_gradient_loss = tf.reduce_mean(policy_gradient_loss)                 # ()

            # Calculating policy gradient for draw action
            # Bernoulli log-likelihood of the taken draw action; epsilon keeps both logs finite
            with tf.variable_scope('draw_gradient_scope'):
                draw_action = to_float(draw_action)                                       # (b,)
                draw_prob = self.model.outputs['draw_prob']                               # (b,)
                log_prob_of_draw = draw_action * tf.log(draw_prob + log_epsilon) \
                                   + (1. - draw_action) * tf.log(1. - draw_prob + log_epsilon)
                draw_gradient_loss = -1. * log_prob_of_draw * advantages                  # (b,)
                draw_gradient_loss = tf.reduce_mean(draw_gradient_loss)                   # ()

            # Calculating entropy loss
            # (negative mean entropy, so minimizing the loss encourages exploration)
            with tf.variable_scope('entropy_scope'):
                categorical_dist = tfp.distributions.Categorical(logits=logits)
                entropy = categorical_dist.entropy()
                entropy_loss = -tf.reduce_mean(entropy)                                   # ()

            # Scopes
            scope = ['policy', 'draw']
            global_ignored_scope = [] if not hps('ignored_scope') else hps('ignored_scope').split(',')
            global_ignored_scope += ['value']

            # Creating REINFORCE loss with baseline
            reinforce_loss = policy_gradient_loss \
                             + hps('draw_coeff') * draw_gradient_loss \
                             + hps('entropy_coeff') * entropy_loss
            reinforce_loss = tf.cond(stop_gradient_all,
                                     lambda: tf.stop_gradient(reinforce_loss),            # pylint: disable=cell-var-from-loop
                                     lambda: reinforce_loss)                              # pylint: disable=cell-var-from-loop
            cost_and_scope = [(reinforce_loss, scope, None)]

            # Creating optimizer op
            reinforce_op = self.model.create_optimizer_op(cost_and_scope=cost_and_scope,
                                                          ignored_scope=global_ignored_scope,
                                                          max_gradient_norm=None)        # AvgGradOptimizer will clip

            # Getting AvgGradOptimizer.update(version_step)
            assert isinstance(self.model.optimizer, AvgGradOptimizer), 'REINFORCE requires gradient averaging'
            update_op = self.model.optimizer.update(self.version_step)
            init_op = self.model.optimizer.init()

    # Storing outputs
    self._add_output('rl_policy_loss', policy_gradient_loss)
    self._add_output('rl_draw_loss', draw_gradient_loss)
    self._add_output('rl_entropy_loss', entropy_loss)
    self._add_output('rl_total_loss', reinforce_loss)
    self._add_output('optimizer_op', reinforce_op)
    self._add_output('update_op', update_op)
    self._add_output('init_op', init_op)

    # --------------------------------------
    # Hooks
    # --------------------------------------
    def hook_baseline_pre_condition(dataset):
        """ Pre-Condition: First queue to run (no queue has been marked as processed yet) """
        return getattr(dataset, 'last_queue', '') == ''

    def hook_baseline_post_queue(dataset):
        """ Post-Queue: Marks the 'reinforce_policy' queue as processed """
        dataset.last_queue = 'reinforce_policy'

    def hook_update_pre_condition(dataset):
        """ Pre-Condition: the last processed queue must be 'reinforce_policy' """
        return getattr(dataset, 'last_queue', None) == 'reinforce_policy'

    def hook_update_pre_queue(dataset):
        """ Pre-Queue: Restricts the queue to 1 dequeue maximum """
        dataset.nb_items_to_pull_from_queue = min(dataset.nb_items_to_pull_from_queue, 1)

    def hook_update_post_queue(dataset):
        """ Post-Queue: Marks the update as processed """
        dataset.last_queue = 'reinforce_update'

    # --------------------------------------
    # Queues
    # --------------------------------------
    self.queue_dataset.create_queue('reinforce_policy',
                                    placeholders={self.model.placeholders['decoder_type']: [TRAINING_DECODER]},
                                    outputs=[self.model.outputs[output_name]
                                             for output_name in ['optimizer_op'] + self.get_evaluation_tags()],
                                    with_status=True,
                                    pre_condition=hook_baseline_pre_condition,
                                    post_queue=hook_baseline_post_queue)
    self.queue_dataset.create_queue('reinforce_update',
                                    placeholders={self.model.placeholders['decoder_type']: [GREEDY_DECODER]},
                                    outputs=[self.model.outputs['update_op']],
                                    with_status=True,
                                    pre_condition=hook_update_pre_condition,
                                    pre_queue=hook_update_pre_queue,
                                    post_queue=hook_update_post_queue)
    self.queue_dataset.create_queue('optimizer_init',
                                    placeholders={self.model.placeholders['decoder_type']: [GREEDY_DECODER]},
                                    outputs=[self.model.outputs['init_op']],
                                    with_status=True)