Example #1
 def _get_fake_data(self, inputs, mlm_logits):
     """Sample from the generator to create corrupted input."""
     inputs = pretrain_helpers.unmask(inputs)
     disallow = tf.one_hot(
         inputs.masked_lm_ids,
         depth=self._bert_config.vocab_size,
         dtype=tf.float32) if self._config.disallow_correct else None
     sampled_tokens = tf.stop_gradient(
         pretrain_helpers.sample_from_softmax(mlm_logits /
                                              self._config.temperature,
                                              disallow=disallow))
     sampled_tokids = tf.argmax(sampled_tokens, -1, output_type=tf.int32)
     updated_input_ids, masked = pretrain_helpers.scatter_update(
         inputs.input_ids, sampled_tokids, inputs.masked_lm_positions)
     if self._config.electric_objective:
         labels = masked
     else:
         labels = masked * (1 - tf.cast(
             tf.equal(updated_input_ids, inputs.input_ids), tf.int32))
     updated_inputs = pretrain_data.get_updated_inputs(
         inputs, input_ids=updated_input_ids)
     FakedData = collections.namedtuple(
         "FakedData", ["inputs", "is_fake_tokens", "sampled_tokens"])
     return FakedData(inputs=updated_inputs,
                      is_fake_tokens=labels,
                      sampled_tokens=sampled_tokens)
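A minimal NumPy sketch of the labeling logic above (toy ids and a hand-picked sample stand in for pretrain_helpers.sample_from_softmax and scatter_update): in the non-electric branch, a masked position only counts as fake when the sampled token differs from the original one.

import numpy as np

# Toy example: one sequence of 8 token ids with two masked-out positions.
original_ids = np.array([101, 7, 12, 9, 25, 3, 40, 102])
masked_positions = np.array([2, 4])      # positions that were replaced by [MASK]

# Stand-in for the argmax of one-hot samples drawn from the generator softmax.
sampled_ids = np.array([12, 31])         # position 2 happens to re-sample the original token

# Scatter the sampled ids back into the (unmasked) input.
updated_ids = original_ids.copy()
updated_ids[masked_positions] = sampled_ids

# A masked position is labeled fake only if the sample differs from the original.
masked = np.zeros_like(original_ids)
masked[masked_positions] = 1
is_fake = masked * (updated_ids != original_ids).astype(int)

print(updated_ids)  # [101   7  12   9  31   3  40 102]
print(is_fake)      # [0 0 0 0 1 0 0 0]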
Example #2
def _get_fake_data(inputs, mlm_logits):
    """Sample from the generator to create corrupted input."""
    masked_lm_weights = inputs.masked_lm_weights
    inputs = pretrain_helpers.unmask(inputs)
    disallow = None
    sampled_tokens = tf.stop_gradient(
        pretrain_helpers.sample_from_softmax(mlm_logits / 1.0,
                                             disallow=disallow))

    # sampled_tokens: [batch_size, n_pos, n_vocab]
    # mlm_logits: [batch_size, n_pos, n_vocab]
    sampled_tokens_fp32 = tf.cast(sampled_tokens, dtype=tf.float32)
    print(sampled_tokens_fp32, "===sampled_tokens_fp32===")
    # mlm_logprobs: [batch_size, n_pos, n_vocab]
    mlm_logprobs = tf.nn.log_softmax(mlm_logits, axis=-1)
    # pseudo_logprob: [batch_size, n_pos]
    pseudo_logprob = tf.reduce_sum(mlm_logprobs * sampled_tokens_fp32, axis=-1)
    pseudo_logprob *= tf.cast(masked_lm_weights, dtype=tf.float32)
    # [batch_size]
    pseudo_logprob = tf.reduce_sum(pseudo_logprob, axis=-1)
    # [batch_size]
    # pseudo_logprob /= (1e-10+tf.reduce_sum(tf.cast(masked_lm_weights, dtype=tf.float32), axis=-1))
    print("== _get_fake_data pseudo_logprob ==", pseudo_logprob)
    sampled_tokids = tf.argmax(sampled_tokens, -1, output_type=tf.int32)
    updated_input_ids, masked = pretrain_helpers.scatter_update(
        inputs.input_ids, sampled_tokids, inputs.masked_lm_positions)

    labels = masked * (
        1 - tf.cast(tf.equal(updated_input_ids, inputs.input_ids), tf.int32))
    updated_inputs = pretrain_data.get_updated_inputs(
        inputs, input_ids=updated_input_ids)
    FakedData = collections.namedtuple(
        "FakedData",
        ["inputs", "is_fake_tokens", "sampled_tokens", "pseudo_logprob"])
    return FakedData(inputs=updated_inputs,
                     is_fake_tokens=labels,
                     sampled_tokens=sampled_tokens,
                     pseudo_logprob=pseudo_logprob)
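The pseudo_logprob above is the generator's log-softmax evaluated at the sampled one-hot tokens, weighted by masked_lm_weights and summed over positions. A NumPy sketch with assumed toy shapes:

import numpy as np

def log_softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

np.random.seed(0)
batch_size, n_pos, n_vocab = 2, 3, 5
mlm_logits = np.random.randn(batch_size, n_pos, n_vocab)

# One-hot samples (stand-in for sample_from_softmax) and the masked-slot weights.
sampled_ids = np.array([[1, 4, 0], [2, 2, 3]])
sampled_tokens = np.eye(n_vocab)[sampled_ids]                  # [batch, n_pos, n_vocab]
masked_lm_weights = np.array([[1., 1., 0.], [1., 1., 0.]])     # last slot is padding

mlm_logprobs = log_softmax(mlm_logits, axis=-1)
pseudo_logprob = (mlm_logprobs * sampled_tokens).sum(-1)       # [batch, n_pos]
pseudo_logprob = (pseudo_logprob * masked_lm_weights).sum(-1)  # [batch]
print(pseudo_logprob)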
Example #3
def unmask(inputs: pretrain_data.Inputs):
  unmasked_input_ids, _ = scatter_update(
      inputs.input_ids, inputs.masked_lm_ids, inputs.masked_lm_positions)
  return pretrain_data.get_updated_inputs(inputs, input_ids=unmasked_input_ids)
Example #4
def mask(config: configure_pretraining.PretrainingConfig,
         inputs: pretrain_data.Inputs, mask_prob, proposal_distribution=1.0,
         disallow_from_mask=None, already_masked=None):
  """Implementation of dynamic masking. The optional arguments aren't needed for
  BERT/ELECTRA and are from early experiments in "strategically" masking out
  tokens instead of uniformly at random.

  Args:
    config: configure_pretraining.PretrainingConfig
    inputs: pretrain_data.Inputs containing input input_ids/input_mask
    mask_prob: percent of tokens to mask
    proposal_distribution: for non-uniform masking can be a [B, L] tensor
                           of scores for masking each position.
    disallow_from_mask: a boolean tensor of [B, L] of positions that should
                        not be masked out
    already_masked: a boolean tensor of [B, N] of already masked-out tokens
                    for multiple rounds of masking
  Returns: a pretrain_data.Inputs with masking added
  """
  # Get the batch size, sequence length, and max masked-out tokens
  N = config.max_predictions_per_seq
  B, L = modeling.get_shape_list(inputs.input_ids)

  # Find indices where masking out a token is allowed
  vocab = tokenization.FullTokenizer(
      config.vocab_file, do_lower_case=config.do_lower_case).vocab
  candidates_mask = _get_candidates_mask(inputs, vocab, disallow_from_mask)

  # Set the number of tokens to mask out per example
  num_tokens = tf.cast(tf.reduce_sum(inputs.input_mask, -1), tf.float32)
  num_to_predict = tf.maximum(1, tf.minimum(
      N, tf.cast(tf.round(num_tokens * mask_prob), tf.int32)))
  masked_lm_weights = tf.cast(tf.sequence_mask(num_to_predict, N), tf.float32)
  if already_masked is not None:
    masked_lm_weights *= (1 - already_masked)

  # Get a probability of masking each position in the sequence
  candidate_mask_float = tf.cast(candidates_mask, tf.float32)
  sample_prob = (proposal_distribution * candidate_mask_float)
  sample_prob /= tf.reduce_sum(sample_prob, axis=-1, keepdims=True)

  # Sample the positions to mask out
  sample_prob = tf.stop_gradient(sample_prob)
  sample_logits = tf.log(sample_prob)
  masked_lm_positions = tf.random.categorical(
      sample_logits, N, dtype=tf.int32)
  masked_lm_positions *= tf.cast(masked_lm_weights, tf.int32)

  # Get the ids of the masked-out tokens
  shift = tf.expand_dims(L * tf.range(B), -1)
  flat_positions = tf.reshape(masked_lm_positions + shift, [-1, 1])
  masked_lm_ids = tf.gather_nd(tf.reshape(inputs.input_ids, [-1]),
                               flat_positions)
  masked_lm_ids = tf.reshape(masked_lm_ids, [B, -1])
  masked_lm_ids *= tf.cast(masked_lm_weights, tf.int32)

  # Update the input ids
  replace_with_mask_positions = masked_lm_positions * tf.cast(
      tf.less(tf.random.uniform([B, N]), 0.85), tf.int32)
  inputs_ids, _ = scatter_update(
      inputs.input_ids, tf.fill([B, N], vocab["[MASK]"]),
      replace_with_mask_positions)

  return pretrain_data.get_updated_inputs(
      inputs,
      input_ids=tf.stop_gradient(inputs_ids),
      masked_lm_positions=masked_lm_positions,
      masked_lm_ids=masked_lm_ids,
      masked_lm_weights=masked_lm_weights
  )
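The "Get the ids of the masked-out tokens" step flattens the batch and offsets each row's positions by row_index * L so a single gather can read out all masked ids. A toy NumPy sketch of that indexing trick (shapes are assumed):

import numpy as np

B, L, N = 2, 6, 2
input_ids = np.array([[101, 11, 12, 13, 14, 102],
                      [101, 21, 22, 23,  0,   0]])
masked_lm_positions = np.array([[2, 4],
                                [1, 3]])

shift = (L * np.arange(B))[:, None]                    # [[0], [6]]
flat_positions = (masked_lm_positions + shift).reshape(-1)
masked_lm_ids = input_ids.reshape(-1)[flat_positions].reshape(B, N)
print(masked_lm_ids)  # [[12 14]
                      #  [21 23]]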
Example #5
    def _sample_masking_subset(self, inputs: pretrain_data.Inputs,
                               action_probs):
        #calculate shifted action_probs
        input_mask = inputs.input_mask
        segment_ids = inputs.segment_ids
        input_ids = inputs.input_ids

        shape = modeling.get_shape_list(input_ids, expected_rank=2)
        batch_size = shape[0]
        max_seq_len = shape[1]

        def _remove_special_token(elems):
            action_prob = tf.cast(elems[0], tf.float32)
            segment = tf.cast(elems[1], tf.int32)
            input = tf.cast(elems[2], tf.int32)
            mask = tf.cast(elems[3], tf.int32)

            seq_len = tf.reduce_sum(mask)
            seg1_len = seq_len - tf.reduce_sum(segment)
            seq1_idx = tf.range(start=1, limit=seg1_len - 1, dtype=tf.int32)
            seq2_limit = tf.math.maximum(seg1_len, seq_len - 1)
            seq2_idx = tf.range(start=seg1_len,
                                limit=seq2_limit,
                                dtype=tf.int32)
            mask_idx = tf.range(start=seq_len,
                                limit=max_seq_len,
                                dtype=tf.int32)
            index_tensor = tf.concat([seq1_idx, seq2_idx, mask_idx], axis=0)

            seq1_prob = action_prob[1:seg1_len - 1]
            seq2_prob = action_prob[seg1_len:seq2_limit]
            mask_prob = tf.ones_like(mask_idx, dtype=tf.float32) * 1e-20
            cleaned_action_prob = tf.concat([seq1_prob, seq2_prob, mask_prob],
                                            axis=0)
            cleaned_mask = tf.concat([
                mask[1:seg1_len - 1], mask[seg1_len:seq_len - 1],
                mask[seq_len:max_seq_len]
            ],
                                     axis=0)

            cleaned_input = tf.concat([
                input[1:seg1_len - 1], input[seg1_len:seq_len - 1],
                input[seq_len:max_seq_len]
            ],
                                      axis=0)

            cleaned_action_prob = cleaned_action_prob[0:max_seq_len - 3]
            index_tensor = index_tensor[0:max_seq_len - 3]
            cleaned_input = cleaned_input[0:max_seq_len - 3]
            cleaned_mask = cleaned_mask[0:max_seq_len - 3]

            return (cleaned_action_prob, index_tensor, cleaned_input,
                    cleaned_mask)

        # Remove CLS and SEP action probs
        elems = tf.stack([
            action_probs,
            tf.cast(segment_ids, tf.float32),
            tf.cast(input_ids, tf.float32),
            tf.cast(input_mask, tf.float32)
        ], 1)
        cleaned_action_probs, index_tensors, cleaned_inputs, cleaned_input_mask = tf.map_fn(
            _remove_special_token,
            elems,
            dtype=(tf.float32, tf.int32, tf.int32, tf.int32),
            parallel_iterations=1)
        logZ, log_prob = self._calculate_partition_table(
            cleaned_input_mask, cleaned_action_probs,
            self._config.max_predictions_per_seq)

        samples, log_q = self._sampling_a_subset(
            logZ, log_prob, self._config.max_predictions_per_seq)

        # Collect masked_lm_ids and masked_lm_positions
        zero_values = tf.zeros_like(index_tensors, tf.int32)
        selected_position = tf.where(tf.equal(samples, 1), index_tensors,
                                     zero_values)
        masked_lm_positions, _ = tf.nn.top_k(
            selected_position,
            self._config.max_predictions_per_seq,
            sorted=False)

        # Get the ids of the masked-out tokens
        shift = tf.expand_dims(max_seq_len * tf.range(batch_size), -1)
        flat_positions = tf.reshape(masked_lm_positions + shift, [-1, 1])
        masked_lm_ids = tf.gather_nd(tf.reshape(input_ids, [-1]),
                                     flat_positions)
        masked_lm_ids = tf.reshape(masked_lm_ids, [batch_size, -1])

        # Update the input ids
        replaced_prob = tf.random.uniform(
            [batch_size, self._config.max_predictions_per_seq])
        replace_with_mask_positions = masked_lm_positions * tf.cast(
            tf.less(replaced_prob, 0.85), tf.int32)
        inputs_ids, _ = scatter_update(
            inputs.input_ids,
            tf.fill([batch_size, self._config.max_predictions_per_seq],
                    self._vocab["[MASK]"]), replace_with_mask_positions)

        # Replace with random tokens
        replace_with_random_positions = masked_lm_positions * tf.cast(
            tf.greater(replaced_prob, 0.925), tf.int32)
        random_tokens = tf.random.uniform(
            [batch_size, self._config.max_predictions_per_seq],
            minval=0,
            maxval=len(self._vocab),
            dtype=tf.int32)

        inputs_ids, _ = scatter_update(inputs_ids, random_tokens,
                                       replace_with_random_positions)

        masked_lm_weights = tf.ones_like(masked_lm_ids, tf.float32)
        inv_vocab = self._inv_vocab
        # Apply mask on input
        if self._config.debug:

            def pretty_print(inputs_ids, masked_lm_ids, masked_lm_positions,
                             masked_lm_weights, tag_ids):
                debug_inputs = Inputs(input_ids=inputs_ids,
                                      input_mask=None,
                                      segment_ids=None,
                                      masked_lm_positions=masked_lm_positions,
                                      masked_lm_ids=masked_lm_ids,
                                      masked_lm_weights=masked_lm_weights,
                                      tag_ids=tag_ids)
                pretrain_data.print_tokens(debug_inputs, inv_vocab)

                ## TODO: save the masking choice
                return inputs_ids, masked_lm_ids, masked_lm_positions, masked_lm_weights

            mask_shape = masked_lm_ids.get_shape()
            inputs_ids, masked_lm_ids, masked_lm_positions, masked_lm_weights = \
              tf.py_func(pretty_print, [inputs_ids, masked_lm_ids, masked_lm_positions, masked_lm_weights, inputs.tag_ids],
                         (tf.int32, tf.int32, tf.int32, tf.float32))
            inputs_ids.set_shape(inputs.input_ids.get_shape())
            masked_lm_ids.set_shape(mask_shape)
            masked_lm_positions.set_shape(mask_shape)
            masked_lm_weights.set_shape(mask_shape)

        masked_input = pretrain_data.get_updated_inputs(
            inputs,
            input_ids=tf.stop_gradient(inputs_ids),  # use the updated (masked) ids
            masked_lm_positions=tf.stop_gradient(masked_lm_positions),
            masked_lm_ids=tf.stop_gradient(masked_lm_ids),
            masked_lm_weights=tf.stop_gradient(masked_lm_weights),
            tag_ids=inputs.tag_ids)

        return log_q, masked_input
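The replacement rule above follows the BERT recipe: for each selected position a uniform draw below 0.85 yields [MASK], a draw above 0.925 yields a random vocabulary token, and the remainder keep the original token (roughly 85% / 7.5% / 7.5%). A standalone NumPy sketch with an assumed toy vocabulary:

import numpy as np

np.random.seed(1)
vocab_size, mask_id = 100, 3
token_ids = np.array([17, 42, 58, 23, 91])   # tokens at the selected positions
p = np.random.uniform(size=token_ids.shape)

replaced = np.where(p < 0.85, mask_id, token_ids)           # 85%: [MASK]
random_tokens = np.random.randint(0, vocab_size, size=token_ids.shape)
replaced = np.where(p > 0.925, random_tokens, replaced)     # 7.5%: random token
print(p.round(3))
print(replaced)                                             # remaining 7.5%: unchanged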
Example #6
    def __init__(self, config: configure_pretraining.PretrainingConfig,
                 features, is_training):
        # Set up model config
        self._config = config
        self._bert_config = training_utils.get_bert_config(config)
        self._teacher_config = training_utils.get_teacher_config(config)

        embedding_size = (self._bert_config.hidden_size
                          if config.embedding_size is None else
                          config.embedding_size)

        tokenizer = tokenization.FullTokenizer(
            config.vocab_file, do_lower_case=config.do_lower_case)
        self._vocab = tokenizer.vocab
        self._inv_vocab = tokenizer.inv_vocab

        # Mask the input
        inputs = pretrain_data.features_to_inputs(features)
        old_model = self._build_transformer(inputs,
                                            is_training,
                                            embedding_size=embedding_size)
        input_states = old_model.get_sequence_output()
        input_states = tf.stop_gradient(input_states)

        teacher_output = self._build_teacher(input_states,
                                             inputs,
                                             is_training,
                                             embedding_size=embedding_size)
        # calculate the proposal distribution

        action_prob = teacher_output.action_probs  #pi(x_i)

        coin_toss = tf.random.uniform([])
        log_q, masked_inputs = self._sample_masking_subset(inputs, action_prob)
        if config.masking_strategy == pretrain_helpers.MIX_ADV_STRATEGY:
            random_masked_input = pretrain_helpers.mask(
                config, pretrain_data.features_to_inputs(features),
                config.mask_prob)
            B, L = modeling.get_shape_list(inputs.input_ids)
            N = config.max_predictions_per_seq
            strategy_prob = tf.random.uniform([B])
            strategy_prob = tf.expand_dims(
                tf.cast(tf.greater(strategy_prob, 0.5), tf.int32), 1)
            l_strategy_prob = tf.tile(strategy_prob, [1, L])
            n_strategy_prob = tf.tile(strategy_prob, [1, N])
            mix_input_ids = masked_inputs.input_ids * l_strategy_prob + random_masked_input.input_ids * (
                1 - l_strategy_prob)
            mix_masked_lm_positions = masked_inputs.masked_lm_positions * n_strategy_prob + random_masked_input.masked_lm_positions * (
                1 - n_strategy_prob)
            mix_masked_lm_ids = masked_inputs.masked_lm_ids * n_strategy_prob + random_masked_input.masked_lm_ids * (
                1 - n_strategy_prob)
            n_strategy_prob = tf.cast(n_strategy_prob, tf.float32)
            mix_masked_lm_weights = masked_inputs.masked_lm_weights * n_strategy_prob + random_masked_input.masked_lm_weights * (
                1 - n_strategy_prob)
            mix_masked_inputs = pretrain_data.get_updated_inputs(
                inputs,
                input_ids=tf.stop_gradient(mix_input_ids),
                masked_lm_positions=mix_masked_lm_positions,
                masked_lm_ids=mix_masked_lm_ids,
                masked_lm_weights=mix_masked_lm_weights,
                tag_ids=inputs.tag_ids)
            masked_inputs = mix_masked_inputs

        # BERT model
        model = self._build_transformer(masked_inputs,
                                        is_training,
                                        reuse=tf.AUTO_REUSE,
                                        embedding_size=embedding_size)
        mlm_output = self._get_masked_lm_output(masked_inputs, model)
        self.total_loss = mlm_output.loss

        # Teacher reward is the -log p(x_S|x;B)
        reward = tf.stop_gradient(
            tf.reduce_mean(mlm_output.per_example_loss, 1))
        self._baseline = tf.reduce_mean(reward, -1)
        self._std = tf.math.reduce_std(reward, -1)

        # Calculate teacher loss
        def compute_teacher_loss(log_q, reward, baseline, std):
            advantage = tf.abs((reward - baseline) / std)
            advantage = tf.stop_gradient(advantage)
            log_q = tf.Print(log_q, [log_q], "log_q: ")
            teacher_loss = tf.reduce_mean(-log_q * advantage)
            return teacher_loss

        teacher_loss = tf.cond(
            coin_toss < 0.1, lambda: compute_teacher_loss(
                log_q, reward, self._baseline, self._std),
            lambda: tf.constant(0.0))
        self.total_loss = mlm_output.loss + teacher_loss
        self.teacher_loss = teacher_loss
        self.mlm_loss = mlm_output.loss

        # Evaluation
        eval_fn_inputs = {
            "input_ids": masked_inputs.input_ids,
            "masked_lm_preds": mlm_output.preds,
            "mlm_loss": mlm_output.per_example_loss,
            "masked_lm_ids": masked_inputs.masked_lm_ids,
            "masked_lm_weights": masked_inputs.masked_lm_weights,
            "input_mask": masked_inputs.input_mask
        }
        eval_fn_keys = eval_fn_inputs.keys()
        eval_fn_values = [eval_fn_inputs[k] for k in eval_fn_keys]
        """Computes the loss and accuracy of the model."""
        d = {k: arg for k, arg in zip(eval_fn_keys, eval_fn_values)}
        metrics = dict()
        metrics["masked_lm_accuracy"] = tf.metrics.accuracy(
            labels=tf.reshape(d["masked_lm_ids"], [-1]),
            predictions=tf.reshape(d["masked_lm_preds"], [-1]),
            weights=tf.reshape(d["masked_lm_weights"], [-1]))
        metrics["masked_lm_loss"] = tf.metrics.mean(
            values=tf.reshape(d["mlm_loss"], [-1]),
            weights=tf.reshape(d["masked_lm_weights"], [-1]))
        self.eval_metrics = metrics
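The teacher is trained with a REINFORCE-style objective: the reward is the student's per-example MLM loss, the advantage is normalized by the batch baseline and standard deviation (and treated as a constant via stop_gradient), and the loss is -log_q times that advantage. A worked NumPy toy example with assumed values:

import numpy as np

reward = np.array([2.1, 2.6, 1.9, 2.4])      # mean per-example MLM loss of the student
log_q = np.array([-3.2, -2.8, -3.5, -3.0])   # log-probability of each sampled mask set

baseline = reward.mean()
std = reward.std()
advantage = np.abs((reward - baseline) / std)   # constant w.r.t. the teacher
teacher_loss = np.mean(-log_q * advantage)
print(baseline, std.round(3), teacher_loss.round(3))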
Example #7
def mask(config: configure_pretraining.PretrainingConfig,
         inputs: pretrain_data.Inputs, mask_prob, proposal_distribution=1.0,
         disallow_from_mask=None, already_masked=None):
  """Implementation of dynamic masking. The optional arguments aren't needed for
  BERT/ELECTRA and are from early experiments in "strategically" masking out
  tokens instead of uniformly at random.

  Args:
    config: configure_pretraining.PretrainingConfig
    inputs: pretrain_data.Inputs containing input input_ids/input_mask
    mask_prob: percent of tokens to mask
    proposal_distribution: for non-uniform masking can be a [B, L] tensor
                           of scores for masking each position.
    disallow_from_mask: a boolean tensor of [B, L] of positions that should
                        not be masked out
    already_masked: a boolean tensor of [B, N] of already masked-out tokens
                    for multiple rounds of masking
  Returns: a pretrain_data.Inputs with masking added
  """
  # Get the batch size, sequence length, and max masked-out tokens
  N = config.max_predictions_per_seq
  B, L = modeling.get_shape_list(inputs.input_ids)

  # Find indices where masking out a token is allowed
  tokenizer = tokenization.FullTokenizer(
      config.vocab_file, do_lower_case=config.do_lower_case)
  vocab = tokenizer.vocab
  inv_vocab = tokenizer.inv_vocab
  candidates_mask = _get_candidates_mask(inputs, vocab, disallow_from_mask)

  # Set the number of tokens to mask out per example
  num_tokens = tf.cast(tf.reduce_sum(inputs.input_mask, -1), tf.float32)
  num_to_predict = tf.maximum(1, tf.minimum(
      N, tf.cast(tf.round(num_tokens * mask_prob), tf.int32)))
  masked_lm_weights = tf.cast(tf.sequence_mask(num_to_predict, N), tf.float32)
  if already_masked is not None:
    masked_lm_weights *= (1 - already_masked)

  # Get a probability of masking each position in the sequence
  candidate_mask_float = tf.cast(candidates_mask, tf.float32)

  if config.masking_strategy == RAND_STRATEGY or config.masking_strategy == MIX_ADV_STRATEGY:
    sample_prob = (proposal_distribution * candidate_mask_float)
  elif config.masking_strategy == POS_STRATEGY:
    unfavor_pos_mask = _get_unfavor_pos_mask(inputs)
    unfavor_pos_mask_float = tf.cast(unfavor_pos_mask, tf.float32)
    prefer_pos_mask_float = 1 - unfavor_pos_mask_float

    # Preferred positions get proposal weight 1.0; other positions get 0.05.
    # proposal_distribution = prefer_pos_mask_float
    proposal_distribution = 0.95 * prefer_pos_mask_float + 0.05
    sample_prob = (proposal_distribution * candidate_mask_float)
  elif config.masking_strategy == ENTROPY_STRATEGY:
    sample_prob = (proposal_distribution * candidate_mask_float)
  elif config.masking_strategy == MIX_POS_STRATEGY:
    rand_sample_prob = (proposal_distribution * candidate_mask_float)
    unfavor_pos_mask = _get_unfavor_pos_mask(inputs)
    unfavor_pos_mask_float = tf.cast(unfavor_pos_mask, tf.float32)
    prefer_pos_mask_float = 1 - unfavor_pos_mask_float

    # Preferred positions get proposal weight 1.0; other positions get 0.05.
    # proposal_distribution = prefer_pos_mask_float
    proposal_distribution = 0.95 * prefer_pos_mask_float + 0.05
    pos_sample_prob = (proposal_distribution * candidate_mask_float)

    strategy_prob = tf.random.uniform([B])
    strategy_prob = tf.expand_dims(
        tf.cast(tf.greater(strategy_prob, 0.5), tf.float32), 1)
    strategy_prob = tf.tile(strategy_prob, [1, L])
    sample_prob = (rand_sample_prob * strategy_prob +
                   pos_sample_prob * (1 - strategy_prob))
  elif config.masking_strategy == MIX_ENTROPY_STRATEGY:
    rand_sample_prob = (proposal_distribution * candidate_mask_float)
    entropy_sample_prob = (proposal_distribution * candidate_mask_float)
    strategy_prob = tf.random.uniform([B])
    strategy_prob = tf.expand_dims(
        tf.cast(tf.greater(strategy_prob, 0.5), tf.float32), 1)
    strategy_prob = tf.tile(strategy_prob, [1, L])
    sample_prob = (rand_sample_prob * strategy_prob +
                   entropy_sample_prob * (1 - strategy_prob))
  else:
    raise ValueError("{} strategy is not supported".format(config.masking_strategy))
  sample_prob /= tf.reduce_sum(sample_prob, axis=-1, keepdims=True)

  # Sample the positions to mask out
  sample_prob = tf.stop_gradient(sample_prob)
  sample_logits = tf.log(sample_prob)
  masked_lm_positions = tf.random.categorical(
      sample_logits, N, dtype=tf.int32)
  masked_lm_positions *= tf.cast(masked_lm_weights, tf.int32)

  # Get the ids of the masked-out tokens
  shift = tf.expand_dims(L * tf.range(B), -1)
  flat_positions = tf.reshape(masked_lm_positions + shift, [-1, 1])
  masked_lm_ids = tf.gather_nd(tf.reshape(inputs.input_ids, [-1]),
                               flat_positions)
  masked_lm_ids = tf.reshape(masked_lm_ids, [B, -1])
  masked_lm_ids *= tf.cast(masked_lm_weights, tf.int32)

  # Update the input ids
  replace_prob = tf.random.uniform([B, N])
  replace_with_mask_positions = masked_lm_positions * tf.cast(
      tf.less(replace_prob, 0.85), tf.int32)
  inputs_ids, _ = scatter_update(
      inputs.input_ids, tf.fill([B, N], vocab["[MASK]"]),
      replace_with_mask_positions)

  # Replace with random tokens
  replace_with_random_positions = masked_lm_positions * tf.cast(
      tf.greater(replace_prob, 0.925), tf.int32)
  random_tokens = tf.random.uniform(
      [B, N], minval=0, maxval=len(vocab), dtype=tf.int32)
  inputs_ids, _ = scatter_update(inputs_ids, random_tokens,
                                 replace_with_random_positions)

  if config.debug:
    def pretty_print(inputs_ids, masked_lm_ids, masked_lm_positions,
                     masked_lm_weights, tag_ids):
      debug_inputs = Inputs(
          input_ids=inputs_ids,
          input_mask=None,
          segment_ids=None,
          masked_lm_positions=masked_lm_positions,
          masked_lm_ids=masked_lm_ids,
          masked_lm_weights=masked_lm_weights,
          tag_ids=tag_ids)
      pretrain_data.print_tokens(debug_inputs, inv_vocab)

      ## TODO: save the masking choice
      return inputs_ids, masked_lm_ids, masked_lm_positions, masked_lm_weights

    mask_shape = masked_lm_ids.get_shape()
    inputs_ids, masked_lm_ids, masked_lm_positions, masked_lm_weights = \
      tf.py_func(pretty_print,[inputs_ids, masked_lm_ids, masked_lm_positions, masked_lm_weights, inputs.tag_ids],
                 (tf.int32, tf.int32, tf.int32, tf.float32))
    inputs_ids.set_shape(inputs.input_ids.get_shape())
    masked_lm_ids.set_shape(mask_shape)
    masked_lm_positions.set_shape(mask_shape)
    masked_lm_weights.set_shape(mask_shape)

  return pretrain_data.get_updated_inputs(
      inputs,
      input_ids=tf.stop_gradient(inputs_ids),
      masked_lm_positions=masked_lm_positions,
      masked_lm_ids=masked_lm_ids,
      masked_lm_weights=masked_lm_weights,
      tag_ids=inputs.tag_ids)
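For the MIX_* strategies above, a per-example coin toss picks which proposal distribution that whole example draws its mask positions from. A NumPy sketch with assumed toy proposals:

import numpy as np

np.random.seed(2)
B, L = 4, 6
rand_sample_prob = np.full((B, L), 1.0 / L)                        # uniform proposal
pos_sample_prob = np.tile([0.05, 0.05, 1.0, 1.0, 0.05, 0.05], (B, 1))
pos_sample_prob /= pos_sample_prob.sum(-1, keepdims=True)          # position-based proposal

coin = (np.random.uniform(size=(B, 1)) > 0.5).astype(float)        # broadcasts over L
sample_prob = rand_sample_prob * coin + pos_sample_prob * (1 - coin)
sample_prob /= sample_prob.sum(-1, keepdims=True)
print(coin.ravel())
print(sample_prob.round(3))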
Example #8
  def _get_richer_data(self, fake_data):
    inputs_tf = fake_data.inputs.input_ids
    labels_tf = fake_data.is_fake_tokens
    lens_tf = tf.reduce_sum(fake_data.inputs.input_mask, 1)
    #retrieve the basic config
    V = self._bert_config.vocab_size
    #sub: 10%, del + ins: 5%
    N = int(self._config.max_predictions_per_seq * self._config.rich_prob)
    B, L = modeling.get_shape_list(inputs_tf)
    nlms = 0
    bilm = None
    if self._config.use_bilm:
      with open(self._config.bilm_file, 'rb') as f:
        bilm = tf.constant(np.load(f), tf.int32)
      _, nlms = modeling.get_shape_list(bilm)
    #make multiple partitions for edit op
    splits_list = []
    for i in range(B):
      one = tf.random.uniform([N * 4], 1, lens_tf[i], tf.int32)
      one, _ = tf.unique(one)
      one = tf.cond(tf.less(tf.shape(one)[0], N * 2 + 1),
                    lambda: tf.expand_dims(tf.range(1, N * 2 + 2), 0),
                    lambda: tf.sort(tf.reshape(one[: N * 2 + 1], [1, N * 2 + 1])))
      splits_list.append(one[:, 2::2])
    splits_tf = tf.concat(splits_list, 0)
    splits_up = tf.concat([splits_tf, tf.expand_dims(tf.constant([L] * B, tf.int32), 1)], 1)
    splits_lo = tf.concat([tf.expand_dims(tf.constant([0] * B, tf.int32), 1), splits_tf], 1)
    size_splits = splits_up - splits_lo
    # update the inputs and labels by applying random insertions, deletions, and swaps
    new_labels_list = []
    new_inputs_list = []
    for i in range(B):
      inputs_splits = tf.split(inputs_tf[i, :], size_splits[i, :])
      labels_splits = tf.split(labels_tf[i, :], size_splits[i, :])
      one_inputs = []
      one_labels = []
      size_split = len(inputs_splits)
      inputs_end = inputs_splits[-1]
      labels_end = labels_splits[-1]
      for j in range(size_split-1):
        inputs = inputs_splits[j]
        labels = labels_splits[j]  # label 1 for substitution
        rand_op = random.randint(2, self._config.num_preds - 1)
        if rand_op == 2: #label 2 for insertion
          if bilm is None: #noise
            insert_tok = tf.random.uniform([1], 1, V, tf.int32)
          else: #2-gram prediction
            insert_tok = tf.expand_dims(bilm[inputs[-1], random.randint(0, nlms-1)], 0)
          is_end_valid = tf.less_equal(2, tf.shape(inputs_end)[0])
          inputs = tf.cond(is_end_valid, lambda: tf.concat([inputs, insert_tok], 0), lambda: inputs)
          labels = tf.cond(is_end_valid, lambda: tf.concat([labels, tf.constant([2])], 0), lambda: labels)
          inputs_end = tf.cond(is_end_valid, lambda: inputs_end[:-1], lambda: inputs_end)
          labels_end = tf.cond(is_end_valid, lambda: labels_end[:-1], lambda: labels_end)
        elif rand_op == 3: #label 3 for deletion
          labels = tf.concat([labels[:-2], tf.constant([3])], 0)
          inputs = inputs[:-1]
          inputs_end = tf.concat([inputs_end, tf.constant([0])], 0)
          labels_end = tf.concat([labels_end, tf.constant([0])], 0)
        elif rand_op == 4: #label 4 for swapping
          labels = tf.concat([labels[:-1], tf.constant([4])], 0)
          inputs = tf.concat([inputs[:-2], [inputs[-1]], [inputs[-2]]], 0)
        one_labels.append(labels)
        one_inputs.append(inputs)
      one_inputs.append(inputs_end)
      one_labels.append(labels_end)
      one_inputs_tf = tf.concat(one_inputs, 0)
      one_labels_tf = tf.concat(one_labels, 0)
      one_inputs_tf = tf.cond(tf.less(lens_tf[i], N * 2 + 1), lambda: inputs_tf[i, :], lambda: one_inputs_tf)
      one_labels_tf = tf.cond(tf.less(lens_tf[i], N * 2 + 1), lambda: labels_tf[i, :], lambda: one_labels_tf)
      new_inputs_list.append(tf.expand_dims(one_inputs_tf, 0))
      new_labels_list.append(tf.expand_dims(one_labels_tf, 0))

    new_inputs_tf = tf.concat(new_inputs_list, 0)
    new_labels_tf = tf.concat(new_labels_list, 0)
    new_input_mask = tf.cast(tf.not_equal(new_inputs_tf, 0), tf.int32)
    updated_inputs = pretrain_data.get_updated_inputs(
        fake_data.inputs, input_ids=new_inputs_tf, input_mask=new_input_mask)
    RicherData = collections.namedtuple("RicherData", [
        "inputs", "is_fake_tokens", "sampled_tokens"])
    return RicherData(inputs=updated_inputs, is_fake_tokens=new_labels_tf,
                     sampled_tokens=fake_data.sampled_tokens)
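A loose, plain-Python sketch of the edit operations (label codes follow the comments above: 2 = insertion, 3 = deletion, 4 = swap; label-1 substitutions come from the generator's fake tokens). The real code keeps the sequence length fixed by borrowing from the padded tail; that bookkeeping is omitted here:

import random

random.seed(0)
tokens = [101, 11, 12, 13, 14, 102]
labels = [0] * len(tokens)
vocab_size = 1000
pos = 3

op = random.randint(2, 4)
if op == 2:      # insertion: add a noise token after pos
    tokens.insert(pos + 1, random.randrange(1, vocab_size))
    labels.insert(pos + 1, 2)
elif op == 3:    # deletion: drop the token at pos and mark the gap
    del tokens[pos]
    labels[pos] = 3
    del labels[pos + 1]
elif op == 4:    # swap: exchange the token at pos with its right neighbour
    tokens[pos], tokens[pos + 1] = tokens[pos + 1], tokens[pos]
    labels[pos] = 4
print(op, tokens, labels)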