def _ParseRecord(self, record):
  """Parses one serialized example into a `NestedMap` of per-token tensors.

  Args:
    record: A serialized example accepted by `tf.io.parse_single_example`.

  Returns:
    A `py_utils.NestedMap` holding `ids`, `segment_ids`, `paddings` and
    `segment_pos`. `masked_ids` and `masked_pos` are also present unless
    `params.remove_mask` is set.
  """
  p = self.params
  max_len = p.max_sequence_length
  max_preds = p.max_predictions_per_seq
  feature_spec = {
      'input_ids': tf.io.FixedLenFeature([max_len], tf.int64),
      'input_mask': tf.io.FixedLenFeature([max_len], tf.int64),
      'masked_lm_positions': tf.io.FixedLenFeature([max_preds], tf.int64),
      'masked_lm_ids': tf.io.FixedLenFeature([max_preds], tf.int64),
      'masked_lm_weights': tf.io.FixedLenFeature([max_preds], tf.float32),
  }
  parsed = tf.io.parse_single_example(record, feature_spec)

  # The sum of the weights is used as the number of leading masked entries
  # to keep; the remaining prediction slots are dropped.
  num_masked = tf.cast(
      tf.reduce_sum(parsed['masked_lm_weights']), dtype=tf.int32)
  kept_positions = tf.slice(parsed['masked_lm_positions'], [0], [num_masked])
  kept_ids = tf.cast(
      tf.slice(parsed['masked_lm_ids'], [0], [num_masked]), dtype=tf.int32)
  scatter_indices = tf.reshape(kept_positions, [-1, 1])

  ret = py_utils.NestedMap()
  ret.masked_ids = tf.cast(parsed['input_ids'], dtype=tf.int32)
  # Write the kept ids back at their positions to recover the unmasked ids.
  ret.ids = tf.tensor_scatter_nd_update(
      tensor=ret.masked_ids, indices=scatter_indices, updates=kept_ids)
  # 1.0 at every kept masked position, 0.0 elsewhere.
  ret.masked_pos = tf.tensor_scatter_nd_update(
      tensor=tf.zeros_like(ret.masked_ids, dtype=tf.float32),
      indices=scatter_indices,
      updates=tf.ones_like(kept_ids, dtype=tf.float32))
  ret.segment_ids = tf.cast(parsed['input_mask'], dtype=tf.float32)

  eos_matches = tf.where(tf.math.equal(ret.ids, p.eos_token_id))
  first_eos_idx = eos_matches[0][0]

  def _DropFirstEos(x):
    # Cut out the entry at `first_eos_idx` and append a single 0 so the
    # length of `x` is unchanged.
    tail_pad = tf.constant(0, shape=(1,), dtype=x.dtype)
    before = x[:first_eos_idx]
    after = x[first_eos_idx + 1:]
    return tf.concat([before, after, tail_pad], axis=0)

  ret = ret.Transform(_DropFirstEos)

  ret.paddings = 1.0 - ret.segment_ids
  index_range = tf.cast(tf.range(max_len), dtype=tf.float32)
  ret.segment_pos = tf.cast(ret.segment_ids * index_range, dtype=tf.int32)

  if p.remove_mask:
    del ret.masked_pos
    del ret.masked_ids
  return ret
def _Slice(tensor):
  """Extracts the entry at index `state0.t` along `params.axis`.

  The selected axis is removed from the result.
  """
  axis = self.params.axis
  full_shape = py_utils.GetShape(tensor)
  # Start offsets: `state0.t` on `axis`, the one_hot default everywhere else.
  # For axis=1 this is [0, t, 0, 0, ...].
  start = tf.one_hot(axis, tf.rank(tensor), on_value=state0.t)
  # Extent: the full shape, except length 1 on `axis`.
  # For axis=1 this is [shape[0], 1, shape[2], ...].
  unit = tf.constant([1], dtype=tf.int32)
  extent = tf.concat([full_shape[0:axis], unit, full_shape[axis + 1:]],
                     axis=0)
  # Take the length-1 slab and drop the now-trivial axis.
  return tf.squeeze(tf.slice(tensor, start, extent), axis=axis)
def GetDecoderEmbeddingsDefaultTheta(self, input_ids, t=None):
  """Embeds decoder input ids, adds position embeddings and applies dropout.

  Args:
    input_ids: Ids to look up; the size of dimension 0 is used as the
      sequence length when `t` is None.
    t: Optional step index. When given, only the position embedding at row
      `t` is added (decoding path).

  Returns:
    The token embeddings plus position embeddings, after `tgt_dropout`.
  """
  p = self.params
  seq_len = tf.shape(input_ids)[0]
  token_embs = self.tgt_token_emb.EmbLookup(self.theta.tgt_token_emb,
                                            input_ids)
  if t is not None:
    # Decoding: build the table up to max_seq_len and keep only row `t`.
    all_pos_embs = self.tgt_pos_emb.FProp(self.theta.tgt_pos_emb,
                                          p.max_seq_len)
    pos_embs = tf.slice(all_pos_embs, [t, 0],
                        [1, p.token_emb.embedding_dim])
  else:
    # Insert a size-1 axis at position 1 before adding to the token
    # embeddings.
    seq_pos_embs = self.tgt_pos_emb.FProp(self.theta.tgt_pos_emb, seq_len)
    pos_embs = tf.expand_dims(seq_pos_embs, 1)
  summed = token_embs + pos_embs
  return self.tgt_dropout.FProp(self.theta.tgt_dropout, summed)
def GetDecoderEmbeddingsDefaultTheta(self, input_ids, task_ids=None, t=None):
  """Embeds decoder inputs with position (and optional task) embeddings.

  Args:
    input_ids: Ids to look up in the target token embedding.
    task_ids: Optional task ids. Their embeddings are added only when this
      is not None and `params.dec_task_emb` is truthy.
    t: Optional step index. When given, only the position embedding at row
      `t` is added (decoding path).

  Returns:
    The summed embeddings after `tgt_dropout`.
  """
  p = self.params
  embs = self.tgt_token_emb.EmbLookup(self.theta.tgt_token_emb, input_ids)

  if t is not None:
    # Decoding: build the table up to max_seq_len and keep only row `t`.
    all_pos_embs = self.tgt_pos_emb.FProp(self.theta.tgt_pos_emb,
                                          p.max_seq_len)
    pos_embs = tf.slice(all_pos_embs, [t, 0],
                        [1, p.token_emb.embedding_dim])
  else:
    # The time axis is 0 when p.batch_dim is truthy, otherwise 1.
    if p.batch_dim:
      time_dim = 0
    else:
      time_dim = 1
    seq_len = tf.shape(input_ids)[time_dim]
    seq_pos_embs = self.tgt_pos_emb.FProp(self.theta.tgt_pos_emb, seq_len)
    # Insert a size-1 axis at p.batch_dim before adding.
    pos_embs = tf.expand_dims(seq_pos_embs, p.batch_dim)

  embs = embs + pos_embs
  if task_ids is not None:
    if p.dec_task_emb:
      task_embs = self.tgt_task_emb.EmbLookup(self.theta.tgt_task_emb,
                                              task_ids)
      embs = embs + task_embs
  return self.tgt_dropout.FProp(self.theta.tgt_dropout, embs)
def _PreprocessForTraining(self, image):
  """Randomly crops, resizes, flips and color-distorts one training image.

  Args:
    image: A `tf.float32` Tensor of shape [height, width, num_channels=3].
      Values are assumed to be in [0, 1].

  Returns:
    A 3-D float Tensor of shape `params.output_image_size` + [3], scaled by
    the width of `params.output_range` and shifted by its lower bound.
  """
  p = self.params
  assert image.dtype == tf.float32

  # A single all-zero box: there are no objects of interest, so the whole
  # image is the candidate region for the random crop.
  image_shape = tf.shape(image)
  no_boxes = tf.zeros([1, 1, 4], dtype=tf.float32)
  crop_begin, crop_size, _ = tf.image.sample_distorted_bounding_box(
      image_shape,
      bounding_boxes=no_boxes,
      area_range=(p.training_crop_min_area, 1.0),
      use_image_if_no_bounding_boxes=True)
  cropped = tf.slice(image, crop_begin, crop_size)
  # The dynamic slice drops static shape info; restore the channel count.
  cropped.set_shape([None, None, 3])

  # Resize straight to the target size. The aspect ratio of the crop is not
  # preserved, so the image may be stretched.
  out_height, out_width = p.output_image_size
  resized = tf.image.resize(cropped, [out_height, out_width], antialias=True)
  resized.set_shape([out_height, out_width, 3])

  flipped = tf.image.random_flip_left_right(resized)
  distorted = _DistortBrightnessAndColor(flipped)

  # Map [0, 1] onto output_range.
  range_width = float(p.output_range[1] - p.output_range[0])
  return distorted * range_width + p.output_range[0]
def _BeamSearchStep(self, theta, encoder_outputs, cur_step, step_ids,
                    core_bs_states, other_states, num_hyps_per_beam,
                    pre_beam_search_step_callback,
                    post_beam_search_step_callback):
  """Extend beam search hyps for one step.

    | num_beams = Number of source sequences to be decoded.
    | num_hyps_per_beam = Number of hyps to keep per source sequence.
    | num_hyps = num_beams * num_hyps_per_beam
    | src_seq_len = Number of time steps in the source sequence.
    | src_batch = Number of examples in the source sequence.
    | tgt_seq_len = Maximum allowed time steps in the target sequence.
    | tgt_batch = num_hyps_per_beam * src_batch

  The step runs in four stages: the pre-step callback produces log-probs and
  attention probs, `ops.beam_search_step` extends the hyps, every tensor in
  `other_states` is re-gathered so it follows the hyp it belongs to, and the
  post-step callback produces the final client states.

  Args:
    theta: A `.NestedMap` object containing weights' values of the decoder
      layer and its children layers.
    encoder_outputs: A `.NestedMap` containing encoder outputs to be passed to
      the callbacks.
    cur_step: A scalar int tensor, the current time step, 0-based.
    step_ids: An int tensor of shape [num_hyps, 1]. The input ids to the
      current search step.
    core_bs_states: A tuple of core beam search states. This list is
      maintained by this helper class. It is unpacked here into seven
      elements (best_scores, cumulative_scores, scores, hyps, prev_hyps,
      done_hyps, atten_probs).
    other_states: A `.NestedMap` of other beam search states. This `.NestedMap`
      is managed and updated by the client. It is expected that each of its
      member tensors are of rank >= 1. t[i, ...] is the state of the i-th hyp
      at the beginning of this search step.
    num_hyps_per_beam: Num of hyps to keep per beam.
    pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
      See class header comments for more details.
    post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
      See class header comments for more details.

  Returns:
    A tuple of following elements for the next beam search step,
    (next step, all_done, step_ids, core_bs_states, other_states)
  """
  p = self.params

  # The client callback supplies this step's results (log_probs, atten_probs
  # and, when used, is_last_chunk) along with updated other_states.
  bs_results, other_states = pre_beam_search_step_callback(
      theta, encoder_outputs, step_ids, other_states, num_hyps_per_beam)

  (best_scores, cumulative_scores, in_scores, in_hyps, in_prev_hyps,
   in_done_hyps, in_atten_probs) = core_bs_states

  # is_last_chunk is only forwarded when the model uses an end-of-chunk id;
  # otherwise an empty list is passed in its place.
  (out_best_scores, out_cumulative_scores, out_scores, out_hyps,
   out_prev_hyps, out_done_hyps, out_atten_probs,
   all_done) = ops.beam_search_step(
       bs_results.log_probs,
       bs_results.atten_probs,
       best_scores,
       cumulative_scores,
       in_scores,
       in_hyps,
       in_prev_hyps,
       in_done_hyps,
       in_atten_probs,
       bs_results.is_last_chunk if self._model_uses_eoc_id else [],
       cur_step,
       eoc_id=p.target_eoc_id,
       eos_id=p.target_eos_id,
       beam_size=p.beam_size,
       num_hyps_per_beam=num_hyps_per_beam,
       valid_eos_max_logit_delta=p.valid_eos_max_logit_delta,
       merge_paths=p.merge_paths,
       allow_empty_terminated_hyp=p.allow_empty_terminated_hyp,
       ensure_full_beam=p.ensure_full_beam,
       force_eos_in_last_step=p.force_eos_in_last_step,
       local_eos_threshold=p.local_eos_threshold)

  # Row `cur_step` of out_hyps, reshaped like step_ids, is returned as the
  # step ids for the next step. The static shape is copied from step_ids.
  new_step_ids = tf.reshape(out_hyps[cur_step, :], tf.shape(step_ids))
  new_step_ids.set_shape(step_ids.get_shape())

  # Row `cur_step` of out_prev_hyps, flattened to 1-D. Used below as the
  # gather indices that reorder the client states.
  old_hyp_ids = tf.reshape(
      tf.slice(out_prev_hyps, begin=[cur_step, 0], size=[1, -1]), [-1])

  if p.batch_major_compute:
    # Transformed the indices into the key/value cache for fast decoding
    # (prefix_states in other_states) due to the num_hyps dimension of
    # cache is computed as num_beams by num_hyps_per_beam, which is different
    # from the old_hyp_ids assumption (num_hyps_per_beam by num_beams).
    # Both transpose and recomputation are required to correct the indices.
    num_beams = tf.shape(best_scores)[0]
    # Step 1: reorder the *positions* of the index vector (transpose).
    old_hyp_ids_in_cache_order = tf.reshape(
        tf.transpose(tf.reshape(old_hyp_ids, [num_hyps_per_beam, -1])), [-1])
    # Step 2: remap each index *value* from (hyp * num_beams + beam) to
    # (beam * num_hyps_per_beam + hyp).
    old_hyp_ids_in_cache_order = (
        (old_hyp_ids_in_cache_order % num_beams) * num_hyps_per_beam +
        old_hyp_ids_in_cache_order // num_beams)

  new_bs_states = (out_best_scores, out_cumulative_scores, out_scores,
                   out_hyps, out_prev_hyps, out_done_hyps, out_atten_probs)

  def ReOrderHyps(x_in):
    """Reorders x_in based on prev hyp ids.

    Non-tensors and tensors whose static rank is unknown or 0 are returned
    unchanged. Tensors of rank > 2 are gathered along axis 1 when
    `p.batch_major_state` is false; all other tensors are gathered along
    axis 0. The static shape of the input is copied onto the result.
    """
    if (isinstance(x_in, tf.Tensor) and x_in.shape.ndims and
        x_in.shape.ndims > 0):
      if x_in.shape.ndims > 2 and not p.batch_major_state:
        # Use corrected indices only here for batch major compute as key/value
        # caches are the states being affected.
        correct_old_hyp_ids = (
            old_hyp_ids_in_cache_order
            if p.batch_major_compute else old_hyp_ids)
        x_out = tf.gather(x_in, correct_old_hyp_ids, axis=1)
      else:
        x_out = tf.gather(x_in, old_hyp_ids)
      x_out.set_shape(x_in.get_shape())
      return x_out
    else:
      return x_in

  # Reorder every client state, then let the client finish the step.
  new_other_states = other_states.Transform(ReOrderHyps)

  final_other_states = post_beam_search_step_callback(theta, encoder_outputs,
                                                      new_step_ids,
                                                      new_other_states)

  return (cur_step + 1, all_done, new_step_ids, new_bs_states,
          final_other_states)
def _maybe_update_block_mask(self, weights, threshold):
  """Performs block-granular masking of the weights.

  Block pruning occurs only if the block dims are not [1, 1] and the weight
  tensor, when squeezed, has ndims = 2. Otherwise the call is delegated to
  `self._update_mask` on the unmodified weights (elementwise pruning).

  A block dim of -1 is replaced by the full size of the corresponding
  dimension of the squeezed weights.

  Args:
    weights: The weight tensor that needs to be masked.
    threshold: The current threshold value. The function will compute a new
      threshold and return the exponential moving average using the current
      value of threshold.

  Returns:
    new_threshold: The new value of the threshold based on weights, and
      sparsity at the current global_step.
    new_mask: On the block path, a tensor reshaped to the shape of `weights`
      holding the block mask expanded back to element granularity. On the
      elementwise path, whatever `self._update_mask` returns.

  Raises:
    ValueError: if block pooling function is not AVG or MAX.
  """
  block_dims = self._get_block_dims(weights.op.name)
  squeezed_weights = tf.squeeze(weights)
  if squeezed_weights.get_shape().ndims != 2 or block_dims == [1, 1]:
    return self._update_mask(weights, threshold)

  # Resolve -1 entries on a private copy. `_get_block_dims` is not visible
  # here and may hand back a list owned by this object; writing into it in
  # place would overwrite the -1 with this weight's size and leak it into
  # later calls for differently-shaped weights.
  squeezed_shape = squeezed_weights.get_shape()
  block_dims = list(block_dims)
  for i in range(2):
    if block_dims[i] == -1:
      block_dims[i] = squeezed_shape[i]

  if self._block_pooling_function not in ['AVG', 'MAX']:
    raise ValueError(
        'Unknown pooling function for block sparsity: %s' %
        self._block_pooling_function)

  with tf.name_scope(weights.op.name + '_pruning_ops'):
    abs_weights = tf.abs(squeezed_weights)

    pool_window = block_dims
    pool_fn = pruning_utils.factorized_pool
    squeeze_axis = None
    if not self._spec.use_tpu:
      # tf.nn.pool is fed a 4-D tensor here: wrap the 2-D weights with size-1
      # leading and trailing axes, and remember to strip them afterwards.
      pool_fn = tf.nn.pool
      abs_weights = tf.reshape(abs_weights, [
          1,
          abs_weights.get_shape()[0],
          abs_weights.get_shape()[1], 1
      ])
      squeeze_axis = [0, 3]

    # Non-overlapping pooling: stride equals the window, so each block is
    # reduced to one value.
    pooled_weights = pool_fn(
        abs_weights,
        window_shape=pool_window,
        pooling_type=self._block_pooling_function,
        strides=pool_window,
        padding='SAME',
        name=weights.op.name + '_pooled')

    if pooled_weights.get_shape().ndims != 2:
      pooled_weights = tf.squeeze(pooled_weights, axis=squeeze_axis)

    # Threshold the per-block values, then blow the block mask back up to
    # element granularity.
    smoothed_threshold, new_mask = self._update_mask(pooled_weights,
                                                     threshold)
    updated_mask = pruning_utils.expand_tensor(new_mask, block_dims)
    # The expanded mask is cropped back to the squeezed weight shape; with
    # 'SAME' padding it can be larger when the block size does not divide
    # the weight dimensions.
    sliced_mask = tf.slice(
        updated_mask, [0, 0],
        [squeezed_shape[0], squeezed_shape[1]])

  return smoothed_threshold, tf.reshape(sliced_mask, tf.shape(weights))