def _ProcessSingleInput(self, source_id, src, tgt):
  """Tokenizes one (src, tgt) string pair and packs the result as features.

  Each string is reshaped to a batch of one, converted with
  `self.StringsToIds` using the source / target tokenizer key, and every
  resulting tensor is passed through `tf.squeeze` before being returned.

  Args:
    source_id: Forwarded to `self._GetTaskIds` to get (src, tgt) task ids.
    src: The source string tensor.
    tgt: The target string tensor.

  Returns:
    A `.NestedMap` with:

    - src: `ids`, `ids_indicator`, `task_ids` (and `strs` when not on TPU).
    - tgt: `ids`, `labels`, `ids_indicator`, `task_ids` (and `strs` when not
      on TPU).
  """
  # StringsToIds is fed a length-1 batch built from each string.
  _, src_labels, src_pad = self.StringsToIds(
      tf.reshape(src, [1]), is_source=True, key=self._src_tokenizer_key)
  tgt_ids, tgt_labels, tgt_pad = self.StringsToIds(
      tf.reshape(tgt, [1]), is_source=False, key=self._tgt_tokenizer_key)

  # A tokenizer may fill padded slots with EOS; force those slots to 0 so the
  # padded region always carries the same value.
  src_labels = py_utils.ApplyPadding(src_pad, src_labels)
  tgt_ids = py_utils.ApplyPadding(tgt_pad, tgt_ids)
  tgt_labels = py_utils.ApplyPadding(tgt_pad, tgt_labels)

  # ids_indicator is 1 exactly where the tokenizer produced a non-padded id.
  # Unlike weights it is never modified later, which makes it suitable for
  # recovering the true sequence length.
  features = py_utils.NestedMap(
      src=py_utils.NestedMap(ids=src_labels, ids_indicator=1 - src_pad),
      tgt=py_utils.NestedMap(
          ids=tgt_ids, labels=tgt_labels, ids_indicator=1 - tgt_pad))

  # Multiplying by the indicator leaves zeros in the padded task_id slots.
  src_task_id, tgt_task_id = self._GetTaskIds(source_id)
  features.src.task_ids = (
      tf.cast(features.src.ids_indicator, dtype=tf.int32) * src_task_id)
  features.tgt.task_ids = (
      tf.cast(features.tgt.ids_indicator, dtype=tf.int32) * tgt_task_id)

  # The raw strings are only attached when not running on TPU.
  if not py_utils.use_tpu():
    features.src.strs = src
    features.tgt.strs = tgt

  return features.Transform(tf.squeeze)
def SequenceAppendToken(x, x_paddings, token, extend=False):
  """Writes `token` right after the last valid element of each row of `x`.

  Args:
    x: A sequence of tokens of shape [batch_size, x_len_max].
    x_paddings: The paddings of `x`.
    token: The token to append (of type integer).
    extend: Whether to grow `x` by one position along the length dimension.
      This must be true whenever some row of `x` already has `x_len_max` valid
      tokens, otherwise an invalid sequence will be emitted.

  Returns:
    A tuple.

    - The new sequence, Tensor of shape [batch_size, x_len_max].
    - The new paddings, Tensor of shape [batch_size, x_len_max].
  """
  num_rows = py_utils.GetShape(x)[0]
  # Number of valid tokens per row, recovered from the paddings.
  valid_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)

  if extend:
    x = tf.pad(x, [[0, 0], [0, 1]])
  out_shape = py_utils.GetShape(x)
  out_len = out_shape[1]

  # Zero out everything past the valid prefix so the addition below writes
  # the token onto a clean slot.
  x *= tf.sequence_mask(valid_len, out_len, x.dtype)

  # One (row, valid_len) write position per row, each receiving `token`.
  write_positions = tf.stack([tf.range(num_rows), valid_len], axis=1)
  token_values = tf.cast(tf.fill([num_rows], token), x.dtype)
  x += tf.scatter_nd(write_positions, token_values, out_shape)

  # The appended token counts as valid, hence valid_len + 1.
  x_paddings = 1 - tf.sequence_mask(valid_len + 1, out_len, x_paddings.dtype)
  return x, x_paddings
def _InputBatch(self):
  """Assembles the src/tgt batch, casting floats to fprop_dtype if needed.

  Returns:
    A `.NestedMap` with `bucket_keys`, `src.{ids, paddings}` and
    `tgt.{ids, labels, weights, paddings}`. `src.ids` and `tgt.labels` are
    cast to int32. When `params.fprop_dtype` is set and differs from
    `params.dtype`, every floating-point tensor is cast to `fprop_dtype`.
  """
  params = self.params

  batch = py_utils.NestedMap(
      bucket_keys=self._bucket_keys,
      src=py_utils.NestedMap(
          ids=tf.cast(self._src_ids, dtype=tf.int32),
          paddings=self._src_paddings),
      tgt=py_utils.NestedMap(
          ids=self._tgt_ids,
          labels=tf.cast(self._tgt_labels, dtype=tf.int32),
          weights=self._tgt_weights,
          paddings=self._tgt_paddings))

  wants_cast = (
      params.fprop_dtype is not None and params.dtype != params.fprop_dtype)
  if not wants_cast:
    return batch

  def _ToFpropDtype(tensor):
    # Only floating-point tensors are converted; ids etc. are left alone.
    if tensor.dtype.is_floating:
      return tf.cast(tensor, params.fprop_dtype)
    return tensor

  return batch.Transform(_ToFpropDtype)
def _Moments(inputs, mask, enable_cross_replica_sum_on_tpu=False):
  """Computes mean and variance over the valid data points in inputs.

  The statistics are reduced over every dimension except the last one.

  Args:
    inputs: The tensor to compute statistics of.
    mask: A non-negative tensor of the same rank as `inputs` that weights each
      data point. It may have any dtype castable to `inputs.dtype`.
    enable_cross_replica_sum_on_tpu: If true and running on TPU, sums and
      counts are aggregated across replicas with `tf.tpu.cross_replica_sum`.

  Returns:
    A (mean, variance) tuple.
  """
  inputs = py_utils.with_dependencies([
      py_utils.assert_equal(tf.rank(inputs), tf.rank(mask)),
      py_utils.assert_greater_equal(mask, tf.zeros_like(mask)),
  ], inputs)
  # Cast the mask once so that the sum, the count and the variance all use
  # the same dtype as `inputs`. Casting only inside the first product left
  # count_v / sum_vv in the mask's dtype, which breaks the divisions and the
  # second product for any mask whose dtype differs from inputs.dtype.
  mask = tf.cast(mask, inputs.dtype)
  rank = tf.rank(mask)
  reduce_over_dims = tf.range(0, rank - 1)
  sum_v = tf.reduce_sum(inputs * mask, reduce_over_dims)
  count_v = tf.reduce_sum(mask, reduce_over_dims)
  # Input shape is guaranteed to be a multiple of mask shape because the
  # inputs * mask op above was successfully broadcasted.
  mask_multiplier = tf.shape(inputs)[:-1] // tf.shape(mask)[:-1]
  count_v *= tf.cast(tf.reduce_prod(mask_multiplier), count_v.dtype)
  if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
    sum_v = tf.tpu.cross_replica_sum(sum_v)
    count_v = tf.tpu.cross_replica_sum(count_v)

  # Guard against division by zero when nothing is valid.
  count_v = tf.maximum(count_v, 1.0)
  mean = sum_v / count_v
  sum_vv = tf.reduce_sum((inputs - mean) * (inputs - mean) * mask,
                         reduce_over_dims)

  if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
    sum_vv = tf.tpu.cross_replica_sum(sum_vv)

  variance = py_utils.with_dependencies([
      py_utils.assert_greater_equal(sum_vv, tf.zeros_like(sum_vv)),
  ], sum_vv / count_v)
  return mean, variance
def FProp(self, theta, current_step):
  """Returns the learning rate decay factor for `current_step`.

  The value is
  `model_dim**-0.5 * min(1, step / warmup) * rsqrt(max(step, warmup))`.

  Args:
    theta: Unused.
    current_step: The current step; cast to float32.

  Returns:
    A float32 tensor holding the decay factor.
  """
  p = self.params
  step = tf.cast(current_step, tf.float32)
  warmup = tf.cast(p.warmup_steps, tf.float32)
  # Ramps linearly from 0 to 1 over the warmup steps, then stays at 1.
  warmup_factor = tf.minimum(1.0, step / warmup)
  # Constant during warmup, inverse square root of the step afterwards.
  decay_factor = tf.math.rsqrt(tf.maximum(step, warmup))
  dim_scale = p.model_dim**-0.5
  return dim_scale * warmup_factor * decay_factor
def FProp(self, theta, current_step):
  """Returns the learning rate decay factor for `current_step`.

  The number of warmup steps is derived as
  `warmup_examples / (batch_size * num_replicas)`; the result is
  `min((step + 1) * warmup**-1.5, (step + 1)**-0.5)`.

  Args:
    theta: Unused.
    current_step: The current step; cast to float32.

  Returns:
    A float32 tensor holding the decay factor.
  """
  p = self.params
  examples_per_step = p.batch_size * self._num_replicas
  warmup = tf.cast(p.warmup_examples / examples_per_step, tf.float32)
  next_step = tf.cast(current_step, tf.float32) + 1
  ramp_up = next_step * warmup**-1.5
  ramp_down = next_step**-0.5
  return tf.minimum(ramp_up, ramp_down)
def _InputBatch(self):
  """Produces a batch of consecutive counter values shaped like `self.shape`.

  A `StatsCounter` named 'CountingInputGenerator' is advanced by the number
  of elements in the batch; the batch holds the values the counter covered
  before this increment.

  Returns:
    A `.NestedMap` with `src_ids` (float32, shape `self.shape`) and `tgt_ids`
    (the sum of `src_ids` over axis 0).
  """
  num_elements = tf.reduce_prod(self.shape)
  counter = summary_utils.StatsCounter('CountingInputGenerator')
  # IncBy returns the post-increment value; subtract to get the first value
  # belonging to this batch.
  after_increment = tf.cast(counter.IncBy(num_elements), dtype=tf.int32)
  first_value = tf.stop_gradient(after_increment - num_elements)
  flat_values = first_value + tf.range(num_elements)
  src_ids = tf.reshape(tf.cast(flat_values, dtype=tf.float32), self.shape)
  tgt_ids = tf.reduce_sum(src_ids, axis=0)
  return py_utils.NestedMap(src_ids=src_ids, tgt_ids=tgt_ids)
def FProp(self, theta, current_step):
  """Returns the learning rate decay factor for `current_step`.

  The result is
  `model_dim**-0.5 * min((step + 1) * warmup**-1.5, (step + 1)**-0.5)`
  with `warmup = warmup_steps * worker_replicas`. When `decay_end` is set the
  step is held at `decay_end` once it reaches that value.

  Args:
    theta: Unused.
    current_step: The current step; cast to float32.

  Returns:
    A float32 tensor holding the decay factor.
  """
  p = self.params
  step = tf.cast(current_step, tf.float32)
  warmup = tf.cast(p.warmup_steps * p.worker_replicas, tf.float32)
  if p.decay_end is not None:
    # Freeze the schedule at decay_end.
    frozen_step = tf.cast(p.decay_end, tf.float32)
    step = tf.where(step < p.decay_end, step, frozen_step)
  next_step = step + 1
  ramp_up = next_step * warmup**-1.5
  ramp_down = next_step**-0.5
  return p.model_dim**-0.5 * tf.minimum(ramp_up, ramp_down)
def ComputePredictions(self, theta, batch):
  # pyformat: disable
  """Builds the canvas from `batch` and runs the parent's predictions on it.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    batch: A `.NestedMap`.

      - src: A `.NestedMap`.
        - ids: The source ids, ends in <eos>.
        - paddings: The source paddings.

      - tgt: A `.NestedMap`.
        - ids: The target ids, ends in <eos>.
        - paddings: The target paddings.

  Returns:
    A `.NestedMap`.

      - outputs: The contextualized output vectors of shape
        [batch_size, time_dim, model_dim].
      - tgt: A `.NestedMap` (optional, only during training).
        - ids: The canvas ids.
        - paddings: The canvas paddings.
        - target_indices: The target indices.
        - target_weights: The target weights.
  """
  # pyformat: enable
  # TODO(williamchan): Currently, we only support KERMIT mode (i.e., no
  # encoder, unified architecture).
  assert not self.params.encoder

  # src and tgt ids can arrive with different integer types; normalize both
  # to int32 (this updates the ids stored on the passed-in batch).
  batch.src.ids = tf.cast(batch.src.ids, tf.int32)
  batch.tgt.ids = tf.cast(batch.tgt.ids, tf.int32)

  canvas = self._CreateCanvasAndTargets(batch)
  # The parent class only sees the canvas, presented as its `tgt` input.
  canvas_batch = py_utils.NestedMap(
      tgt=py_utils.NestedMap(
          ids=canvas.canvas, paddings=canvas.canvas_paddings))
  predictions = super(InsertionModel, self).ComputePredictions(
      theta, canvas_batch)

  if self.do_eval:
    return predictions

  predictions.tgt = py_utils.NestedMap(
      ids=canvas.canvas,
      paddings=canvas.canvas_paddings,
      target_indices=canvas.target_indices,
      target_weights=canvas.target_weights)
  return predictions
def FProp(self, theta, inputs, paddings, domain_ids=None):
  """Applies data augmentation by randomly masking spectrum in inputs.

  With a single configured domain the augmentation network is applied to the
  whole batch. With several configured domains, one augmentation is computed
  per domain and each is kept only where `domain_ids` equals that domain;
  positions matching no configured domain keep the original inputs.

  Args:
    theta: A NestedMap object containing weights' values of this layer and its
      children layers.
    inputs: A tensor of shape [batch, time, freq, num_channels].
    paddings: A 0/1 tensor of shape [batch, time].
    domain_ids: input domain_ids of shape [batch, time].

  Returns:
    A pair of 2 tensors:

    - augmented_inputs: A tensor of shape [batch, time, freq, num_channels].
    - paddings: A 0/1 tensor of shape [batch, time].
  """
  p = self.params

  # A tensor seed in case stateless random ops are needed.
  global_seed = (
      _global_seed_from_inputs(inputs)
      if p.use_input_dependent_random_seed else None)

  batch_size, series_length, _, _ = py_utils.GetShape(inputs)

  if len(p.domain_ids) <= 1:
    single_domain_out = self._AugmentationNetwork(
        series_length,
        inputs,
        paddings,
        global_seed=global_seed,
        domain_id_index=0)
    return single_domain_out, paddings

  augmented_sum = tf.zeros_like(inputs)
  untouched = inputs
  for index, domain_id in enumerate(p.domain_ids):
    per_domain = self._AugmentationNetwork(
        series_length,
        inputs,
        paddings,
        global_seed=global_seed,
        domain_id_index=index)
    domain_column = tf.expand_dims(tf.tile([domain_id], [batch_size]), -1)
    target_domain = tf.cast(domain_column, dtype=p.dtype)
    # [batch, time].
    domain_mask = tf.cast(
        tf.equal(domain_ids, target_domain), dtype=p.dtype)
    # Keep this domain's augmentation only where the domain matches ...
    per_domain = self.EinsumBxycBxBxyc(
        per_domain, domain_mask, name='einsum_domainmasking')
    # ... and drop the original inputs at those same positions.
    untouched = self.EinsumBxycBxBxyc(
        untouched, 1.0 - domain_mask, name='einsum_domainmasking2')
    augmented_sum = per_domain + augmented_sum

  return untouched + augmented_sum, paddings
def _GreedySearchStep(self, theta, encoder_outputs, cur_step, step_ids,
                      hyp_ids, hyp_lens, done_hyps, other_states,
                      pre_beam_search_step_callback,
                      post_beam_search_step_callback):
  """Advances every greedy search hypothesis by one token.

  Args:
    theta: A `.NestedMap` object containing weights' values of the decoder
      layer and its children layers.
    encoder_outputs: A `.NestedMap` containing encoder outputs to be passed to
      the callbacks.
    cur_step: A scalar int tensor, the current time step, 0-based.
    step_ids: An int tensor of shape [num_hyps, 1]. The input ids to the
      current search step.
    hyp_ids: An int tensor of shape [num_hyps, tgt_seq_len].
    hyp_lens: Valid length of all the hyps. Tokens after eos ids are not
      counted.
    done_hyps: Whether or not a hyp has finished.
    other_states: A `.NestedMap` of other beam search states. This
      `.NestedMap` is managed and updated by the client. It is expected that
      each of its member tensors are of rank >= 1. t[i, ...] is the state of
      the i-th hyp at the beginning of this search step.
    pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
      See class header comments for more details.
    post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
      See class header comments for more details.

  Returns:
    A tuple of following elements for the next greedy search step,
    (next step, new_step_ids, hyp_ids, hyp_lens, done_hyps, other_states)
  """
  eos_id = self.params.target_eos_id

  # Only hyps that are still running get one token longer.
  still_running = 1 - tf.cast(done_hyps, tf.int32)
  hyp_lens = hyp_lens + still_running

  # The trailing 1 is num_hyps_per_beam.
  bs_results, pre_step_states = pre_beam_search_step_callback(
      theta, encoder_outputs, step_ids, other_states, 1)

  # Greedy choice: the highest-scoring id, shaped like the incoming step_ids.
  best_ids = tf.cast(tf.math.argmax(bs_results.log_probs, 1), tf.int32)
  new_step_ids = tf.reshape(best_ids, tf.shape(step_ids))

  final_other_states = post_beam_search_step_callback(
      theta, encoder_outputs, new_step_ids, pre_step_states)

  # Stash new_step_ids into the right slot.
  flat_step_ids = tf.reshape(new_step_ids, [-1])
  hyp_ids = inplace_ops.alias_inplace_update(hyp_ids, cur_step,
                                             flat_step_ids)

  # A hyp becomes done as soon as it emits the end-of-sequence token.
  emitted_eos = tf.equal(flat_step_ids, eos_id)
  done_hyps = tf.math.logical_or(done_hyps, emitted_eos)

  next_step = cur_step + 1
  return (next_step, new_step_ids, hyp_ids, hyp_lens, done_hyps,
          final_other_states)
def UnstackFeatures(self, src_inputs, src_paddings):
  """Undoes frame stacking on `src_inputs` and rebuilds matching paddings.

  Args:
    src_inputs: A rank-4 tensor; its second dimension is multiplied by
      `params.stack_height` and its third dimension is inferred.
    src_paddings: The paddings of the stacked inputs, summed over axis 1 to
      obtain per-example lengths.

  Returns:
    A (src_inputs, src_paddings) tuple for the unstacked sequence, with the
    paddings as an int32 tensor.
  """
  stack_height = self.params.stack_height
  batch, stacked_len, _, channels = py_utils.GetShape(src_inputs)
  unstacked_len = stacked_len * stack_height
  unstacked_inputs = tf.reshape(src_inputs,
                                [batch, unstacked_len, -1, channels])

  # Every valid stacked frame expands into stack_height valid frames.
  valid_stacked_frames = tf.reduce_sum(1 - src_paddings, axis=1)
  unstacked_lengths = tf.cast(stack_height * valid_stacked_frames, tf.int32)
  valid_mask = tf.sequence_mask(unstacked_lengths, maxlen=unstacked_len)
  unstacked_paddings = 1 - tf.cast(valid_mask, tf.int32)
  return unstacked_inputs, unstacked_paddings
def FProp(self, theta, current_step):
  """Returns the learning rate decay factor for `current_step`.

  The result is `model_dim**-0.5 * min(step + 1, (step + 1)**-0.5, peak)`
  where `peak = (decay_start * worker_replicas)**-0.5`. When `decay_end` is
  set the step is held at `decay_end` once it reaches that value.

  Args:
    theta: Unused.
    current_step: The current step; cast to float32.

  Returns:
    A float32 tensor holding the decay factor.
  """
  p = self.params
  decay_start = tf.cast(p.decay_start * p.worker_replicas, tf.float32)
  step = tf.cast(current_step, tf.float32)
  if p.decay_end is not None:
    # Freeze the schedule at decay_end.
    frozen_step = tf.cast(p.decay_end, tf.float32)
    step = tf.where(step < p.decay_end, step, frozen_step)
  # The factor never exceeds its value at the decay start.
  peak = decay_start**-0.5
  next_step = step + 1
  unclipped = tf.minimum(next_step, next_step**-0.5)
  return (p.model_dim**-0.5) * tf.minimum(unclipped, peak)
def FProp(self, theta, current_step):
  """Evaluates the sub-schedule whose interval contains `current_step`.

  Each schedule in `self.schedules` is evaluated at the step measured from
  the start of its own interval (clamped at 0), and
  `py_utils.PiecewiseConstant` picks among those values using
  `params.boundaries`.

  Args:
    theta: A `.NestedMap` whose `schedules` entry holds the theta of each
      sub-schedule.
    current_step: The current step; cast to int64.

  Returns:
    The selected schedule value, with the dtype of the first schedule's value.
  """
  p = self.params
  step = tf.cast(current_step, tf.int64)

  def _StepWithinInterval(start):
    # Steps before the interval begins are clamped to 0.
    zero = tf.cast(0, step.dtype)
    return tf.maximum(zero, step - tf.cast(start, step.dtype))

  # The first interval starts at 0, the others at each boundary.
  starts = [0] + p.boundaries
  values = [
      schedule.FProp(schedule_theta, _StepWithinInterval(start))
      for start, schedule, schedule_theta in zip(starts, self.schedules,
                                                 theta.schedules)
  ]
  return py_utils.PiecewiseConstant(step, p.boundaries, values,
                                    values[0].dtype)
def _Value(self, current_step):
  """Returns the clipping cap for `current_step`.

  The cap is `start_cap` before `start_step`, then moves linearly towards
  `end_cap`, which it reaches at `end_step` and keeps afterwards.

  Args:
    current_step: The current step; cast to float32.

  Returns:
    A float32 tensor holding the cap.
  """
  p = self.params
  start = tf.cast(p.start_step, tf.float32)
  end = tf.cast(p.end_step, tf.float32)
  step = tf.cast(current_step, tf.float32)

  # Fraction of the ramp completed; capped at 1 once end_step is passed.
  span = end - start
  progress = tf.minimum(span, step - start) / span
  interpolated_cap = progress * p.end_cap + (1.0 - progress) * p.start_cap

  before_ramp = tf.less(step, p.start_step)
  return tf.cond(before_ramp,
                 lambda: tf.cast(p.start_cap, tf.float32),
                 lambda: tf.cast(interpolated_cap, tf.float32))
def SequenceConcat(x, x_paddings, y, y_paddings, pad=0):
  """Concats sequence `x` with sequence `y`.

  This function is length aware (based off the paddings).

  Args:
    x: A sequence of tokens of shape [batch_size, x_len_max].
    x_paddings: The paddings of `x`.
    y: A sequence of tokens of shape [batch_size, y_len_max].
    y_paddings: The paddings of `y`.
    pad: The <pad> token to fill the concatenated sequence (of type integer).

  Returns:
    A tuple.

    - Concatenation of `x` and `y` of shape
      [batch_size, x_len_max + y_len_max].
    - Paddings of the concatenation of shape
      [batch_size, x_len_max + y_len_max].
  """
  # Get the length (w/ eos).
  x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)
  y_len = tf.cast(tf.round(tf.reduce_sum(1 - y_paddings, 1)), tf.int32)

  batch_size = py_utils.GetShape(x)[0]
  y_len_max = py_utils.GetShape(y)[1]

  # Make room for `y` at the end of `x`.
  x = tf.concat([x, tf.fill(py_utils.GetShape(y), pad)], 1)
  # Zero every position past the valid prefix of `x`, based on its length.
  # Comparing against `pad` instead would (a) leave non-<pad> filler in the
  # padded region (e.g. a tokenizer that pads with <eos>), which then gets
  # summed into the scattered `y` tokens below, and (b) wipe valid tokens
  # whose id happens to equal `pad`.
  x *= tf.sequence_mask(x_len, py_utils.GetShape(x)[1], x.dtype)

  # Compute the write indices of `y` in `xy`: row b of `y` lands at columns
  # x_len[b], x_len[b] + 1, ...
  indices = tf.stack([
      tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, y_len_max]),
      (tf.tile(tf.expand_dims(tf.range(y_len_max), 0), [batch_size, 1]) +
       tf.expand_dims(x_len, 1)),
  ], 2)

  xy = x + tf.scatter_nd(indices, y, py_utils.GetShape(x))

  # Everything past x_len + y_len (including whatever the padded part of `y`
  # scattered there) is overwritten with `pad`.
  xy = tf.where(
      tf.less(tf.expand_dims(tf.range(py_utils.GetShape(xy)[1]), 0),
              tf.expand_dims(x_len + y_len, 1)),
      xy, tf.fill(py_utils.GetShape(xy), pad))

  xy_paddings = 1 - tf.sequence_mask(x_len + y_len,
                                     py_utils.GetShape(xy)[1],
                                     x_paddings.dtype)
  return xy, xy_paddings
def _global_seed_from_inputs(input_floats):
  """Derives a two-element integer seed from the inputs and the wall clock.

  Args:
    input_floats: a set of float input tensors that are derived from the input
      data (for example, input tokens). The important thing is that these are
      usually different for each batch.

  Returns:
    A tensor of shape=[2] with integer seed tensors derived from the inputs:
    `[t + s, t - s]`, where `t` is the current timestamp modulo 10000000 and
    `s` is the sum of absolute input values, both as int64.
  """
  now = tf.cast(tf.timestamp(), dtype=tf.int64)
  clock_part = tf.math.floormod(now, 10000000)
  abs_total = tf.reduce_sum(tf.math.abs(input_floats))
  input_part = tf.cast(abs_total, dtype=tf.int64)
  seed_pair = [clock_part + input_part, clock_part - input_part]
  return tf.stack(seed_pair, axis=-1)
def _ProcessBeamSearchDecodeOut(self, input_batch, encoder_outputs,
                                decoder_outs):
  """Repackages positional beam search outputs into a keyed dict.

  The first three entries of `decoder_outs` are stored on `self` as
  `r1_shape`, `r2_shape` and `r3_shape` and logged; entries 3 through 10 are
  returned under descriptive keys together with reference information taken
  from `input_batch`.

  Args:
    input_batch: The input batch; `tgt.ids`, `tgt.paddings` and `eval_weight`
      are read from it.
    encoder_outputs: Unused.
    decoder_outs: An indexable sequence with at least 11 entries.

  Returns:
    A dict with keys 'target_ids', 'eval_weight', 'tlen', 'hyps', 'prev_hyps',
    'done_hyps', 'scores', 'atten_probs', 'eos_scores', 'eos_atten_probs' and
    'source_seq_lengths'.
  """
  self.r1_shape = decoder_outs[0]
  self.r2_shape = decoder_outs[1]
  self.r3_shape = decoder_outs[2]
  tf.logging.info('r1_shape: %s', self.r1_shape)
  tf.logging.info('r2_shape: %s', self.r2_shape)
  tf.logging.info('r3_shape: %s', self.r3_shape)

  decoded = dict(
      hyps=decoder_outs[3],
      prev_hyps=decoder_outs[4],
      done_hyps=decoder_outs[5],
      scores=decoder_outs[6],
      atten_probs=decoder_outs[7],
      eos_scores=decoder_outs[8],
      eos_atten_probs=decoder_outs[9],
      source_seq_lengths=decoder_outs[10])

  # Number of non-padded target positions, minus one.
  num_valid = tf.reduce_sum(1.0 - input_batch.tgt.paddings, 1)
  tlen = tf.cast(tf.round(num_valid - 1.0), tf.int32)

  ret_dict = dict(
      # The first target position is dropped.
      target_ids=input_batch.tgt.ids[:, 1:],
      eval_weight=input_batch.eval_weight,
      tlen=tlen)
  ret_dict.update(decoded)
  return ret_dict
def Polynomial(x):
  """Interpolates from `p.start` to `p.limit` along a power curve.

  Args:
    x: The tensor at which to evaluate the function.

  Returns:
    `y0` where `x < x0`, `y1` where `x >= x1`, and
    `y0 + ((x - x0) / (x1 - x0))**power * (y1 - y0)` in between, where
    `(x0, y0) = p.start` and `(x1, y1) = p.limit`.
  """
  p = self.params
  (x0, y0), (x1, y1) = p.start, p.limit
  assert x0 < x1, '%s must be < %s' % (x0, x1)

  # Bring all four anchor coordinates to the dtype of x.
  x0, x1, y0, y1 = [tf.cast(v, dtype=x.dtype) for v in (x0, x1, y0, y1)]

  fraction = (x - x0) / (x1 - x0)
  interpolated = y0 + fraction**p.power * (y1 - y0)
  at_or_past_limit = tf.where(x >= x1, y1, interpolated)
  return tf.where(x < x0, y0, at_or_past_limit)
def ApplyBias():
  """Mixes label probabilities into the predictions and updates consistent.

  Returns:
    A (log_probs, consistent) tuple: the log of the interpolation between the
    predicted probabilities and the (smoothed one-hot) label probabilities,
    and the updated per-hyp consistency flags.
  """

  def _PerHyp(per_source):
    """Tiles a per-source tensor across hyps and flattens it."""
    row = tf.reshape(per_source, [1, -1])  # [1, src_batch]
    # [num_hyps_per_beam, src_batch]
    tiled = tf.tile(row, [num_hyps_per_beam, 1])
    # num_hyps_per_beam*src_batch
    num_tgt = tf.shape(step_ids)[0]
    return tf.reshape(tiled, [num_tgt])

  # A hyp stays consistent while its step_ids equal the previous step's
  # labels.
  # TODO(navari): Consider updating consistent only if weights > 0. Then
  # re-evaluate the need for bias_only_if_consistent=True.
  # At step 0 there is no previous label; the gathered value is meaningless
  # there and is overridden through is_step0 below.
  previous_label = _PerHyp(
      tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
  is_step0 = tf.equal(time_step, 0)
  matches_previous = tf.equal(previous_label, tf.squeeze(step_ids, 1))
  consistent_now = tf.math.logical_or(is_step0, matches_previous)
  consistent = tf.math.logical_and(states.consistent, consistent_now)

  # Label and weight columns belonging to the current time_step.
  label = _PerHyp(tf.gather(labels, time_step, axis=1))
  weight = _PerHyp(tf.gather(weights, time_step, axis=1))
  if p.bias_only_if_consistent:
    weight = weight * tf.cast(consistent, p.dtype)

  # Turn the label ids into a probability distribution over the vocab.
  vocab_size = tf.shape(bs_results.log_probs)[1]
  # Keeps every probability strictly positive so the log below is finite.
  uncertainty = tf.constant(1e-10, p.dtype)
  off_prob = uncertainty / tf.cast(vocab_size - 1, p.dtype)
  label_probs = tf.one_hot(
      label,
      vocab_size,
      on_value=1 - uncertainty,
      off_value=off_prob,
      dtype=p.dtype)  # [tgt_batch, vocab_size]
  pred_probs = tf.exp(bs_results.log_probs)

  # Interpolate predicted probs and label probs; weight must be in [0, 1].
  weight = tf.expand_dims(weight, 1)
  weight_in_range = [
      py_utils.assert_less_equal(weight, 1.),
      py_utils.assert_greater_equal(weight, 0.)
  ]
  mixed = (1.0 - weight) * pred_probs + weight * label_probs
  probs = py_utils.with_dependencies(weight_in_range, mixed)
  return tf.math.log(probs), consistent
def IncBy(self, delta): """Increment the counter by delta and return the new value.""" # NOTE: We must ensure _value is computed (_var + 0) before # updating _var with delta. delta = tf.cast(delta, tf.int64) with tf.control_dependencies([self._value]): scalar(self._name, self._value) return tf.identity(tf.assign_add(self._var, delta))
def restore(self, restored_tensors, restored_shapes):
  """Assigns the first restored tensor to `self.op` as bfloat16.

  Args:
    restored_tensors: A sequence of tensors; only the first entry is used.
    restored_shapes: Either None, or a sequence whose first entry is the shape
      the restored tensor is reshaped to before assignment.

  Returns:
    The result of `tf.assign` on `self.op`.
  """
  tensor = restored_tensors[0]
  reshaped = restored_shapes is not None
  if reshaped:
    tensor = tf.reshape(tensor, restored_shapes[0])
  # Shape validation is only requested when no reshape happened and the
  # destination shape is fully known.
  check_shape = (not reshaped) and self.op.get_shape().is_fully_defined()
  return tf.assign(
      self.op, tf.cast(tensor, tf.bfloat16), validate_shape=check_shape)
def _matmul_gather(self, values, axis=0, batch_major_state=True):
  """Gathers the entries of `values` selected by `self._ids`.

  Values of statically known rank > 2 are gathered along rows with
  `tf.gather`; everything else is gathered by multiplying with a one-hot
  encoding of `self._ids`.

  Args:
    values: Values to gather from.
    axis: Axis to gather on. Defaults to 0 (rows).
    batch_major_state: Whether the values to gather from use batch major or
      not. Defaults to True. For Transformer model, batch_major_state is set
      to False (time is the major dim).

  Returns:
    Gathered values, cast back to the dtype `values` had on entry.

  Raises:
    NotImplementedError: if axis is neither 0 nor 1.
  """
  out_dtype = values.dtype
  # Anything that is not float32 / bfloat16 is computed in float32 and cast
  # back to out_dtype at the end.
  computes_natively = out_dtype == tf.float32 or out_dtype == tf.bfloat16
  if not computes_natively:
    values = tf.cast(values, tf.float32)

  if axis == 1:
    # Column gather: values x one_hot, with the ids along the one-hot's rows.
    selector = tf.one_hot(
        self._ids, self._ids_size, dtype=values.dtype, axis=0)
    return tf.cast(tf.matmul(values, selector), out_dtype)

  if axis != 0:
    raise NotImplementedError("Only row/col-wise gather implemented.")

  rank = values.shape.rank
  if rank is None or rank <= 2:
    # Row gather: one_hot x values.
    selector = tf.one_hot(self._ids, self._ids_size, dtype=values.dtype)
    return tf.cast(tf.matmul(selector, values), out_dtype)

  # Higher-rank state: gather along the batch dimension, moving it to the
  # front first (and back afterwards) when the state is time major.
  if not batch_major_state:
    values = tf.transpose(values, [1, 0, 2])
  gathered = tf.cast(
      tf.gather(values, tf.cast(self._ids, tf.int32)), out_dtype)
  if batch_major_state:
    return gathered
  return tf.transpose(gathered, [1, 0, 2])
def _Apply():
  """Applies the gradients in `var_grad` through `optimizer`.

  Returns:
    The result of `optimizer.apply_gradients`, named 'meta_backprop'.
  """
  if self.params.use_bf16_gradients_ar:
    # Gradients are cast to float32 before being handed to the optimizer.
    grads_and_vars = [
        (tf.cast(grad, tf.float32), var) for (var, grad) in var_grad.Flatten()
    ]
  else:
    grads_and_vars = [(grad, var) for (var, grad) in var_grad.Flatten()]
  return optimizer.apply_gradients(grads_and_vars, name='meta_backprop')
def FProp(self, theta, current_step):
  """Returns a cosine-shaped value going from initial_value to final_value.

  The value equals `initial_value` at step 0 and reaches `final_value` at
  `total_steps`, where it stays for all later steps.

  Args:
    theta: Unused.
    current_step: The current step; cast to float32.

  Returns:
    A float32 tensor holding the schedule value.
  """
  p = self.params
  assert p.total_steps > 0
  assert p.initial_value > p.final_value
  with tf.name_scope(p.name):
    # Fraction of the schedule completed, capped at 1.
    progress = tf.minimum(
        1.0, tf.cast(current_step, tf.float32) / p.total_steps)
    decay_gap = p.initial_value - p.final_value
    cosine_term = 1 + tf.cos(math.pi * progress)
    return p.final_value + 0.5 * decay_gap * cosine_term
def MakeCausalPadding(seq_len, block_size, left_context, right_context):
  """Builds the block-wise causal padding tensor for a full sequence.

  Args:
    seq_len: int or scalar int tensor. Sequence length.
    block_size: int. Number of time frames in a block.
    left_context: int. Left context size.
    right_context: int. Right context size.

  Returns:
    A tensor of [num_blocks, block_size, context_size] taking values in
    {0, 1}, where context_size = block_size + (left_context - 1) +
    right_context. Element b, i, j is zero if in the b-th block, the i-th
    frame can access the j-th frame in the context.
  """
  seq_len = py_utils.with_dependencies([
      py_utils.assert_greater_equal(
          seq_len, 1, message='seq_len must be at least 1')
  ], seq_len)

  # Round up so a trailing partial block is still covered.
  num_blocks = (seq_len + block_size - 1) // block_size
  padded_len = num_blocks * block_size
  context_size = block_size + (left_context - 1) + right_context

  # [num_blocks, block_size]: source positions in the original sequence.
  src_positions = tf.reshape(tf.range(padded_len), [num_blocks, block_size])
  # [num_blocks,]: source positions at the start of each block.
  block_starts = tf.range(0, padded_len, block_size)
  # [context_size]: positions relative to the block start.
  context_offsets = tf.range(context_size) - (left_context - 1)
  # [num_blocks, context_size]: target positions in the original sequence.
  tgt_positions = (
      block_starts[:, tf.newaxis] + context_offsets[tf.newaxis, :])

  # [num_blocks, block_size, context_size]: source minus target position for
  # every source-target pair.
  src_3d = src_positions[:, :, tf.newaxis]
  tgt_3d = tgt_positions[:, tf.newaxis, :]
  position_diff = src_3d - tgt_3d

  # [num_blocks, block_size, context_size]: the target lies within the
  # allowed window around the source.
  in_window = tf.math.logical_and(position_diff >= -right_context,
                                  position_diff < left_context)
  # [num_blocks, block_size]: if the source position is valid, not padded.
  valid_src = src_positions < seq_len
  # [num_blocks, context_size]: if the target position is valid, not padded.
  valid_tgt = tf.math.logical_and(tgt_positions >= 0,
                                  tgt_positions < seq_len)
  both_valid = tf.math.logical_and(valid_src[:, :, tf.newaxis],
                                   valid_tgt[:, tf.newaxis, :])
  can_attend = tf.math.logical_and(in_window, both_valid)

  return 1.0 - tf.cast(can_attend, dtype=tf.float32)
def FProp(self, theta, inputs):
  """Projects the last dimension of `inputs` and records the flops spent.

  Args:
    theta: A NestedMap object containing weights' values of this layer and its
      children layers.
    inputs: The inputs tensor. Shaped [..., input_dims].

  Returns:
    Projected inputs.
  """
  p = self.params
  with tf.name_scope(p.name):
    # Cost: number of input vectors x (input_dims * output_dims) x 2.
    leading_dims = tf.cast(tf.shape(inputs)[:-1], tf.int64)
    num_vectors = tf.reduce_prod(leading_dims)
    weight_size = tf.cast(
        symbolic.ToTensor(p.input_dims * p.output_dims), tf.int64)
    computation_cost.Add(self, 'flops', num_vectors * weight_size * 2)
    return py_utils.ProjectLastDim(inputs, theta.w, p.input_dims,
                                   p.output_dims)
def ApplyClippingWithState(self, state, x):
  """Clips `x` to the symmetric range [-state, state].

  Args:
    state: Clipping state; cast to the dtype of `x` and used as the cap.
    x: Input tensor to clip.

  Returns:
    Clipped (or identity) x.
  """
  limit = tf.cast(state, x.dtype)
  lower, upper = -limit, limit
  return tf.clip_by_value(x, lower, upper)
def GenerateStepSeedPair(p, unused_global_step=None, op_seed=None):
  """Variant of py_utils.GenerateStepSeedPair using GetOverWriteGlobalStep.

  Args:
    p: Params; `is_inference` and `random_seed` are read from it.
    unused_global_step: Unused.
    op_seed: Optional value added to both seeds.

  Returns:
    A tensor of two seeds (int32 on TPU, int64 otherwise).
  """
  seed_dtype = tf.int32 if py_utils.use_tpu() else tf.int64
  has_fixed_seed = p.random_seed is not None

  if p.is_inference and not has_fixed_seed:
    # Stateless random ops are fully determined by their seeds, so identical
    # inference inputs would always give identical outputs, even for models
    # meant to be random at inference time (e.g. dropout). When the graph was
    # exported with random_seed=None, draw fresh random seeds as a workaround.
    return tf.random.uniform([2], maxval=seed_dtype.max, dtype=seed_dtype)

  with tf.name_scope('op_seed') as scope:
    step_part = tf.cast(GetOverWriteGlobalStep(), seed_dtype)
    name_part = tf.cast(py_utils.GenerateSeedFromName(scope), seed_dtype)
    seeds = tf.stack([step_part, name_part])

    if has_fixed_seed:
      seeds += p.random_seed
    if op_seed is not None:
      seeds += op_seed
    return seeds
def _TimeWarp(self,
              inputs,
              seq_lengths,
              global_seed,
              dtype=tf.float32,
              domain_id_index=0):
  """Applies random time warping to `inputs`.

  Args:
    inputs: Batch of input features of shape (batch_size, time_length,
      num_freq, channels).
    seq_lengths: The actual sequence lengths, of shape (batch_size,).
    global_seed: an integer seed tensor for stateless random ops.
    dtype: Data type.
    domain_id_index: Domain ID index.

  Returns:
    Inputs with random time warping applied, or `inputs` itself when the
    configuration for this domain disables warping.
  """
  p = self.params
  batch_size, time_length, _, _ = py_utils.GetShape(inputs)

  # Warping parameters configured for this domain.
  max_frames = p.time_warp_max_frames[domain_id_index]
  max_ratio = p.time_warp_max_ratio[domain_id_index]
  bound_mode = p.time_warp_bound[domain_id_index]
  assert bound_mode in ('static', 'dynamic')

  # Nothing to do when the maximum warp length is zero.
  static_and_zero = bound_mode == 'static' and max_frames == 0
  if static_and_zero or max_ratio <= 0.0:
    return inputs

  seq_lengths = tf.cast(seq_lengths, tf.int32)

  # With a dynamic bound the upper limit on warp frames is discarded.
  if bound_mode == 'dynamic':
    max_frames = None

  # Create the warping matrix in the time direction and apply it.
  warp_matrix = self._GetWarpMatrix(
      batch_size,
      choose_range=seq_lengths,
      matrix_size=time_length,
      global_seed=global_seed,
      max_warp_frames=max_frames,
      dtype=dtype,
      max_ratio=max_ratio)
  return self.EinsumBxycBzxBzyc(inputs, warp_matrix, name='einsum_forwarping')