def test_get_board_value(model):
    """ Tests the get_board_value method

        Builds fake board-state and current-power placeholders, calls `model.get_board_value` with
        `reuse=True` and checks that the number of global variables in the graph did not change.
    """
    from diplomacy_research.utils.tensorflow import tf

    # Snapshot of the variable count before building the board value ops
    initial_var_count = len(tf.global_variables())

    fake_board_state = tf.placeholder(name='fake_board',
                                      dtype=tf.float32,
                                      shape=[None, NB_NODES, NB_FEATURES])
    fake_current_power = tf.placeholder(name='fake_current_power',
                                        dtype=tf.int32,
                                        shape=[None])
    model.get_board_value(fake_board_state, fake_current_power, reuse=True)

    # Since reuse=True is requested, no new variable is expected to have been created
    final_var_count = len(tf.global_variables())
    assert initial_var_count == final_var_count, 'New variables added when getting board value.'
def initialize(self, session):
    """ Initialize the adapter (init global vars and the dataset)

        Does nothing when the feedable dataset cannot support an iterator or when no iterator is set.
        Otherwise it stores the session, then runs the initializer for every global/local variable that
        is not initialized yet (only if the graph is not finalized). Finally it starts the feedable dataset,
        or initializes it if it was already started.

        :param session: The session used to run the initializers and to start / initialize the dataset.
        :type session: tensorflow.python.client.session.Session
    """
    # Nothing to do without iterator support or without an iterator
    if not (self.feedable_dataset.can_support_iterator and self.iterator):
        return

    from diplomacy_research.utils.tensorflow import tf
    assert session, 'You must pass a session to initialize the adapter'
    assert isinstance(self.feedable_dataset, QueueDataset), 'The dataset must be a QueueDataset'
    self.session = session

    # Initializes uninit global vars (new ops can only be added to a non-finalized graph)
    graph = self.graph or tf.get_default_graph()
    if not graph.finalized:
        with graph.as_default():
            candidate_vars = tf.global_variables() + tf.local_variables()
            init_flags = self.session.run([tf.is_variable_initialized(var) for var in candidate_vars])
            uninitialized_vars = [var for var, flag in zip(candidate_vars, init_flags) if not flag]
            if uninitialized_vars:
                LOGGER.info('Initialized %d variables.', len(uninitialized_vars))
                self.session.run(tf.variables_initializer(uninitialized_vars))

    # Starting the feedable dataset, or initializing it if it was already started
    dataset = self.feedable_dataset
    needs_start = not dataset.is_started
    if needs_start and self.session:
        dataset.start(self.session)
    elif not dataset.is_initialized and self.session:
        dataset.initialize(self.session)
def _build_policy_final(self):
    """ Builds the policy model (final step)

        Inside the 'policy' variable scope, builds an LSTM decoder with static attention over the board
        state (`board_state_conv` / `board_alignments`) and a dense output layer over the vocabulary.
        The same LSTM cell and output layer are used twice:
            - regular decoding, driven by `CustomHelper` (with `decoder_type` and `temperature`)
            - diverse beam search decoding, driven by `CustomBeamHelper` on beam-tiled inputs
        An assert checks that the beam search pass did not create any new global variables.

        Reads `self.features` ('player_seed', 'temperature', 'dropout_rate', 'draw_target'),
        `self.placeholders` ('stop_gradient_all') and `self.outputs` (from the earlier build steps), then
        registers the resulting tensors with `self.add_meta_information(outputs)`.

        Note: graph construction order matters here (op / variable naming and the point at which
        `global_vars_after_decoder` is captured), so statements should not be reordered casually.
    """
    from diplomacy_research.utils.tensorflow import tf
    from diplomacy_research.models.layers.attention import StaticAttentionWrapper
    from diplomacy_research.models.layers.beam_decoder import DiverseBeamSearchDecoder
    from diplomacy_research.models.layers.decoder import MaskedBasicDecoder
    from diplomacy_research.models.layers.dropout import SeededDropoutWrapper
    from diplomacy_research.models.layers.dynamic_decode import dynamic_decode
    from diplomacy_research.models.policy.token_based.helper import CustomHelper, CustomBeamHelper
    from diplomacy_research.utils.tensorflow import cross_entropy, sequence_loss, to_int32, to_float, get_tile_beam

    # Quick function to retrieve hparams and placeholders and function shorthands
    hps = lambda hparam_name: self.hparams[hparam_name]
    pholder = lambda placeholder_name: self.placeholders[placeholder_name]

    # Training loop
    with tf.variable_scope('policy', reuse=tf.AUTO_REUSE):
        with tf.device(self.cluster_config.worker_device if self.cluster_config else None):

            # Features
            player_seeds = self.features['player_seed']         # tf.int32 - (b,)
            temperature = self.features['temperature']          # tf.flt32 - (b,)
            dropout_rates = self.features['dropout_rate']       # tf.flt32 - (b,)

            # Placeholders
            stop_gradient_all = pholder('stop_gradient_all')

            # Outputs (from initial steps)
            batch_size = self.outputs['batch_size']
            board_alignments = self.outputs['board_alignments']
            decoder_inputs = self.outputs['decoder_inputs']
            decoder_mask = self.outputs['decoder_mask']
            decoder_type = self.outputs['decoder_type']
            raw_decoder_lengths = self.outputs['raw_decoder_lengths']
            decoder_lengths = self.outputs['decoder_lengths']
            board_state_conv = self.outputs['board_state_conv']
            word_embedding = self.outputs['word_embedding']

            # --- Decoding ---
            with tf.variable_scope('decoder_scope', reuse=tf.AUTO_REUSE):
                # Single LSTM cell shared by the regular and the beam search decoders
                lstm_cell = tf.contrib.rnn.LSTMBlockCell(hps('lstm_size'))

                # decoder output to token (also shared by both decoders)
                decoder_output_layer = tf.layers.Dense(units=VOCABULARY_SIZE,
                                                       activation=None,
                                                       kernel_initializer=tf.random_normal_initializer,
                                                       use_bias=True)

                # ======== Regular Decoding ========
                # Applying dropout to input + attention and to output layer
                decoder_cell = SeededDropoutWrapper(cell=lstm_cell,
                                                    seeds=player_seeds,
                                                    input_keep_probs=1. - dropout_rates,
                                                    output_keep_probs=1. - dropout_rates,
                                                    variational_recurrent=hps('use_v_dropout'),
                                                    input_size=hps('word_emb_size') + hps('attn_size'),
                                                    dtype=tf.float32)

                # Apply attention over orderable location at each position
                decoder_cell = StaticAttentionWrapper(cell=decoder_cell,
                                                      memory=board_state_conv,
                                                      alignments=board_alignments,
                                                      sequence_length=raw_decoder_lengths,
                                                      output_attention=False)

                # Setting initial state
                decoder_init_state = decoder_cell.zero_state(batch_size, tf.float32)

                # ---- Helper ----
                # The helper receives the inputs without their last token; targets are the inputs shifted by 1
                helper = CustomHelper(decoder_type=decoder_type,
                                      inputs=decoder_inputs[:, :-1],
                                      embedding=word_embedding,
                                      sequence_length=decoder_lengths,
                                      mask=decoder_mask,
                                      time_major=False,
                                      softmax_temperature=temperature)

                # ---- Decoder ----
                # Mask built from the raw lengths, padded to the longest (non-raw) decoder length
                sequence_mask = tf.sequence_mask(raw_decoder_lengths,
                                                 maxlen=tf.reduce_max(decoder_lengths),
                                                 dtype=tf.float32)
                maximum_iterations = TOKENS_PER_ORDER * NB_SUPPLY_CENTERS
                model_decoder = MaskedBasicDecoder(cell=decoder_cell,
                                                   helper=helper,
                                                   initial_state=decoder_init_state,
                                                   output_layer=decoder_output_layer,
                                                   extract_state=True)
                training_results, _, _ = dynamic_decode(decoder=model_decoder,
                                                        output_time_major=False,
                                                        maximum_iterations=maximum_iterations,
                                                        swap_memory=hps('swap_memory'))
                # Snapshot used below to check that the beam search pass only reuses existing variables
                global_vars_after_decoder = set(tf.global_variables())

                # ======== Beam Search Decoding ========
                tile_beam = get_tile_beam(hps('beam_width'))

                # Applying dropout to input + attention and to output layer
                decoder_cell = SeededDropoutWrapper(cell=lstm_cell,
                                                    seeds=tile_beam(player_seeds),
                                                    input_keep_probs=tile_beam(1. - dropout_rates),
                                                    output_keep_probs=tile_beam(1. - dropout_rates),
                                                    variational_recurrent=hps('use_v_dropout'),
                                                    input_size=hps('word_emb_size') + hps('attn_size'),
                                                    dtype=tf.float32)

                # Apply attention over orderable location at each position
                decoder_cell = StaticAttentionWrapper(cell=decoder_cell,
                                                      memory=tile_beam(board_state_conv),
                                                      alignments=tile_beam(board_alignments),
                                                      sequence_length=tile_beam(raw_decoder_lengths),
                                                      output_attention=False)

                # Setting initial state (one state per beam entry)
                decoder_init_state = decoder_cell.zero_state(batch_size * hps('beam_width'), tf.float32)

                # ---- Beam Helper and Decoder ----
                beam_helper = CustomBeamHelper(cell=decoder_cell,
                                               embedding=word_embedding,
                                               mask=decoder_mask,
                                               sequence_length=decoder_lengths,
                                               output_layer=decoder_output_layer,
                                               initial_state=decoder_init_state,
                                               beam_width=hps('beam_width'))
                beam_decoder = DiverseBeamSearchDecoder(beam_helper=beam_helper,
                                                        sequence_length=decoder_lengths,
                                                        nb_groups=hps('beam_groups'))
                beam_results, beam_state, _ = dynamic_decode(decoder=beam_decoder,
                                                             output_time_major=False,
                                                             maximum_iterations=maximum_iterations,
                                                             swap_memory=hps('swap_memory'))

                # Making sure we haven't created new global variables
                assert not set(tf.global_variables()) - global_vars_after_decoder, 'New global vars were created'

                # Processing results
                logits = training_results.rnn_output                    # (b, dec_len, VOCAB_SIZE)
                logits_length = tf.shape(logits)[1]                     # dec_len
                decoder_target = decoder_inputs[:, 1:1 + logits_length]

                # Selected tokens are the token that was actually fed at the next position
                # Where sample_id == -1 the target token is used, otherwise the sampled id (then masked)
                sample_mask = to_float(tf.math.equal(training_results.sample_id, -1))
                selected_tokens = to_int32(
                    sequence_mask * (sample_mask * to_float(decoder_target)
                                     + (1. - sample_mask) * to_float(training_results.sample_id)))

                # Argmax tokens are the most likely token outputted at each position
                argmax_tokens = to_int32(to_float(tf.argmax(logits, axis=-1)) * sequence_mask)

                # Log-probability of each selected token (zeroed outside of the sequence mask)
                log_probs = -1. * cross_entropy(logits=logits, labels=selected_tokens) * sequence_mask

            # Computing policy loss
            with tf.variable_scope('policy_loss'):
                policy_loss = sequence_loss(logits=logits,
                                            targets=decoder_target,
                                            weights=sequence_mask,
                                            average_across_batch=True,
                                            average_across_timesteps=True)
                # When stop_gradient_all is True, the loss value is kept but no gradient flows through it
                policy_loss = tf.cond(stop_gradient_all,
                                      lambda: tf.stop_gradient(policy_loss),        # pylint: disable=cell-var-from-loop
                                      lambda: policy_loss)                          # pylint: disable=cell-var-from-loop

    # Building output tags
    outputs = {'tag/policy/token_based/v005_markovian_film_board_align': True,
               'targets': decoder_inputs[:, 1:],
               'selected_tokens': selected_tokens,
               'argmax_tokens': argmax_tokens,
               'logits': logits,
               'log_probs': log_probs,
               'beam_tokens': tf.transpose(beam_results.predicted_ids, perm=[0, 2, 1]),    # [batch, beam, steps]
               'beam_log_probs': beam_state.log_probs,
               'rnn_states': training_results.rnn_state,
               'policy_loss': policy_loss,
               'draw_prob': self.outputs.get('draw_prob', tf.zeros_like(self.features['draw_target'])),
               'learning_rate': self.learning_rate}

    # Adding features, placeholders and outputs to graph
    self.add_meta_information(outputs)
def _build_policy_final(self):
    """ Builds the policy model (final step)

        Inside the 'policy' variable scope, builds an LSTM decoder with Bahdanau attention over the board
        state, whose outputs are logits over the candidate orders (`candidates`) at each position.
        The same LSTM cell and the same explicit attention variable scope are used twice:
            - regular decoding, driven by `CustomHelper` (with `decoder_type` and `temperature`)
            - diverse beam search decoding, driven by `CustomBeamHelper` on beam-tiled inputs
        An assert checks that the beam search pass did not create any new global variables.

        Reads `self.features` ('player_seed', 'temperature', 'dropout_rate', 'draw_target'),
        `self.placeholders` ('stop_gradient_all') and `self.outputs` (from the earlier build steps), then
        registers the resulting tensors with `self.add_meta_information(outputs)`.

        Note: graph construction order matters here (op / variable naming and the point at which
        `global_vars_after_decoder` is captured), so statements should not be reordered casually.
    """
    from diplomacy_research.utils.tensorflow import tf
    from diplomacy_research.models.layers.attention import AttentionWrapper, BahdanauAttention
    from diplomacy_research.models.layers.beam_decoder import DiverseBeamSearchDecoder
    from diplomacy_research.models.layers.decoder import CandidateBasicDecoder
    from diplomacy_research.models.layers.dropout import SeededDropoutWrapper
    from diplomacy_research.models.layers.dynamic_decode import dynamic_decode
    from diplomacy_research.models.policy.order_based.helper import CustomHelper, CustomBeamHelper
    from diplomacy_research.utils.tensorflow import cross_entropy, sequence_loss, to_int32, to_float, get_tile_beam

    # Quick function to retrieve hparams and placeholders and function shorthands
    hps = lambda hparam_name: self.hparams[hparam_name]
    pholder = lambda placeholder_name: self.placeholders[placeholder_name]

    # Training loop
    with tf.variable_scope('policy', reuse=tf.AUTO_REUSE):
        with tf.device(self.cluster_config.worker_device if self.cluster_config else None):

            # Features
            player_seeds = self.features['player_seed']         # tf.int32 - (b,)
            temperature = self.features['temperature']          # tf.flt32 - (b,)
            dropout_rates = self.features['dropout_rate']       # tf.flt32 - (b,)

            # Placeholders
            stop_gradient_all = pholder('stop_gradient_all')

            # Outputs (from initial steps)
            batch_size = self.outputs['batch_size']
            decoder_inputs = self.outputs['decoder_inputs']
            decoder_type = self.outputs['decoder_type']
            raw_decoder_lengths = self.outputs['raw_decoder_lengths']
            decoder_lengths = self.outputs['decoder_lengths']
            board_state_conv = self.outputs['board_state_conv']
            order_embedding = self.outputs['order_embedding']
            candidate_embedding = self.outputs['candidate_embedding']
            candidates = self.outputs['candidates']
            max_candidate_length = self.outputs['max_candidate_length']

            # --- Decoding ---
            with tf.variable_scope('decoder_scope', reuse=tf.AUTO_REUSE):
                # Single LSTM cell shared by the regular and the beam search decoders
                lstm_cell = tf.contrib.rnn.LSTMBlockCell(hps('lstm_size'))

                # ======== Regular Decoding ========
                # Applying dropout to input + attention and to output layer
                decoder_cell = SeededDropoutWrapper(cell=lstm_cell,
                                                    seeds=player_seeds,
                                                    input_keep_probs=1. - dropout_rates,
                                                    output_keep_probs=1. - dropout_rates,
                                                    variational_recurrent=hps('use_v_dropout'),
                                                    input_size=hps('order_emb_size') + hps('attn_size'),
                                                    dtype=tf.float32)

                # apply attention over location
                # curr_state [batch, NB_NODES, attn_size]
                # An explicit VariableScope object is created once and passed to both attention mechanisms
                # and wrappers (regular and beam), so that both passes resolve to the same attention variables
                attention_scope = tf.VariableScope(name='policy/decoder_scope/Attention', reuse=tf.AUTO_REUSE)
                attention_mechanism = BahdanauAttention(num_units=hps('attn_size'),
                                                        memory=board_state_conv,
                                                        normalize=True,
                                                        name_or_scope=attention_scope)
                decoder_cell = AttentionWrapper(cell=decoder_cell,
                                                attention_mechanism=attention_mechanism,
                                                output_attention=False,
                                                name_or_scope=attention_scope)

                # Setting initial state
                # The initial attention vector is the mean of board_state_conv over axis 1
                decoder_init_state = decoder_cell.zero_state(batch_size, tf.float32)
                decoder_init_state = decoder_init_state.clone(attention=tf.reduce_mean(board_state_conv, axis=1))

                # ---- Helper ----
                # The helper receives the inputs without their last token; targets are the inputs shifted by 1
                helper = CustomHelper(decoder_type=decoder_type,
                                      inputs=decoder_inputs[:, :-1],
                                      order_embedding=order_embedding,
                                      candidate_embedding=candidate_embedding,
                                      sequence_length=decoder_lengths,
                                      candidates=candidates,
                                      time_major=False,
                                      softmax_temperature=temperature)

                # ---- Decoder ----
                # Mask built from the raw lengths, padded to the longest (non-raw) decoder length
                sequence_mask = tf.sequence_mask(raw_decoder_lengths,
                                                 maxlen=tf.reduce_max(decoder_lengths),
                                                 dtype=tf.float32)
                maximum_iterations = NB_SUPPLY_CENTERS
                model_decoder = CandidateBasicDecoder(cell=decoder_cell,
                                                      helper=helper,
                                                      initial_state=decoder_init_state,
                                                      max_candidate_length=max_candidate_length,
                                                      extract_state=True)
                training_results, _, _ = dynamic_decode(decoder=model_decoder,
                                                        output_time_major=False,
                                                        maximum_iterations=maximum_iterations,
                                                        swap_memory=hps('swap_memory'))
                # Snapshot used below to check that the beam search pass only reuses existing variables
                global_vars_after_decoder = set(tf.global_variables())

                # ======== Beam Search Decoding ========
                tile_beam = get_tile_beam(hps('beam_width'))

                # Applying dropout to input + attention and to output layer
                decoder_cell = SeededDropoutWrapper(cell=lstm_cell,
                                                    seeds=tile_beam(player_seeds),
                                                    input_keep_probs=tile_beam(1. - dropout_rates),
                                                    output_keep_probs=tile_beam(1. - dropout_rates),
                                                    variational_recurrent=hps('use_v_dropout'),
                                                    input_size=hps('order_emb_size') + hps('attn_size'),
                                                    dtype=tf.float32)

                # apply attention over location
                # curr_state [batch, NB_NODES, attn_size]
                attention_mechanism = BahdanauAttention(num_units=hps('attn_size'),
                                                        memory=tile_beam(board_state_conv),
                                                        normalize=True,
                                                        name_or_scope=attention_scope)
                decoder_cell = AttentionWrapper(cell=decoder_cell,
                                                attention_mechanism=attention_mechanism,
                                                output_attention=False,
                                                name_or_scope=attention_scope)

                # Setting initial state (one state per beam entry)
                decoder_init_state = decoder_cell.zero_state(batch_size * hps('beam_width'), tf.float32)
                decoder_init_state = decoder_init_state.clone(attention=tf.reduce_mean(tile_beam(board_state_conv),
                                                                                       axis=1))

                # ---- Beam Helper and Decoder ----
                beam_helper = CustomBeamHelper(cell=decoder_cell,
                                               order_embedding=order_embedding,
                                               candidate_embedding=candidate_embedding,
                                               candidates=candidates,
                                               sequence_length=decoder_lengths,
                                               initial_state=decoder_init_state,
                                               beam_width=hps('beam_width'))
                beam_decoder = DiverseBeamSearchDecoder(beam_helper=beam_helper,
                                                        sequence_length=decoder_lengths,
                                                        nb_groups=hps('beam_groups'))
                beam_results, beam_state, _ = dynamic_decode(decoder=beam_decoder,
                                                             output_time_major=False,
                                                             maximum_iterations=maximum_iterations,
                                                             swap_memory=hps('swap_memory'))

                # Making sure we haven't created new global variables
                assert not set(tf.global_variables()) - global_vars_after_decoder, 'New global vars were created'

                # Processing results
                candidate_logits = training_results.rnn_output          # (b, dec_len, max_cand_len)
                logits_length = tf.shape(candidate_logits)[1]           # dec_len
                decoder_target = decoder_inputs[:, 1:1 + logits_length]

                # Selected tokens are the token that was actually fed at the next position
                # Where sample_id == -1 the target token is used, otherwise the sampled id (then masked)
                sample_mask = to_float(tf.math.equal(training_results.sample_id, -1))
                selected_tokens = to_int32(
                    sequence_mask * (sample_mask * to_float(decoder_target)
                                     + (1. - sample_mask) * to_float(training_results.sample_id)))

                # Computing ArgMax tokens
                # The logits are over candidate positions: the argmax position is mapped back to the
                # candidate token through a one-hot selection over `candidates`
                argmax_id = to_int32(tf.argmax(candidate_logits, axis=-1))
                max_nb_candidate = tf.shape(candidate_logits)[2]
                candidate_ids = \
                    tf.reduce_sum(tf.one_hot(argmax_id, max_nb_candidate, dtype=tf.int32) * candidates, axis=-1)
                argmax_tokens = to_int32(to_float(candidate_ids) * sequence_mask)

                # Extracting the position of the target candidate
                # (index, along the last axis, of the entry of `candidates` equal to the token)
                tokens_labels = tf.argmax(to_int32(tf.math.equal(selected_tokens[:, :, None], candidates)), -1)
                target_labels = tf.argmax(to_int32(tf.math.equal(decoder_target[:, :, None], candidates)), -1)

                # Log Probs (zeroed outside of the sequence mask)
                log_probs = -1. * cross_entropy(logits=candidate_logits, labels=tokens_labels) * sequence_mask

            # Computing policy loss
            with tf.variable_scope('policy_loss'):
                policy_loss = sequence_loss(logits=candidate_logits,
                                            targets=target_labels,
                                            weights=sequence_mask,
                                            average_across_batch=True,
                                            average_across_timesteps=True)
                # When stop_gradient_all is True, the loss value is kept but no gradient flows through it
                policy_loss = tf.cond(stop_gradient_all,
                                      lambda: tf.stop_gradient(policy_loss),        # pylint: disable=cell-var-from-loop
                                      lambda: policy_loss)                          # pylint: disable=cell-var-from-loop

    # Building output tags
    outputs = {'tag/policy/order_based/v001_markovian_no_film': True,
               'targets': decoder_inputs[:, 1:],
               'selected_tokens': selected_tokens,
               'argmax_tokens': argmax_tokens,
               'logits': candidate_logits,
               'log_probs': log_probs,
               'beam_tokens': tf.transpose(beam_results.predicted_ids, perm=[0, 2, 1]),    # [batch, beam, steps]
               'beam_log_probs': beam_state.log_probs,
               'rnn_states': training_results.rnn_state,
               'policy_loss': policy_loss,
               'draw_prob': self.outputs.get('draw_prob', tf.zeros_like(self.features['draw_target'])),
               'learning_rate': self.learning_rate}

    # Adding features, placeholders and outputs to graph
    self.add_meta_information(outputs)
def _build_policy_final(self):
    """ Builds the policy model (final step)

        Inside the 'policy' variable scope, builds a transformer decoder (`TransformerCell`). Its feeder
        cell applies seeded dropout and static attention over the board state. Its outputs are logits over
        the candidate orders (`candidates`) at each position. It is used twice:
            - regular decoding, driven by `CustomHelper` (with `decoder_type` and `temperature`)
            - diverse beam search decoding, driven by `CustomBeamHelper` on beam-tiled inputs
        Both transformer cells use a VariableScope named 'policy/training_scope/transformer' (reuse=False
        for the first one, reuse=True for the beam one). An assert checks that the beam search pass did
        not create any new global variables.

        Reads `self.features` ('player_seed', 'temperature', 'dropout_rate', 'draw_target'),
        `self.placeholders` ('stop_gradient_all') and `self.outputs` (from the earlier build steps), then
        registers the resulting tensors with `self.add_meta_information(outputs)`.

        Note: graph construction order matters here (op / variable naming and the point at which
        `global_vars_after_decoder` is captured), so statements should not be reordered casually.
    """
    from diplomacy_research.utils.tensorflow import tf
    from diplomacy_research.models.layers.attention import StaticAttentionWrapper
    from diplomacy_research.models.layers.beam_decoder import DiverseBeamSearchDecoder
    from diplomacy_research.models.layers.decoder import CandidateBasicDecoder
    from diplomacy_research.models.layers.dropout import SeededDropoutWrapper
    from diplomacy_research.models.layers.dynamic_decode import dynamic_decode
    from diplomacy_research.models.layers.initializers import uniform
    from diplomacy_research.models.layers.transformer import TransformerCell
    from diplomacy_research.models.layers.wrappers import IdentityCell
    from diplomacy_research.models.policy.order_based.helper import CustomHelper, CustomBeamHelper
    from diplomacy_research.utils.tensorflow import cross_entropy, sequence_loss, to_int32, to_float, get_tile_beam

    # Quick function to retrieve hparams and placeholders and function shorthands
    hps = lambda hparam_name: self.hparams[hparam_name]
    pholder = lambda placeholder_name: self.placeholders[placeholder_name]

    # Training loop
    with tf.variable_scope('policy', reuse=tf.AUTO_REUSE):
        with tf.device(self.cluster_config.worker_device if self.cluster_config else None):

            # Features
            player_seeds = self.features['player_seed']         # tf.int32 - (b,)
            temperature = self.features['temperature']          # tf.flt32 - (b,)
            dropout_rates = self.features['dropout_rate']       # tf.flt32 - (b,)

            # Placeholders
            stop_gradient_all = pholder('stop_gradient_all')

            # Outputs (from initial steps)
            batch_size = self.outputs['batch_size']
            board_alignments = self.outputs['board_alignments']
            decoder_inputs = self.outputs['decoder_inputs']
            decoder_type = self.outputs['decoder_type']
            raw_decoder_lengths = self.outputs['raw_decoder_lengths']
            decoder_lengths = self.outputs['decoder_lengths']
            board_state_conv = self.outputs['board_state_conv']
            order_embedding = self.outputs['order_embedding']
            candidate_embedding = self.outputs['candidate_embedding']
            candidates = self.outputs['candidates']
            max_candidate_length = self.outputs['max_candidate_length']

            # Creating a position embedding with one row per supply center (NB_SUPPLY_CENTERS x trsf_emb_size)
            # NOTE(review): it is created unconditionally here (self.outputs is not checked for an existing one)
            # Embeddings needs to be cached locally on the worker, otherwise TF can't compute their gradients
            with tf.variable_scope('position_embedding_scope'):
                caching_device = self.cluster_config.caching_device if self.cluster_config else None
                position_embedding = uniform(name='position_embedding',
                                             shape=[NB_SUPPLY_CENTERS, hps('trsf_emb_size')],
                                             scale=1.,
                                             caching_device=caching_device)

            # Past Attentions (no past attentions / message lengths are provided to the transformer here)
            past_attentions, message_lengths = None, None

            # --- Decoding ---
            with tf.variable_scope('decoder_scope', reuse=tf.AUTO_REUSE):
                feeder_cell = IdentityCell(output_size=hps('trsf_emb_size') + hps('attn_size'))

                # ======== Regular Decoding ========
                # Applying dropout to the feeder cell
                # NOTE(review): only `input_keep_probs` is set; no output dropout is configured here - confirm
                feeder_cell = SeededDropoutWrapper(cell=feeder_cell,
                                                   seeds=player_seeds,
                                                   input_keep_probs=1. - dropout_rates,
                                                   variational_recurrent=hps('use_v_dropout'),
                                                   input_size=hps('trsf_emb_size') + hps('attn_size'),
                                                   dtype=tf.float32)

                # Apply attention over orderable location at each position
                feeder_cell = StaticAttentionWrapper(cell=feeder_cell,
                                                     memory=board_state_conv,
                                                     alignments=board_alignments,
                                                     sequence_length=raw_decoder_lengths,
                                                     output_attention=False)

                # Setting initial state
                feeder_cell_init_state = feeder_cell.zero_state(batch_size, tf.float32)

                # ---- Helper ----
                # The helper receives the inputs without their last token; targets are the inputs shifted by 1
                helper = CustomHelper(decoder_type=decoder_type,
                                      inputs=decoder_inputs[:, :-1],
                                      order_embedding=order_embedding,
                                      candidate_embedding=candidate_embedding,
                                      sequence_length=decoder_lengths,
                                      candidates=candidates,
                                      time_major=False,
                                      softmax_temperature=temperature)

                # ---- Transformer Cell ----
                # Variables are created in this scope (reuse=False); the beam pass reopens it with reuse=True
                trsf_scope = tf.VariableScope(name='policy/training_scope/transformer', reuse=False)
                transformer_cell = TransformerCell(nb_layers=hps('trsf_nb_layers'),
                                                   nb_heads=hps('trsf_nb_heads'),
                                                   word_embedding=order_embedding,
                                                   position_embedding=position_embedding,
                                                   batch_size=batch_size,
                                                   feeder_cell=feeder_cell,
                                                   feeder_init_state=feeder_cell_init_state,
                                                   past_attentions=past_attentions,
                                                   past_seq_lengths=message_lengths,
                                                   scope=trsf_scope,
                                                   name='transformer')
                transformer_cell_init_state = transformer_cell.zero_state(batch_size, tf.float32)

                # ---- Invariants ----
                # Shape invariant passed to dynamic_decode for the 'past_attentions' state entry
                # (batch size and sequence length are left unknown)
                invariants_map = {
                    'past_attentions': tf.TensorShape([None,                                            # batch size
                                                       hps('trsf_nb_layers'),                           # nb_layers
                                                       2,                                               # key, value
                                                       hps('trsf_nb_heads'),                            # nb heads
                                                       None,                                            # Seq len
                                                       hps('trsf_emb_size') // hps('trsf_nb_heads')])}  # Head size

                # ---- Decoder ----
                # Mask built from the raw lengths, padded to the longest (non-raw) decoder length
                sequence_mask = tf.sequence_mask(raw_decoder_lengths,
                                                 maxlen=tf.reduce_max(decoder_lengths),
                                                 dtype=tf.float32)
                maximum_iterations = NB_SUPPLY_CENTERS
                model_decoder = CandidateBasicDecoder(cell=transformer_cell,
                                                      helper=helper,
                                                      initial_state=transformer_cell_init_state,
                                                      max_candidate_length=max_candidate_length,
                                                      extract_state=True)
                training_results, _, _ = dynamic_decode(decoder=model_decoder,
                                                        output_time_major=False,
                                                        maximum_iterations=maximum_iterations,
                                                        invariants_map=invariants_map,
                                                        swap_memory=hps('swap_memory'))
                # Snapshot used below to check that the beam search pass only reuses existing variables
                global_vars_after_decoder = set(tf.global_variables())

                # ======== Beam Search Decoding ========
                tile_beam = get_tile_beam(hps('beam_width'))
                beam_feeder_cell = IdentityCell(output_size=hps('trsf_emb_size') + hps('attn_size'))

                # Applying dropout to the feeder cell
                # NOTE(review): only `input_keep_probs` is set; no output dropout is configured here - confirm
                beam_feeder_cell = SeededDropoutWrapper(cell=beam_feeder_cell,
                                                        seeds=tile_beam(player_seeds),
                                                        input_keep_probs=tile_beam(1. - dropout_rates),
                                                        variational_recurrent=hps('use_v_dropout'),
                                                        input_size=hps('trsf_emb_size') + hps('attn_size'),
                                                        dtype=tf.float32)

                # Apply attention over orderable location at each position
                beam_feeder_cell = StaticAttentionWrapper(cell=beam_feeder_cell,
                                                          memory=tile_beam(board_state_conv),
                                                          alignments=tile_beam(board_alignments),
                                                          sequence_length=tile_beam(raw_decoder_lengths),
                                                          output_attention=False)

                # Setting initial state (one state per beam entry)
                beam_feeder_init_state = beam_feeder_cell.zero_state(batch_size * hps('beam_width'), tf.float32)

                # ---- Transformer Cell ----
                # Same scope name as the training transformer, reopened with reuse=True
                trsf_scope = tf.VariableScope(name='policy/training_scope/transformer', reuse=True)
                beam_trsf_cell = TransformerCell(nb_layers=hps('trsf_nb_layers'),
                                                 nb_heads=hps('trsf_nb_heads'),
                                                 word_embedding=order_embedding,
                                                 position_embedding=position_embedding,
                                                 batch_size=batch_size * hps('beam_width'),
                                                 feeder_cell=beam_feeder_cell,
                                                 feeder_init_state=beam_feeder_init_state,
                                                 past_attentions=tile_beam(past_attentions),
                                                 past_seq_lengths=tile_beam(message_lengths),
                                                 scope=trsf_scope,
                                                 name='transformer')
                beam_trsf_cell_init_state = beam_trsf_cell.zero_state(batch_size * hps('beam_width'), tf.float32)

                # ---- Beam Helper and Decoder ----
                beam_helper = CustomBeamHelper(cell=beam_trsf_cell,
                                               order_embedding=order_embedding,
                                               candidate_embedding=candidate_embedding,
                                               candidates=candidates,
                                               sequence_length=decoder_lengths,
                                               initial_state=beam_trsf_cell_init_state,
                                               beam_width=hps('beam_width'))
                beam_decoder = DiverseBeamSearchDecoder(beam_helper=beam_helper,
                                                        sequence_length=decoder_lengths,
                                                        nb_groups=hps('beam_groups'))
                beam_results, beam_state, _ = dynamic_decode(decoder=beam_decoder,
                                                             output_time_major=False,
                                                             maximum_iterations=maximum_iterations,
                                                             invariants_map=invariants_map,
                                                             swap_memory=hps('swap_memory'))

                # Making sure we haven't created new global variables
                assert not set(tf.global_variables()) - global_vars_after_decoder, 'New global vars were created'

                # Processing results
                candidate_logits = training_results.rnn_output          # (b, dec_len, max_cand_len)
                logits_length = tf.shape(candidate_logits)[1]           # dec_len
                decoder_target = decoder_inputs[:, 1:1 + logits_length]

                # Selected tokens are the token that was actually fed at the next position
                # Where sample_id == -1 the target token is used, otherwise the sampled id (then masked)
                sample_mask = to_float(tf.math.equal(training_results.sample_id, -1))
                selected_tokens = to_int32(
                    sequence_mask * (sample_mask * to_float(decoder_target)
                                     + (1. - sample_mask) * to_float(training_results.sample_id)))

                # Computing ArgMax tokens
                # The logits are over candidate positions: the argmax position is mapped back to the
                # candidate token through a one-hot selection over `candidates`
                argmax_id = to_int32(tf.argmax(candidate_logits, axis=-1))
                max_nb_candidate = tf.shape(candidate_logits)[2]
                candidate_ids = \
                    tf.reduce_sum(tf.one_hot(argmax_id, max_nb_candidate, dtype=tf.int32) * candidates, axis=-1)
                argmax_tokens = to_int32(to_float(candidate_ids) * sequence_mask)

                # Extracting the position of the target candidate
                # (index, along the last axis, of the entry of `candidates` equal to the token)
                tokens_labels = tf.argmax(to_int32(tf.math.equal(selected_tokens[:, :, None], candidates)), -1)
                target_labels = tf.argmax(to_int32(tf.math.equal(decoder_target[:, :, None], candidates)), -1)

                # Log Probs (zeroed outside of the sequence mask)
                log_probs = -1. * cross_entropy(logits=candidate_logits, labels=tokens_labels) * sequence_mask

            # Computing policy loss
            with tf.variable_scope('policy_loss'):
                policy_loss = sequence_loss(logits=candidate_logits,
                                            targets=target_labels,
                                            weights=sequence_mask,
                                            average_across_batch=True,
                                            average_across_timesteps=True)
                # When stop_gradient_all is True, the loss value is kept but no gradient flows through it
                policy_loss = tf.cond(stop_gradient_all,
                                      lambda: tf.stop_gradient(policy_loss),        # pylint: disable=cell-var-from-loop
                                      lambda: policy_loss)                          # pylint: disable=cell-var-from-loop

    # Building output tags
    outputs = {'tag/policy/order_based/v015_film_transformer_gpt': True,
               'targets': decoder_inputs[:, 1:],
               'selected_tokens': selected_tokens,
               'argmax_tokens': argmax_tokens,
               'logits': candidate_logits,
               'log_probs': log_probs,
               'beam_tokens': tf.transpose(beam_results.predicted_ids, perm=[0, 2, 1]),    # [batch, beam, steps]
               'beam_log_probs': beam_state.log_probs,
               'rnn_states': training_results.rnn_state,
               'policy_loss': policy_loss,
               'draw_prob': self.outputs.get('draw_prob', tf.zeros_like(self.features['draw_target'])),
               'learning_rate': self.learning_rate}

    # Adding features, placeholders and outputs to graph
    self.add_meta_information(outputs)
def create_optimizer_op(self, cost_and_scope, ignored_scope=None, max_gradient_norm=None):
    """ Creates an optimizer op to reduce the cost
        :param cost_and_scope: List of tuples (cost, scope, ignored_scope)
                - cost is a tensor representing the cost to minimize
                - scope is either a string, or a list of strings. Contains the scope(s) where to get the vars to update
                - ignored_scope is either None, a string, or a list of strings. Contains scope(s) to ignore.
        :param ignored_scope: A scope or list of scope for which we know we won't compute gradients
        :param max_gradient_norm: Optional. If set, gradients will be clipped to this value.
        :return: The optimizer op

        Note: The ignored scope inside 'cost_and_scope' is local to that cost, while the arg ignored_scope is global
              for all costs.

        Gradients of every cost are computed separately with respect to the trainable vars of its scope(s)
        (minus the globally ignored vars). Gradients of a var that receives several of them are summed.
        Each gradient is passed through `ensure_finite`, optionally clipped by global norm, and applied
        under a control dependency on the UPDATE_OPS of every scope in `cost_and_scope`, de-duplicated.
    """
    # pylint: disable=too-many-branches
    from diplomacy_research.utils.tensorflow import tf, scope_vars, ensure_finite
    from diplomacy_research.utils import gradient_checkpoint
    assert self.optimizer, 'Optimizer must be defined in self.optimizer before calling this method.'
    if self.cluster_config \
            and self.hparams['training_mode'] == 'supervised' \
            and 'sync_gradients' in self.hparams \
            and self.hparams['sync_gradients']:
        assert isinstance(self.optimizer, tf.train.SyncReplicasOptimizer), 'optimizer must be SyncReplicasOptimizer'
    assert self.hparams['grad_aggregation'].upper() in ['ADD_N', 'ACCUMULATE_N', 'TREE'], 'Invalid aggregation'

    def _trainable_vars_in(scopes):
        """ Returns the trainable vars of a scope (str) or list of scopes, without duplicates, in discovery order """
        scopes = scopes if isinstance(scopes, list) else [scopes]
        found_vars, seen_vars = [], set()
        for scope_name in scopes:
            for variable in scope_vars(scope_name, trainable_only=True):
                if variable not in seen_vars:
                    seen_vars.add(variable)
                    found_vars.append(variable)
        return found_vars

    # Warning if more than 1 optimizer is created
    self.nb_optimizers += 1
    if self.nb_optimizers > 1:
        LOGGER.warning('You have created %d optimizers for this model. This is not recommended (High memory usage)',
                       self.nb_optimizers)

    # Determining aggregation_method based on accumulate_n flag
    aggregation_method = None
    if self.hparams['grad_aggregation'].upper() == 'ACCUMULATE_N':
        aggregation_method = tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
    elif self.hparams['grad_aggregation'].upper() == 'TREE':
        aggregation_method = tf.AggregationMethod.EXPERIMENTAL_TREE

    # Finding all trainable variables
    all_trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    ignored_scope_trainable_vars = set(_trainable_vars_in(ignored_scope)) if ignored_scope is not None else set()

    # Building a list of all trainable vars, and removing them when used by an op
    unused_trainable_vars = set(all_trainable_vars) - ignored_scope_trainable_vars

    # Summing gradients if we are optimizing multiple costs
    global_gradients = {}                   # var name -> list of gradients (one per cost using that var)
    for cost, scope, local_ignored_scope in cost_and_scope:
        local_ignored_vars = \
            set(_trainable_vars_in(local_ignored_scope)) if local_ignored_scope is not None else set()

        # Computing gradients with respect to all scope vars (except global ignored vars, but incl. local ignored)
        scope_trainable_vars = [variable for variable in _trainable_vars_in(scope)
                                if variable not in ignored_scope_trainable_vars]

        # Computing gradients
        if self.hparams['gradient_checkpoint']:
            LOGGER.info('****** Optimizing graph with gradient checkpointing...')
            gradients = gradient_checkpoint.gradients(cost, scope_trainable_vars,
                                                      checkpoints=self.hparams['gradient_checkpoint'],
                                                      aggregation_method=aggregation_method)
            LOGGER.info('Done optimizing graph with gradient checkpointing...')
        else:
            LOGGER.info('****** Computing gradients with respect to %s...', str(cost))
            gradients = tf.gradients(cost, scope_trainable_vars, aggregation_method=aggregation_method)

        # Storing gradients in global_gradients
        for trainable_var, gradient in zip(scope_trainable_vars, gradients):
            if trainable_var in local_ignored_vars:
                continue
            unused_trainable_vars.discard(trainable_var)
            if gradient is None:
                LOGGER.warning('Gradient for %s is None. Is the graph disconnected?', str(trainable_var))
                continue
            global_gradients.setdefault(str(trainable_var.name), []).append(gradient)

    # Warning about missing trainable variables
    for variable in unused_trainable_vars:
        LOGGER.warning('The training variable %s has not been included in the optimizer_op.', str(variable))

    # Warning about ignored training variables
    for variable in ignored_scope_trainable_vars:
        LOGGER.info('Ignoring variable: "%s" (Shape: %s).', str(variable.name), str(variable.shape))

    # Computing and clipping gradients (one entry per trainable var, None when it has no gradient)
    gradients = []
    for variable in all_trainable_vars:
        var_gradients = global_gradients.get(str(variable.name), [])
        if not var_gradients:
            gradients += [None]
        elif len(var_gradients) == 1:
            gradients += var_gradients
        else:
            if any(isinstance(grad, tf.IndexedSlices) for grad in var_gradients):
                LOGGER.info('Adding IndexedSlices for %s', variable)
            gradients += [tf.add_n(var_gradients, name='%s/Add_N' % (variable.name.split(':')[0]))]
    gradients = [ensure_finite(gradient) for gradient in gradients]
    if max_gradient_norm is not None:
        gradients, _ = tf.clip_by_global_norm(gradients, max_gradient_norm)

    # Finding update ops (de-duplicated, even when several costs share the same scope)
    update_ops = []
    for _, scope, _ in cost_and_scope:
        scope_names = scope if isinstance(scope, list) else [scope]
        for scope_name in scope_names:
            for update_op in tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=scope_name):
                if update_op not in update_ops:
                    update_ops.append(update_op)

    # Printing number of variables
    global_vars = tf.global_variables()
    LOGGER.info('Model has %d global vars and %d trainable vars', len(global_vars), len(all_trainable_vars))

    # Computing the number of parameters
    nb_global_params = sum(reduce(mul, variable.shape.as_list(), 1) for variable in global_vars)
    nb_trainable_params = sum(reduce(mul, variable.shape.as_list(), 1) for variable in all_trainable_vars)
    LOGGER.info('Model has %s parameters (%s for trainable vars)',
                '{:,}'.format(nb_global_params),
                '{:,}'.format(nb_trainable_params))

    # Creating optimizer op (with dependencies on update for batch norm)
    with tf.control_dependencies(update_ops):
        opt_op = self.optimizer.apply_gradients(zip(gradients, all_trainable_vars), global_step=self.global_step)

    # Returning optimization op
    return opt_op