コード例 #1
0
 def test_get_board_value(model):
     """ Checks that calling get_board_value with reuse=True does not create new global variables
         :param model: The model whose `get_board_value` method is exercised
     """
     from diplomacy_research.utils.tensorflow import tf

     def _count_global_vars():
         """ Returns the number of global variables currently registered """
         return len(tf.global_variables())

     vars_count_start = _count_global_vars()

     # Fake inputs - Only used to build the ops, never fed
     fake_board = tf.placeholder(name='fake_board', dtype=tf.float32, shape=[None, NB_NODES, NB_FEATURES])
     fake_power = tf.placeholder(name='fake_current_power', dtype=tf.int32, shape=[None])
     model.get_board_value(fake_board, fake_power, reuse=True)

     # The variable count must be the same before and after the call
     vars_count_end = _count_global_vars()
     assert vars_count_start == vars_count_end, 'New variables added when getting board value.'
コード例 #2
0
    def initialize(self, session):
        """ Prepares the adapter: stores the session, initializes the variables that are not yet initialized,
            then starts (or initializes) the feedable dataset.

            Does nothing if the dataset can't support an iterator or if the adapter has no iterator.

            :param session: The session used to run the initializers and to start / initialize the dataset
            :type session: tensorflow.python.client.session.Session
        """
        dataset = self.feedable_dataset

        # Nothing to initialize without an iterator-capable dataset and an iterator
        if not (dataset.can_support_iterator and self.iterator):
            return

        from diplomacy_research.utils.tensorflow import tf
        assert session, 'You must pass a session to initialize the adapter'
        assert isinstance(dataset, QueueDataset), 'The dataset must be a QueueDataset'
        self.session = session

        # Variables can only be initialized on a graph that has not been finalized
        graph = self.graph or tf.get_default_graph()
        if not graph.finalized:
            with graph.as_default():
                candidate_vars = tf.global_variables() + tf.local_variables()
                init_flags = self.session.run([tf.is_variable_initialized(var) for var in candidate_vars])

                # Only keeping the variables reported as not initialized
                pending_vars = []
                for var, flag in zip(candidate_vars, init_flags):
                    if not flag:
                        pending_vars.append(var)

                if pending_vars:
                    LOGGER.info('Initialized %d variables.', len(pending_vars))
                    self.session.run(tf.variables_initializer(pending_vars))

        # Starting the dataset if it is not started, otherwise initializing it if it is not initialized
        if self.session:
            if not dataset.is_started:
                dataset.start(self.session)
            elif not dataset.is_initialized:
                dataset.initialize(self.session)
コード例 #3
0
ファイル: model.py プロジェクト: zhanpengfang/research
    def _build_policy_final(self):
        """ Builds the policy model (final step)

            Builds, under the 'policy' / 'decoder_scope' variable scopes, two decoders sharing the same LSTM cell
            and the same token output layer:
                1) a regular decoder (MaskedBasicDecoder driven by CustomHelper), and
                2) a diverse beam search decoder (DiverseBeamSearchDecoder driven by CustomBeamHelper), whose
                   per-item tensors are tiled with `get_tile_beam(beam_width)`.
            An assertion checks that the beam search decoder did not create any new global variable.

            Reads:
                - self.hparams: 'lstm_size', 'use_v_dropout', 'word_emb_size', 'attn_size', 'swap_memory',
                  'beam_width', 'beam_groups'
                - self.placeholders: 'stop_gradient_all'
                - self.features: 'player_seed', 'temperature', 'dropout_rate', 'draw_target'
                - self.outputs (from the initial steps): 'batch_size', 'board_alignments', 'decoder_inputs',
                  'decoder_mask', 'decoder_type', 'raw_decoder_lengths', 'decoder_lengths', 'board_state_conv',
                  'word_embedding' and, optionally, 'draw_prob'
                - self.cluster_config, self.learning_rate

            The resulting tensors (targets, selected / argmax tokens, logits, log probs, beam results, rnn states,
            policy loss, ...) are registered with `self.add_meta_information`. Nothing is returned.
        """
        from diplomacy_research.utils.tensorflow import tf
        from diplomacy_research.models.layers.attention import StaticAttentionWrapper
        from diplomacy_research.models.layers.beam_decoder import DiverseBeamSearchDecoder
        from diplomacy_research.models.layers.decoder import MaskedBasicDecoder
        from diplomacy_research.models.layers.dropout import SeededDropoutWrapper
        from diplomacy_research.models.layers.dynamic_decode import dynamic_decode
        from diplomacy_research.models.policy.token_based.helper import CustomHelper, CustomBeamHelper
        from diplomacy_research.utils.tensorflow import cross_entropy, sequence_loss, to_int32, to_float, get_tile_beam

        # Quick function to retrieve hparams and placeholders and function shorthands
        hps = lambda hparam_name: self.hparams[hparam_name]
        pholder = lambda placeholder_name: self.placeholders[placeholder_name]

        # Building the policy graph (variables in this scope are shared through AUTO_REUSE)
        with tf.variable_scope('policy', reuse=tf.AUTO_REUSE):
            # Ops are placed on the worker device when a cluster config is set (otherwise no device constraint)
            with tf.device(self.cluster_config.worker_device if self.cluster_config else None):

                # Features
                player_seeds = self.features['player_seed']                 # tf.int32 - (b,)
                temperature = self.features['temperature']                  # tf.flt32 - (b,)
                dropout_rates = self.features['dropout_rate']               # tf.flt32 - (b,)

                # Placeholders
                stop_gradient_all = pholder('stop_gradient_all')

                # Outputs (from initial steps)
                batch_size = self.outputs['batch_size']
                board_alignments = self.outputs['board_alignments']
                decoder_inputs = self.outputs['decoder_inputs']
                decoder_mask = self.outputs['decoder_mask']
                decoder_type = self.outputs['decoder_type']
                raw_decoder_lengths = self.outputs['raw_decoder_lengths']
                decoder_lengths = self.outputs['decoder_lengths']
                board_state_conv = self.outputs['board_state_conv']
                word_embedding = self.outputs['word_embedding']

                # --- Decoding ---
                with tf.variable_scope('decoder_scope', reuse=tf.AUTO_REUSE):
                    # The same LSTM cell object is wrapped twice below (regular decoding, then beam search)
                    lstm_cell = tf.contrib.rnn.LSTMBlockCell(hps('lstm_size'))

                    # decoder output to token
                    # This layer is also shared by the regular decoder and by the beam helper
                    decoder_output_layer = tf.layers.Dense(units=VOCABULARY_SIZE,
                                                           activation=None,
                                                           kernel_initializer=tf.random_normal_initializer,
                                                           use_bias=True)

                    # ======== Regular Decoding ========
                    # Applying dropout to input + attention and to output layer
                    # Keep probabilities are computed per batch item as (1 - dropout_rate)
                    decoder_cell = SeededDropoutWrapper(cell=lstm_cell,
                                                        seeds=player_seeds,
                                                        input_keep_probs=1. - dropout_rates,
                                                        output_keep_probs=1. - dropout_rates,
                                                        variational_recurrent=hps('use_v_dropout'),
                                                        input_size=hps('word_emb_size') + hps('attn_size'),
                                                        dtype=tf.float32)

                    # Apply attention over orderable location at each position
                    decoder_cell = StaticAttentionWrapper(cell=decoder_cell,
                                                          memory=board_state_conv,
                                                          alignments=board_alignments,
                                                          sequence_length=raw_decoder_lengths,
                                                          output_attention=False)

                    # Setting initial state
                    decoder_init_state = decoder_cell.zero_state(batch_size, tf.float32)

                    # ---- Helper ----
                    # The helper receives the decoder inputs without their last time step
                    helper = CustomHelper(decoder_type=decoder_type,
                                          inputs=decoder_inputs[:, :-1],
                                          embedding=word_embedding,
                                          sequence_length=decoder_lengths,
                                          mask=decoder_mask,
                                          time_major=False,
                                          softmax_temperature=temperature)

                    # ---- Decoder ----
                    # Mask built from the raw lengths, but padded to the longest value in decoder_lengths
                    sequence_mask = tf.sequence_mask(raw_decoder_lengths,
                                                     maxlen=tf.reduce_max(decoder_lengths),
                                                     dtype=tf.float32)
                    maximum_iterations = TOKENS_PER_ORDER * NB_SUPPLY_CENTERS
                    model_decoder = MaskedBasicDecoder(cell=decoder_cell,
                                                       helper=helper,
                                                       initial_state=decoder_init_state,
                                                       output_layer=decoder_output_layer,
                                                       extract_state=True)
                    training_results, _, _ = dynamic_decode(decoder=model_decoder,
                                                            output_time_major=False,
                                                            maximum_iterations=maximum_iterations,
                                                            swap_memory=hps('swap_memory'))

                    # Snapshot of the global variables, compared against after building the beam search decoder
                    global_vars_after_decoder = set(tf.global_variables())

                    # ======== Beam Search Decoding ========
                    # tile_beam is applied below to every per-item tensor given to the beam search cell
                    tile_beam = get_tile_beam(hps('beam_width'))

                    # Applying dropout to input + attention and to output layer
                    decoder_cell = SeededDropoutWrapper(cell=lstm_cell,
                                                        seeds=tile_beam(player_seeds),
                                                        input_keep_probs=tile_beam(1. - dropout_rates),
                                                        output_keep_probs=tile_beam(1. - dropout_rates),
                                                        variational_recurrent=hps('use_v_dropout'),
                                                        input_size=hps('word_emb_size') + hps('attn_size'),
                                                        dtype=tf.float32)

                    # Apply attention over orderable location at each position
                    decoder_cell = StaticAttentionWrapper(cell=decoder_cell,
                                                          memory=tile_beam(board_state_conv),
                                                          alignments=tile_beam(board_alignments),
                                                          sequence_length=tile_beam(raw_decoder_lengths),
                                                          output_attention=False)

                    # Setting initial state
                    # The zero state is requested for (batch_size * beam_width) items
                    decoder_init_state = decoder_cell.zero_state(batch_size * hps('beam_width'), tf.float32)

                    # ---- Beam Helper and Decoder ----
                    beam_helper = CustomBeamHelper(cell=decoder_cell,
                                                   embedding=word_embedding,
                                                   mask=decoder_mask,
                                                   sequence_length=decoder_lengths,
                                                   output_layer=decoder_output_layer,
                                                   initial_state=decoder_init_state,
                                                   beam_width=hps('beam_width'))
                    beam_decoder = DiverseBeamSearchDecoder(beam_helper=beam_helper,
                                                            sequence_length=decoder_lengths,
                                                            nb_groups=hps('beam_groups'))
                    beam_results, beam_state, _ = dynamic_decode(decoder=beam_decoder,
                                                                 output_time_major=False,
                                                                 maximum_iterations=maximum_iterations,
                                                                 swap_memory=hps('swap_memory'))

                    # Making sure we haven't created new global variables
                    assert not set(tf.global_variables()) - global_vars_after_decoder, 'New global vars were created'

                    # Processing results
                    # Targets are the decoder inputs shifted by one position, truncated to the decoded length
                    logits = training_results.rnn_output                            # (b, dec_len, VOCAB_SIZE)
                    logits_length = tf.shape(logits)[1]                             # dec_len
                    decoder_target = decoder_inputs[:, 1:1 + logits_length]

                    # Selected tokens are the token that was actually fed at the next position
                    # i.e. the target token where sample_id is -1, the sample_id otherwise (0 outside the mask)
                    # NOTE(review): -1 presumably marks positions where the helper did not sample - confirm
                    sample_mask = to_float(tf.math.equal(training_results.sample_id, -1))
                    selected_tokens = to_int32(
                        sequence_mask * (sample_mask * to_float(decoder_target)
                                         + (1. - sample_mask) * to_float(training_results.sample_id)))

                    # Argmax tokens are the most likely token outputted at each position
                    argmax_tokens = to_int32(to_float(tf.argmax(logits, axis=-1)) * sequence_mask)

                    # Log probs are the negated cross-entropy of the selected tokens, multiplied by the mask
                    log_probs = -1. * cross_entropy(logits=logits, labels=selected_tokens) * sequence_mask

                # Computing policy loss
                with tf.variable_scope('policy_loss'):
                    policy_loss = sequence_loss(logits=logits,
                                                targets=decoder_target,
                                                weights=sequence_mask,
                                                average_across_batch=True,
                                                average_across_timesteps=True)

                    # The loss goes through tf.stop_gradient when the 'stop_gradient_all' placeholder is true
                    policy_loss = tf.cond(stop_gradient_all,
                                          lambda: tf.stop_gradient(policy_loss),                                        # pylint: disable=cell-var-from-loop
                                          lambda: policy_loss)                                                          # pylint: disable=cell-var-from-loop

        # Building output tags
        # 'targets' keeps the full shifted decoder inputs (not truncated to the decoded length)
        # 'draw_prob' defaults to zeros shaped like the 'draw_target' feature if it is not already in self.outputs
        outputs = {'tag/policy/token_based/v005_markovian_film_board_align': True,
                   'targets': decoder_inputs[:, 1:],
                   'selected_tokens': selected_tokens,
                   'argmax_tokens': argmax_tokens,
                   'logits': logits,
                   'log_probs': log_probs,
                   'beam_tokens': tf.transpose(beam_results.predicted_ids, perm=[0, 2, 1]),     # [batch, beam, steps]
                   'beam_log_probs': beam_state.log_probs,
                   'rnn_states': training_results.rnn_state,
                   'policy_loss': policy_loss,
                   'draw_prob': self.outputs.get('draw_prob', tf.zeros_like(self.features['draw_target'])),
                   'learning_rate': self.learning_rate}

        # Adding features, placeholders and outputs to graph
        self.add_meta_information(outputs)
コード例 #4
0
    def _build_policy_final(self):
        """ Builds the policy model (final step)

            Builds, under the 'policy' / 'decoder_scope' variable scopes, two decoders sharing the same LSTM cell
            and the same attention variable scope ('policy/decoder_scope/Attention'):
                1) a regular decoder (CandidateBasicDecoder driven by CustomHelper), and
                2) a diverse beam search decoder (DiverseBeamSearchDecoder driven by CustomBeamHelper), whose
                   attention memory and dropout inputs are tiled with `get_tile_beam(beam_width)`.
            An assertion checks that the beam search decoder did not create any new global variable.

            Reads:
                - self.hparams: 'lstm_size', 'use_v_dropout', 'order_emb_size', 'attn_size', 'swap_memory',
                  'beam_width', 'beam_groups'
                - self.placeholders: 'stop_gradient_all'
                - self.features: 'player_seed', 'temperature', 'dropout_rate', 'draw_target'
                - self.outputs (from the initial steps): 'batch_size', 'decoder_inputs', 'decoder_type',
                  'raw_decoder_lengths', 'decoder_lengths', 'board_state_conv', 'order_embedding',
                  'candidate_embedding', 'candidates', 'max_candidate_length' and, optionally, 'draw_prob'
                - self.cluster_config, self.learning_rate

            The resulting tensors (targets, selected / argmax tokens, candidate logits, log probs, beam results,
            rnn states, policy loss, ...) are registered with `self.add_meta_information`. Nothing is returned.
        """
        from diplomacy_research.utils.tensorflow import tf
        from diplomacy_research.models.layers.attention import AttentionWrapper, BahdanauAttention
        from diplomacy_research.models.layers.beam_decoder import DiverseBeamSearchDecoder
        from diplomacy_research.models.layers.decoder import CandidateBasicDecoder
        from diplomacy_research.models.layers.dropout import SeededDropoutWrapper
        from diplomacy_research.models.layers.dynamic_decode import dynamic_decode
        from diplomacy_research.models.policy.order_based.helper import CustomHelper, CustomBeamHelper
        from diplomacy_research.utils.tensorflow import cross_entropy, sequence_loss, to_int32, to_float, get_tile_beam

        # Quick function to retrieve hparams and placeholders and function shorthands
        hps = lambda hparam_name: self.hparams[hparam_name]
        pholder = lambda placeholder_name: self.placeholders[placeholder_name]

        # Building the policy graph (variables in this scope are shared through AUTO_REUSE)
        with tf.variable_scope('policy', reuse=tf.AUTO_REUSE):
            # Ops are placed on the worker device when a cluster config is set (otherwise no device constraint)
            with tf.device(self.cluster_config.worker_device if self.cluster_config else None):

                # Features
                player_seeds = self.features['player_seed']                 # tf.int32 - (b,)
                temperature = self.features['temperature']                  # tf.flt32 - (b,)
                dropout_rates = self.features['dropout_rate']               # tf.flt32 - (b,)

                # Placeholders
                stop_gradient_all = pholder('stop_gradient_all')

                # Outputs (from initial steps)
                batch_size = self.outputs['batch_size']
                decoder_inputs = self.outputs['decoder_inputs']
                decoder_type = self.outputs['decoder_type']
                raw_decoder_lengths = self.outputs['raw_decoder_lengths']
                decoder_lengths = self.outputs['decoder_lengths']
                board_state_conv = self.outputs['board_state_conv']
                order_embedding = self.outputs['order_embedding']
                candidate_embedding = self.outputs['candidate_embedding']
                candidates = self.outputs['candidates']
                max_candidate_length = self.outputs['max_candidate_length']

                # --- Decoding ---
                with tf.variable_scope('decoder_scope', reuse=tf.AUTO_REUSE):
                    # The same LSTM cell object is wrapped twice below (regular decoding, then beam search)
                    lstm_cell = tf.contrib.rnn.LSTMBlockCell(hps('lstm_size'))

                    # ======== Regular Decoding ========
                    # Applying dropout to input + attention and to output layer
                    # Keep probabilities are computed per batch item as (1 - dropout_rate)
                    decoder_cell = SeededDropoutWrapper(cell=lstm_cell,
                                                        seeds=player_seeds,
                                                        input_keep_probs=1. - dropout_rates,
                                                        output_keep_probs=1. - dropout_rates,
                                                        variational_recurrent=hps('use_v_dropout'),
                                                        input_size=hps('order_emb_size') + hps('attn_size'),
                                                        dtype=tf.float32)

                    # apply attention over location
                    # curr_state [batch, NB_NODES, attn_size]
                    # The same explicit scope object is given to the attention layers of both decoders
                    attention_scope = tf.VariableScope(name='policy/decoder_scope/Attention', reuse=tf.AUTO_REUSE)
                    attention_mechanism = BahdanauAttention(num_units=hps('attn_size'),
                                                            memory=board_state_conv,
                                                            normalize=True,
                                                            name_or_scope=attention_scope)
                    decoder_cell = AttentionWrapper(cell=decoder_cell,
                                                    attention_mechanism=attention_mechanism,
                                                    output_attention=False,
                                                    name_or_scope=attention_scope)

                    # Setting initial state
                    # The initial attention is the mean of board_state_conv over axis 1
                    decoder_init_state = decoder_cell.zero_state(batch_size, tf.float32)
                    decoder_init_state = decoder_init_state.clone(attention=tf.reduce_mean(board_state_conv, axis=1))

                    # ---- Helper ----
                    # The helper receives the decoder inputs without their last time step
                    helper = CustomHelper(decoder_type=decoder_type,
                                          inputs=decoder_inputs[:, :-1],
                                          order_embedding=order_embedding,
                                          candidate_embedding=candidate_embedding,
                                          sequence_length=decoder_lengths,
                                          candidates=candidates,
                                          time_major=False,
                                          softmax_temperature=temperature)

                    # ---- Decoder ----
                    # Mask built from the raw lengths, but padded to the longest value in decoder_lengths
                    sequence_mask = tf.sequence_mask(raw_decoder_lengths,
                                                     maxlen=tf.reduce_max(decoder_lengths),
                                                     dtype=tf.float32)
                    maximum_iterations = NB_SUPPLY_CENTERS
                    model_decoder = CandidateBasicDecoder(cell=decoder_cell,
                                                          helper=helper,
                                                          initial_state=decoder_init_state,
                                                          max_candidate_length=max_candidate_length,
                                                          extract_state=True)
                    training_results, _, _ = dynamic_decode(decoder=model_decoder,
                                                            output_time_major=False,
                                                            maximum_iterations=maximum_iterations,
                                                            swap_memory=hps('swap_memory'))

                    # Snapshot of the global variables, compared against after building the beam search decoder
                    global_vars_after_decoder = set(tf.global_variables())

                    # ======== Beam Search Decoding ========
                    # tile_beam is applied below to the per-item tensors given to the beam search cell
                    tile_beam = get_tile_beam(hps('beam_width'))

                    # Applying dropout to input + attention and to output layer
                    decoder_cell = SeededDropoutWrapper(cell=lstm_cell,
                                                        seeds=tile_beam(player_seeds),
                                                        input_keep_probs=tile_beam(1. - dropout_rates),
                                                        output_keep_probs=tile_beam(1. - dropout_rates),
                                                        variational_recurrent=hps('use_v_dropout'),
                                                        input_size=hps('order_emb_size') + hps('attn_size'),
                                                        dtype=tf.float32)

                    # apply attention over location
                    # curr_state [batch, NB_NODES, attn_size]
                    attention_mechanism = BahdanauAttention(num_units=hps('attn_size'),
                                                            memory=tile_beam(board_state_conv),
                                                            normalize=True,
                                                            name_or_scope=attention_scope)
                    decoder_cell = AttentionWrapper(cell=decoder_cell,
                                                    attention_mechanism=attention_mechanism,
                                                    output_attention=False,
                                                    name_or_scope=attention_scope)

                    # Setting initial state
                    # Same as for the regular decoder, but for (batch_size * beam_width) items and a tiled memory
                    decoder_init_state = decoder_cell.zero_state(batch_size * hps('beam_width'), tf.float32)
                    decoder_init_state = decoder_init_state.clone(attention=tf.reduce_mean(tile_beam(board_state_conv),
                                                                                           axis=1))

                    # ---- Beam Helper and Decoder ----
                    beam_helper = CustomBeamHelper(cell=decoder_cell,
                                                   order_embedding=order_embedding,
                                                   candidate_embedding=candidate_embedding,
                                                   candidates=candidates,
                                                   sequence_length=decoder_lengths,
                                                   initial_state=decoder_init_state,
                                                   beam_width=hps('beam_width'))
                    beam_decoder = DiverseBeamSearchDecoder(beam_helper=beam_helper,
                                                            sequence_length=decoder_lengths,
                                                            nb_groups=hps('beam_groups'))
                    beam_results, beam_state, _ = dynamic_decode(decoder=beam_decoder,
                                                                 output_time_major=False,
                                                                 maximum_iterations=maximum_iterations,
                                                                 swap_memory=hps('swap_memory'))

                    # Making sure we haven't created new global variables
                    assert not set(tf.global_variables()) - global_vars_after_decoder, 'New global vars were created'

                    # Processing results
                    # Targets are the decoder inputs shifted by one position, truncated to the decoded length
                    candidate_logits = training_results.rnn_output                  # (b, dec_len, max_cand_len)
                    logits_length = tf.shape(candidate_logits)[1]                   # dec_len
                    decoder_target = decoder_inputs[:, 1:1 + logits_length]

                    # Selected tokens are the token that was actually fed at the next position
                    # i.e. the target token where sample_id is -1, the sample_id otherwise (0 outside the mask)
                    # NOTE(review): -1 presumably marks positions where the helper did not sample - confirm
                    sample_mask = to_float(tf.math.equal(training_results.sample_id, -1))
                    selected_tokens = to_int32(
                        sequence_mask * (sample_mask * to_float(decoder_target)
                                         + (1. - sample_mask) * to_float(training_results.sample_id)))

                    # Computing ArgMax tokens
                    # The argmax is an index in the candidate axis; the one-hot product with `candidates`
                    # followed by a sum over the last axis retrieves the candidate value at that index
                    argmax_id = to_int32(tf.argmax(candidate_logits, axis=-1))
                    max_nb_candidate = tf.shape(candidate_logits)[2]
                    candidate_ids = \
                        tf.reduce_sum(tf.one_hot(argmax_id, max_nb_candidate, dtype=tf.int32) * candidates, axis=-1)
                    argmax_tokens = to_int32(to_float(candidate_ids) * sequence_mask)

                    # Extracting the position of the target candidate
                    # Each token is compared against the candidates and the argmax over the last axis is kept
                    tokens_labels = tf.argmax(to_int32(tf.math.equal(selected_tokens[:, :, None], candidates)), -1)
                    target_labels = tf.argmax(to_int32(tf.math.equal(decoder_target[:, :, None], candidates)), -1)

                    # Log Probs
                    # Negated cross-entropy of the selected candidate positions, multiplied by the mask
                    log_probs = -1. * cross_entropy(logits=candidate_logits, labels=tokens_labels) * sequence_mask

                # Computing policy loss
                with tf.variable_scope('policy_loss'):
                    policy_loss = sequence_loss(logits=candidate_logits,
                                                targets=target_labels,
                                                weights=sequence_mask,
                                                average_across_batch=True,
                                                average_across_timesteps=True)

                    # The loss goes through tf.stop_gradient when the 'stop_gradient_all' placeholder is true
                    policy_loss = tf.cond(stop_gradient_all,
                                          lambda: tf.stop_gradient(policy_loss),                                        # pylint: disable=cell-var-from-loop
                                          lambda: policy_loss)                                                          # pylint: disable=cell-var-from-loop

        # Building output tags
        # 'targets' keeps the full shifted decoder inputs (not truncated to the decoded length)
        # 'draw_prob' defaults to zeros shaped like the 'draw_target' feature if it is not already in self.outputs
        outputs = {'tag/policy/order_based/v001_markovian_no_film': True,
                   'targets': decoder_inputs[:, 1:],
                   'selected_tokens': selected_tokens,
                   'argmax_tokens': argmax_tokens,
                   'logits': candidate_logits,
                   'log_probs': log_probs,
                   'beam_tokens': tf.transpose(beam_results.predicted_ids, perm=[0, 2, 1]),     # [batch, beam, steps]
                   'beam_log_probs': beam_state.log_probs,
                   'rnn_states': training_results.rnn_state,
                   'policy_loss': policy_loss,
                   'draw_prob': self.outputs.get('draw_prob', tf.zeros_like(self.features['draw_target'])),
                   'learning_rate': self.learning_rate}

        # Adding features, placeholders and outputs to graph
        self.add_meta_information(outputs)
コード例 #5
0
    def _build_policy_final(self):
        """ Builds the policy model (final step) """
        from diplomacy_research.utils.tensorflow import tf
        from diplomacy_research.models.layers.attention import StaticAttentionWrapper
        from diplomacy_research.models.layers.beam_decoder import DiverseBeamSearchDecoder
        from diplomacy_research.models.layers.decoder import CandidateBasicDecoder
        from diplomacy_research.models.layers.dropout import SeededDropoutWrapper
        from diplomacy_research.models.layers.dynamic_decode import dynamic_decode
        from diplomacy_research.models.layers.initializers import uniform
        from diplomacy_research.models.layers.transformer import TransformerCell
        from diplomacy_research.models.layers.wrappers import IdentityCell
        from diplomacy_research.models.policy.order_based.helper import CustomHelper, CustomBeamHelper
        from diplomacy_research.utils.tensorflow import cross_entropy, sequence_loss, to_int32, to_float, get_tile_beam

        # Quick function to retrieve hparams and placeholders and function shorthands
        hps = lambda hparam_name: self.hparams[hparam_name]
        pholder = lambda placeholder_name: self.placeholders[placeholder_name]

        # Training loop
        with tf.variable_scope('policy', reuse=tf.AUTO_REUSE):
            with tf.device(self.cluster_config.worker_device if self.
                           cluster_config else None):

                # Features
                player_seeds = self.features['player_seed']  # tf.int32 - (b,)
                temperature = self.features['temperature']  # tf,flt32 - (b,)
                dropout_rates = self.features[
                    'dropout_rate']  # tf.flt32 - (b,)

                # Placeholders
                stop_gradient_all = pholder('stop_gradient_all')

                # Outputs (from initial steps)
                batch_size = self.outputs['batch_size']
                board_alignments = self.outputs['board_alignments']
                decoder_inputs = self.outputs['decoder_inputs']
                decoder_type = self.outputs['decoder_type']
                raw_decoder_lengths = self.outputs['raw_decoder_lengths']
                decoder_lengths = self.outputs['decoder_lengths']
                board_state_conv = self.outputs['board_state_conv']
                order_embedding = self.outputs['order_embedding']
                candidate_embedding = self.outputs['candidate_embedding']
                candidates = self.outputs['candidates']
                max_candidate_length = self.outputs['max_candidate_length']

                # Creating a smaller position embedding if it's not present in the outputs
                # Embeddings needs to be cached locally on the worker, otherwise TF can't compute their gradients
                with tf.variable_scope('position_embedding_scope'):
                    caching_device = self.cluster_config.caching_device if self.cluster_config else None
                    position_embedding = uniform(
                        name='position_embedding',
                        shape=[NB_SUPPLY_CENTERS,
                               hps('trsf_emb_size')],
                        scale=1.,
                        caching_device=caching_device)

                # Past Attentions
                past_attentions, message_lengths = None, None

                # --- Decoding ---
                with tf.variable_scope('decoder_scope', reuse=tf.AUTO_REUSE):
                    feeder_cell = IdentityCell(
                        output_size=hps('trsf_emb_size') + hps('attn_size'))

                    # ======== Regular Decoding ========
                    # Applying Dropout to input, attention and output
                    feeder_cell = SeededDropoutWrapper(
                        cell=feeder_cell,
                        seeds=player_seeds,
                        input_keep_probs=1. - dropout_rates,
                        variational_recurrent=hps('use_v_dropout'),
                        input_size=hps('trsf_emb_size') + hps('attn_size'),
                        dtype=tf.float32)

                    # Apply attention over orderable location at each position
                    feeder_cell = StaticAttentionWrapper(
                        cell=feeder_cell,
                        memory=board_state_conv,
                        alignments=board_alignments,
                        sequence_length=raw_decoder_lengths,
                        output_attention=False)

                    # Setting initial state
                    feeder_cell_init_state = feeder_cell.zero_state(
                        batch_size, tf.float32)

                    # ---- Helper ----
                    helper = CustomHelper(
                        decoder_type=decoder_type,
                        inputs=decoder_inputs[:, :-1],
                        order_embedding=order_embedding,
                        candidate_embedding=candidate_embedding,
                        sequence_length=decoder_lengths,
                        candidates=candidates,
                        time_major=False,
                        softmax_temperature=temperature)

                    # ---- Transformer Cell ----
                    trsf_scope = tf.VariableScope(
                        name='policy/training_scope/transformer', reuse=False)
                    transformer_cell = TransformerCell(
                        nb_layers=hps('trsf_nb_layers'),
                        nb_heads=hps('trsf_nb_heads'),
                        word_embedding=order_embedding,
                        position_embedding=position_embedding,
                        batch_size=batch_size,
                        feeder_cell=feeder_cell,
                        feeder_init_state=feeder_cell_init_state,
                        past_attentions=past_attentions,
                        past_seq_lengths=message_lengths,
                        scope=trsf_scope,
                        name='transformer')
                    transformer_cell_init_state = transformer_cell.zero_state(
                        batch_size, tf.float32)

                    # ---- Invariants ----
                    invariants_map = {
                        'past_attentions':
                        tf.TensorShape([
                            None,  # batch size
                            hps('trsf_nb_layers'),  # nb_layers
                            2,  # key, value
                            hps('trsf_nb_heads'),  # nb heads
                            None,  # Seq len
                            hps('trsf_emb_size') // hps('trsf_nb_heads')
                        ])
                    }  # Head size

                    # ---- Decoder ----
                    sequence_mask = tf.sequence_mask(
                        raw_decoder_lengths,
                        maxlen=tf.reduce_max(decoder_lengths),
                        dtype=tf.float32)
                    maximum_iterations = NB_SUPPLY_CENTERS
                    model_decoder = CandidateBasicDecoder(
                        cell=transformer_cell,
                        helper=helper,
                        initial_state=transformer_cell_init_state,
                        max_candidate_length=max_candidate_length,
                        extract_state=True)
                    training_results, _, _ = dynamic_decode(
                        decoder=model_decoder,
                        output_time_major=False,
                        maximum_iterations=maximum_iterations,
                        invariants_map=invariants_map,
                        swap_memory=hps('swap_memory'))
                    global_vars_after_decoder = set(tf.global_variables())

                    # ======== Beam Search Decoding ========
                    tile_beam = get_tile_beam(hps('beam_width'))
                    beam_feeder_cell = IdentityCell(
                        output_size=hps('trsf_emb_size') + hps('attn_size'))

                    # Applying Dropout to input, attention and output
                    beam_feeder_cell = SeededDropoutWrapper(
                        cell=beam_feeder_cell,
                        seeds=tile_beam(player_seeds),
                        input_keep_probs=tile_beam(1. - dropout_rates),
                        variational_recurrent=hps('use_v_dropout'),
                        input_size=hps('trsf_emb_size') + hps('attn_size'),
                        dtype=tf.float32)

                    # Apply attention over orderable location at each position
                    beam_feeder_cell = StaticAttentionWrapper(
                        cell=beam_feeder_cell,
                        memory=tile_beam(board_state_conv),
                        alignments=tile_beam(board_alignments),
                        sequence_length=tile_beam(raw_decoder_lengths),
                        output_attention=False)

                    # Setting initial state
                    beam_feeder_init_state = beam_feeder_cell.zero_state(
                        batch_size * hps('beam_width'), tf.float32)

                    # ---- Transformer Cell ----
                    trsf_scope = tf.VariableScope(
                        name='policy/training_scope/transformer', reuse=True)
                    beam_trsf_cell = TransformerCell(
                        nb_layers=hps('trsf_nb_layers'),
                        nb_heads=hps('trsf_nb_heads'),
                        word_embedding=order_embedding,
                        position_embedding=position_embedding,
                        batch_size=batch_size * hps('beam_width'),
                        feeder_cell=beam_feeder_cell,
                        feeder_init_state=beam_feeder_init_state,
                        past_attentions=tile_beam(past_attentions),
                        past_seq_lengths=tile_beam(message_lengths),
                        scope=trsf_scope,
                        name='transformer')
                    beam_trsf_cell_init_state = beam_trsf_cell.zero_state(
                        batch_size * hps('beam_width'), tf.float32)

                    # ---- Beam Helper and Decoder ----
                    beam_helper = CustomBeamHelper(
                        cell=beam_trsf_cell,
                        order_embedding=order_embedding,
                        candidate_embedding=candidate_embedding,
                        candidates=candidates,
                        sequence_length=decoder_lengths,
                        initial_state=beam_trsf_cell_init_state,
                        beam_width=hps('beam_width'))
                    beam_decoder = DiverseBeamSearchDecoder(
                        beam_helper=beam_helper,
                        sequence_length=decoder_lengths,
                        nb_groups=hps('beam_groups'))
                    beam_results, beam_state, _ = dynamic_decode(
                        decoder=beam_decoder,
                        output_time_major=False,
                        maximum_iterations=maximum_iterations,
                        invariants_map=invariants_map,
                        swap_memory=hps('swap_memory'))

                    # Making sure we haven't created new global variables
                    assert not set(
                        tf.global_variables()
                    ) - global_vars_after_decoder, 'New global vars were created'

                    # Processing results
                    candidate_logits = training_results.rnn_output  # (b, dec_len, max_cand_len)
                    logits_length = tf.shape(candidate_logits)[1]  # dec_len
                    decoder_target = decoder_inputs[:, 1:1 + logits_length]

                    # Selected tokens are the token that was actually fed at the next position
                    sample_mask = to_float(
                        tf.math.equal(training_results.sample_id, -1))
                    selected_tokens = to_int32(
                        sequence_mask *
                        (sample_mask * to_float(decoder_target) +
                         (1. - sample_mask) *
                         to_float(training_results.sample_id)))

                    # Computing ArgMax tokens
                    argmax_id = to_int32(tf.argmax(candidate_logits, axis=-1))
                    max_nb_candidate = tf.shape(candidate_logits)[2]
                    candidate_ids = \
                        tf.reduce_sum(tf.one_hot(argmax_id, max_nb_candidate, dtype=tf.int32) * candidates, axis=-1)
                    argmax_tokens = to_int32(
                        to_float(candidate_ids) * sequence_mask)

                    # Extracting the position of the target candidate
                    tokens_labels = tf.argmax(
                        to_int32(
                            tf.math.equal(selected_tokens[:, :, None],
                                          candidates)), -1)
                    target_labels = tf.argmax(
                        to_int32(
                            tf.math.equal(decoder_target[:, :, None],
                                          candidates)), -1)

                    # Log Probs
                    log_probs = -1. * cross_entropy(
                        logits=candidate_logits,
                        labels=tokens_labels) * sequence_mask

                # Computing policy loss
                with tf.variable_scope('policy_loss'):
                    policy_loss = sequence_loss(logits=candidate_logits,
                                                targets=target_labels,
                                                weights=sequence_mask,
                                                average_across_batch=True,
                                                average_across_timesteps=True)
                    policy_loss = tf.cond(
                        stop_gradient_all,
                        lambda: tf.stop_gradient(policy_loss),  # pylint: disable=cell-var-from-loop
                        lambda: policy_loss)  # pylint: disable=cell-var-from-loop

        # Building output tags
        outputs = {
            'tag/policy/order_based/v015_film_transformer_gpt':
            True,
            'targets':
            decoder_inputs[:, 1:],
            'selected_tokens':
            selected_tokens,
            'argmax_tokens':
            argmax_tokens,
            'logits':
            candidate_logits,
            'log_probs':
            log_probs,
            'beam_tokens':
            tf.transpose(beam_results.predicted_ids,
                         perm=[0, 2, 1]),  # [batch, beam, steps]
            'beam_log_probs':
            beam_state.log_probs,
            'rnn_states':
            training_results.rnn_state,
            'policy_loss':
            policy_loss,
            'draw_prob':
            self.outputs.get('draw_prob',
                             tf.zeros_like(self.features['draw_target'])),
            'learning_rate':
            self.learning_rate
        }

        # Adding features, placeholders and outputs to graph
        self.add_meta_information(outputs)
コード例 #6
0
ファイル: base_model.py プロジェクト: zhanpengfang/research
    def create_optimizer_op(self,
                            cost_and_scope,
                            ignored_scope=None,
                            max_gradient_norm=None):
        """ Creates an optimizer op to reduce the cost
            :param cost_and_scope: List of tuples (cost, scope, ignored_scope)
                - cost is a tensor representing the cost to minimize
                - scope is either a string, or a list of strings. Contains the scope(s) where to get the vars to update
                - ignored_scope is either None, a string, or a list of strings. Contains scope(s) to ignore.
            :param ignored_scope: A scope or list of scope for which we know we won't compute gradients
            :param max_gradient_norm: Optional. If set, gradients will be clipped to this value.
            :return: The optimizer op (result of self.optimizer.apply_gradients, run after the scopes' update ops)

            Note: The ignored scope inside 'cost_and_scope' is local to that cost, while the arg ignored_scope is global
                  for all costs.

            Side effects: increments self.nb_optimizers and logs a warning when more than one optimizer is created.
        """
        # pylint: disable=too-many-branches
        from diplomacy_research.utils.tensorflow import tf, scope_vars, ensure_finite
        from diplomacy_research.utils import gradient_checkpoint

        def _trainable_vars_in(scopes):
            """ Returns the set of trainable variables under `scopes` (None, a scope, or a list of scopes) """
            if scopes is None:
                return set()
            found_vars = set()
            for scope_name in (scopes if isinstance(scopes, list) else [scopes]):
                found_vars.update(scope_vars(scope_name, trainable_only=True))
            return found_vars

        assert self.optimizer, 'Optimizer must be defined in self.optimizer before calling this method.'
        if self.cluster_config \
                and self.hparams['training_mode'] == 'supervised' \
                and 'sync_gradients' in self.hparams \
                and self.hparams['sync_gradients']:
            assert isinstance(self.optimizer, tf.train.SyncReplicasOptimizer
                              ), 'optimizer must be SyncReplicasOptimizer'
        grad_aggregation = self.hparams['grad_aggregation'].upper()
        assert grad_aggregation in ['ADD_N', 'ACCUMULATE_N', 'TREE'], 'Invalid aggregation'

        # Warning if more than 1 optimizer is created
        self.nb_optimizers += 1
        if self.nb_optimizers > 1:
            LOGGER.warning(
                'You have created %d optimizers for this model. This is not recommended (High memory usage)',
                self.nb_optimizers)

        # Determining aggregation_method ('ADD_N' keeps None, i.e. the default of the gradient function)
        aggregation_method = None
        if grad_aggregation == 'ACCUMULATE_N':
            aggregation_method = tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
        elif grad_aggregation == 'TREE':
            aggregation_method = tf.AggregationMethod.EXPERIMENTAL_TREE

        # Finding all trainable variables
        all_trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

        # Variables under the global ignored scope(s) never receive a gradient from any cost
        ignored_scope_trainable_vars = _trainable_vars_in(ignored_scope)

        # Trainable vars not yet touched by a cost. Entries are removed as costs claim them,
        # whatever remains at the end is reported as missing from the optimizer op.
        unused_trainable_vars = set(all_trainable_vars) - ignored_scope_trainable_vars

        # Collecting, per variable name, the gradients contributed by each cost
        global_gradients = {}
        for cost, scope, local_ignored_scope in cost_and_scope:
            local_ignored_vars = _trainable_vars_in(local_ignored_scope)

            # Computing gradients with respect to all scope vars (except global ignored vars, but incl. local ignored)
            # The list keeps first-seen order (it is zipped with the gradients below), the set makes dedup O(1)
            scope_trainable_vars = []
            seen_vars = set()
            for scope_name in (scope if isinstance(scope, list) else [scope]):
                for variable in scope_vars(scope_name, trainable_only=True):
                    if variable in seen_vars or variable in ignored_scope_trainable_vars:
                        continue
                    seen_vars.add(variable)
                    scope_trainable_vars += [variable]

            # Computing gradients
            if self.hparams['gradient_checkpoint']:
                LOGGER.info(
                    '****** Optimizing graph with gradient checkpointing...')
                gradients = gradient_checkpoint.gradients(
                    cost,
                    scope_trainable_vars,
                    checkpoints=self.hparams['gradient_checkpoint'],
                    aggregation_method=aggregation_method)
                LOGGER.info(
                    'Done optimizing graph with gradient checkpointing...')
            else:
                LOGGER.info('****** Computing gradients with respect to %s...',
                            str(cost))
                gradients = tf.gradients(cost,
                                         scope_trainable_vars,
                                         aggregation_method=aggregation_method)

            # Storing gradients in global_gradients
            for trainable_var, gradient in zip(scope_trainable_vars, gradients):
                # Locally ignored vars: the gradient was computed, but is dropped for this cost only
                if trainable_var in local_ignored_vars:
                    continue
                # A var counts as 'used' as soon as a cost targets it, even if its gradient is None
                unused_trainable_vars.discard(trainable_var)
                if gradient is None:
                    LOGGER.warning(
                        'Gradient for %s is None. Is the graph disconnected?',
                        str(trainable_var))
                    continue
                global_gradients.setdefault(str(trainable_var.name), []).append(gradient)

        # Warning about missing trainable variables
        for variable in unused_trainable_vars:
            LOGGER.warning(
                'The training variable %s has not been included in the optimizer_op.',
                str(variable))

        # Warning about ignored training variables
        for variable in ignored_scope_trainable_vars:
            LOGGER.info('Ignoring variable: "%s" (Shape: %s).',
                        str(variable.name), str(variable.shape))

        # Building one gradient per trainable var (aligned with all_trainable_vars), summing if multiple costs
        gradients = []
        for variable in all_trainable_vars:
            var_gradients = global_gradients.get(str(variable.name), [])
            if not var_gradients:
                gradients += [None]
            elif len(var_gradients) == 1:
                gradients += var_gradients
            else:
                if any(isinstance(grad, tf.IndexedSlices) for grad in var_gradients):
                    LOGGER.info('Adding IndexedSlices for %s', variable)
                gradients += [
                    tf.add_n(var_gradients,
                             name='%s/Add_N' % (variable.name.split(':')[0]))
                ]
        # NOTE(review): the list may contain None entries - ensure_finite is presumed to accept them, to confirm
        gradients = [ensure_finite(gradient) for gradient in gradients]
        if max_gradient_norm is not None:
            gradients, _ = tf.clip_by_global_norm(gradients, max_gradient_norm)

        # Finding update ops (deduplicated for both string and list scopes, so that two costs sharing the
        # same scope do not register the same update op twice)
        update_ops = []
        for _, scope, _ in cost_and_scope:
            for scope_name in (scope if isinstance(scope, list) else [scope]):
                for update_op in tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                                   scope=scope_name):
                    if update_op not in update_ops:
                        update_ops += [update_op]

        # Printing number of variables
        global_vars = tf.global_variables()
        LOGGER.info('Model has %d global vars and %d trainable vars',
                    len(global_vars), len(all_trainable_vars))

        # Computing the number of parameters (product of the static shape of each variable)
        nb_global_params = sum(
            reduce(mul, variable.shape.as_list(), 1)
            for variable in global_vars)
        nb_trainable_params = sum(
            reduce(mul, variable.shape.as_list(), 1)
            for variable in all_trainable_vars)
        LOGGER.info('Model has %s parameters (%s for trainable vars)',
                    '{:,}'.format(nb_global_params),
                    '{:,}'.format(nb_trainable_params))

        # Creating optimizer op (with dependencies on update for batch norm)
        with tf.control_dependencies(update_ops):
            opt_op = self.optimizer.apply_gradients(
                zip(gradients, all_trainable_vars),
                global_step=self.global_step)

        # Returning optimization op
        return opt_op