def _get_fake_data(self, inputs, mlm_logits):
  """Sample from the generator to create corrupted input."""
  inputs = pretrain_helpers.unmask(inputs)
  # Optionally forbid the generator from sampling the original (correct) token.
  disallow = tf.one_hot(
      inputs.masked_lm_ids, depth=self._bert_config.vocab_size,
      dtype=tf.float32) if self._config.disallow_correct else None
  sampled_tokens = tf.stop_gradient(pretrain_helpers.sample_from_softmax(
      mlm_logits / self._config.temperature, disallow=disallow))
  sampled_tokids = tf.argmax(sampled_tokens, -1, output_type=tf.int32)
  # Write the sampled tokens into the masked-out positions.
  updated_input_ids, masked = pretrain_helpers.scatter_update(
      inputs.input_ids, sampled_tokids, inputs.masked_lm_positions)
  if self._config.electric_objective:
    labels = masked
  else:
    # A position only counts as fake if the sampled token actually differs
    # from the original one.
    labels = masked * (1 - tf.cast(
        tf.equal(updated_input_ids, inputs.input_ids), tf.int32))
  updated_inputs = pretrain_data.get_updated_inputs(
      inputs, input_ids=updated_input_ids)
  FakedData = collections.namedtuple("FakedData", [
      "inputs", "is_fake_tokens", "sampled_tokens"])
  return FakedData(inputs=updated_inputs, is_fake_tokens=labels,
                   sampled_tokens=sampled_tokens)
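# The helper below is an illustrative reconstruction, not part of this module:
# a minimal sketch of what pretrain_helpers.sample_from_softmax is assumed to
# do (Gumbel-max sampling from softmax(logits), with optionally disallowed
# tokens, as in the reference ELECTRA code). Only the call sites above are
# authoritative for the real helper's behavior.
def _sample_from_softmax_sketch(logits, disallow=None):
  """Minimal sketch: one-hot samples from softmax(logits).

  logits: [batch_size, n_pos, n_vocab]. disallow: optional one-hot tensor of
  the same shape marking tokens that must not be sampled.
  """
  if disallow is not None:
    logits -= 1000.0 * disallow  # push disallowed logits to effectively -inf
  # Gumbel-max trick: argmax(logits + g), with g ~ Gumbel(0, 1), is a draw
  # from Categorical(softmax(logits)).
  uniform = tf.random.uniform(tf.shape(logits), minval=0, maxval=1)
  gumbel = -tf.math.log(-tf.math.log(uniform + 1e-9) + 1e-9)
  sampled_ids = tf.argmax(logits + gumbel, axis=-1, output_type=tf.int32)
  return tf.one_hot(sampled_ids, depth=tf.shape(logits)[-1])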
def _get_fake_data(inputs, mlm_logits):
  """Sample from the generator to create corrupted input."""
  masked_lm_weights = inputs.masked_lm_weights
  inputs = pretrain_helpers.unmask(inputs)
  disallow = None
  # Temperature is fixed at 1.0 in this variant.
  sampled_tokens = tf.stop_gradient(pretrain_helpers.sample_from_softmax(
      mlm_logits / 1.0, disallow=disallow))
  # sampled_tokens: [batch_size, n_pos, n_vocab] (one-hot)
  # mlm_logits: [batch_size, n_pos, n_vocab]
  sampled_tokens_fp32 = tf.cast(sampled_tokens, dtype=tf.float32)
  # mlm_logprobs: [batch_size, n_pos, n_vocab]
  mlm_logprobs = tf.nn.log_softmax(mlm_logits, axis=-1)
  # Log-probability of the sampled token at each position: [batch_size, n_pos]
  pseudo_logprob = tf.reduce_sum(mlm_logprobs * sampled_tokens_fp32, axis=-1)
  # Zero out unmasked positions, then sum over positions: [batch_size]
  pseudo_logprob *= tf.cast(masked_lm_weights, dtype=tf.float32)
  pseudo_logprob = tf.reduce_sum(pseudo_logprob, axis=-1)
  # pseudo_logprob /= (1e-10 + tf.reduce_sum(
  #     tf.cast(masked_lm_weights, dtype=tf.float32), axis=-1))
  sampled_tokids = tf.argmax(sampled_tokens, -1, output_type=tf.int32)
  updated_input_ids, masked = pretrain_helpers.scatter_update(
      inputs.input_ids, sampled_tokids, inputs.masked_lm_positions)
  labels = masked * (1 - tf.cast(
      tf.equal(updated_input_ids, inputs.input_ids), tf.int32))
  updated_inputs = pretrain_data.get_updated_inputs(
      inputs, input_ids=updated_input_ids)
  FakedData = collections.namedtuple("FakedData", [
      "inputs", "is_fake_tokens", "sampled_tokens", "pseudo_logprob"])
  return FakedData(inputs=updated_inputs, is_fake_tokens=labels,
                   sampled_tokens=sampled_tokens, pseudo_logprob=pseudo_logprob)
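# Hypothetical standalone sanity check (NumPy only, never called by the
# pipeline) of the pseudo_logprob reduction above: because sampled_tokens is
# one-hot over the vocab, reducing log_softmax(mlm_logits) * sampled_tokens
# over the vocab axis selects the generator's log-probability of exactly the
# sampled token at each position.
def _check_pseudo_logprob_reduction():
  import numpy as np
  logits = np.random.randn(2, 4, 8)  # [batch_size, n_pos, n_vocab]
  logprobs = logits - np.log(np.exp(logits).sum(-1, keepdims=True))
  sampled_ids = np.random.randint(0, 8, size=(2, 4))
  one_hot = np.eye(8)[sampled_ids]  # one-hot: [batch_size, n_pos, n_vocab]
  picked = (logprobs * one_hot).sum(-1)  # [batch_size, n_pos]
  expected = np.take_along_axis(
      logprobs, sampled_ids[..., None], axis=-1).squeeze(-1)
  assert np.allclose(picked, expected)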
def _sample_masking_subset(self, inputs: pretrain_data.Inputs, action_probs):
  # Calculate shifted action_probs.
  input_mask = inputs.input_mask
  segment_ids = inputs.segment_ids
  input_ids = inputs.input_ids
  shape = modeling.get_shape_list(input_ids, expected_rank=2)
  batch_size = shape[0]
  max_seq_len = shape[1]

  def _remove_special_token(elems):
    """Drops [CLS]/[SEP] positions from one example's probs, ids, and mask."""
    action_prob = tf.cast(elems[0], tf.float32)
    segment = tf.cast(elems[1], tf.int32)
    input_row = tf.cast(elems[2], tf.int32)
    mask = tf.cast(elems[3], tf.int32)
    seq_len = tf.reduce_sum(mask)
    seg1_len = seq_len - tf.reduce_sum(segment)
    seq1_idx = tf.range(start=1, limit=seg1_len - 1, dtype=tf.int32)
    seq2_limit = tf.math.maximum(seg1_len, seq_len - 1)
    seq2_idx = tf.range(start=seg1_len, limit=seq2_limit, dtype=tf.int32)
    mask_idx = tf.range(start=seq_len, limit=max_seq_len, dtype=tf.int32)
    index_tensor = tf.concat([seq1_idx, seq2_idx, mask_idx], axis=0)
    seq1_prob = action_prob[1:seg1_len - 1]
    seq2_prob = action_prob[seg1_len:seq2_limit]
    # Padding positions get a near-zero selection probability.
    mask_prob = tf.ones_like(mask_idx, dtype=tf.float32) * 1e-20
    cleaned_action_prob = tf.concat([seq1_prob, seq2_prob, mask_prob], axis=0)
    cleaned_mask = tf.concat(
        [mask[1:seg1_len - 1], mask[seg1_len:seq_len - 1],
         mask[seq_len:max_seq_len]], axis=0)
    cleaned_input = tf.concat(
        [input_row[1:seg1_len - 1], input_row[seg1_len:seq_len - 1],
         input_row[seq_len:max_seq_len]], axis=0)
    # Truncate to a fixed length ([CLS] and the two [SEP]s are removed).
    cleaned_action_prob = cleaned_action_prob[0:max_seq_len - 3]
    index_tensor = index_tensor[0:max_seq_len - 3]
    cleaned_input = cleaned_input[0:max_seq_len - 3]
    cleaned_mask = cleaned_mask[0:max_seq_len - 3]
    return (cleaned_action_prob, index_tensor, cleaned_input, cleaned_mask)

  # Remove [CLS] and [SEP] action probs.
  elems = tf.stack([action_probs,
                    tf.cast(segment_ids, tf.float32),
                    tf.cast(input_ids, tf.float32),
                    tf.cast(input_mask, tf.float32)], 1)
  (cleaned_action_probs, index_tensors,
   cleaned_inputs, cleaned_input_mask) = tf.map_fn(
       _remove_special_token, elems,
       dtype=(tf.float32, tf.int32, tf.int32, tf.int32),
       parallel_iterations=1)

  logZ, log_prob = self._calculate_partition_table(
      cleaned_input_mask, cleaned_action_probs,
      self._config.max_predictions_per_seq)
  samples, log_q = self._sampling_a_subset(
      logZ, log_prob, self._config.max_predictions_per_seq)

  # Collect masked_lm_ids and masked_lm_positions.
  zero_values = tf.zeros_like(index_tensors, tf.int32)
  selected_position = tf.where(tf.equal(samples, 1), index_tensors, zero_values)
  masked_lm_positions, _ = tf.nn.top_k(
      selected_position, self._config.max_predictions_per_seq, sorted=False)

  # Get the ids of the masked-out tokens.
  shift = tf.expand_dims(max_seq_len * tf.range(batch_size), -1)
  flat_positions = tf.reshape(masked_lm_positions + shift, [-1, 1])
  masked_lm_ids = tf.gather_nd(tf.reshape(input_ids, [-1]), flat_positions)
  masked_lm_ids = tf.reshape(masked_lm_ids, [batch_size, -1])

  # Update the input ids: 85% of the selected positions become [MASK].
  replaced_prob = tf.random.uniform(
      [batch_size, self._config.max_predictions_per_seq])
  replace_with_mask_positions = masked_lm_positions * tf.cast(
      tf.less(replaced_prob, 0.85), tf.int32)
  inputs_ids, _ = scatter_update(
      inputs.input_ids,
      tf.fill([batch_size, self._config.max_predictions_per_seq],
              self._vocab["[MASK]"]),
      replace_with_mask_positions)

  # Replace 7.5% with random tokens; the remaining 7.5% keep the original token.
  replace_with_random_positions = masked_lm_positions * tf.cast(
      tf.greater(replaced_prob, 0.925), tf.int32)
  random_tokens = tf.random.uniform(
      [batch_size, self._config.max_predictions_per_seq],
      minval=0, maxval=len(self._vocab), dtype=tf.int32)
  inputs_ids, _ = scatter_update(inputs_ids, random_tokens,
                                 replace_with_random_positions)

  masked_lm_weights = tf.ones_like(masked_lm_ids, tf.float32)
  inv_vocab = self._inv_vocab

  if self._config.debug:

    def pretty_print(inputs_ids, masked_lm_ids, masked_lm_positions,
                     masked_lm_weights, tag_ids):
      debug_inputs = Inputs(
          input_ids=inputs_ids,
          input_mask=None,
          segment_ids=None,
          masked_lm_positions=masked_lm_positions,
          masked_lm_ids=masked_lm_ids,
          masked_lm_weights=masked_lm_weights,
          tag_ids=tag_ids)
      pretrain_data.print_tokens(debug_inputs, inv_vocab)
      # TODO: save the mask choice.
      return inputs_ids, masked_lm_ids, masked_lm_positions, masked_lm_weights

    mask_shape = masked_lm_ids.get_shape()
    inputs_ids, masked_lm_ids, masked_lm_positions, masked_lm_weights = \
        tf.py_func(pretty_print,
                   [inputs_ids, masked_lm_ids, masked_lm_positions,
                    masked_lm_weights, inputs.tag_ids],
                   (tf.int32, tf.int32, tf.int32, tf.float32))
    inputs_ids.set_shape(inputs.input_ids.get_shape())
    masked_lm_ids.set_shape(mask_shape)
    masked_lm_positions.set_shape(mask_shape)
    masked_lm_weights.set_shape(mask_shape)

  # Apply the mask on the input, using the corrupted ids computed above.
  masked_input = pretrain_data.get_updated_inputs(
      inputs,
      input_ids=tf.stop_gradient(inputs_ids),
      masked_lm_positions=tf.stop_gradient(masked_lm_positions),
      masked_lm_ids=tf.stop_gradient(masked_lm_ids),
      masked_lm_weights=tf.stop_gradient(masked_lm_weights),
      tag_ids=inputs.tag_ids)
  return log_q, masked_input
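# Hypothetical quick check (NumPy only, not called by the pipeline) of the
# replacement thresholds used in _sample_masking_subset: a single uniform draw
# per selected position gives p < 0.85 -> [MASK], p > 0.925 -> random token,
# and the remaining band 0.85 <= p <= 0.925 keeps the original token, i.e. an
# 85% / 7.5% / 7.5% split of the masked-out positions (a more aggressive
# variant of BERT's 80/10/10 rule).
def _check_replacement_split():
  import numpy as np
  p = np.random.uniform(size=1000000)
  frac_mask = (p < 0.85).mean()
  frac_random = (p > 0.925).mean()
  frac_keep = 1.0 - frac_mask - frac_random
  # Empirical fractions should be close to 0.85 / 0.075 / 0.075.
  assert abs(frac_mask - 0.85) < 0.01
  assert abs(frac_random - 0.075) < 0.01
  assert abs(frac_keep - 0.075) < 0.01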