def _get_fake_data(self, inputs, mlm_logits):
  """Sample from the generator to create corrupted input."""
  inputs = pretrain_helpers.unmask(inputs)
  disallow = tf.one_hot(
      inputs.masked_lm_ids, depth=self._bert_config.vocab_size,
      dtype=tf.float32) if self._config.disallow_correct else None
  sampled_tokens = tf.stop_gradient(pretrain_helpers.sample_from_softmax(
      mlm_logits / self._config.temperature, disallow=disallow))
  sampled_tokids = tf.argmax(sampled_tokens, -1, output_type=tf.int32)
  updated_input_ids, masked = pretrain_helpers.scatter_update(
      inputs.input_ids, sampled_tokids, inputs.masked_lm_positions)
  if self._config.electric_objective:
    labels = masked
  else:
    labels = masked * (1 - tf.cast(
        tf.equal(updated_input_ids, inputs.input_ids), tf.int32))
  updated_inputs = pretrain_data.get_updated_inputs(
      inputs, input_ids=updated_input_ids)
  FakedData = collections.namedtuple("FakedData", [
      "inputs", "is_fake_tokens", "sampled_tokens"])
  return FakedData(inputs=updated_inputs, is_fake_tokens=labels,
                   sampled_tokens=sampled_tokens)
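# The `disallow` one-hot above is what prevents the generator from re-sampling
# the original token when disallow_correct is set. The numpy function below is
# NOT the project's pretrain_helpers.sample_from_softmax; it is a minimal,
# self-contained sketch of the same idea (mask the forbidden id's logit, then
# sample from the softmax). All names here are illustrative assumptions.
import numpy as np

def sample_from_softmax_np(logits, disallow=None, seed=0):
  """Sample one token id per position; ids marked in `disallow` get ~0 prob."""
  rng = np.random.default_rng(seed)
  if disallow is not None:
    logits = np.where(disallow > 0, -1e9, logits)  # effectively zero probability
  probs = np.exp(logits - logits.max(-1, keepdims=True))
  probs /= probs.sum(-1, keepdims=True)
  flat = probs.reshape(-1, probs.shape[-1])
  draws = np.array([rng.choice(flat.shape[-1], p=p) for p in flat])
  return draws.reshape(probs.shape[:-1])

# Example: two masked positions, vocab of 4; the original ids [1, 3] can never
# be drawn again.
_logits = np.zeros((2, 4))
_disallow = np.eye(4)[[1, 3]]
print(sample_from_softmax_np(_logits, _disallow))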
def _get_fake_data(inputs, mlm_logits):
  """Sample from the generator to create corrupted input."""
  masked_lm_weights = inputs.masked_lm_weights
  inputs = pretrain_helpers.unmask(inputs)
  disallow = None
  sampled_tokens = tf.stop_gradient(pretrain_helpers.sample_from_softmax(
      mlm_logits / 1.0, disallow=disallow))
  # sampled_tokens: [batch_size, n_pos, n_vocab]
  # mlm_logits: [batch_size, n_pos, n_vocab]
  sampled_tokens_fp32 = tf.cast(sampled_tokens, dtype=tf.float32)
  # mlm_logprobs: [batch_size, n_pos, n_vocab]
  mlm_logprobs = tf.nn.log_softmax(mlm_logits, axis=-1)
  # pseudo_logprob: [batch_size, n_pos]
  pseudo_logprob = tf.reduce_sum(mlm_logprobs * sampled_tokens_fp32, axis=-1)
  pseudo_logprob *= tf.cast(masked_lm_weights, dtype=tf.float32)
  # pseudo_logprob: [batch_size]
  pseudo_logprob = tf.reduce_sum(pseudo_logprob, axis=-1)
  # Optionally normalize by the number of masked positions:
  # pseudo_logprob /= (1e-10 + tf.reduce_sum(
  #     tf.cast(masked_lm_weights, dtype=tf.float32), axis=-1))
  sampled_tokids = tf.argmax(sampled_tokens, -1, output_type=tf.int32)
  updated_input_ids, masked = pretrain_helpers.scatter_update(
      inputs.input_ids, sampled_tokids, inputs.masked_lm_positions)
  labels = masked * (
      1 - tf.cast(tf.equal(updated_input_ids, inputs.input_ids), tf.int32))
  updated_inputs = pretrain_data.get_updated_inputs(
      inputs, input_ids=updated_input_ids)
  FakedData = collections.namedtuple("FakedData", [
      "inputs", "is_fake_tokens", "sampled_tokens", "pseudo_logprob"])
  return FakedData(inputs=updated_inputs, is_fake_tokens=labels,
                   sampled_tokens=sampled_tokens, pseudo_logprob=pseudo_logprob)
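# A minimal numpy sketch of the pseudo_logprob computed above: sum the
# generator's log-probability of the tokens it actually sampled over the
# masked (non-padding) positions. Illustration only; names and shapes are
# assumptions, not the project's API.
import numpy as np

def pseudo_logprob_np(mlm_logits, sampled_ids, masked_lm_weights):
  # Numerically stable log-softmax over the vocab axis.
  m = mlm_logits.max(-1, keepdims=True)
  logprobs = mlm_logits - m - np.log(np.exp(mlm_logits - m).sum(-1, keepdims=True))
  # Log-prob of each sampled token: [batch, n_pos].
  picked = np.take_along_axis(logprobs, sampled_ids[..., None], axis=-1)[..., 0]
  # Zero out padding positions and sum over positions: [batch].
  return (picked * masked_lm_weights).sum(-1)

_plogits = np.random.randn(2, 3, 5)                  # [batch=2, n_pos=3, vocab=5]
_pids = np.array([[0, 2, 4], [1, 1, 3]])
_pweights = np.array([[1., 1., 0.], [1., 0., 0.]])
print(pseudo_logprob_np(_plogits, _pids, _pweights))  # shape [2]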
def unmask(inputs: pretrain_data.Inputs):
  unmasked_input_ids, _ = scatter_update(
      inputs.input_ids, inputs.masked_lm_ids, inputs.masked_lm_positions)
  return pretrain_data.get_updated_inputs(inputs, input_ids=unmasked_input_ids)
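# Sketch of what unmask/scatter_update accomplishes, in plain numpy: write the
# original token ids back into the sequence at the masked positions. This
# ignores the padding/no-op handling of the real scatter_update helper and is
# purely illustrative.
import numpy as np

def scatter_update_np(input_ids, updates, positions):
  out = input_ids.copy()
  batch = np.arange(input_ids.shape[0])[:, None]
  out[batch, positions] = updates
  return out

_uids = np.array([[101, 103, 7, 103, 102]])          # 103 standing in for [MASK]
print(scatter_update_np(_uids, np.array([[5, 9]]), np.array([[1, 3]])))
# -> [[101   5   7   9 102]]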
def mask(config: configure_pretraining.PretrainingConfig,
         inputs: pretrain_data.Inputs, mask_prob, proposal_distribution=1.0,
         disallow_from_mask=None, already_masked=None):
  """Implementation of dynamic masking. The optional arguments aren't needed
  for BERT/ELECTRA and are from early experiments in "strategically" masking
  out tokens instead of uniformly at random.

  Args:
    config: configure_pretraining.PretrainingConfig
    inputs: pretrain_data.Inputs containing input input_ids/input_mask
    mask_prob: percent of tokens to mask
    proposal_distribution: for non-uniform masking can be a [B, L] tensor of
      scores for masking each position.
    disallow_from_mask: a boolean tensor of [B, L] of positions that should not
      be masked out
    already_masked: a boolean tensor of [B, N] of already masked-out tokens for
      multiple rounds of masking

  Returns: a pretrain_data.Inputs with masking added
  """
  # Get the batch size, sequence length, and max masked-out tokens
  N = config.max_predictions_per_seq
  B, L = modeling.get_shape_list(inputs.input_ids)

  # Find indices where masking out a token is allowed
  vocab = tokenization.FullTokenizer(
      config.vocab_file, do_lower_case=config.do_lower_case).vocab
  candidates_mask = _get_candidates_mask(inputs, vocab, disallow_from_mask)

  # Set the number of tokens to mask out per example
  num_tokens = tf.cast(tf.reduce_sum(inputs.input_mask, -1), tf.float32)
  num_to_predict = tf.maximum(1, tf.minimum(
      N, tf.cast(tf.round(num_tokens * mask_prob), tf.int32)))
  masked_lm_weights = tf.cast(tf.sequence_mask(num_to_predict, N), tf.float32)
  if already_masked is not None:
    masked_lm_weights *= (1 - already_masked)

  # Get a probability of masking each position in the sequence
  candidate_mask_float = tf.cast(candidates_mask, tf.float32)
  sample_prob = (proposal_distribution * candidate_mask_float)
  sample_prob /= tf.reduce_sum(sample_prob, axis=-1, keepdims=True)

  # Sample the positions to mask out
  sample_prob = tf.stop_gradient(sample_prob)
  sample_logits = tf.log(sample_prob)
  masked_lm_positions = tf.random.categorical(
      sample_logits, N, dtype=tf.int32)
  masked_lm_positions *= tf.cast(masked_lm_weights, tf.int32)

  # Get the ids of the masked-out tokens
  shift = tf.expand_dims(L * tf.range(B), -1)
  flat_positions = tf.reshape(masked_lm_positions + shift, [-1, 1])
  masked_lm_ids = tf.gather_nd(tf.reshape(inputs.input_ids, [-1]),
                               flat_positions)
  masked_lm_ids = tf.reshape(masked_lm_ids, [B, -1])
  masked_lm_ids *= tf.cast(masked_lm_weights, tf.int32)

  # Update the input ids
  replace_with_mask_positions = masked_lm_positions * tf.cast(
      tf.less(tf.random.uniform([B, N]), 0.85), tf.int32)
  inputs_ids, _ = scatter_update(
      inputs.input_ids, tf.fill([B, N], vocab["[MASK]"]),
      replace_with_mask_positions)

  return pretrain_data.get_updated_inputs(
      inputs,
      input_ids=tf.stop_gradient(inputs_ids),
      masked_lm_positions=masked_lm_positions,
      masked_lm_ids=masked_lm_ids,
      masked_lm_weights=masked_lm_weights
  )
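# Illustrative numpy analogue of the position sampling above: clip the number
# of predictions to [1, N], build masked_lm_weights, normalize the proposal
# over candidate positions, and draw N positions (with replacement, like
# tf.random.categorical). A sketch only; names are assumptions.
import numpy as np

def sample_mask_positions_np(input_mask, candidates_mask, mask_prob, N,
                             proposal_distribution=1.0, seed=0):
  rng = np.random.default_rng(seed)
  num_tokens = input_mask.sum(-1)
  num_to_predict = np.clip(np.round(num_tokens * mask_prob).astype(int), 1, N)
  masked_lm_weights = (np.arange(N)[None, :] < num_to_predict[:, None]).astype(int)
  sample_prob = proposal_distribution * candidates_mask.astype(float)
  sample_prob /= sample_prob.sum(-1, keepdims=True)
  positions = np.stack([rng.choice(len(p), size=N, p=p) for p in sample_prob])
  return positions * masked_lm_weights, masked_lm_weights

_smask = np.array([[1, 1, 1, 1, 0, 0]])
_scands = np.array([[0, 1, 1, 1, 0, 0]])             # [CLS]/[SEP]/pad excluded
print(sample_mask_positions_np(_smask, _scands, mask_prob=0.5, N=3))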
def _sample_masking_subset(self, inputs: pretrain_data.Inputs, action_probs):
  # Calculate shifted action_probs
  input_mask = inputs.input_mask
  segment_ids = inputs.segment_ids
  input_ids = inputs.input_ids

  shape = modeling.get_shape_list(input_ids, expected_rank=2)
  batch_size = shape[0]
  max_seq_len = shape[1]

  def _remove_special_token(elems):
    action_prob = tf.cast(elems[0], tf.float32)
    segment = tf.cast(elems[1], tf.int32)
    input = tf.cast(elems[2], tf.int32)
    mask = tf.cast(elems[3], tf.int32)

    seq_len = tf.reduce_sum(mask)
    seg1_len = seq_len - tf.reduce_sum(segment)
    seq1_idx = tf.range(start=1, limit=seg1_len - 1, dtype=tf.int32)
    seq2_limit = tf.math.maximum(seg1_len, seq_len - 1)
    seq2_idx = tf.range(start=seg1_len, limit=seq2_limit, dtype=tf.int32)
    mask_idx = tf.range(start=seq_len, limit=max_seq_len, dtype=tf.int32)
    index_tensor = tf.concat([seq1_idx, seq2_idx, mask_idx], axis=0)

    seq1_prob = action_prob[1:seg1_len - 1]
    seq2_prob = action_prob[seg1_len:seq2_limit]
    mask_prob = tf.ones_like(mask_idx, dtype=tf.float32) * 1e-20
    cleaned_action_prob = tf.concat([seq1_prob, seq2_prob, mask_prob], axis=0)

    cleaned_mask = tf.concat([
        mask[1:seg1_len - 1], mask[seg1_len:seq_len - 1],
        mask[seq_len:max_seq_len]
    ], axis=0)
    cleaned_input = tf.concat([
        input[1:seg1_len - 1], input[seg1_len:seq_len - 1],
        input[seq_len:max_seq_len]
    ], axis=0)

    cleaned_action_prob = cleaned_action_prob[0:max_seq_len - 3]
    index_tensor = index_tensor[0:max_seq_len - 3]
    cleaned_input = cleaned_input[0:max_seq_len - 3]
    cleaned_mask = cleaned_mask[0:max_seq_len - 3]
    return (cleaned_action_prob, index_tensor, cleaned_input, cleaned_mask)

  # Remove CLS and SEP action probs
  elems = tf.stack([
      action_probs,
      tf.cast(segment_ids, tf.float32),
      tf.cast(input_ids, tf.float32),
      tf.cast(input_mask, tf.float32)
  ], 1)
  cleaned_action_probs, index_tensors, cleaned_inputs, cleaned_input_mask = \
      tf.map_fn(_remove_special_token, elems,
                dtype=(tf.float32, tf.int32, tf.int32, tf.int32),
                parallel_iterations=1)

  logZ, log_prob = self._calculate_partition_table(
      cleaned_input_mask, cleaned_action_probs,
      self._config.max_predictions_per_seq)
  samples, log_q = self._sampling_a_subset(
      logZ, log_prob, self._config.max_predictions_per_seq)

  # Collect masked_lm_ids and masked_lm_positions
  zero_values = tf.zeros_like(index_tensors, tf.int32)
  selected_position = tf.where(tf.equal(samples, 1), index_tensors, zero_values)
  masked_lm_positions, _ = tf.nn.top_k(
      selected_position, self._config.max_predictions_per_seq, sorted=False)

  # Get the ids of the masked-out tokens
  shift = tf.expand_dims(max_seq_len * tf.range(batch_size), -1)
  flat_positions = tf.reshape(masked_lm_positions + shift, [-1, 1])
  masked_lm_ids = tf.gather_nd(tf.reshape(input_ids, [-1]), flat_positions)
  masked_lm_ids = tf.reshape(masked_lm_ids, [batch_size, -1])

  # Update the input ids: 85% of the selected positions get [MASK]
  replaced_prob = tf.random.uniform(
      [batch_size, self._config.max_predictions_per_seq])
  replace_with_mask_positions = masked_lm_positions * tf.cast(
      tf.less(replaced_prob, 0.85), tf.int32)
  inputs_ids, _ = scatter_update(
      inputs.input_ids,
      tf.fill([batch_size, self._config.max_predictions_per_seq],
              self._vocab["[MASK]"]),
      replace_with_mask_positions)

  # Replace with random tokens
  replace_with_random_positions = masked_lm_positions * tf.cast(
      tf.greater(replaced_prob, 0.925), tf.int32)
  random_tokens = tf.random.uniform(
      [batch_size, self._config.max_predictions_per_seq],
      minval=0, maxval=len(self._vocab), dtype=tf.int32)
  inputs_ids, _ = scatter_update(inputs_ids, random_tokens,
                                 replace_with_random_positions)

  masked_lm_weights = tf.ones_like(masked_lm_ids, tf.float32)
  inv_vocab = self._inv_vocab

  # Apply mask on input
  if self._config.debug:
    def pretty_print(inputs_ids, masked_lm_ids, masked_lm_positions,
                     masked_lm_weights, tag_ids):
      debug_inputs = Inputs(
          input_ids=inputs_ids,
          input_mask=None,
          segment_ids=None,
          masked_lm_positions=masked_lm_positions,
          masked_lm_ids=masked_lm_ids,
          masked_lm_weights=masked_lm_weights,
          tag_ids=tag_ids)
      pretrain_data.print_tokens(debug_inputs, inv_vocab)
      ## TODO: save to the mask choice
      return inputs_ids, masked_lm_ids, masked_lm_positions, masked_lm_weights

    mask_shape = masked_lm_ids.get_shape()
    inputs_ids, masked_lm_ids, masked_lm_positions, masked_lm_weights = \
        tf.py_func(pretty_print,
                   [inputs_ids, masked_lm_ids, masked_lm_positions,
                    masked_lm_weights, inputs.tag_ids],
                   (tf.int32, tf.int32, tf.int32, tf.float32))
    inputs_ids.set_shape(inputs.input_ids.get_shape())
    masked_lm_ids.set_shape(mask_shape)
    masked_lm_positions.set_shape(mask_shape)
    masked_lm_weights.set_shape(mask_shape)

  masked_input = pretrain_data.get_updated_inputs(
      inputs,
      input_ids=tf.stop_gradient(inputs_ids),
      masked_lm_positions=tf.stop_gradient(masked_lm_positions),
      masked_lm_ids=tf.stop_gradient(masked_lm_ids),
      masked_lm_weights=tf.stop_gradient(masked_lm_weights),
      tag_ids=inputs.tag_ids)
  return log_q, masked_input
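# Small numpy sketch of the flat-index gather used above: each row's
# masked_lm_positions are shifted by row_index * max_seq_len so a single gather
# over the flattened batch picks out the masked-out token ids. Illustration
# only.
import numpy as np

def gather_masked_ids_np(input_ids, masked_lm_positions):
  B, L = input_ids.shape
  shift = (L * np.arange(B))[:, None]
  flat_positions = (masked_lm_positions + shift).reshape(-1)
  return input_ids.reshape(-1)[flat_positions].reshape(B, -1)

_gids = np.array([[10, 11, 12, 13],
                  [20, 21, 22, 23]])
_gpos = np.array([[1, 3],
                  [0, 2]])
print(gather_masked_ids_np(_gids, _gpos))            # -> [[11 13] [20 22]]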
def __init__(self, config: configure_pretraining.PretrainingConfig,
             features, is_training):
  # Set up model config
  self._config = config
  self._bert_config = training_utils.get_bert_config(config)
  self._teacher_config = training_utils.get_teacher_config(config)

  embedding_size = (
      self._bert_config.hidden_size if config.embedding_size is None
      else config.embedding_size)

  tokenizer = tokenization.FullTokenizer(
      config.vocab_file, do_lower_case=config.do_lower_case)
  self._vocab = tokenizer.vocab
  self._inv_vocab = tokenizer.inv_vocab

  # Mask the input
  inputs = pretrain_data.features_to_inputs(features)
  old_model = self._build_transformer(
      inputs, is_training, embedding_size=embedding_size)
  input_states = old_model.get_sequence_output()
  input_states = tf.stop_gradient(input_states)

  teacher_output = self._build_teacher(
      input_states, inputs, is_training, embedding_size=embedding_size)
  # Calculate the proposal distribution pi(x_i)
  action_prob = teacher_output.action_probs

  coin_toss = tf.random.uniform([])
  log_q, masked_inputs = self._sample_masking_subset(inputs, action_prob)
  if config.masking_strategy == pretrain_helpers.MIX_ADV_STRATEGY:
    random_masked_input = pretrain_helpers.mask(
        config, pretrain_data.features_to_inputs(features), config.mask_prob)
    B, L = modeling.get_shape_list(inputs.input_ids)
    N = config.max_predictions_per_seq
    strategy_prob = tf.random.uniform([B])
    strategy_prob = tf.expand_dims(
        tf.cast(tf.greater(strategy_prob, 0.5), tf.int32), 1)
    l_strategy_prob = tf.tile(strategy_prob, [1, L])
    n_strategy_prob = tf.tile(strategy_prob, [1, N])
    mix_input_ids = (
        masked_inputs.input_ids * l_strategy_prob +
        random_masked_input.input_ids * (1 - l_strategy_prob))
    mix_masked_lm_positions = (
        masked_inputs.masked_lm_positions * n_strategy_prob +
        random_masked_input.masked_lm_positions * (1 - n_strategy_prob))
    mix_masked_lm_ids = (
        masked_inputs.masked_lm_ids * n_strategy_prob +
        random_masked_input.masked_lm_ids * (1 - n_strategy_prob))
    n_strategy_prob = tf.cast(n_strategy_prob, tf.float32)
    mix_masked_lm_weights = (
        masked_inputs.masked_lm_weights * n_strategy_prob +
        random_masked_input.masked_lm_weights * (1 - n_strategy_prob))
    mix_masked_inputs = pretrain_data.get_updated_inputs(
        inputs,
        input_ids=tf.stop_gradient(mix_input_ids),
        masked_lm_positions=mix_masked_lm_positions,
        masked_lm_ids=mix_masked_lm_ids,
        masked_lm_weights=mix_masked_lm_weights,
        tag_ids=inputs.tag_ids)
    masked_inputs = mix_masked_inputs

  # BERT model
  model = self._build_transformer(
      masked_inputs, is_training, reuse=tf.AUTO_REUSE,
      embedding_size=embedding_size)
  mlm_output = self._get_masked_lm_output(masked_inputs, model)
  self.total_loss = mlm_output.loss

  # Teacher reward is the -log p(x_S|x;B)
  reward = tf.stop_gradient(tf.reduce_mean(mlm_output.per_example_loss, 1))
  self._baseline = tf.reduce_mean(reward, -1)
  self._std = tf.math.reduce_std(reward, -1)

  # Calculate teacher loss
  def compute_teacher_loss(log_q, reward, baseline, std):
    advantage = tf.abs((reward - baseline) / std)
    advantage = tf.stop_gradient(advantage)
    log_q = tf.Print(log_q, [log_q], "log_q: ")
    teacher_loss = tf.reduce_mean(-log_q * advantage)
    return teacher_loss

  teacher_loss = tf.cond(
      coin_toss < 0.1,
      lambda: compute_teacher_loss(log_q, reward, self._baseline, self._std),
      lambda: tf.constant(0.0))
  self.total_loss = mlm_output.loss + teacher_loss
  self.teacher_loss = teacher_loss
  self.mlm_loss = mlm_output.loss

  # Evaluation: compute the loss and accuracy of the model.
  eval_fn_inputs = {
      "input_ids": masked_inputs.input_ids,
      "masked_lm_preds": mlm_output.preds,
      "mlm_loss": mlm_output.per_example_loss,
      "masked_lm_ids": masked_inputs.masked_lm_ids,
      "masked_lm_weights": masked_inputs.masked_lm_weights,
      "input_mask": masked_inputs.input_mask
  }
  eval_fn_keys = eval_fn_inputs.keys()
  eval_fn_values = [eval_fn_inputs[k] for k in eval_fn_keys]

  d = {k: arg for k, arg in zip(eval_fn_keys, eval_fn_values)}
  metrics = dict()
  metrics["masked_lm_accuracy"] = tf.metrics.accuracy(
      labels=tf.reshape(d["masked_lm_ids"], [-1]),
      predictions=tf.reshape(d["masked_lm_preds"], [-1]),
      weights=tf.reshape(d["masked_lm_weights"], [-1]))
  metrics["masked_lm_loss"] = tf.metrics.mean(
      values=tf.reshape(d["mlm_loss"], [-1]),
      weights=tf.reshape(d["masked_lm_weights"], [-1]))
  self.eval_metrics = metrics
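# Numpy sketch of the weighted masked-LM metrics above: both accuracy and mean
# loss count only positions where masked_lm_weights is nonzero, mirroring
# tf.metrics.accuracy / tf.metrics.mean with weights. Illustrative stand-in,
# not the TF metrics API.
import numpy as np

def masked_lm_metrics_np(preds, labels, per_example_loss, weights):
  w = weights.reshape(-1)
  correct = (preds.reshape(-1) == labels.reshape(-1)).astype(float)
  accuracy = (correct * w).sum() / w.sum()
  mean_loss = (per_example_loss.reshape(-1) * w).sum() / w.sum()
  return accuracy, mean_loss

_mpreds = np.array([[3, 7, 0]])
_mlabels = np.array([[3, 5, 0]])
_mloss = np.array([[0.2, 1.4, 0.0]])
_mweights = np.array([[1., 1., 0.]])
print(masked_lm_metrics_np(_mpreds, _mlabels, _mloss, _mweights))  # ~ (0.5, 0.8)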
def mask(config: configure_pretraining.PretrainingConfig,
         inputs: pretrain_data.Inputs, mask_prob, proposal_distribution=1.0,
         disallow_from_mask=None, already_masked=None):
  """Implementation of dynamic masking. The optional arguments aren't needed
  for BERT/ELECTRA and are from early experiments in "strategically" masking
  out tokens instead of uniformly at random.

  Args:
    config: configure_pretraining.PretrainingConfig
    inputs: pretrain_data.Inputs containing input input_ids/input_mask
    mask_prob: percent of tokens to mask
    proposal_distribution: for non-uniform masking can be a [B, L] tensor of
      scores for masking each position.
    disallow_from_mask: a boolean tensor of [B, L] of positions that should not
      be masked out
    already_masked: a boolean tensor of [B, N] of already masked-out tokens for
      multiple rounds of masking

  Returns: a pretrain_data.Inputs with masking added
  """
  # Get the batch size, sequence length, and max masked-out tokens
  N = config.max_predictions_per_seq
  B, L = modeling.get_shape_list(inputs.input_ids)

  # Find indices where masking out a token is allowed
  tokenizer = tokenization.FullTokenizer(
      config.vocab_file, do_lower_case=config.do_lower_case)
  vocab = tokenizer.vocab
  inv_vocab = tokenizer.inv_vocab
  candidates_mask = _get_candidates_mask(inputs, vocab, disallow_from_mask)

  # Set the number of tokens to mask out per example
  num_tokens = tf.cast(tf.reduce_sum(inputs.input_mask, -1), tf.float32)
  num_to_predict = tf.maximum(1, tf.minimum(
      N, tf.cast(tf.round(num_tokens * mask_prob), tf.int32)))
  masked_lm_weights = tf.cast(tf.sequence_mask(num_to_predict, N), tf.float32)
  if already_masked is not None:
    masked_lm_weights *= (1 - already_masked)

  # Get a probability of masking each position in the sequence
  candidate_mask_float = tf.cast(candidates_mask, tf.float32)
  if (config.masking_strategy == RAND_STRATEGY or
      config.masking_strategy == MIX_ADV_STRATEGY):
    sample_prob = (proposal_distribution * candidate_mask_float)
  elif config.masking_strategy == POS_STRATEGY:
    unfavor_pos_mask = _get_unfavor_pos_mask(inputs)
    unfavor_pos_mask_float = tf.cast(unfavor_pos_mask, tf.float32)
    prefer_pos_mask_float = 1 - unfavor_pos_mask_float
    # Preferred positions get weight 1.0 (0.95 + 0.05); unfavored ones get 0.05
    # proposal_distribution = prefer_pos_mask_float
    proposal_distribution = 0.95 * prefer_pos_mask_float + 0.05
    sample_prob = (proposal_distribution * candidate_mask_float)
  elif config.masking_strategy == ENTROPY_STRATEGY:
    sample_prob = (proposal_distribution * candidate_mask_float)
  elif config.masking_strategy == MIX_POS_STRATEGY:
    rand_sample_prob = (proposal_distribution * candidate_mask_float)
    unfavor_pos_mask = _get_unfavor_pos_mask(inputs)
    unfavor_pos_mask_float = tf.cast(unfavor_pos_mask, tf.float32)
    prefer_pos_mask_float = 1 - unfavor_pos_mask_float
    # Preferred positions get weight 1.0 (0.95 + 0.05); unfavored ones get 0.05
    # proposal_distribution = prefer_pos_mask_float
    proposal_distribution = 0.95 * prefer_pos_mask_float + 0.05
    pos_sample_prob = (proposal_distribution * candidate_mask_float)
    strategy_prob = tf.random.uniform([B])
    strategy_prob = tf.expand_dims(
        tf.cast(tf.greater(strategy_prob, 0.5), tf.float32), 1)
    strategy_prob = tf.tile(strategy_prob, [1, L])
    sample_prob = (rand_sample_prob * strategy_prob +
                   pos_sample_prob * (1 - strategy_prob))
  elif config.masking_strategy == MIX_ENTROPY_STRATEGY:
    rand_sample_prob = (proposal_distribution * candidate_mask_float)
    entropy_sample_prob = (proposal_distribution * candidate_mask_float)
    strategy_prob = tf.random.uniform([B])
    strategy_prob = tf.expand_dims(
        tf.cast(tf.greater(strategy_prob, 0.5), tf.float32), 1)
    strategy_prob = tf.tile(strategy_prob, [1, L])
    sample_prob = (rand_sample_prob * strategy_prob +
                   entropy_sample_prob * (1 - strategy_prob))
  else:
    raise ValueError(
        "{} strategy is not supported".format(config.masking_strategy))
  sample_prob /= tf.reduce_sum(sample_prob, axis=-1, keepdims=True)

  # Sample the positions to mask out
  sample_prob = tf.stop_gradient(sample_prob)
  sample_logits = tf.log(sample_prob)
  masked_lm_positions = tf.random.categorical(sample_logits, N, dtype=tf.int32)
  masked_lm_positions *= tf.cast(masked_lm_weights, tf.int32)

  # Get the ids of the masked-out tokens
  shift = tf.expand_dims(L * tf.range(B), -1)
  flat_positions = tf.reshape(masked_lm_positions + shift, [-1, 1])
  masked_lm_ids = tf.gather_nd(tf.reshape(inputs.input_ids, [-1]),
                               flat_positions)
  masked_lm_ids = tf.reshape(masked_lm_ids, [B, -1])
  masked_lm_ids *= tf.cast(masked_lm_weights, tf.int32)

  # Update the input ids: 85% of selected positions get [MASK]
  replace_prob = tf.random.uniform([B, N])
  replace_with_mask_positions = masked_lm_positions * tf.cast(
      tf.less(replace_prob, 0.85), tf.int32)
  inputs_ids, _ = scatter_update(
      inputs.input_ids, tf.fill([B, N], vocab["[MASK]"]),
      replace_with_mask_positions)

  # Replace another 7.5% with random tokens
  replace_with_random_positions = masked_lm_positions * tf.cast(
      tf.greater(replace_prob, 0.925), tf.int32)
  random_tokens = tf.random.uniform(
      [B, N], minval=0, maxval=len(vocab), dtype=tf.int32)
  inputs_ids, _ = scatter_update(
      inputs_ids, random_tokens, replace_with_random_positions)

  if config.debug:
    def pretty_print(inputs_ids, masked_lm_ids, masked_lm_positions,
                     masked_lm_weights, tag_ids):
      debug_inputs = Inputs(
          input_ids=inputs_ids,
          input_mask=None,
          segment_ids=None,
          masked_lm_positions=masked_lm_positions,
          masked_lm_ids=masked_lm_ids,
          masked_lm_weights=masked_lm_weights,
          tag_ids=tag_ids)
      pretrain_data.print_tokens(debug_inputs, inv_vocab)
      ## TODO: save to the mask choice
      return inputs_ids, masked_lm_ids, masked_lm_positions, masked_lm_weights

    mask_shape = masked_lm_ids.get_shape()
    inputs_ids, masked_lm_ids, masked_lm_positions, masked_lm_weights = \
        tf.py_func(pretty_print,
                   [inputs_ids, masked_lm_ids, masked_lm_positions,
                    masked_lm_weights, inputs.tag_ids],
                   (tf.int32, tf.int32, tf.int32, tf.float32))
    inputs_ids.set_shape(inputs.input_ids.get_shape())
    masked_lm_ids.set_shape(mask_shape)
    masked_lm_positions.set_shape(mask_shape)
    masked_lm_weights.set_shape(mask_shape)

  return pretrain_data.get_updated_inputs(
      inputs,
      input_ids=tf.stop_gradient(inputs_ids),
      masked_lm_positions=masked_lm_positions,
      masked_lm_ids=masked_lm_ids,
      masked_lm_weights=masked_lm_weights,
      tag_ids=inputs.tag_ids
  )
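# Numpy sketch of the replacement rule above: for each selected position a
# uniform draw decides whether it becomes [MASK] (draw < 0.85), a random vocab
# token (draw > 0.925), or keeps the original token (otherwise). The tiny
# vocabulary, mask id, and function name are assumptions for illustration.
import numpy as np

def corrupt_positions_np(input_ids, positions, mask_id, vocab_size, seed=0):
  rng = np.random.default_rng(seed)
  out = input_ids.copy()
  batch = np.arange(input_ids.shape[0])[:, None]
  replace_prob = rng.uniform(size=positions.shape)
  mask_sel = replace_prob < 0.85
  rand_sel = replace_prob > 0.925
  rand_tokens = rng.integers(0, vocab_size, size=positions.shape)
  new_vals = np.where(mask_sel, mask_id,
                      np.where(rand_sel, rand_tokens, out[batch, positions]))
  out[batch, positions] = new_vals
  return out

_cids = np.array([[101, 7, 8, 9, 102]])
print(corrupt_positions_np(_cids, np.array([[1, 2, 3]]), mask_id=103,
                           vocab_size=200))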
def _get_richer_data(self, fake_data):
  inputs_tf = fake_data.inputs.input_ids
  labels_tf = fake_data.is_fake_tokens
  lens_tf = tf.reduce_sum(fake_data.inputs.input_mask, 1)

  # Retrieve the basic config
  V = self._bert_config.vocab_size
  # sub: 10%, del + ins: 5%
  N = int(self._config.max_predictions_per_seq * self._config.rich_prob)
  B, L = modeling.get_shape_list(inputs_tf)
  nlms = 0
  bilm = None
  if self._config.use_bilm:
    with open(self._config.bilm_file, 'rb') as f:
      bilm = tf.constant(np.load(f), tf.int32)
    _, nlms = modeling.get_shape_list(bilm)

  # Make multiple partitions for edit ops
  splits_list = []
  for i in range(B):
    one = tf.random.uniform([N * 4], 1, lens_tf[i], tf.int32)
    one, _ = tf.unique(one)
    one = tf.cond(
        tf.less(tf.shape(one)[0], N * 2 + 1),
        lambda: tf.expand_dims(tf.range(1, N * 2 + 2), 0),
        lambda: tf.sort(tf.reshape(one[:N * 2 + 1], [1, N * 2 + 1])))
    splits_list.append(one[:, 2::2])
  splits_tf = tf.concat(splits_list, 0)
  splits_up = tf.concat(
      [splits_tf, tf.expand_dims(tf.constant([L] * B, tf.int32), 1)], 1)
  splits_lo = tf.concat(
      [tf.expand_dims(tf.constant([0] * B, tf.int32), 1), splits_tf], 1)
  size_splits = splits_up - splits_lo

  # Update the inputs and labels given random insertion and deletion
  new_labels_list = []
  new_inputs_list = []
  for i in range(B):
    inputs_splits = tf.split(inputs_tf[i, :], size_splits[i, :])
    labels_splits = tf.split(labels_tf[i, :], size_splits[i, :])
    one_inputs = []
    one_labels = []
    size_split = len(inputs_splits)
    inputs_end = inputs_splits[-1]
    labels_end = labels_splits[-1]
    for j in range(size_split - 1):
      inputs = inputs_splits[j]
      labels = labels_splits[j]
      # Label 1 is substitution (from the generator); sample one of the
      # remaining edit ops.
      rand_op = random.randint(2, self._config.num_preds - 1)
      if rand_op == 2:
        # Label 2 for insertion
        if bilm is None:
          # Noise
          insert_tok = tf.random.uniform([1], 1, V, tf.int32)
        else:
          # 2-gram prediction
          insert_tok = tf.expand_dims(
              bilm[inputs[-1], random.randint(0, nlms - 1)], 0)
        is_end_valid = tf.less_equal(2, tf.shape(inputs_end)[0])
        inputs = tf.cond(is_end_valid,
                         lambda: tf.concat([inputs, insert_tok], 0),
                         lambda: inputs)
        labels = tf.cond(is_end_valid,
                         lambda: tf.concat([labels, tf.constant([2])], 0),
                         lambda: labels)
        inputs_end = tf.cond(is_end_valid, lambda: inputs_end[:-1],
                             lambda: inputs_end)
        labels_end = tf.cond(is_end_valid, lambda: labels_end[:-1],
                             lambda: labels_end)
      elif rand_op == 3:
        # Label 3 for deletion
        labels = tf.concat([labels[:-2], tf.constant([3])], 0)
        inputs = inputs[:-1]
        inputs_end = tf.concat([inputs_end, tf.constant([0])], 0)
        labels_end = tf.concat([labels_end, tf.constant([0])], 0)
      elif rand_op == 4:
        # Label 4 for swapping
        labels = tf.concat([labels[:-1], tf.constant([4])], 0)
        inputs = tf.concat([inputs[:-2], [inputs[-1]], [inputs[-2]]], 0)
      one_labels.append(labels)
      one_inputs.append(inputs)
    one_inputs.append(inputs_end)
    one_labels.append(labels_end)
    one_inputs_tf = tf.concat(one_inputs, 0)
    one_labels_tf = tf.concat(one_labels, 0)
    one_inputs_tf = tf.cond(tf.less(lens_tf[i], N * 2 + 1),
                            lambda: inputs_tf[i, :], lambda: one_inputs_tf)
    one_labels_tf = tf.cond(tf.less(lens_tf[i], N * 2 + 1),
                            lambda: labels_tf[i, :], lambda: one_labels_tf)
    new_inputs_list.append(tf.expand_dims(one_inputs_tf, 0))
    new_labels_list.append(tf.expand_dims(one_labels_tf, 0))
  new_inputs_tf = tf.concat(new_inputs_list, 0)
  new_labels_tf = tf.concat(new_labels_list, 0)

  new_input_mask = tf.cast(tf.not_equal(new_inputs_tf, 0), tf.int32)
  updated_inputs = pretrain_data.get_updated_inputs(
      fake_data.inputs, input_ids=new_inputs_tf, input_mask=new_input_mask)
  RicherData = collections.namedtuple("RicherData", [
      "inputs", "is_fake_tokens", "sampled_tokens"])
  return RicherData(inputs=updated_inputs, is_fake_tokens=new_labels_tf,
                    sampled_tokens=fake_data.sampled_tokens)