Example #1
def _get_fake_data(self, inputs, mlm_logits):
    """Sample from the generator to create corrupted input."""
    inputs = pretrain_helpers.unmask(inputs)
    # Optionally forbid the generator from sampling the original token
    disallow = tf.one_hot(
        inputs.masked_lm_ids,
        depth=self._bert_config.vocab_size,
        dtype=tf.float32) if self._config.disallow_correct else None
    sampled_tokens = tf.stop_gradient(
        pretrain_helpers.sample_from_softmax(
            mlm_logits / self._config.temperature, disallow=disallow))
    sampled_tokids = tf.argmax(sampled_tokens, -1, output_type=tf.int32)
    updated_input_ids, masked = pretrain_helpers.scatter_update(
        inputs.input_ids, sampled_tokids, inputs.masked_lm_positions)
    if self._config.electric_objective:
        labels = masked
    else:
        # A sampled token only counts as fake if it differs from the original
        labels = masked * (1 - tf.cast(
            tf.equal(updated_input_ids, inputs.input_ids), tf.int32))
    updated_inputs = pretrain_data.get_updated_inputs(
        inputs, input_ids=updated_input_ids)
    FakedData = collections.namedtuple(
        "FakedData", ["inputs", "is_fake_tokens", "sampled_tokens"])
    return FakedData(inputs=updated_inputs,
                     is_fake_tokens=labels,
                     sampled_tokens=sampled_tokens)
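
Note: `pretrain_helpers.sample_from_softmax` is not shown in this listing. A minimal sketch of its likely behavior (an assumption modeled on the ELECTRA reference implementation, not code from these examples) is one-hot Gumbel-max sampling, with the optional `disallow` one-hot tensor used to forbid re-sampling the original token:

import tensorflow.compat.v1 as tf

def sample_from_softmax(logits, disallow=None):
    """Sketch only: draw one-hot samples from softmax(logits) via Gumbel-max."""
    if disallow is not None:
        # Push disallowed ids toward -inf so they can never be the argmax
        logits -= 1000.0 * disallow
    # Gumbel(0, 1) noise: -log(-log(U)), U ~ Uniform(0, 1)
    uniform = tf.random.uniform(tf.shape(logits), minval=1e-9, maxval=1.0)
    gumbel = -tf.log(-tf.log(uniform))
    # argmax of (logits + Gumbel noise) is an exact sample from softmax(logits)
    return tf.one_hot(tf.argmax(logits + gumbel, -1, output_type=tf.int32),
                      depth=tf.shape(logits)[-1], dtype=tf.float32)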
Example #2
def _get_fake_data(inputs, mlm_logits):
    """Sample from the generator to create corrupted input."""
    masked_lm_weights = inputs.masked_lm_weights
    inputs = pretrain_helpers.unmask(inputs)
    disallow = None
    sampled_tokens = tf.stop_gradient(
        pretrain_helpers.sample_from_softmax(mlm_logits / 1.0,
                                             disallow=disallow))

    # sampled_tokens: [batch_size, n_pos, n_vocab]
    # mlm_logits: [batch_size, n_pos, n_vocab]
    sampled_tokens_fp32 = tf.cast(sampled_tokens, dtype=tf.float32)
    # mlm_logprobs: [batch_size, n_pos, n_vocab]
    mlm_logprobs = tf.nn.log_softmax(mlm_logits, axis=-1)
    # pseudo_logprob: [batch_size, n_pos], log-prob of each sampled token
    pseudo_logprob = tf.reduce_sum(mlm_logprobs * sampled_tokens_fp32, axis=-1)
    # Zero out positions that are not real masks
    pseudo_logprob *= tf.cast(masked_lm_weights, dtype=tf.float32)
    # pseudo_logprob: [batch_size] after summing over the masked positions
    pseudo_logprob = tf.reduce_sum(pseudo_logprob, axis=-1)
    # Optionally normalize by the number of masked positions:
    # pseudo_logprob /= (1e-10 + tf.reduce_sum(
    #     tf.cast(masked_lm_weights, dtype=tf.float32), axis=-1))
    sampled_tokids = tf.argmax(sampled_tokens, -1, output_type=tf.int32)
    updated_input_ids, masked = pretrain_helpers.scatter_update(
        inputs.input_ids, sampled_tokids, inputs.masked_lm_positions)

    labels = masked * (
        1 - tf.cast(tf.equal(updated_input_ids, inputs.input_ids), tf.int32))
    updated_inputs = pretrain_data.get_updated_inputs(
        inputs, input_ids=updated_input_ids)
    FakedData = collections.namedtuple(
        "FakedData",
        ["inputs", "is_fake_tokens", "sampled_tokens", "pseudo_logprob"])
    return FakedData(inputs=updated_inputs,
                     is_fake_tokens=labels,
                     sampled_tokens=sampled_tokens,
                     pseudo_logprob=pseudo_logprob)
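
To make the shape bookkeeping concrete, here is a small self-contained NumPy check of the pseudo-log-probability reduction above (all numbers made up for illustration): multiplying the log-softmax by the one-hot samples and summing over the vocab axis gathers the log-probability of each sampled token, and the weighted second sum collapses the masked positions to one scalar per sequence.

import numpy as np

# Toy shapes: batch_size=1, n_pos=2, n_vocab=3
mlm_logits = np.array([[[2.0, 0.5, 0.1],
                        [0.3, 1.2, 0.9]]])
# Suppose the generator sampled id 0 at position 0 and id 1 at position 1
sampled_tokens = np.array([[[1.0, 0.0, 0.0],
                            [0.0, 1.0, 0.0]]])
masked_lm_weights = np.array([[1.0, 0.0]])  # only position 0 is a real mask

# log_softmax over the vocab axis, same as tf.nn.log_softmax(mlm_logits, -1)
mlm_logprobs = mlm_logits - np.log(np.exp(mlm_logits).sum(-1, keepdims=True))

# [batch_size, n_pos]: the log-probability of each sampled token
per_position = (mlm_logprobs * sampled_tokens).sum(-1)
# [batch_size]: weighted sum over the genuinely masked positions
pseudo_logprob = (per_position * masked_lm_weights).sum(-1)
print(pseudo_logprob)  # one scalar per sequence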
Example #3
    def _sample_masking_subset(self, inputs: pretrain_data.Inputs,
                               action_probs):
        # Calculate shifted action_probs: drop [CLS]/[SEP] and re-index the rest
        input_mask = inputs.input_mask
        segment_ids = inputs.segment_ids
        input_ids = inputs.input_ids

        shape = modeling.get_shape_list(input_ids, expected_rank=2)
        batch_size = shape[0]
        max_seq_len = shape[1]

        def _remove_special_token(elems):
            """Drop the [CLS]/[SEP] entries of one example and re-index it."""
            action_prob = tf.cast(elems[0], tf.float32)
            segment = tf.cast(elems[1], tf.int32)
            input_row = tf.cast(elems[2], tf.int32)
            mask = tf.cast(elems[3], tf.int32)

            seq_len = tf.reduce_sum(mask)
            seg1_len = seq_len - tf.reduce_sum(segment)
            seq1_idx = tf.range(start=1, limit=seg1_len - 1, dtype=tf.int32)
            seq2_limit = tf.math.maximum(seg1_len, seq_len - 1)
            seq2_idx = tf.range(start=seg1_len,
                                limit=seq2_limit,
                                dtype=tf.int32)
            mask_idx = tf.range(start=seq_len,
                                limit=max_seq_len,
                                dtype=tf.int32)
            index_tensor = tf.concat([seq1_idx, seq2_idx, mask_idx], axis=0)

            seq1_prob = action_prob[1:seg1_len - 1]
            seq2_prob = action_prob[seg1_len:seq2_limit]
            mask_prob = tf.ones_like(mask_idx, dtype=tf.float32) * 1e-20
            cleaned_action_prob = tf.concat([seq1_prob, seq2_prob, mask_prob],
                                            axis=0)
            cleaned_mask = tf.concat([
                mask[1:seg1_len - 1], mask[seg1_len:seq_len - 1],
                mask[seq_len:max_seq_len]
            ], axis=0)

            cleaned_input = tf.concat([
                input_row[1:seg1_len - 1], input_row[seg1_len:seq_len - 1],
                input_row[seq_len:max_seq_len]
            ], axis=0)

            # Truncate to max_seq_len - 3 ([CLS] and two [SEP] tokens removed)
            cleaned_action_prob = cleaned_action_prob[0:max_seq_len - 3]
            index_tensor = index_tensor[0:max_seq_len - 3]
            cleaned_input = cleaned_input[0:max_seq_len - 3]
            cleaned_mask = cleaned_mask[0:max_seq_len - 3]

            return (cleaned_action_prob, index_tensor, cleaned_input,
                    cleaned_mask)

        # Remove CLS and SEP action probs
        elems = tf.stack([
            action_probs,
            tf.cast(segment_ids, tf.float32),
            tf.cast(input_ids, tf.float32),
            tf.cast(input_mask, tf.float32)
        ], 1)
        cleaned_action_probs, index_tensors, cleaned_inputs, cleaned_input_mask = tf.map_fn(
            _remove_special_token,
            elems,
            dtype=(tf.float32, tf.int32, tf.int32, tf.int32),
            parallel_iterations=1)
        logZ, log_prob = self._calculate_partition_table(
            cleaned_input_mask, cleaned_action_probs,
            self._config.max_predictions_per_seq)

        samples, log_q = self._sampling_a_subset(
            logZ, log_prob, self._config.max_predictions_per_seq)

        # Collect masked_lm_ids and masked_lm_positions
        zero_values = tf.zeros_like(index_tensors, tf.int32)
        selected_position = tf.where(tf.equal(samples, 1), index_tensors,
                                     zero_values)
        masked_lm_positions, _ = tf.nn.top_k(
            selected_position,
            self._config.max_predictions_per_seq,
            sorted=False)

        # Get the ids of the masked-out tokens
        shift = tf.expand_dims(max_seq_len * tf.range(batch_size), -1)
        flat_positions = tf.reshape(masked_lm_positions + shift, [-1, 1])
        masked_lm_ids = tf.gather_nd(tf.reshape(input_ids, [-1]),
                                     flat_positions)
        masked_lm_ids = tf.reshape(masked_lm_ids, [batch_size, -1])

        # Update the input ids: replace 85% of the selected positions with
        # [MASK]; the rest keep the original token or are randomized below
        replaced_prob = tf.random.uniform(
            [batch_size, self._config.max_predictions_per_seq])
        replace_with_mask_positions = masked_lm_positions * tf.cast(
            tf.less(replaced_prob, 0.85), tf.int32)
        inputs_ids, _ = scatter_update(
            inputs.input_ids,
            tf.fill([batch_size, self._config.max_predictions_per_seq],
                    self._vocab["[MASK]"]), replace_with_mask_positions)

        # Replace a further 7.5% (replaced_prob > 0.925) with random tokens
        replace_with_random_positions = masked_lm_positions * tf.cast(
            tf.greater(replaced_prob, 0.925), tf.int32)
        random_tokens = tf.random.uniform(
            [batch_size, self._config.max_predictions_per_seq],
            minval=0,
            maxval=len(self._vocab),
            dtype=tf.int32)

        inputs_ids, _ = scatter_update(inputs_ids, random_tokens,
                                       replace_with_random_positions)

        masked_lm_weights = tf.ones_like(masked_lm_ids, tf.float32)
        inv_vocab = self._inv_vocab
        # Optionally pretty-print the masked inputs for debugging
        if self._config.debug:

            def pretty_print(inputs_ids, masked_lm_ids, masked_lm_positions,
                             masked_lm_weights, tag_ids):
                debug_inputs = Inputs(input_ids=inputs_ids,
                                      input_mask=None,
                                      segment_ids=None,
                                      masked_lm_positions=masked_lm_positions,
                                      masked_lm_ids=masked_lm_ids,
                                      masked_lm_weights=masked_lm_weights,
                                      tag_ids=tag_ids)
                pretrain_data.print_tokens(debug_inputs, inv_vocab)

                ## TODO: save the mask choice
                return inputs_ids, masked_lm_ids, masked_lm_positions, masked_lm_weights

            mask_shape = masked_lm_ids.get_shape()
            inputs_ids, masked_lm_ids, masked_lm_positions, masked_lm_weights = \
              tf.py_func(pretty_print, [inputs_ids, masked_lm_ids, masked_lm_positions, masked_lm_weights, inputs.tag_ids],
                         (tf.int32, tf.int32, tf.int32, tf.float32))
            inputs_ids.set_shape(inputs.input_ids.get_shape())
            masked_lm_ids.set_shape(mask_shape)
            masked_lm_positions.set_shape(mask_shape)
            masked_lm_weights.set_shape(mask_shape)

        masked_input = pretrain_data.get_updated_inputs(
            inputs,
            input_ids=tf.stop_gradient(inputs_ids),
            masked_lm_positions=tf.stop_gradient(masked_lm_positions),
            masked_lm_ids=tf.stop_gradient(masked_lm_ids),
            masked_lm_weights=tf.stop_gradient(masked_lm_weights),
            tag_ids=inputs.tag_ids)

        return log_q, masked_input
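
Note: the examples above also rely on a `scatter_update` helper that is not reproduced in this listing. The sketch below (a simplification for the rank-2 integer case, adapted from the ELECTRA reference implementation; treat it as an assumption rather than the exact helper) shows the mechanics: positions are flattened across the batch, the updates are scattered in one pass, and duplicate padding positions are repaired by dividing by their scatter counts.

import tensorflow.compat.v1 as tf

def scatter_update(sequence, updates, positions):
    """Sketch only: write `updates` into `sequence` at `positions`.

    sequence:  [batch_size, seq_len]      int32 token ids
    updates:   [batch_size, n_positions]  int32 replacement ids
    positions: [batch_size, n_positions]  indices into each row; 0 acts as
               padding, so position 0 ([CLS]) is never modified.
    Returns (updated_sequence, updates_mask).
    """
    batch_size, seq_len = tf.shape(sequence)[0], tf.shape(sequence)[1]
    n_positions = tf.shape(positions)[1]

    # Flatten to one long row so a single scatter_nd covers the whole batch
    shift = tf.expand_dims(seq_len * tf.range(batch_size), -1)
    flat_positions = tf.reshape(positions + shift, [-1, 1])

    scattered = tf.scatter_nd(flat_positions, tf.reshape(updates, [-1]),
                              [batch_size * seq_len])
    counts = tf.scatter_nd(flat_positions,
                           tf.ones([batch_size * n_positions], tf.int32),
                           [batch_size * seq_len])
    # scatter_nd sums duplicate indices (e.g. repeated padding zeros);
    # dividing by the count recovers a single update value
    scattered = tf.math.floordiv(scattered, tf.maximum(1, counts))

    scattered = tf.reshape(scattered, [batch_size, seq_len])
    updates_mask = tf.minimum(tf.reshape(counts, [batch_size, seq_len]), 1)
    # Never touch position 0, which the callers use as the padding position
    updates_mask *= tf.concat(
        [tf.zeros([batch_size, 1], tf.int32),
         tf.ones([batch_size, seq_len - 1], tf.int32)], -1)

    updated_sequence = (1 - updates_mask) * sequence + updates_mask * scattered
    return updated_sequence, updates_mask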