def _Proc(record):
  """Parses one serialized tf.Example into unpacked source/target tensors.

  Args:
    record: A serialized tf.Example holding int64 variable-length 'inputs'
      and 'targets' features and an optional float 'eval_weight' feature.

  Returns:
    A pair (tensors, bucket_key). tensors is the list
    [src_ids, src_paddings, tgt_ids, tgt_paddings, tgt_labels, tgt_weights,
    src_pos, src_seg, tgt_pos, tgt_seg, eval_weight]; bucket_key is the
    value produced by _DerivePaddingsAndIds.
  """
  feature_spec = {
      'inputs': tf.io.VarLenFeature(tf.int64),
      'targets': tf.io.VarLenFeature(tf.int64),
      # Examples that do not carry an eval weight count with weight 1.0.
      'eval_weight': tf.io.FixedLenFeature([], tf.float32, default_value=1.0),
  }
  parsed = tf.io.parse_single_example(record, feature_spec)
  # VarLen features come back sparse; keep only their dense value vectors.
  src_ids = parsed['inputs'].values
  tgt_labels = parsed['targets'].values
  eval_weight = parsed['eval_weight']

  # The input is unpacked, so each example is its own single segment.
  (src_paddings, tgt_ids, tgt_paddings, tgt_weights,
   bucket_key) = _DerivePaddingsAndIds(src_ids, tgt_labels)

  src_pos = tf.range(tf.shape(src_ids)[0], dtype=tf.int32)
  src_seg = tf.zeros_like(src_paddings)
  tgt_pos = tf.range(tf.shape(tgt_ids)[0], dtype=tf.int32)
  tgt_seg = tf.zeros_like(tgt_paddings)

  example = [
      src_ids, src_paddings, tgt_ids, tgt_paddings, tgt_labels, tgt_weights,
      src_pos, src_seg, tgt_pos, tgt_seg, eval_weight
  ]
  return example, bucket_key
def Step(recurrent_theta, state0, inputs):
  """Computes one decoder step by sampling the next id for every row.

  `p`, `pre_step_callback` and `post_step_callback` are taken from the
  enclosing scope.

  Args:
    recurrent_theta: NestedMap with `theta`, `encoder_outputs` and
      `random_seed`.
    state0: NestedMap with `ids` ([batch]), `timestep` and `bs_state`.
    inputs: Unused.

  Returns:
    (state1, empty NestedMap). state1 holds `timestep`, `logits`, `ids` and
    `bs_state` for the next step.
  """
  del inputs  # The step is driven entirely by state0.
  with tf.name_scope('single_sampler_step'):
    step_result, next_bs_state = pre_step_callback(
        recurrent_theta.theta,
        recurrent_theta.encoder_outputs,
        tf.expand_dims(state0.ids, 1),  # [batch, 1].
        state0.bs_state,
        num_hyps_per_beam=1)
    log_probs = step_result.log_probs
    batch_size = tf.shape(log_probs)[0]

    state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
    state1.logits = log_probs

    # Stateless sampling seeded by (random_seed, timestep) keeps it reproducible.
    step_seed = tf.stack([recurrent_theta.random_seed, state0.timestep])
    sampled = tf.random.stateless_categorical(
        log_probs / p.temperature,
        num_samples=1,
        seed=step_seed,
        dtype=state0.ids.dtype,
        name='sample_next_id')
    state1.ids = tf.reshape(sampled, [batch_size])  # [batch].

    if 'is_last_chunk' in step_result and p.target_eoc_id >= 0:
      # On the last chunk, an end-of-chunk id is turned into end-of-sequence.
      eoc_on_last_chunk = tf.math.logical_and(
          step_result.is_last_chunk, tf.equal(state1.ids, p.target_eoc_id))
      eos_fill = tf.fill(tf.shape(state1.ids), p.target_eos_id)
      state1.ids = tf.where(eoc_on_last_chunk, eos_fill, state1.ids)

    state1.bs_state = post_step_callback(recurrent_theta.theta,
                                         recurrent_theta.encoder_outputs,
                                         state1.ids, next_bs_state)
    return state1, py_utils.NestedMap()
def _InputBatch(self):
  """Builds one batch by sampling rows of an in-memory (data, label) pair.

  The data and label tensors are restored from the checkpoint `p.ckpt` once.
  They are cast to float32 and kept around via `ops.cached_call`. Rows are
  then selected with `ops.random_permutation_sequence`.

  Returns:
    A NestedMap with 'raw', 'data', 'label' and 'weight'. When not running
    on TPU it also has 'sample_ids'.
  """
  p = self.params

  @tf.function
  def ReadData():
    x, y = io_ops.restore_v2(p.ckpt, [p.data, p.label], [''] * 2,
                             [p.data_dtype, p.label_dtype])
    # Normalize both tensors to float32 whatever their stored dtype.
    return tf.cast(x, tf.float32), tf.cast(y, tf.float32)

  # Load data and label into memory once and keep them around.
  data, label = ops.cached_call(
      f=ReadData.get_concrete_function(), T=[tf.float32, tf.float32])

  batch_size = self.InfeedBatchSize()
  example_shape = list(p.data_shape)
  data = tf.reshape(data, [-1] + example_shape)
  label = py_utils.HasShape(tf.reshape(label, [-1]), [tf.shape(data)[0]])

  sample_ids = ops.random_permutation_sequence(
      num=p.num_samples,
      batch=batch_size,
      repeat=p.repeat,
      seed=p.random_seed or 0)
  num_sampled = tf.shape(sample_ids)[0]

  raw = py_utils.PadOrTrimTo(
      tf.gather(data, sample_ids), [batch_size] + example_shape)
  batch = py_utils.NestedMap(
      raw=raw,
      data=self._Preprocess(raw),
      label=py_utils.PadOrTrimTo(tf.gather(label, sample_ids), [batch_size]),
      # One unit weight per sampled row, padded/trimmed to the batch size.
      weight=py_utils.PadOrTrimTo(
          tf.ones([num_sampled], dtype=tf.float32), [batch_size]))
  if not py_utils.use_tpu():
    batch['sample_ids'] = sample_ids
  return batch
def ComputeSplits(batch_size, num_splits):
  """Returns the number of values assigned to each of `num_splits` splits.

  Every split receives floor(batch_size / num_splits) values. The remaining
  batch_size % num_splits values are handed out one each to the leading
  splits.

  Example::

    batch_size: [5]
    num_splits: 3
    returns: [2, 2, 1]

  Args:
    batch_size: tensor of rank 0, size of the tensor to be split.
    num_splits: number of splits to split the tensor into.

  Returns:
    A tensor of length num_splits containing the size of each split.
  """
  base_sizes = tf.tile(
      tf.div([batch_size], num_splits),
      tf.constant([num_splits], dtype=tf.int32))
  remainder = tf.math.floormod([batch_size], num_splits)
  # One extra value for each of the first `remainder` splits, 0 for the rest.
  leading_ones = tf.tile(tf.constant([1]), remainder)
  trailing_zeros = tf.tile(
      tf.constant([0]), tf.subtract(tf.shape(base_sizes),
                                    tf.shape(leading_ones)))
  extra = tf.concat([leading_ones, trailing_zeros], 0)
  split_sizes = tf.add(base_sizes, extra)
  if num_splits == 1:
    # TF drops the static shape in this case, so restore it explicitly.
    split_sizes.set_shape([1])
  return split_sizes
def _Moments(inputs, mask, enable_cross_replica_sum_on_tpu=False):
  """Computes mean and variance over the valid data points in inputs.

  Reduces over every dimension except the last. `mask` weights each point;
  it must have the same rank as `inputs`, be non-negative, and broadcast
  against `inputs`.

  Args:
    inputs: The data tensor.
    mask: Non-negative weights with the same rank as inputs.
    enable_cross_replica_sum_on_tpu: When True and running on TPU, the sums
      and counts are aggregated across replicas.

  Returns:
    (mean, variance), each reduced over all but the last dimension.
  """
  inputs = py_utils.with_dependencies([
      py_utils.assert_equal(tf.rank(inputs), tf.rank(mask)),
      py_utils.assert_greater_equal(mask, tf.zeros_like(mask)),
  ], inputs)
  axes = tf.range(0, tf.rank(mask) - 1)

  weighted_sum = tf.reduce_sum(inputs * tf.cast(mask, inputs.dtype), axes)
  weight_count = tf.reduce_sum(mask, axes)
  # The multiply above broadcast mask over inputs, so inputs' leading shape is
  # a multiple of mask's; scale the count by that broadcast factor.
  broadcast_factor = tf.reduce_prod(tf.shape(inputs)[:-1] //
                                    tf.shape(mask)[:-1])
  weight_count *= tf.cast(broadcast_factor, weight_count.dtype)

  sum_across_replicas = (
      py_utils.use_tpu() and enable_cross_replica_sum_on_tpu)
  if sum_across_replicas:
    weighted_sum = tf.tpu.cross_replica_sum(weighted_sum)
    weight_count = tf.tpu.cross_replica_sum(weight_count)
  weight_count = tf.maximum(weight_count, 1.0)  # Avoid dividing by zero.
  mean = weighted_sum / weight_count

  centered = inputs - mean
  squared_dev_sum = tf.reduce_sum(centered * centered * mask, axes)
  if sum_across_replicas:
    squared_dev_sum = tf.tpu.cross_replica_sum(squared_dev_sum)
  variance = py_utils.with_dependencies([
      py_utils.assert_greater_equal(squared_dev_sum,
                                    tf.zeros_like(squared_dev_sum)),
  ], squared_dev_sum / weight_count)
  return mean, variance
def FProp(self, theta, input_batch):
  """Encodes a source batch given as `ids` and `paddings`.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    input_batch: A `.NestedMap` with fields:
      - ids: The inputs tensor, expected shape [batch, time].
      - paddings: The paddings tensor, expected shape [batch, time].

  Returns:
    A NestedMap containing:
      - encoded: The encoded features, a tensor of shape [time, batch, depth].
      - padding: of shape [time, batch].
      - segment_id: always None here (packed inputs are not handled).
  """
  p = self.params
  with tf.name_scope(p.name):
    ids = input_batch.ids
    # Validate the [batch, time] layout, then switch to time-major.
    inputs = py_utils.with_dependencies([
        py_utils.assert_shape_match(tf.shape(ids), [-1, -1]),
        py_utils.assert_shape_match(
            tf.shape(ids), tf.shape(input_batch.paddings))
    ], tf.transpose(ids))
    ps = tf.expand_dims(tf.transpose(input_batch.paddings), 2)

    xs = self.ApplyClipping(theta, self.emb.EmbLookup(theta.emb, inputs))
    self._emb_out = xs

    # When cc_schedule is specified, lstm_tpl should be a QuantizedLSTMCell
    # with the same cc_schedule so the RNN output stays in clipping range.
    # The first layer's FProp result is used directly; later layers return
    # (outputs, state) pairs.
    xs = self.dropout.FProp(theta.dropout,
                            self.rnn[0].FProp(theta.rnn[0], xs, ps))
    for i in range(1, p.num_lstm_layers):
      layer = self.rnn[i]
      ys, _ = layer.FProp(theta.rnn[i], xs, ps)
      ys = self.dropout.FProp(theta.dropout, ys)
      layer_params = (
          layer.params.cell if hasattr(layer.params, 'cell') else layer.params)
      if layer_params.num_input_nodes != layer_params.num_output_nodes:
        # Depth changes, so a residual connection is impossible.
        xs = ys
      else:
        xs = self.ApplyClipping(theta, xs + ys)  # Residual skip.

    return py_utils.NestedMap(
        encoded=xs, padding=tf.squeeze(ps, [2]), segment_id=None)
def FProp(self, theta, inputs, paddings):
  """Applies the convolution to `inputs`.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    inputs: The inputs tensor, expected shape [batch, time, frequency,
      channel]. Time plays the role of image height and frequency the role of
      image width.
    paddings: The paddings tensor, expected shape [batch, time].

  Returns:
    An (outputs, out_paddings) pair.
  """
  p = self.params
  with tf.name_scope(p.name):
    inputs = py_utils.with_dependencies([
        py_utils.assert_shape_match(tf.shape(paddings), [-1, -1]),
        py_utils.assert_shape_match(
            tf.shape(inputs),
            tf.concat([
                tf.shape(paddings),
                [-1, symbolic.ToStatic(self.input_channels)]
            ], 0))
    ], inputs)

    # Zero out padded frames: paddings is [batch, time], so broadcast it over
    # the trailing frequency and channel dims.
    inputs *= 1.0 - tf.expand_dims(tf.expand_dims(paddings, -1), -1)

    out = self._ApplyConv(theta, inputs)
    if p.partial_conv:
      out = self._RescaleBoundary(out, paddings)

    # NOTE: this may be slightly inaccurate when p.dilation_rate[0] > 1, since
    # SAME-padded pooling is not implemented for dilation_rate > 1.
    # window=p.filter_stride[0] is kept for compatibility with the legacy
    # implementation; consider switching to the actual filter shape.
    out_paddings = ComputeConvOutputPadding(
        paddings, window=p.filter_stride[0], stride=p.filter_stride[0])
    # Padded output frames are deliberately not zeroed here; subsequent
    # layers are assumed to do so if they need it.
    out = py_utils.HasShape(
        out, symbolic.ToStatic(self.OutShape(tf.shape(inputs))))
    return out, out_paddings
def _AugmentationNetwork(self,
                         series_length,
                         inputs,
                         paddings,
                         global_seed,
                         domain_id_index=0):
  """Applies time warping, time masking and frequency masking to inputs.

  Args:
    series_length: Total length of the time series; only used to restack
      the features when `p.unstack` is set.
    inputs: Batch of input features of shape (batch_size, time_length,
      num_freq, channels).
    paddings: Batch of padding vectors of shape (batch_size, time_length).
    global_seed: an integer seed tensor for stateless random ops.
    domain_id_index: domain id index.

  Returns:
    Augmented features of shape (batch_size, time_length, num_freq,
    channels).
  """
  p = self.params
  shared_kwargs = dict(
      global_seed=global_seed, dtype=p.dtype, domain_id_index=domain_id_index)

  if p.unstack:
    inputs, paddings = self.UnstackFeatures(inputs, paddings)

  lengths = tf.reduce_sum(1 - paddings, 1)  # Valid frames per example.
  inputs = self._TimeWarp(inputs, lengths, **shared_kwargs)
  inputs = self._TimeMask(
      inputs,
      lengths,
      noisify=p.use_noise,
      gaussian_noise=p.gaussian_noise,
      **shared_kwargs)
  inputs = self._FrequencyMask(inputs, **shared_kwargs)

  if p.unstack:
    # Undo the unstacking now that the augmentations have been applied.
    inputs = tf.reshape(
        inputs,
        [tf.shape(inputs)[0], series_length, -1, tf.shape(inputs)[3]])
  return inputs
def FProp(self, theta, input_batch):
  """Encodes `input_batch` with the embedding and stacked RNN layers.

  Args:
    theta: A `.NestedMap` of weights for this layer and its children.
    input_batch: A `.NestedMap` with `ids` and `paddings`, both [batch, time],
      plus `segment_ids` when `p.packed_input` is set.

  Returns:
    A NestedMap with `encoded` (time-major), `padding` ([time, batch]) and
    `segment_id` ([time, batch], or None without packed input).
  """
  p = self.params
  with tf.name_scope(p.name):
    inputs = py_utils.with_dependencies([
        py_utils.assert_shape_match(tf.shape(input_batch.ids), [-1, -1]),
        py_utils.assert_shape_match(
            tf.shape(input_batch.ids), tf.shape(input_batch.paddings))
    ], tf.transpose(input_batch.ids))
    ps = tf.expand_dims(tf.transpose(input_batch.paddings), 2)
    segment_ids = (
        tf.expand_dims(tf.transpose(input_batch.segment_ids), 2)
        if p.packed_input else None)

    xs = self.ApplyClipping(theta, self.emb.EmbLookup(theta.emb, inputs))
    summary_utils.histogram('input_emb', xs)
    xs = self.dropout.FProp(theta.dropout, xs)

    layer_outputs = []
    for i in range(p.num_lstm_layers):
      ys = self.rnn[i].FProp(theta.rnn[i], xs, ps, segment_id=segment_ids)
      ys = self.dropout.FProp(theta.dropout, ys)
      # Layers from p.residual_start onwards add a residual connection.
      xs = self.ApplyClipping(theta, xs + ys) if i >= p.residual_start else ys
      layer_outputs.append(xs)
      summary_utils.histogram('layer_out_%s' % i, xs)

    if p.is_transparent:
      xs = self.transparent_merger.FProp(theta.transparent_merger,
                                         layer_outputs)

    if p.lstm_cell_size * 2 != p.encoder_out_dim:
      # The stacked output depth is taken to be 2 * lstm_cell_size; project it
      # to encoder_out_dim.
      xs = self.final_proj.FProp(theta.final_proj, xs, ps)
      summary_utils.histogram('final_proj_out', xs)

    if segment_ids is not None:
      segment_ids = tf.squeeze(segment_ids, [2])
    return py_utils.NestedMap(
        encoded=xs, padding=tf.squeeze(ps, [2]), segment_id=segment_ids)
def _DerivePaddingsAndIds(src_ids, tgt_labels):
  """Derives trivial paddings/weights and shifted target ids.

  tgt_ids is tgt_labels shifted right by one, with `p.sos_id` (from the
  enclosing scope) prepended. Paddings are all zeros and target weights all
  ones.

  Returns:
    (src_paddings, tgt_ids, tgt_paddings, tgt_weights, bucket_key), where
    bucket_key is the int32 maximum of the source and target lengths.
  """
  tgt_ids = tf.concat([[p.sos_id], tgt_labels[:-1]], axis=0)
  src_paddings = tf.zeros(tf.shape(src_ids), dtype=tf.float32)
  tgt_shape = tf.shape(tgt_ids)
  tgt_paddings = tf.zeros(tgt_shape, dtype=tf.float32)
  tgt_weights = tf.ones(tgt_shape, dtype=tf.float32)

  src_length = tf.reduce_sum(1.0 - src_paddings)
  tgt_length = tf.reduce_sum(1.0 - tgt_paddings)
  bucket_key = tf.cast(tf.maximum(src_length, tgt_length), tf.int32)
  return src_paddings, tgt_ids, tgt_paddings, tgt_weights, bucket_key
def FProp(self, theta, input_batch, state0=None):
  """Streaming encoder step over `input_batch`.

  Args:
    theta: A `.NestedMap` of weights for this layer and its children.
    input_batch: A `.NestedMap` with `ids` and `paddings`, both [batch, time].
    state0: Optional streaming state with one `rnn` entry per layer; when
      falsy, a zero state is created.

  Returns:
    A NestedMap with `encoded` (time-major), `padding` ([time, batch]),
    `segment_id` (always None) and `state` (the updated per-layer RNN
    states).
  """
  p = self.params
  with tf.name_scope(p.name):
    # Switch the ids to time-major [t, b] after validating the input shapes.
    inputs = py_utils.with_dependencies([
        py_utils.assert_shape_match(tf.shape(input_batch.ids), [-1, -1]),
        py_utils.assert_shape_match(
            tf.shape(input_batch.ids), tf.shape(input_batch.paddings))
    ], tf.transpose(input_batch.ids))
    ps = tf.expand_dims(tf.transpose(input_batch.paddings), 2)

    if not state0:
      state0 = self.zero_state(theta, tf.shape(inputs)[1])
    state1 = py_utils.NestedMap(rnn=[None] * p.num_lstm_layers)

    xs = self.ApplyClipping(theta, self.emb.EmbLookup(theta.emb, inputs))
    summary_utils.histogram('input_emb', xs)
    xs = self.dropout.FProp(theta.dropout, xs)

    layer_outputs = []
    for i in range(p.num_lstm_layers):
      ys, state1.rnn[i] = self.rnn[i].FProp(
          theta.rnn[i], xs, ps, state0=state0.rnn[i])
      ys = self.dropout.FProp(theta.dropout, ys)
      xs = self.ApplyClipping(theta, xs + ys) if i >= p.residual_start else ys
      layer_outputs.append(xs)
      summary_utils.histogram('layer_out_%s' % i, xs)

    if p.is_transparent:
      xs = self.transparent_merger.FProp(theta.transparent_merger,
                                         layer_outputs)

    return py_utils.NestedMap(
        encoded=xs,
        padding=tf.squeeze(ps, [2]),
        segment_id=None,
        state=state1)
def InitBeamSearchStateCallback(theta, encoder_outputs, num_hyps_per_beam):
  """Wrapper for adding bias to _InitBeamSearchStateCallback.

  Expands the state so it tracks whether each hypothesis is consistent with
  the provided target.

  Args:
    theta: A NestedMap object containing weights' values of this layer and
      its children layers.
    encoder_outputs: A NestedMap computed by encoder. Dimension 1 of its
      `padding` field is used as the source batch size (time-major padding).
    num_hyps_per_beam: An int, number hyps to keep for source sentence.

  Returns:
    initial_results: a `.NestedMap` of initial results.
    states: a `.NestedMap` of initial model states that the client would like
      to keep track of for each hyp. The states relevant here are:
      time_step: A scalar indicating the current decoder step (=0 for the
        initial state). Must be provided and maintained by super.
      consistent: A boolean tensor of shape [tgt_batch], where
        tgt_batch = src_batch * num_hyps_per_beam. It tracks whether each
        hypothesis has exactly matched encoder_outputs.targets so far, and is
        all True initially.
  """
  initial_results, states = self._InitBeamSearchStateCallback(
      theta, encoder_outputs, num_hyps_per_beam)
  # The consistency bias relies on the wrapped callback's step counter.
  assert hasattr(states, 'time_step')
  num_hyps = tf.shape(encoder_outputs.padding)[1] * num_hyps_per_beam
  # Every hypothesis starts out consistent with the target.
  states.consistent = tf.ones([num_hyps], dtype=tf.bool)
  return initial_results, states
def Finalize(self):
  """Completes the combined figure and returns its image summary op."""
  grid_shape = self._subplot_grid_shape
  if grid_shape is None:
    grid_shape = (len(self._subplots), 1)

  # The py_func used for rendering only accepts a flat list of tensors, so
  # record which [start, end) slice of that list belongs to each subplot.
  flattened_tensors = []
  subplot_slices = []
  for subplot in self._subplots:
    start = len(flattened_tensors)
    flattened_tensors.extend(subplot.tensor_list)
    subplot_slices.append((start, len(flattened_tensors)))

  def PlotFunc(fig, *numpy_data_list):
    gs = gridspec.GridSpec(*grid_shape, **self._gridspec_kwargs)
    for index, subplot in enumerate(self._subplots):
      start, end = subplot_slices[index]
      subplot.plot_func(fig, fig.add_subplot(gs[index]),
                        *numpy_data_list[start:end])

  func = functools.partial(_RenderMatplotlibFigures, self._figsize,
                           self._max_outputs, PlotFunc)

  # All plotted tensors must share the same leading (batch) dimension.
  batch_sizes = [tf.shape(t)[0] for t in flattened_tensors]
  num_tensors = len(flattened_tensors)
  same_batch_check = tf.assert_equal(
      batch_sizes, [batch_sizes[0]] * num_tensors, summarize=num_tensors)
  with tf.control_dependencies([same_batch_check]):
    rendered = tf.py_func(
        func, flattened_tensors, tf.uint8, name='RenderMatplotlibFigures')
  return tf.summary.image(self._name, rendered, max_outputs=self._max_outputs)
def FProp(self, theta, inputs, paddings=None):
  """Applies batch normalization.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    inputs: The inputs tensor. Shaped [..., dim].
    paddings: The paddings tensor. Shaped [..., 1], with the same rank as the
      input tensor. When None, `self._GetDefaultPaddings(inputs)` is used.

  Returns:
    The normalized tensor, with the same shape as `inputs`.
  """
  p = self.params
  if paddings is None:
    paddings = self._GetDefaultPaddings(inputs)
  with tf.name_scope(p.name):
    mean, variance, beta, gamma = self.ComputeAndUpdateMoments(
        theta, inputs, paddings)
    sanity_checks = [
        py_utils.assert_greater_equal(variance, tf.zeros_like(variance)),
        py_utils.assert_shape_match([tf.shape(inputs)[-1]], tf.shape(mean)),
        py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                    tf.shape(variance)),
    ]
    with tf.control_dependencies(sanity_checks):
      if not (p.use_fused_batch_norm_for_eval and self.do_eval):
        bn_output = tf.nn.batch_normalization(inputs, mean, variance, beta,
                                              gamma, self._epsilon)
      else:
        bn_output, _, _ = nn.fused_batch_norm(
            inputs,
            gamma,
            beta,
            mean,
            variance,
            self._epsilon,
            is_training=False)
      if p.set_padded_output_to_zero:
        bn_output *= 1.0 - paddings
    return bn_output
def _GreedySearchStep(self, theta, encoder_outputs, cur_step, step_ids,
                      hyp_ids, hyp_lens, done_hyps, other_states,
                      pre_beam_search_step_callback,
                      post_beam_search_step_callback):
  """Extends the greedy search hyps by one step.

  Args:
    theta: A `.NestedMap` object containing weights' values of the decoder
      layer and its children layers.
    encoder_outputs: A `.NestedMap` containing encoder outputs to be passed to
      the callbacks.
    cur_step: A scalar int tensor, the current time step, 0-based.
    step_ids: An int tensor of shape [num_hyps, 1]. The input ids to the
      current search step.
    hyp_ids: An int tensor of shape [tgt_seq_len, num_hyps] (time-major). Row
      `cur_step` is overwritten with the ids chosen at this step.
    hyp_lens: An int tensor of shape [num_hyps]: valid length of all the hyps.
      Tokens after eos ids are not counted.
    done_hyps: A bool tensor of shape [num_hyps]: whether or not a hyp has
      finished.
    other_states: A `.NestedMap` of other beam search states. This
      `.NestedMap` is managed and updated by the client. It is expected that
      each of its member tensors are of rank >= 1. t[i, ...] is the state of
      the i-th hyp at the beginning of this search step.
    pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
      See class header comments for more details.
    post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
      See class header comments for more details.

  Returns:
    A tuple (cur_step + 1, new_step_ids, hyp_ids, hyp_lens, done_hyps,
    other_states) for the next greedy search step. new_step_ids has the same
    shape as step_ids, and other_states is the result of
    post_beam_search_step_callback.
  """
  p = self.params
  # Increment hyp_lens by 1 if the hyp is not finished yet.
  hyp_lens = hyp_lens + (1 - tf.cast(done_hyps, tf.int32))

  bs_results, new_other_states = pre_beam_search_step_callback(
      theta, encoder_outputs, step_ids, other_states,
      1)  # num_hyps_per_beam
  # Greedy choice: the highest-scoring id for each hyp.
  new_step_ids = tf.math.argmax(bs_results.log_probs, 1)
  new_step_ids = tf.cast(new_step_ids, tf.int32)
  new_step_ids = tf.reshape(new_step_ids, tf.shape(step_ids))
  final_other_states = post_beam_search_step_callback(theta, encoder_outputs,
                                                      new_step_ids,
                                                      new_other_states)

  # Stash new_step_ids into row `cur_step` of the time-major hyp_ids buffer.
  new_step_ids_1d = tf.reshape(new_step_ids, [-1])
  hyp_ids = inplace_ops.alias_inplace_update(hyp_ids, cur_step,
                                             new_step_ids_1d)
  # A hyp is done once the newly chosen id is the end-of-sequence token.
  done_hyps = tf.math.logical_or(done_hyps,
                                 tf.equal(new_step_ids_1d, p.target_eos_id))

  return (cur_step + 1, new_step_ids, hyp_ids, hyp_lens, done_hyps,
          final_other_states)
def FProp(self, theta, inputs):
  """Applies batch normalization.

  Follows the implementation in github.com/tensorflow/tpu/blob/master/models/
  official/amoeba_net/network_utils.py#L550.

  In eval mode the moving statistics from `theta` are used. Otherwise the
  batch moments come from `self._Moments` with group size `p.bn_group_size`
  and are numerics-checked.

  Args:
    theta: A nested map object containing weights' values of this layer and
      its children layers.
    inputs: The inputs tensor. Shaped [..., dim].

  Returns:
    The normalized tensor, with the same shape and dtype as `inputs`.
  """
  p = self.params
  original_dtype = inputs.dtype
  # Normalize in p.dtype, then cast back to the caller's dtype at the end.
  inputs = tf.cast(inputs, p.dtype)
  inputs = py_utils.with_dependencies([
      py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                  tf.shape(theta.beta))
  ], inputs)
  with tf.name_scope(p.name) as scope:
    if not self.do_eval:
      mean, variance = self._Moments(inputs, p.bn_group_size)
      mean = py_utils.CheckNumerics(
          mean, 'mean of {} failed numeric check'.format(scope))
      variance = py_utils.CheckNumerics(
          variance, 'variance of {} failed numeric check'.format(scope))
    else:
      mean, variance = theta.moving_mean, theta.moving_variance
    outputs = tf.nn.batch_normalization(inputs, mean, variance, theta.beta,
                                        theta.gamma, p.epsilon)
    outputs.set_shape(inputs.get_shape())
    return tf.cast(outputs, original_dtype)
def ReOrderHyps(x_in):
  """Reorders x_in based on prev hyp ids.

  Closure: `p`, `old_hyp_ids`, `old_hyp_ids_in_cache_order`, `cur_step`,
  `beam_search_tpu_ops` and `inplace_ops` come from the enclosing scope.
  Non-tensors and scalar tensors are returned unchanged.
  """
  if isinstance(x_in, tf.Tensor) and x_in.shape.ndims > 0:
    # For rank > 1 tensors we make use of an efficient matmul based gather
    # on tpu that takes in account the range of the values. For R1, we
    # rely on the tf.gather and xla to optimize it efficiently for R1
    # layout.
    if x_in.shape.ndims > 1:
      if p.batch_major_state:
        # Hyp-major state: gather whole rows by old_hyp_ids.
        num_hyps = tf.shape(old_hyp_ids)[0]
        x_out = beam_search_tpu_ops.fast_gather(
            x_in,
            old_hyp_ids,
            num_hyps,
            max_value=None,
            batch_major_state=p.batch_major_state)
      else:
        # Use corrected indices only here for batch major compute as
        # key/value caches are the states being affected.
        correct_old_hyp_ids = (old_hyp_ids_in_cache_order
                               if p.batch_major_compute else old_hyp_ids)

        def _GatherStep(x_in, t):
          """Gather for one time step.

          Args:
            x_in: in the shape of [T, B, ...] we first get slice(t) from the
              tensors, then gather correct_old_hyp_ids from the slice and
              write the gathered slice back in place to update the original
              x_in.
            t: current time step

          Returns:
            Updated x_in and time step
          """
          x = tf.gather(tf.gather(x_in, t), correct_old_hyp_ids)
          return inplace_ops.alias_inplace_update(
              x_in, t, x), t + 1

        # Reorder time slices 0..cur_step (inclusive) one at a time; later
        # slices are left untouched.
        x_out, _ = tf.while_loop(
            lambda _, t: t <= cur_step, _GatherStep,
            (x_in, tf.zeros([], tf.int32)))
    else:
      x_out = tf.gather(x_in, old_hyp_ids)
    # Reordering never changes the shape; keep the static shape information.
    x_out.set_shape(x_in.get_shape())
    return x_out
  else:
    return x_in
def SplitTensors(xs, num_splits):
  """Splits tensors in `xs` evenly into num_splits along the 1st dimension.

  Split sizes come from ComputeSplits: each piece gets
  floor(batch / num_splits) rows, and the leading pieces take one extra row
  each until the remainder is used up.

  Args:
    xs: A tuple of tensors. Each tensor's 1st dimension is the same size.
    num_splits: A python integer.

  Returns:
    A list with len(xs) elements. The i-th element is the list of num_splits
    tensors obtained by splitting xs[i] along its first dimension.
  """
  batch_size = tf.shape(xs[0])[0]
  # Assert that the first dim of all tensors in xs is equal, and that there
  # is at least one row per split.
  all_batch_dims = tf.stack([tf.shape(x)[0] for x in xs])
  all_batch_dims = py_utils.with_dependencies([
      py_utils.assert_equal(
          all_batch_dims,
          batch_size,
          message='first dim of tensors in xs must match'),
      py_utils.assert_greater_equal(
          batch_size,
          num_splits,
          message=('first dim of tensors in xs must be greater than or '
                   'equal to num_splits'))
  ], all_batch_dims)
  splits = ComputeSplits(batch_size, num_splits)
  # Add the above assertions into the compute graph.
  splits = py_utils.with_dependencies([all_batch_dims], splits)
  return [tf.split(axis=0, num_or_size_splits=splits, value=x) for x in xs]
def SequenceLength(padding):
  """Computes the length of each sequence from binary paddings.

  Args:
    padding: A tensor of binary paddings shaped [batch, seqlen], where 1
      marks a padded position and 0 a valid one.

  Returns:
    seq_lens: An int32 tensor of shape [batch] holding the number of
    non-padded positions in each row of `padding`.
  """
  # Round before the cast so float sums such as 2.9999 become 3, not 2.
  seq_lens = tf.cast(tf.round(tf.reduce_sum(1 - padding, axis=1)), tf.int32)
  # Get rid of any extra dimensions (e.g. a trailing unit dim).
  batch_size = tf.shape(padding)[0]
  seq_lens = tf.reshape(seq_lens, [batch_size], name='seq_lens')
  return seq_lens
def ApplyBias():
  """Bias and update log_probs and consistent.

  Closure: `num_hyps_per_beam`, `step_ids`, `labels`, `weights`,
  `time_step`, `states`, `bs_results` and `p` come from the enclosing scope.
  The next-step distribution is interpolated toward the target label at
  `time_step`, with per-hyp weight taken from `weights`.

  Returns:
    (log of the interpolated probs, updated consistent flags).
  """

  def TileForBeamAndFlatten(tensor):
    # Repeats a per-source value for every hyp in the beam; hyp index is
    # beam_slot * src_batch + src_index after the flatten.
    tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
    tensor = tf.tile(
        tensor, [num_hyps_per_beam, 1])  # [num_hyps_per_beam, src_batch]
    tgt_batch = tf.shape(step_ids)[0]  # num_hyps_per_beam*src_batch
    return tf.reshape(tensor, [tgt_batch])

  # Consistent if step_ids == labels from previous step
  # TODO(navari): Consider updating consistent only if weights > 0. Then
  # re-evaluate the need for bias_only_if_consistent=True.
  # Note that prev_label is incorrect for step 0 (it reads label 0) but is
  # overridden by is_step0 below.
  prev_label = TileForBeamAndFlatten(
      tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
  is_step0 = tf.equal(time_step, 0)
  local_consistence = tf.math.logical_or(
      is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
  consistent = tf.math.logical_and(states.consistent, local_consistence)

  # get label, weight slices corresponding to current time_step
  label = TileForBeamAndFlatten(tf.gather(labels, time_step, axis=1))
  weight = TileForBeamAndFlatten(tf.gather(weights, time_step, axis=1))
  if p.bias_only_if_consistent:
    # Hyps that already diverged from the target get zero bias.
    weight = weight * tf.cast(consistent, p.dtype)

  # Convert the integer label ids into smoothed one-hot probability rows.
  vocab_size = tf.shape(bs_results.log_probs)[1]
  uncertainty = tf.constant(
      1e-10, p.dtype)  # avoid 0 probs which may cause issues with log
  label_probs = tf.one_hot(
      label,
      vocab_size,
      on_value=1 - uncertainty,
      off_value=uncertainty / tf.cast(vocab_size - 1, p.dtype),
      dtype=p.dtype)  # [tgt_batch, vocab_size]
  pred_probs = tf.exp(bs_results.log_probs)

  # interpolate predicted probs and label probs; weight must lie in [0, 1].
  weight = tf.expand_dims(weight, 1)
  probs = py_utils.with_dependencies([
      py_utils.assert_less_equal(weight, 1.),
      py_utils.assert_greater_equal(weight, 0.)
  ], (1.0 - weight) * pred_probs + weight * label_probs)
  return tf.math.log(probs), consistent
def _FrequencyMask(self,
                   inputs,
                   global_seed,
                   dtype=tf.float32,
                   domain_id_index=0):
  """Applies random frequency masking to inputs.

  The maximum mask width and the number of masks are read per domain from
  `p.freq_mask_max_bins` and `p.freq_mask_count`. When either is zero the
  inputs are returned unchanged.

  Args:
    inputs: Batch of input features of shape (batch_size, time_length,
      num_freq, channels).
    global_seed: an integer seed tensor for stateless random ops.
    dtype: Data type.
    domain_id_index: domain id index.

  Returns:
    Inputs with random frequency masking applied.
  """
  p = self.params
  max_bins = p.freq_mask_max_bins[domain_id_index]
  mask_count = p.freq_mask_count[domain_id_index]
  if max_bins == 0 or mask_count == 0:
    return inputs  # Masking is disabled for this domain.

  batch_size, _, num_freq, _ = py_utils.GetShape(inputs)
  # Each example may place its masks anywhere in the full frequency range.
  choose_range = tf.cast(
      tf.broadcast_to(num_freq, (batch_size,)), dtype=tf.int32)
  freq_masks = self._GetMask(
      tf.shape(inputs)[0],
      choose_range=choose_range,
      mask_size=num_freq,
      global_seed=global_seed,
      max_length=max_bins,
      masks_per_frame=0.0,
      multiplicity=mask_count,
      dtype=dtype,
      max_ratio=1.0)
  return self.EinsumBxycByBxyc(inputs, freq_masks)
def FProp(self, theta, inputs):
  """Projects the last dimension of inputs with weight `theta.w`.

  Also records a flop estimate with `computation_cost`.

  Args:
    theta: A NestedMap object containing weights' values of this layer and
      its children layers.
    inputs: The inputs tensor. Shaped [..., input_dims].

  Returns:
    Projected inputs.
  """
  p = self.params
  with tf.name_scope(p.name):
    # Flops = 2 * (product of leading dims) * input_dims * output_dims.
    leading_elems = tf.reduce_prod(tf.cast(tf.shape(inputs)[:-1], tf.int64))
    weight_elems = tf.cast(
        symbolic.ToTensor(p.input_dims * p.output_dims), tf.int64)
    computation_cost.Add(self, 'flops', leading_elems * weight_elems * 2)
    return py_utils.ProjectLastDim(inputs, theta.w, p.input_dims,
                                   p.output_dims)
def GreedySearchDecode(self,
                       theta,
                       encoder_outputs,
                       init_beam_search_state=None,
                       pre_beam_search_step_callback=None,
                       post_beam_search_step_callback=None,
                       max_steps=None):
  """Decodes greedily, one argmax token per step, until done or out of steps.

  Args:
    theta: A NestedMap object containing weights' values of the decoder layer
      and its children layers.
    encoder_outputs: A NestedMap containing encoder outputs to be passed to
      the callbacks.
    init_beam_search_state: The `InitBeamSearchState` callback. Please refer
      to the class header comments for more details.
    pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
      Please refer to the class header comments for more details.
    post_beam_search_step_callback: The `PostBeamSearchStepCallback`
      callback. Please refer to the class header comments for more details.
    max_steps: maximum beam search steps. If None, use
      self.params.target_seq_len.

  Returns:
    A tuple (hyp_ids, hyp_lens, done_hyps), where num_hyps equals the source
    batch size:
      - hyp_ids: [num_hyps, max_step]. Hyps end with <eos> token if the <eos>
        token is encountered during search.
      - hyp_lens: [num_hyps].
      - done_hyps: [num_hyps], whether or not an eos is encountered.
  """
  p = self.params
  if max_steps is None:
    max_steps = p.target_seq_len

  # Greedy search behaves like beam search with one hyp per beam.
  initial_results, other_states = init_beam_search_state(
      theta, encoder_outputs, 1)
  num_hyps = tf.shape(initial_results.log_probs)[0]

  if 'step_ids' not in initial_results:
    step_ids = tf.fill([num_hyps, 1],
                       tf.constant(p.target_sos_id, dtype=tf.int32))
  else:
    step_ids = tf.ensure_shape(initial_results.step_ids,
                               [None, 1])  # [num_hyps, 1]

  cur_step = tf.constant(0, dtype=tf.int32)
  done_hyps = inplace_ops.empty(
      shape=[num_hyps], dtype=tf.bool, init=True, name='done_hyps')
  hyp_lens = inplace_ops.empty(
      shape=[num_hyps], dtype=tf.int32, init=True, name='hyp_lens')
  # Time-major buffer; transposed to [num_hyps, max_steps] before returning.
  hyp_ids = inplace_ops.empty(
      shape=[max_steps, num_hyps], dtype=tf.int32, init=True, name='hyp_ids')

  def LoopContinue(cur_step, unused_step_ids, unused_hyp_ids, unused_hyp_lens,
                   done_hyps, unused_other_states_list):
    # Continue while under the step budget and some hyp is still unfinished.
    return tf.math.logical_and(cur_step < max_steps,
                               tf.math.logical_not(tf.reduce_all(done_hyps)))

  def LoopBody(cur_step, step_ids, hyp_ids, hyp_lens, done_hyps,
               other_states_list):
    (next_step, next_step_ids, next_hyp_ids, next_hyp_lens, next_done_hyps,
     next_other_states) = self._GreedySearchStep(
         theta, encoder_outputs, cur_step, step_ids, hyp_ids, hyp_lens,
         done_hyps, other_states.Pack(other_states_list),
         pre_beam_search_step_callback, post_beam_search_step_callback)
    # while_loop carries the NestedMap states as a flat list.
    return (next_step, next_step_ids, next_hyp_ids, next_hyp_lens,
            next_done_hyps, next_other_states.Flatten())

  flat_other_states = other_states.Flatten()
  loop_vars = (cur_step, step_ids, hyp_ids, hyp_lens, done_hyps,
               flat_other_states)
  shape_invariants = (
      tf.TensorShape(cur_step.get_shape()),
      tf.TensorShape(step_ids.get_shape()),
      tf.TensorShape(hyp_ids.get_shape()),
      tf.TensorShape(hyp_lens.get_shape()),
      tf.TensorShape(done_hyps.get_shape()),
      _GetShapes(flat_other_states, none_shapes=True),
  )
  loop_outputs = tf.while_loop(
      LoopContinue,
      LoopBody,
      loop_vars=loop_vars,
      parallel_iterations=10,
      back_prop=False,
      swap_memory=False,
      shape_invariants=shape_invariants)
  _, _, final_hyp_ids, final_hyp_lens, final_done_hyps, _ = loop_outputs

  # Transpose hyp_ids so it matches BeamSearchDecode's output layout.
  return tf.transpose(final_hyp_ids), final_hyp_lens, final_done_hyps
def MergeBeamSearchOutputs(max_hyps_per_beam, beam_search_outputs):
  """Merges the top hypotheses produced by several beam-search decoders.

  Every non-None field of every input is reshaped to
  [source_batch, hyps_per_beam, -1] (`done_hyps` is transposed first), the
  inputs are concatenated along the hyps-per-beam axis, and each beam keeps the
  `max_hyps_per_beam` hyps with the highest `topk_scores`. Hyps whose
  `topk_lens` is 0 are demoted by giving them a score of -1e6 before ranking.

  Args:
    max_hyps_per_beam: Number of hyps to keep per beam after merging. Must be
      less than or equal to the total number of input hyps per beam.
    beam_search_outputs: A list of BeamSearchDecodeOutput objects. All of them
      must share the same source batch size and max sequence length.

  Returns:
    A BeamSearchDecodeOutput containing `max_hyps_per_beam` hypotheses per
    beam. Fields that were None in the first input stay None.

  Raises:
    ValueError: If a non-None field is missing from some of the inputs, or if a
      field other than done_hyps / topk_hyps / topk_ids / topk_lens /
      topk_scores / topk_decoded is non-None.
      NOTE(review): a non-None `other_states` (e.g. a NestedMap) will most
      likely already fail in `tf.reshape` before reaching that check.
  """
  first = beam_search_outputs[0]
  source_batch = tf.shape(first.topk_hyps)[0]

  # Field name -> one [source_batch, hyps_per_beam, -1] tensor per decoder.
  per_field = {}
  for output in beam_search_outputs:
    topk_shape = tf.shape(output.topk_hyps)
    # Attach the batch-size check to the hyps-per-beam value used below so
    # the assertion is actually executed.
    hyps_per_beam = py_utils.with_dependencies(
        [py_utils.assert_equal(source_batch, topk_shape[0])], topk_shape[1])
    for field, value in six.iteritems(output._asdict()):
      if value is None:
        continue
      if field == 'done_hyps':
        value = tf.transpose(value)
      per_field.setdefault(field, []).append(
          tf.reshape(value, [source_batch, hyps_per_beam, -1]))

  num_outputs = len(beam_search_outputs)
  concatenated = {}
  for field, values in six.iteritems(per_field):
    if len(values) != num_outputs:
      raise ValueError('Incomplete values for %s: %s' %
                       (field, beam_search_outputs))
    # Stack all decoders' hyps along the hyps-per-beam axis.
    concatenated[field] = tf.concat(values, axis=1)

  raw_scores = concatenated['topk_scores']
  # Empty hyps (length 0) are demoted by assigning them a score of -1e6.
  ranked_scores = tf.squeeze(
      tf.where(
          tf.equal(concatenated['topk_lens'], 0),
          tf.fill(tf.shape(raw_scores), -1e6), raw_scores), -1)

  # [source_batch, max_hyps_per_beam] positions along the merged hyps axis.
  _, best_idx = tf.nn.top_k(ranked_scores, max_hyps_per_beam)
  batch_idx = tf.tile(
      tf.expand_dims(tf.range(source_batch), -1), [1, max_hyps_per_beam])
  # [source_batch, max_hyps_per_beam, 2]
  gather_indices = tf.stack([batch_idx, best_idx], axis=-1)

  total_hyps = source_batch * max_hyps_per_beam
  merged = first._asdict()
  for field, value in six.iteritems(concatenated):
    picked = tf.gather_nd(value, gather_indices)
    if field == 'done_hyps':
      picked = tf.transpose(tf.reshape(picked, [total_hyps, -1]))
    elif field == 'topk_hyps':
      picked = tf.reshape(picked, [source_batch, max_hyps_per_beam])
    elif field == 'topk_ids':
      picked = tf.reshape(picked, [total_hyps, -1])
    elif field in ('topk_lens', 'topk_scores', 'topk_decoded'):
      picked = tf.reshape(picked, [total_hyps])
    else:
      raise ValueError('Unexpected field: %s' % field)
    merged[field] = picked
  return BeamSearchDecodeOutput(**merged)
def BeamSearchDecode(self,
                     theta,
                     encoder_outputs,
                     num_hyps_per_beam_override=0,
                     init_beam_search_state=None,
                     pre_beam_search_step_callback=None,
                     post_beam_search_step_callback=None,
                     max_steps=None):
  """Performs beam-search based decoding.

  Args:
    theta: A NestedMap object containing weights' values of the decoder layer
      and its children layers.
    encoder_outputs: A NestedMap containing encoder outputs to be passed to the
      callbacks. Mostly opaque to BeamSearchHelper, except that it must contain
      either a 'seq_lengths' field of shape [source_batch_size], or a source
      paddings field of shape [source_max_lengths, source_batch_size] named
      'padding' (preferred) or 'paddings'. The paddings field may also be a
      NestedMap, in which case its first flattened tensor is used.
    num_hyps_per_beam_override: If set to a value <= 0, this parameter is
      ignored. If set to a value > 0, then this value will be used to override
      `p.num_hyps_per_beam`.
    init_beam_search_state: The `InitBeamSearchState` callback. Please refer to
      the class header comments for more details.
    pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
      Please refer to the class header comments for more details.
    post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
      Please refer to the class header comments for more details.
    max_steps: maximum beam search steps. If None, use
      self.params.target_seq_len.

  Returns:
    A `BeamSearchDecodeOutput` built positionally from:
      - done hyps, [max_steps, num_beams * num_hyps_per_beam];
      - topk_hyps, [num_beams, num_hyps_per_beam];
      - topk_ids, [num_beams * num_hyps_per_beam, ...];
      - topk_lens;
      - topk_scores, reshaped to [num_beams, num_hyps_per_beam];
      - None (topk_decoded);
      - the final client-managed other states (a NestedMap).
  """
  p = self.params
  num_hyps_per_beam = p.num_hyps_per_beam
  if num_hyps_per_beam_override > 0:
    num_hyps_per_beam = num_hyps_per_beam_override
  if max_steps is None:
    max_steps = p.target_seq_len

  initial_results, other_states = init_beam_search_state(
      theta, encoder_outputs, num_hyps_per_beam)

  num_hyps = tf.shape(initial_results.log_probs)[0]
  num_beams = num_hyps // num_hyps_per_beam

  if 'step_ids' in initial_results:
    # [num_hyps, 1]
    step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])
  else:
    step_ids = tf.fill([num_hyps, 1],
                       tf.constant(p.target_sos_id, dtype=tf.int32))

  min_score = -1e36
  best_scores = (tf.zeros(shape=[num_beams], dtype=p.dtype) + min_score)
  cumulative_scores = tf.zeros(shape=[num_hyps], dtype=p.dtype)
  in_scores = tf.zeros([max_steps, num_hyps], dtype=p.dtype)
  in_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
  in_prev_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
  in_done_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.string)
  bs_atten_probs = tf.zeros(
      [max_steps, num_hyps,
       tf.shape(initial_results.atten_probs)[1]],
      dtype=p.dtype)
  cur_step = tf.constant(0, dtype=tf.int32)
  all_done = tf.constant(False, dtype=tf.bool)
  # Order matters: _BeamSearchStep unpacks this tuple positionally.
  core_bs_states = (best_scores, cumulative_scores, in_scores, in_hyps,
                    in_prev_hyps, in_done_hyps, bs_atten_probs)

  def LoopContinue(cur_step, all_done, unused_step_ids, unused_core_bs_states,
                   unused_other_states_list):
    return tf.math.logical_and(cur_step < max_steps,
                               tf.math.logical_not(all_done))

  def LoopBody(cur_step, unused_all_done, step_ids, core_bs_states,
               other_states_list):
    (cur_step, all_done, new_step_ids, new_bs_states,
     new_other_states) = self._BeamSearchStep(
         theta, encoder_outputs, cur_step, step_ids, core_bs_states,
         other_states.Pack(other_states_list), num_hyps_per_beam,
         pre_beam_search_step_callback, post_beam_search_step_callback)
    return (cur_step, all_done, new_step_ids, new_bs_states,
            new_other_states.Flatten())

  # tf.while_loop needs flat loop vars; the NestedMap is re-packed each step.
  flat_other_states = other_states.Flatten()
  _, _, _, final_bs_states, flat_final_other_states = tf.while_loop(
      LoopContinue,
      LoopBody,
      loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                 flat_other_states),
      parallel_iterations=10,
      back_prop=False,
      swap_memory=False,
      shape_invariants=(tf.TensorShape(cur_step.get_shape()),
                        tf.TensorShape(all_done.get_shape()),
                        tf.TensorShape(step_ids.get_shape()),
                        _GetShapes(core_bs_states),
                        _GetShapes(flat_other_states, none_shapes=True)))
  # [target_seq_len, num_beams * num_hyps_per_beam].
  final_done_hyps = final_bs_states[5]
  final_other_states = other_states.Pack(flat_final_other_states)

  # Per-source-sequence lengths. Callers may pass `seq_lengths` directly;
  # otherwise they are derived from source paddings of shape
  # [source_max_lengths, source_batch_size] by summing the non-padded steps.
  encoded_seq_lengths = getattr(encoder_outputs, 'seq_lengths', None)
  if encoded_seq_lengths is None:
    # Look up `padding` first (the field this code has always read), then
    # fall back to `paddings`.
    source_paddings = getattr(encoder_outputs, 'padding', None)
    if source_paddings is None:
      source_paddings = encoder_outputs.paddings
    if isinstance(source_paddings, py_utils.NestedMap):
      # Only the first flattened paddings tensor is used.
      source_paddings = source_paddings.Flatten()[0]
    encoded_seq_lengths = tf.cast(
        tf.round(tf.reduce_sum(1.0 - tf.transpose(source_paddings), 1)),
        tf.int32)

  # [num_beams, num_hyps_per_beam].
  topk_hyps = ops.top_k_terminated_hyps(
      final_done_hyps,
      encoded_seq_lengths,
      k=num_hyps_per_beam,
      num_hyps_per_beam=num_hyps_per_beam,
      length_normalization=p.length_normalization,
      coverage_penalty=p.coverage_penalty,
      target_seq_length_ratio=p.target_seq_length_ratio,
      eoc_id=p.target_eoc_id,
      merge_paths=p.merge_paths)
  # [num_beams * num_hyps_per_beam, ...].
  # NOTE(review): presumably unpack_hyp's max_seq_length must be a static
  # int, so 0 is passed when max_steps is a tensor.
  max_seq_length = 0 if isinstance(max_steps, tf.Tensor) else max_steps
  topk_ids, topk_lens, topk_scores = ops.unpack_hyp(
      tf.reshape(topk_hyps, [-1]), max_seq_length=max_seq_length)
  # [num_beams, num_hyps_per_beam].
  topk_scores = tf.reshape(topk_scores, tf.shape(topk_hyps))

  return BeamSearchDecodeOutput(final_done_hyps, topk_hyps, topk_ids,
                                topk_lens, topk_scores, None,
                                final_other_states)
def _BeamSearchStep(self, theta, encoder_outputs, cur_step, step_ids,
                    core_bs_states, other_states, num_hyps_per_beam,
                    pre_beam_search_step_callback,
                    post_beam_search_step_callback):
  """Extends the beam search hyps by one step.

  | num_beams = Number of source sequences to be decoded.
  | num_hyps_per_beam = Number of hyps to keep per source sequence.
  | num_hyps = num_beams * num_hyps_per_beam
  | src_seq_len = Number of time steps in the source sequence.
  | src_batch = Number of examples in the source batch.
  | tgt_seq_len = Maximum allowed time steps in the target sequence.
  | tgt_batch = num_hyps_per_beam * src_batch

  Args:
    theta: A `.NestedMap` object containing weights' values of the decoder
      layer and its children layers.
    encoder_outputs: A `.NestedMap` containing encoder outputs to be passed to
      the callbacks.
    cur_step: A scalar int tensor, the current time step, 0-based.
    step_ids: An int tensor of shape [num_hyps, 1]. The input ids to the
      current search step.
    core_bs_states: A tuple of core beam search states, in the order
      (best_scores, cumulative_scores, in_scores, in_hyps, in_prev_hyps,
      in_done_hyps, in_atten_probs). This tuple is maintained by this helper
      class.
    other_states: A `.NestedMap` of other beam search states. This
      `.NestedMap` is managed and updated by the client. It is expected that
      each of its member tensors are of rank >= 1. t[i, ...] is the state of
      the i-th hyp at the beginning of this search step.
    num_hyps_per_beam: Num of hyps to keep per beam.
    pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
      See class header comments for more details.
    post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
      See class header comments for more details.

  Returns:
    A tuple (cur_step + 1, all_done, new_step_ids, new_core_bs_states,
    other_states) for the next beam search step. new_core_bs_states has the
    same layout as `core_bs_states`; other_states is the value returned by
    `post_beam_search_step_callback`.
  """
  p = self.params

  # The client computes this step's scores (bs_results: log_probs,
  # atten_probs, optionally is_last_chunk) and returns its updated states.
  bs_results, other_states = pre_beam_search_step_callback(
      theta, encoder_outputs, step_ids, other_states, num_hyps_per_beam)

  (best_scores, cumulative_scores, in_scores, in_hyps, in_prev_hyps,
   in_done_hyps, in_atten_probs) = core_bs_states

  (out_best_scores, out_cumulative_scores, out_scores, out_hyps,
   out_prev_hyps, out_done_hyps, out_atten_probs,
   all_done) = ops.beam_search_step(
       tf.cast(bs_results.log_probs, dtype=p.dtype),
       tf.cast(bs_results.atten_probs, dtype=p.dtype),
       best_scores,
       cumulative_scores,
       in_scores,
       in_hyps,
       in_prev_hyps,
       in_done_hyps,
       in_atten_probs,
       # is_last_chunk is only forwarded when the model uses an
       # end-of-chunk id; otherwise an empty list is passed.
       bs_results.is_last_chunk if self._model_uses_eoc_id else [],
       cur_step,
       eoc_id=p.target_eoc_id,
       eos_id=p.target_eos_id,
       beam_size=p.beam_size,
       num_hyps_per_beam=num_hyps_per_beam,
       valid_eos_max_logit_delta=p.valid_eos_max_logit_delta,
       merge_paths=p.merge_paths,
       allow_empty_terminated_hyp=p.allow_empty_terminated_hyp,
       ensure_full_beam=p.ensure_full_beam,
       force_eos_in_last_step=p.force_eos_in_last_step,
       local_eos_threshold=p.local_eos_threshold)

  # The ids selected at this step (row `cur_step` of out_hyps) become the
  # inputs of the next step, with the same shape as step_ids.
  new_step_ids = tf.reshape(out_hyps[cur_step, :], tf.shape(step_ids))
  new_step_ids.set_shape(step_ids.get_shape())

  # For every hyp slot, the index of the hyp it was extended from (row
  # `cur_step` of out_prev_hyps), flattened to [num_hyps]. Used below to
  # reorder the client's per-hyp states.
  old_hyp_ids = tf.reshape(
      tf.slice(out_prev_hyps, begin=[cur_step, 0], size=[1, -1]), [-1])

  if p.batch_major_compute:
    # With batch-major compute, the key/value cache used for fast decoding
    # (prefix_states in other_states) lays out its num_hyps dimension as
    # [num_beams, num_hyps_per_beam], while old_hyp_ids assume the
    # [num_hyps_per_beam, num_beams] layout. Both the transpose (which moves
    # the slot positions) and the value remapping (which translates each
    # index) are needed to get indices into the cache.
    num_beams = tf.shape(best_scores)[0]
    old_hyp_ids_in_cache_order = tf.reshape(
        tf.transpose(tf.reshape(old_hyp_ids, [num_hyps_per_beam, -1])), [-1])
    # A hyp-major index h = i * num_beams + b (hyp i of beam b) maps to the
    # cache index b * num_hyps_per_beam + i.
    old_hyp_ids_in_cache_order = (
        (old_hyp_ids_in_cache_order % num_beams) * num_hyps_per_beam +
        old_hyp_ids_in_cache_order // num_beams)

  new_bs_states = (out_best_scores, out_cumulative_scores, out_scores,
                   out_hyps, out_prev_hyps, out_done_hyps, out_atten_probs)

  def ReOrderHyps(x_in):
    """Reorders x_in based on prev hyp ids.

    Tensors of known rank >= 1 are gathered by the previous-hyp indices:
    along axis 1 when the rank is > 2 and p.batch_major_state is False,
    otherwise along axis 0. Anything else (non-tensors, scalars, tensors of
    unknown rank) is returned unchanged.
    """
    if (isinstance(x_in, tf.Tensor) and x_in.shape.ndims and
        x_in.shape.ndims > 0):
      if x_in.shape.ndims > 2 and not p.batch_major_state:
        # Use corrected indices only here for batch major compute as key/value
        # caches are the states being affected.
        correct_old_hyp_ids = (
            old_hyp_ids_in_cache_order
            if p.batch_major_compute else old_hyp_ids)
        x_out = tf.gather(x_in, correct_old_hyp_ids, axis=1)
      else:
        x_out = tf.gather(x_in, old_hyp_ids)
      # Gathering keeps the same number of hyps, so the static shape is kept.
      x_out.set_shape(x_in.get_shape())
      return x_out
    else:
      return x_in

  new_other_states = other_states.Transform(ReOrderHyps)

  final_other_states = post_beam_search_step_callback(
      theta, encoder_outputs, new_step_ids, new_other_states)

  return (cur_step + 1, all_done, new_step_ids, new_bs_states,
          final_other_states)
def TileForBeamAndFlatten(tensor):
  """Repeats a per-source-example tensor once per hyp and flattens it.

  The input is viewed as one row of `src_batch` values. That row is stacked
  `num_hyps_per_beam` times into [num_hyps_per_beam, src_batch], and the
  result is flattened row by row, so element `hyp * src_batch + b` holds
  source example `b`. The output length is `tf.shape(step_ids)[0]`
  (num_hyps_per_beam * src_batch).

  `num_hyps_per_beam` and `step_ids` are free variables resolved outside this
  function.
  """
  as_row = tf.reshape(tensor, [1, -1])
  stacked = tf.tile(as_row, [num_hyps_per_beam, 1])
  return tf.reshape(stacked, [tf.shape(step_ids)[0]])
def PadToTargetSeqLen(tensor, constant):
  """Right-pads axis 1 of a rank-2 `tensor` with `constant` to the target length.

  The target length is `p.beam_search.target_seq_len` (`p` is a free variable
  resolved outside this function). A tensor that is already at least that long
  is left unpadded; it is never truncated.
  """
  shortfall = tf.maximum(0, p.beam_search.target_seq_len - tf.shape(tensor)[1])
  pad_spec = [[0, 0], [0, shortfall]]
  return tf.pad(tensor, pad_spec, constant_values=constant)
def FProp(self, theta, input_batch):
  """Embeds source ids and transforms with TransformerStack.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    input_batch: A `.NestedMap` with fields:
      - ids: The inputs tensor. It is expected to be of shape [batch, time].
      - paddings: The paddings tensor. Expected shape [batch, time].
      - task_ids: If p.task_emb is provided, must contain per-token task ids
        of shape [batch, time].
      - segment_ids, segment_pos: Required when p.packed_input is set; they
        are sliced along the time axis like `ids`.

  Returns:
    A NestedMap containing
      - encoded: The encoded features, either a tensor of shape
        [time, batch, depth], or a list of tensors if is_transparent is set in
        transformer_stack.
      - padding: of shape [time, batch]
      - segment_id: [time, batch] if packed inputs are supported by the model
        (and all layers), or None otherwise.
      - embedded_inputs: [time, batch, depth] embedded inputs tokens without
        positional encodings.
  """
  p = self.params
  with tf.name_scope(p.name):
    src_segment_id = None
    src_segment_pos = None
    task_ids = None
    input_ids = py_utils.with_dependencies([
        py_utils.assert_shape_match(
            tf.shape(input_batch.ids), tf.shape(input_batch.paddings)),
        py_utils.assert_equal(tf.rank(input_batch.ids), 2)
    ], input_batch.ids)

    if (not py_utils.use_tpu() and
        tf.flags.FLAGS.transformer_encoder_truncates_inputs):
      # Truncate the time axis to the largest number of non-padded steps in
      # the batch. This assumes right-padded inputs; the assertion checks
      # that every dropped position is padding.
      max_seq_length = tf.cast(
          tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings, 1)),
          tf.int32)
      paddings = py_utils.with_dependencies([
          py_utils.assert_equal(
              tf.constant(True, tf.bool),
              tf.reduce_all(input_batch.paddings[:, max_seq_length:] > 0.5))
      ], input_batch.paddings)
      input_ids = input_ids[:, :max_seq_length]
      paddings = paddings[:, :max_seq_length]
      if p.packed_input:
        src_segment_id = input_batch.segment_ids[:, :max_seq_length]
        src_segment_pos = input_batch.segment_pos[:, :max_seq_length]
      if p.task_emb:
        # Per-token task ids must be truncated like the token ids; otherwise
        # their embeddings would keep the full time length and could not be
        # added to input_embs below (assuming EmbLookup preserves the ids'
        # [batch, time] shape).
        task_ids = input_batch.task_ids[:, :max_seq_length]
    else:
      paddings = input_batch.paddings
      if p.packed_input:
        src_segment_id = input_batch.segment_ids
        src_segment_pos = input_batch.segment_pos
      if p.task_emb:
        task_ids = input_batch.task_ids

    max_time = tf.shape(input_ids)[1]

    # Input token embeddings + positional embeddings
    if not p.shared_emb:
      input_embs = self.token_emb.EmbLookup(theta.token_emb,
                                            tf.reshape(input_ids, [-1]))
    else:
      input_embs = self.softmax.EmbLookup(theta.softmax,
                                          tf.reshape(input_ids, [-1]))

    input_embs = tf.reshape(input_embs,
                            [-1, max_time, p.token_emb.embedding_dim])
    # [time, batch, dim]
    orig_input_embs = tf.transpose(input_embs, [1, 0, 2])

    if p.packed_input:
      position_embs = self.position_emb.FPropWithPosition(
          theta.position_emb, src_segment_pos)
    else:
      position_embs = self.position_emb.FProp(theta.position_emb, max_time)
      # Broadcast the same position embeddings over the batch.
      position_embs = tf.reshape(position_embs,
                                 [1, max_time, p.token_emb.embedding_dim])
    input_embs += position_embs
    if p.task_emb:
      input_embs += self.task_emb.EmbLookup(theta.task_emb, task_ids)

    if p.model_dim != p.token_emb.embedding_dim:
      input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs)

    # The transformer stack consumes time-major paddings and segment ids.
    paddings = tf.cast(tf.transpose(paddings), py_utils.FPropDtype(p))
    if p.packed_input:
      src_segment_id = tf.transpose(src_segment_id)
    input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs)

    # [time, batch, dim]
    transformer_input = tf.transpose(input_embs, [1, 0, 2])

    if not self.do_eval and p.apply_source_mask:
      # Augment padding for masked source word positions.
      dtype = paddings.dtype
      source_mask = tf.where(
          tf.equal(input_ids, p.source_mask_id),
          tf.ones_like(input_ids, dtype=dtype),
          tf.zeros_like(input_ids, dtype=dtype))
      # Make sure padding is between 0 and 1.
      paddings = tf.clip_by_value(paddings + tf.transpose(source_mask), 0.0,
                                  1.0)

    encoded, padding, segment_id = self.transformer_stack.FProp(
        theta.transformer_stack, transformer_input, paddings, src_segment_id)
    return py_utils.NestedMap(
        encoded=encoded,
        padding=padding,
        segment_id=segment_id,
        embedded_inputs=orig_input_embs)
def BeamSearchDecodePostProcess(self, num_hyps_per_beam, max_steps, r1_shape,
                                r2_shape, r3_shape, hyps, prev_hyps,
                                done_hyps, scores, atten_probs, eos_scores,
                                eos_atten_probs, source_seq_lengths,
                                *flat_final_other_states):
  """Beam search post processing functions on CPUs.

  The flattened beam search outputs are reshaped back to their original rank,
  turned into terminated hypotheses, and the top `num_hyps_per_beam` of each
  beam are selected and unpacked.

  Args:
    num_hyps_per_beam: Number of hyps per beam.
    max_steps: Maximum number of beam search steps. Used as unpack_hyp's
      max_seq_length when it is a Python int; if it is a tensor, 0 is passed
      instead.
    r1_shape: A tensor of shape [1]; `source_seq_lengths` is reshaped to it.
    r2_shape: A tensor of shape [2] with values [time, b * k].
    r3_shape: A tensor of shape [3] with values [time, b * k, seq_len].
    hyps: A tensor of shape [time * b * k] with the ids of the selected tokens.
    prev_hyps: A tensor of shape [time * b * k] with, for each selected token,
      the index of the previous hyp it extends.
    done_hyps: A tensor of shape [time * b * k] indicating whether each hyp
      was terminated.
    scores: A tensor of shape [time * b * k] with scores of the token selected.
    atten_probs: A tensor of shape [time * b * k, seq_len] which contains the
      attention probabilities over the source words against each word in the
      previous hyps.
    eos_scores: A tensor of shape [time * b * k] with scores of the eos token
      selected.
    eos_atten_probs: A tensor of shape [time * b * k, seq_len] which contains
      the attention probabilities over the source words for the eos token.
    source_seq_lengths: A tensor of source sequence lengths, reshaped to
      `r1_shape` and passed to top_k_terminated_hyps.
      NOTE(review): presumably one length per source sequence ([b]) — confirm.
    *flat_final_other_states: An array of tensors that are part of other
      states; returned unchanged.

  Returns:
    A tuple (final_done_hyps, topk_hyps, topk_ids, topk_lens, topk_scores)
    followed by the elements of `flat_final_other_states`, where
      - final_done_hyps: terminated `Hypothesis` protos from
        hyps_from_beam_search_outs.
      - topk_hyps: [num_beams, num_hyps_per_beam] top terminated hyps.
      - topk_ids, topk_lens: unpacked from the flattened topk_hyps.
      - topk_scores: reshaped to [num_beams, num_hyps_per_beam].
  """
  p = self.params

  def _ReshapeBackToHigherRank(inps, r_shape):
    """Reshapes every tensor in `inps` to `r_shape`."""
    return [tf.reshape(inp, r_shape) for inp in inps]

  # Reshape all tensors back to original shapes of rank 1, 2 and 3.
  (source_seq_lengths,) = _ReshapeBackToHigherRank([source_seq_lengths],
                                                   r1_shape)
  hyps, prev_hyps, done_hyps, scores, eos_scores = _ReshapeBackToHigherRank(
      [hyps, prev_hyps, done_hyps, scores, eos_scores], r2_shape)
  atten_probs, eos_atten_probs = _ReshapeBackToHigherRank(
      [atten_probs, eos_atten_probs], r3_shape)

  final_done_hyps = ops.hyps_from_beam_search_outs(
      hyps,
      prev_hyps,
      done_hyps,
      scores,
      atten_probs,
      eos_scores,
      eos_atten_probs,
      eos_id=p.target_eos_id,
      num_hyps_per_beam=num_hyps_per_beam)
  # [num_beams, num_hyps_per_beam].
  topk_hyps = ops.top_k_terminated_hyps(
      final_done_hyps,
      source_seq_lengths,
      k=num_hyps_per_beam,
      num_hyps_per_beam=num_hyps_per_beam,
      length_normalization=p.length_normalization,
      coverage_penalty=p.coverage_penalty,
      target_seq_length_ratio=p.target_seq_length_ratio,
      eoc_id=p.target_eoc_id,
      merge_paths=p.merge_paths)
  # Unpack a flat vector of hyps and guard a tensor-valued max_steps, exactly
  # as BeamSearchDecode does; max_seq_length is a static op attribute.
  max_seq_length = 0 if isinstance(max_steps, tf.Tensor) else max_steps
  topk_ids, topk_lens, topk_scores = ops.unpack_hyp(
      tf.reshape(topk_hyps, [-1]), max_seq_length=max_seq_length)
  # [num_beams, num_hyps_per_beam].
  topk_scores = tf.reshape(topk_scores, tf.shape(topk_hyps))
  return (final_done_hyps, topk_hyps, topk_ids, topk_lens,
          topk_scores) + tuple(flat_final_other_states)