def Bak(inputs, outputs, d_outputs):
  """Backward step."""
  del inputs  # unused
  output_acts, step_seeds = outputs
  d_outputs = d_outputs[0]
  d_layer_thetas = []
  for layer_idx in reversed(range(num_layers)):
    f_seed, g_seed = step_seeds[layer_idx]
    layer = self.sub_layers[layer_idx]
    layer_theta = theta.sub_layers[layer_idx]
    input_acts, d_inputs, d_theta = layer.ReverseAndGrad(
        layer_theta, output_acts, d_outputs, f_seed, g_seed, *extra_inputs)
    d_layer_thetas.append(d_theta)
    # Passes reconstructed inputs to the previous layer.
    output_acts = input_acts
    d_outputs = d_inputs
  py_utils.ResetStepSeed(final_step_seed)
  d_theta = py_utils.NestedMap(global_step=tf.zeros_like(initial_step_seed))
  d_theta.sub_layers = list(reversed(d_layer_thetas))
  extra_grads = [tf.zeros_like(t) for t in extra_inputs]
  return [tf.zeros_like(initial_step_seed), d_theta, d_inputs, extra_grads]
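
# A pure-Python sketch of the reverse sweep in Bak above: walk the stack
# backwards, let each layer reconstruct its input from its output, and chain
# the input gradient into the next-earlier layer. `reverse_and_grad` is a
# hypothetical stand-in for ReverseAndGrad.
def _BackwardSweepSketch(layers, output_acts, d_outputs):
  d_layer_thetas = []
  for layer in reversed(layers):
    input_acts, d_inputs, d_theta = layer.reverse_and_grad(
        output_acts, d_outputs)
    d_layer_thetas.append(d_theta)
    # The reconstructed input of layer k is the output of layer k-1.
    output_acts, d_outputs = input_acts, d_inputs
  # Gradients were collected last-layer-first; restore layer order.
  return d_outputs, list(reversed(d_layer_thetas))
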
def _Proc(record):
  """Parses a serialized tf.Example record."""
  outputs = [
      ('inputs', tf.io.VarLenFeature(tf.int64)),
      ('targets', tf.io.VarLenFeature(tf.int64)),
      # Default eval weight to 1.0.
      ('eval_weight', tf.io.FixedLenFeature([], tf.float32,
                                            default_value=1.0)),
  ]
  features = tf.io.parse_single_example(record, dict(outputs))
  for k, v in six.iteritems(features):
    if k != 'eval_weight':
      features[k] = v.values
    else:
      eval_weight = v
  src_ids = features['inputs']
  tgt_labels = features['targets']
  # Derive trivial segmentation for unpacked input.
  src_paddings, tgt_ids, tgt_paddings, tgt_weights, bucket_key = (
      _DerivePaddingsAndIds(src_ids, tgt_labels))
  src_len = tf.shape(src_ids)[0]
  tgt_len = tf.shape(tgt_ids)[0]
  src_pos = tf.range(src_len, dtype=tf.int32)
  src_seg = tf.zeros_like(src_paddings)
  tgt_pos = tf.range(tgt_len, dtype=tf.int32)
  tgt_seg = tf.zeros_like(tgt_paddings)
  return [
      src_ids, src_paddings, tgt_ids, tgt_paddings, tgt_labels, tgt_weights,
      src_pos, src_seg, tgt_pos, tgt_seg, eval_weight
  ], bucket_key
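
# A minimal eager-mode sketch (hypothetical feature values) of the
# VarLenFeature -> SparseTensor -> `.values` pattern used by _Proc above.
import tensorflow as tf

record = tf.train.Example(
    features=tf.train.Features(
        feature={
            'inputs': tf.train.Feature(
                int64_list=tf.train.Int64List(value=[3, 4, 5])),
            'targets': tf.train.Feature(
                int64_list=tf.train.Int64List(value=[6, 7])),
        })).SerializeToString()
features = tf.io.parse_single_example(
    record, {
        'inputs': tf.io.VarLenFeature(tf.int64),
        'targets': tf.io.VarLenFeature(tf.int64),
        'eval_weight': tf.io.FixedLenFeature([], tf.float32,
                                             default_value=1.0),
    })
# features['inputs'] is a SparseTensor; `.values` yields the dense id list
# [3, 4, 5], and the absent eval_weight falls back to the 1.0 default.
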
def _Moments(inputs, mask, enable_cross_replica_sum_on_tpu=False):
  """Computes mean and variance over the valid data points in inputs."""
  inputs = py_utils.with_dependencies([
      py_utils.assert_equal(tf.rank(inputs), tf.rank(mask)),
      py_utils.assert_greater_equal(mask, tf.zeros_like(mask)),
  ], inputs)
  rank = tf.rank(mask)
  reduce_over_dims = tf.range(0, rank - 1)
  sum_v = tf.reduce_sum(inputs * tf.cast(mask, inputs.dtype),
                        reduce_over_dims)
  count_v = tf.reduce_sum(mask, reduce_over_dims)
  # Input shape is guaranteed to be a multiple of mask shape because the
  # inputs * mask op above was successfully broadcasted.
  mask_multiplier = tf.shape(inputs)[:-1] // tf.shape(mask)[:-1]
  count_v *= tf.cast(tf.reduce_prod(mask_multiplier), count_v.dtype)
  if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
    sum_v = tf.tpu.cross_replica_sum(sum_v)
    count_v = tf.tpu.cross_replica_sum(count_v)
  count_v = tf.maximum(count_v, 1.0)
  mean = sum_v / count_v
  sum_vv = tf.reduce_sum((inputs - mean) * (inputs - mean) * mask,
                         reduce_over_dims)
  if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
    sum_vv = tf.tpu.cross_replica_sum(sum_vv)
  variance = py_utils.with_dependencies([
      py_utils.assert_greater_equal(sum_vv, tf.zeros_like(sum_vv)),
  ], sum_vv / count_v)
  return mean, variance
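
# A standalone sanity check of the masked-moment computation above (without
# the py_utils asserts or TPU cross-replica sums): only rows with a nonzero
# mask contribute to the per-column mean and variance.
import tensorflow as tf

inputs = tf.constant([[1., 2., 3.],
                      [9., 9., 9.]])  # second row is padding
mask = tf.constant([[1., 1., 1.],
                    [0., 0., 0.]])
reduce_dims = tf.range(0, tf.rank(mask) - 1)
count_v = tf.maximum(tf.reduce_sum(mask, reduce_dims), 1.0)
mean = tf.reduce_sum(inputs * mask, reduce_dims) / count_v  # [1., 2., 3.]
variance = tf.reduce_sum(
    (inputs - mean)**2 * mask, reduce_dims) / count_v  # [0., 0., 0.]
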
def FProp(self, theta, inputs, paddings, domain_ids=None):
  """Applies data augmentation by randomly masking the spectrum of inputs.

  Args:
    theta: A NestedMap object containing weights' values of this layer and
      its children layers.
    inputs: A tensor of shape [batch, time, freq, num_channels].
    paddings: A 0/1 tensor of shape [batch, time].
    domain_ids: input domain_ids of shape [batch, time].

  Returns:
    A pair of tensors:

    - augmented_inputs: A tensor of shape [batch, time, freq, num_channels].
    - paddings: A 0/1 tensor of shape [batch, time].
  """
  p = self.params
  global_seed = None  # A tensor seed in case stateless random ops are needed.
  if p.use_input_dependent_random_seed:
    global_seed = _global_seed_from_inputs(inputs)
  batch_size, series_length, _, _ = py_utils.GetShape(inputs)
  if len(p.domain_ids) > 1:
    augmented_inputs = tf.zeros_like(inputs)
    original_inputs = inputs
    for i, domain_id in enumerate(p.domain_ids):
      augmented_domain = self._AugmentationNetwork(
          series_length,
          inputs,
          paddings,
          global_seed=global_seed,
          domain_id_index=i)
      target_domain = tf.cast(
          tf.expand_dims(tf.tile([domain_id], [batch_size]), -1),
          dtype=p.dtype)
      # [batch, time].
      domain_mask = tf.cast(
          tf.equal(domain_ids, target_domain), dtype=p.dtype)
      augmented_domain = self.EinsumBxycBxBxyc(
          augmented_domain, domain_mask, name='einsum_domainmasking')
      original_inputs = self.EinsumBxycBxBxyc(
          original_inputs, 1.0 - domain_mask, name='einsum_domainmasking2')
      augmented_inputs = augmented_domain + augmented_inputs
    augmented_inputs = original_inputs + augmented_inputs
  else:
    augmented_inputs = self._AugmentationNetwork(
        series_length,
        inputs,
        paddings,
        global_seed=global_seed,
        domain_id_index=0)
  return augmented_inputs, paddings
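
# A hedged sketch of the domain-masking einsum assumed above
# ('bxyc,bx->bxyc'): per (batch, time) position, keep the augmented features
# where the domain matches and the original features elsewhere.
import tensorflow as tf

augmented = tf.random.normal([2, 5, 4, 1])  # [batch, time, freq, channels]
original = tf.random.normal([2, 5, 4, 1])
domain_mask = tf.constant([[1., 1., 0., 0., 0.],
                           [0., 0., 1., 1., 1.]])  # [batch, time]
mixed = (tf.einsum('bxyc,bx->bxyc', augmented, domain_mask) +
         tf.einsum('bxyc,bx->bxyc', original, 1. - domain_mask))
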
def _ApplyAndReset():
  with tf.control_dependencies([
      self._opt.Apply(
          lr, py_utils.ApplyGradMultiplier(var_grad, 1. / p.accum_steps))
  ]):
    return tf.group(
        *[tf.assign(a, tf.zeros_like(a)) for _, a in var_grad.Flatten()])
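
# A minimal TF2 sketch (hypothetical variable names) of the same
# accumulate-then-apply pattern: scale the accumulated gradient down by
# accum_steps, apply it, then zero the accumulator.
import tensorflow as tf

accum_steps = 4
opt = tf.keras.optimizers.SGD(learning_rate=0.1)
var = tf.Variable([1.0, 2.0])
accum_grad = tf.Variable(tf.zeros_like(var))  # filled over accum_steps steps

def apply_and_reset():
  opt.apply_gradients([(accum_grad / accum_steps, var)])
  accum_grad.assign(tf.zeros_like(accum_grad))
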
def PostTrainingStepUpdate(self, global_step):
  """Updates moving_mean, moving_variance after each training step."""
  p = self.params
  # Get sufficient stats that accumulate over microbatches.
  counts = self.accumulators.counts.GetValue()
  mean_ss = self.accumulators.mean_ss.GetValue()
  variance_ss = self.accumulators.variance_ss.GetValue()
  # Compute batch mean and batch variance from sufficient stats.
  mean, variance = tf.nn.normalize_moments(counts, mean_ss, variance_ss, None)
  decay = tf.convert_to_tensor(1.0 - p.decay, p.dtype)
  # Update moving_mean, moving_variance from batch mean and batch variance.
  with tf.name_scope(p.name) as scope:
    with tf.ops.colocate_with(self.vars.moving_mean):
      mean_update = tf.assign_sub(
          self.vars.moving_mean,
          tf.where(
              tf.greater(counts, 0.5),
              (self.vars.moving_mean - tf.cast(mean, p.dtype)) * decay,
              tf.zeros_like(self.vars.moving_mean)),
          name='moving_mean_update')
    with tf.ops.colocate_with(self.vars.moving_variance):
      var_update = tf.assign_sub(
          self.vars.moving_variance,
          tf.where(
              tf.greater(counts, 0.5),
              (self.vars.moving_variance - tf.cast(variance, p.dtype)) *
              decay, tf.zeros_like(self.vars.moving_variance)),
          name='moving_variance_update')
    py_utils.CheckNumerics(
        self.vars.moving_mean,
        'moving mean of {} failed numeric check'.format(scope))
    py_utils.CheckNumerics(
        self.vars.moving_variance,
        'moving variance of {} failed numeric check'.format(scope))
  self.accumulators.counts.Reset()
  self.accumulators.mean_ss.Reset()
  self.accumulators.variance_ss.Reset()
  return tf.group(mean_update, var_update)
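
# The guarded assign_sub above is the standard exponential moving average
# written as a subtraction: moving <- moving - (moving - batch) * (1 - decay)
# equals moving <- decay * moving + (1 - decay) * batch. A standalone check
# with p.decay = 0.99:
import tensorflow as tf

moving_mean = tf.Variable(0.0)
batch_mean = tf.constant(1.0)
moving_mean.assign_sub((moving_mean - batch_mean) * (1.0 - 0.99))
# moving_mean is now 0.99 * 0.0 + 0.01 * 1.0 = 0.01.
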
def FProp(self, theta, inputs, paddings=None):
  """Apply batch normalization.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer
      and its children layers.
    inputs: The inputs tensor. Shaped [..., dim].
    paddings: The paddings tensor. Shaped [..., 1], with the same rank as
      the input tensor.

  Returns:
    Output after applying batch normalization, with the same shape as
    'inputs'.
  """
  p = self.params
  if paddings is None:
    paddings = self._GetDefaultPaddings(inputs)
  with tf.name_scope(p.name):
    norm_mean, norm_variance, beta, gamma = self.ComputeAndUpdateMoments(
        theta, inputs, paddings)
    with tf.control_dependencies([
        py_utils.assert_greater_equal(norm_variance,
                                      tf.zeros_like(norm_variance)),
        py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                    tf.shape(norm_mean)),
        py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                    tf.shape(norm_variance)),
    ]):
      if p.use_fused_batch_norm_for_eval and self.do_eval:
        bn_output, _, _ = nn.fused_batch_norm(
            inputs,
            gamma,
            beta,
            norm_mean,
            norm_variance,
            self._epsilon,
            is_training=False)
      else:
        bn_output = tf.nn.batch_normalization(inputs, norm_mean,
                                              norm_variance, beta, gamma,
                                              self._epsilon)
      if p.set_padded_output_to_zero:
        bn_output *= 1.0 - paddings
  return bn_output
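
# For reference, tf.nn.batch_normalization computes
#   gamma * (x - mean) / sqrt(variance + epsilon) + beta,
# so the fused and non-fused eval paths above agree up to numerics. A small
# standalone example:
import tensorflow as tf

x = tf.random.normal([8, 16])
mean, variance = tf.nn.moments(x, axes=[0])
y = tf.nn.batch_normalization(
    x, mean, variance, offset=None, scale=None, variance_epsilon=1e-3)
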
def _TokenizeOneSentence(i, strs, token_ids_ta, target_ids_ta, paddings_ta):
  """Tokenizes a single sentence."""
  ids, _ = self._wpm_encoder.Encode(strs[i])
  if append_eos:
    ids = tf.concat([ids, [self.eos_id]], axis=0)
  # This truncates after the eos is added, so some sentences might
  # not have </s> at the end.
  token_ids_ta = token_ids_ta.write(
      i,
      py_utils.PadOrTrimTo(
          tf.concat([[self.sos_id], ids], axis=0), [max_length],
          self.eos_id))
  target_ids_ta = target_ids_ta.write(
      i, py_utils.PadOrTrimTo(ids, [max_length], self.eos_id))
  paddings_ta = paddings_ta.write(
      i,
      py_utils.PadOrTrimTo(
          tf.zeros_like(ids, dtype=tf.float32), [max_length], 1.))
  return i + 1, strs, token_ids_ta, target_ids_ta, paddings_ta
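
# A standalone sketch of the pad-or-trim behavior assumed above: sequences
# longer than max_length are truncated, shorter ones are right-padded with a
# constant (the eos id for ids, 1.0 for paddings).
import tensorflow as tf

def _PadOrTrimToSketch(x, length, pad_val):
  x = x[:length]
  return tf.pad(x, [[0, length - tf.shape(x)[0]]], constant_values=pad_val)

ids = tf.constant([11, 12, 13], tf.int32)
print(_PadOrTrimToSketch(ids, 5, 2))  # [11 12 13  2  2]
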
def FProp(self, theta, input_batch):
  """Embeds source ids and transforms with TransformerStack.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer
      and its children layers.
    input_batch: A `.NestedMap` with fields:

      - ids: The inputs tensor. It is expected to be of shape [batch, time].
      - paddings: The paddings tensor. Expected shape [batch, time].
      - task_ids: If p.task_emb is provided, must contain per-token task
        ids of shape [batch, time].

  Returns:
    A NestedMap containing

    - encoded: The encoded features, either a tensor of shape
      [time, batch, depth], or a list of tensors if is_transparent is set
      in transformer_stack.
    - padding: of shape [time, batch]
    - segment_id: [time, batch] if packed inputs are supported by the model
      (and all layers), or None otherwise.
    - embedded_inputs: [time, batch, depth] embedded inputs tokens without
      positional encodings.
  """
  p = self.params
  with tf.name_scope(p.name):
    src_segment_id = None
    src_segment_pos = None
    input_ids = py_utils.with_dependencies([
        py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                    tf.shape(input_batch.paddings)),
        py_utils.assert_equal(tf.rank(input_batch.ids), 2)
    ], input_batch.ids)

    if (not py_utils.use_tpu() and
        tf.flags.FLAGS.transformer_encoder_truncates_inputs):
      max_seq_length = tf.cast(
          tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings, 1)),
          tf.int32)
      paddings = py_utils.with_dependencies([
          py_utils.assert_equal(
              tf.constant(True, tf.bool),
              tf.reduce_all(input_batch.paddings[:, max_seq_length:] > 0.5))
      ], input_batch.paddings)
      input_ids = input_ids[:, :max_seq_length]
      paddings = paddings[:, :max_seq_length]
      if p.packed_input:
        src_segment_id = input_batch.segment_ids[:, :max_seq_length]
        src_segment_pos = input_batch.segment_pos[:, :max_seq_length]
    else:
      paddings = input_batch.paddings
      if p.packed_input:
        src_segment_id = input_batch.segment_ids
        src_segment_pos = input_batch.segment_pos

    max_time = tf.shape(input_ids)[1]

    # Input token embeddings + positional embeddings.
    if not p.shared_emb:
      input_embs = self.token_emb.EmbLookup(theta.token_emb,
                                            tf.reshape(input_ids, [-1]))
    else:
      input_embs = self.softmax.EmbLookup(theta.softmax,
                                          tf.reshape(input_ids, [-1]))

    input_embs = tf.reshape(input_embs,
                            [-1, max_time, p.token_emb.embedding_dim])
    # [time, batch, dim]
    orig_input_embs = tf.transpose(input_embs, [1, 0, 2])

    if p.packed_input:
      position_embs = self.position_emb.FPropWithPosition(
          theta.position_emb, src_segment_pos)
    else:
      position_embs = self.position_emb.FProp(theta.position_emb, max_time)
      position_embs = tf.reshape(position_embs,
                                 [1, max_time, p.token_emb.embedding_dim])
    input_embs += position_embs
    if p.task_emb:
      input_embs += self.task_emb.EmbLookup(theta.task_emb,
                                            input_batch.task_ids)

    if p.model_dim != p.token_emb.embedding_dim:
      input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs)

    paddings = tf.cast(tf.transpose(paddings), py_utils.FPropDtype(p))
    if p.packed_input:
      src_segment_id = tf.transpose(src_segment_id)
    input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs)

    # [time, batch, dim]
    transformer_input = tf.transpose(input_embs, [1, 0, 2])

    if not self.do_eval and p.apply_source_mask:
      # Augment padding for masked source word positions.
      dtype = paddings.dtype
      source_mask = tf.where(
          tf.equal(input_ids, p.source_mask_id),
          tf.ones_like(input_ids, dtype=dtype),
          tf.zeros_like(input_ids, dtype=dtype))
      # Make sure padding is between 0 and 1.
      paddings = tf.clip_by_value(paddings + tf.transpose(source_mask), 0.0,
                                  1.0)

    encoded, padding, segment_id = self.transformer_stack.FProp(
        theta.transformer_stack, transformer_input, paddings, src_segment_id)
    return py_utils.NestedMap(
        encoded=encoded,
        padding=padding,
        segment_id=segment_id,
        embedded_inputs=orig_input_embs)
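
# A small sketch of the source-mask padding augmentation at the end of FProp
# above: positions holding the (hypothetical) mask token id are folded into
# the padding so downstream attention ignores them.
import tensorflow as tf

source_mask_id = 4
input_ids = tf.constant([[7, 4, 9],
                         [4, 8, 0]])  # [batch, time]
paddings = tf.constant([[0., 0., 0.],
                        [0., 0., 1.]])
source_mask = tf.cast(tf.equal(input_ids, source_mask_id), tf.float32)
paddings = tf.clip_by_value(paddings + source_mask, 0.0, 1.0)
# -> [[0., 1., 0.], [1., 0., 1.]]
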
def FProp(self,
          theta,
          x,
          x_paddings=None,
          eos_id=1,
          force_sample_last_token=True):
  """Applies SymbolInsertionLayer.

  We take in `x`, which represents the groundtruth sequence (i.e., English
  sequence). We return a sampled rollin (observed) canvas (i.e., random
  subset of the English sequence), as well as the target (indices) for an
  insertion-based model (i.e., the targets given the random observed subset).

  Args:
    theta: Ignored, this can be None.
    x: The symbol ids of shape `[batch_size, time_dim]`.
    x_paddings: The paddings (1 or 0) of shape `[batch_size, time_dim]` where
      0 is valid and 1 is invalid.
    eos_id: The <eos> token id to represent end-of-slot.
    force_sample_last_token: Set True to force sample the last token of `x`.

  Returns:
    A `NestedMap`.

    - canvas: The canvas (based off of the `rollin_policy`) of shape
      [batch_size, c_dim]. Note that, `c_dim` <= `time_dim` but need not be
      equal.
    - canvas_indices: The canvas indices (into `x`).
    - canvas_paddings: The paddings of `canvas_indices`.
    - target_indices: The target indices of shape [num_targets, 3].
      `num_targets` is the number of total targets in the entire batch.
      [:, 0] captures the batch, [:, 1] captures the slot, and [:, 2]
      captures the token. Each row [batch, slot, vocab] represents the
      indices of the target -- i.e., the batch, slot and vocab combination
      of the target. Typical usage of these indices is to tf.gather_nd
      the log-probs (from the softmax layer).
    - target_weights: The target weights.

  Raises:
    ValueError: If invalid params.
  """
  p = self.params

  batch_size = py_utils.GetShape(x)[0]
  time_dim = py_utils.GetShape(x)[1]

  if x_paddings is None:
    x_paddings = tf.zeros([batch_size, time_dim], tf.float32)

  oracle_policy = p.oracle_policy
  rollin_policy = (
      oracle_policy if p.rollin_policy == 'oracle' else p.rollin_policy)

  if rollin_policy != 'uniform':
    raise ValueError('Unknown or unsupported rollin policy: %s' %
                     rollin_policy)
  if oracle_policy != 'uniform':
    raise ValueError('Unknown or unsupported oracle policy: %s' %
                     oracle_policy)

  x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)

  # Compute the desired length per example in the batch.
  ratio = tf.random.uniform([batch_size], 0.0, 1.0, seed=p.random_seed)
  if force_sample_last_token:
    c_len = tf.minimum(
        tf.cast(ratio * tf.cast(x_len, tf.float32), tf.int32), x_len - 1) + 1
  else:
    c_len = tf.minimum(
        tf.cast(ratio * tf.cast(x_len + 1, tf.float32), tf.int32), x_len)
  # Compute the maximum length across the batch.
  c_len_max = tf.reduce_max(c_len)

  # Grab subset of random valid indices per example.
  z_logits = tf.cast(
      tf.expand_dims(tf.range(time_dim), 0) >= tf.expand_dims(x_len, 1),
      tf.float32) * -1e9
  if force_sample_last_token:
    # Force sample the last token -- i.e., as indexed by `x_len - 1`. We can
    # accomplish this by adding +LARGE_NUMBER to the logits.
    z_logits += tf.cast(
        tf.equal(
            tf.expand_dims(tf.range(time_dim), 0),
            tf.expand_dims(x_len - 1, 1)), tf.float32) * 1e9
  # Gumbel-max trick to sample (we only sample valid positions per sample in
  # the batch).
  z = -tf.math.log(-tf.math.log(
      tf.random.uniform([batch_size, time_dim], seed=p.random_seed)))
  unused_c_values, c_indices = tf.nn.top_k(z_logits + z, time_dim)

  # Trim everything > c_len_max.
  c_indices = c_indices[:, :c_len_max]

  # Invalidate any indices >= c_len, we use the last index as the default
  # invalid index.
  c_indices = tf.where(
      tf.expand_dims(tf.range(c_len_max), 0) < tf.expand_dims(c_len, 1),
      c_indices, tf.fill(py_utils.GetShape(c_indices), time_dim - 1))

  # Materialize the canvas.
  c_indices = tf.sort(c_indices)
  c = tf.gather_nd(
      x,
      tf.stack([
          tf.reshape(
              tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                      [1, c_len_max]), [-1]),
          tf.reshape(c_indices, [-1])
      ], 1))
  c = tf.reshape(c, [batch_size, c_len_max])

  # Compute the paddings.
  c_paddings = 1 - tf.sequence_mask(c_len, c_len_max, dtype=x_paddings.dtype)
  c *= tf.cast(1 - c_paddings, tf.int32)

  indices = tf.concat([
      tf.reshape(
          tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, c_len_max]),
          [batch_size * c_len_max, 1]),
      tf.reshape(c_indices, [batch_size * c_len_max, 1])
  ], 1)
  x_token_is_observed = tf.scatter_nd(
      indices, tf.ones([batch_size * c_len_max], tf.int32),
      py_utils.GetShape(x))
  # `x_segments` captures which slot each `x` belongs to (both observed and
  # tokens that need to be observed).
  x_segments = tf.cumsum(x_token_is_observed, 1, exclusive=True)

  x_token_is_observed = tf.cast(x_token_is_observed, tf.bool)
  prev_x_token_is_observed = tf.pad(
      x_token_is_observed[:, :-1], [[0, 0], [1, 0]], constant_values=True)
  x_token_is_observed = tf.reshape(x_token_is_observed, [-1])
  prev_x_token_is_observed = tf.reshape(prev_x_token_is_observed, [-1])
  x_is_valid = tf.cast(1 - x_paddings, tf.bool)
  x_is_valid = tf.reshape(x_is_valid, [-1])

  # Remap all the observed to <eos>, note some of these need a zero weight
  # (or else there would be <eos> and valid token in the same slot).
  target_indices = tf.cast(tf.reshape(x, [-1, 1]), tf.int32)
  target_indices = tf.where(
      x_token_is_observed,
      tf.fill(py_utils.GetShape(target_indices), eos_id), target_indices)

  # TODO(williamchan): We give uniform 1.0 weight, however, math suggests
  # we may want to weigh this term by the original sequence length.
  target_weights = tf.ones_like(target_indices, tf.float32)

  # We need to set all the weights for <eos> which actually have valid
  # tokens in the slot to zero.
  target_weights = tf.where(
      x_token_is_observed & ~prev_x_token_is_observed,
      tf.zeros_like(target_weights), target_weights)

  # TODO(williamchan): Consider dropping the entries w/ weight zero.

  # Add the batch and slot indices.
  target_indices = tf.concat([
      tf.reshape(
          tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, time_dim]),
          [batch_size * time_dim, 1]),
      tf.reshape(x_segments, [-1, 1]), target_indices
  ], 1)

  # Select only the valid indices. The selected valid ones include slots w/
  # <eos>.
  target_indices = target_indices[x_is_valid]
  target_weights = target_weights[x_is_valid]

  return py_utils.NestedMap(
      canvas=c,
      canvas_indices=c_indices,
      canvas_paddings=c_paddings,
      target_indices=target_indices,
      target_weights=target_weights)
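
# A standalone sketch of the Gumbel-max trick used above: adding Gumbel noise
# to masked logits and taking top-k samples k distinct positions uniformly at
# random from the valid prefix (invalid positions carry -1e9 and never win).
import tensorflow as tf

x_len, time_dim, k = 4, 6, 3
logits = tf.where(tf.range(time_dim) < x_len, 0.0, -1e9)
gumbel = -tf.math.log(-tf.math.log(tf.random.uniform([time_dim])))
_, indices = tf.nn.top_k(logits + gumbel, k)  # 3 distinct indices in [0, 4)
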
def _CreateCanvasAndTargets(self, batch):
  # pyformat: disable
  """Create the canvas and targets.

  Args:
    batch: A `.NestedMap`.

      - src: A `.NestedMap`.
        - ids: The source ids, ends in <eos>.
        - paddings: The source paddings.

      - tgt: A `.NestedMap`.
        - ids: The target ids, ends in <eos>.
        - paddings: The target paddings.

  Returns:
    A `NestedMap`.

    - canvas: The canvas (based off of the `rollin_policy`) of shape
      [batch_size, c_dim].
    - canvas_paddings: The paddings of `canvas_indices`.
    - target_indices: The target indices (i.e., use these indices to
      tf.gather_nd the log-probs). Optional, only during training.
    - target_weights: The target weights. Optional, only during training.
  """
  # pyformat: enable
  p = self.params

  if not self.do_eval:
    # Sample our src and tgt canvas.
    src_descriptor = self._SampleCanvasAndTargets(batch.src.ids,
                                                  batch.src.paddings)
    tgt_descriptor = self._SampleCanvasAndTargets(batch.tgt.ids,
                                                  batch.tgt.paddings)

    # Offset the src ids (to unshare embeddings between src/tgt). Note, we
    # only offset the canvas ids, but we do not offset the vocab ids. This
    # will result in unshared embeddings, but shared softmax. This is due to
    # GPU/TPU memory limitations, empirically it is known that unsharing
    # everything results in better performance.
    vocab_size = p.decoder.softmax.num_classes
    src_descriptor.canvas = tf.where(
        tf.equal(src_descriptor.canvas_paddings, 0),
        src_descriptor.canvas + vocab_size, src_descriptor.canvas)

    # Offset the tgt indices (need shift according to src length).
    batch_size = py_utils.GetShape(batch.src.ids)[0]
    # `target_batch` is a [num_targets, batch_size] tensor where each row
    # identifies which batch the target belongs to. Note the observation
    # that, tf.reduce_sum(target_batch, 1) == 1 \forall rows.
    target_batch = tf.cast(
        tf.equal(
            tf.expand_dims(tf.range(batch_size), 0),
            tf.expand_dims(tgt_descriptor.target_indices[:, 0], 1)),
        tf.int32)
    src_lens = tf.cast(
        tf.reduce_sum(1 - src_descriptor.canvas_paddings, 1), tf.int32)
    # `tgt_offset` is shape [num_targets] where each entry corresponds to
    # the offset needed for that target (due to the source length).
    tgt_offset = tf.matmul(target_batch, tf.expand_dims(src_lens, 1))
    # We shift the tgt slot without touching the batch or vocab.
    tgt_descriptor.target_indices += tf.concat(
        [tf.zeros_like(tgt_offset), tgt_offset,
         tf.zeros_like(tgt_offset)], 1)

    # The canvas is simply the sequence-level concat of the src and tgt.
    canvas, canvas_paddings = insertion.SequenceConcat(
        src_descriptor.canvas, src_descriptor.canvas_paddings,
        tgt_descriptor.canvas, tgt_descriptor.canvas_paddings)
    target_indices = tf.concat(
        [src_descriptor.target_indices, tgt_descriptor.target_indices], 0)
    target_weights = tf.concat(
        [src_descriptor.target_weights, tgt_descriptor.target_weights], 0)

    return py_utils.NestedMap(
        canvas=canvas,
        canvas_paddings=canvas_paddings,
        target_indices=target_indices,
        target_weights=target_weights)
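
# A worked example of the tgt_offset computation above: a one-hot
# [num_targets, batch_size] matrix times the per-example source lengths picks
# out, for each target row, the source-canvas length of its own example.
import tensorflow as tf

target_batch = tf.constant([[1, 0],
                            [1, 0],
                            [0, 1]])  # 3 targets, batch_size 2
src_lens = tf.constant([5, 7])
tgt_offset = tf.matmul(target_batch, tf.expand_dims(src_lens, 1))
# -> [[5], [5], [7]]: the first two targets shift by 5 slots, the third by 7.
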
def FProp(self, theta, inputs, paddings, state0=None, segment_id=None):
  """Computes LSTM forward pass.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer
      and its children layers.
    inputs: A single tensor or a tuple of tensors with cardinality equal to
      rnn_cell.inputs_arity. For every input tensor, the first dimension is
      assumed to be time, second dimension batch, and third dimension depth.
    paddings: A tensor. First dim is time, second dim is batch, and third
      dim is expected to be 1.
    state0: If not None, the initial rnn state in a `.NestedMap`. Defaults
      to the cell's zero-state.
    segment_id: A tensor to support packed inputs. First dim is time, second
      dim is batch, and third dim is expected to be 1.

  Returns:
    A tuple (act, final_state): `act` is the output tensor of shape
    [time, batch, dims]; `final_state` is the final recurrent state.
  """
  p = self.params
  rcell = self.cell
  assert isinstance(rcell, rnn_cell.RNNCell)

  if not isinstance(inputs, (list, tuple)):
    inputs = [inputs]

  # Slicing wm to wm_{i,h} outside the loop to get 20% speedup over regular
  # LSTM baseline. Keeping slicing within the loop gives only < 3% speedup.
  cell_theta = theta.cell.copy()
  num_input_nodes = p.cell.num_input_nodes
  cell_theta['wm_i'] = cell_theta.wm[:num_input_nodes, :]
  cell_theta['wm_h'] = cell_theta.wm[num_input_nodes:, :]
  tf.logging.vlog(1, 'cell_theta: %r', cell_theta)
  if p.packed_input:
    assert segment_id is not None
    reset_mask = rnn_layers.GeneratePackedInputResetMask(
        segment_id, is_reverse=False)
    reset_mask = py_utils.HasShape(reset_mask, tf.shape(paddings))
  else:
    reset_mask = tf.zeros_like(paddings)

  if p.reverse:
    inputs = [tf.reverse(x, [0]) for x in inputs]
    paddings = tf.reverse(paddings, [0])
    reset_mask = tf.reverse(reset_mask, [0])

  if not state0:
    batch_size = py_utils.GetShape(paddings)[1]
    state0 = rcell.zero_state(cell_theta, batch_size)

  # [T, B, H]
  proj_inputs = rcell.ProjectInputSequence(cell_theta,
                                           py_utils.NestedMap(act=inputs))
  proj_inputs = py_utils.NestedMap(
      proj_inputs=proj_inputs, padding=paddings, reset_mask=reset_mask)

  acc_state, final_state = recurrent.Recurrent(
      theta=cell_theta,
      state0=state0,
      inputs=proj_inputs,
      cell_fn=rcell.FPropWithProjectedInput,
      cell_type=rcell.layer_type,
      accumulator_layer=self,
      allow_implicit_capture=p.allow_implicit_capture)

  act = rcell.GetOutput(acc_state)
  if p.reverse:
    act = tf.reverse(act, [0])
  return act, final_state
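
# A hedged sketch of the packed-input reset mask assumed above: a 1 wherever
# a timestep starts a new segment, telling the cell to reset its state at
# segment boundaries. (One plausible construction, not necessarily lingvo's
# exact GeneratePackedInputResetMask.)
import tensorflow as tf

segment_id = tf.constant([[1.], [1.], [2.], [2.]])  # [time, batch=1]
prev = tf.concat([segment_id[:1], segment_id[:-1]], axis=0)
reset_mask = tf.cast(tf.not_equal(segment_id, prev), tf.float32)
# -> [[0.], [0.], [1.], [0.]]: reset when segment 2 begins.
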
def _ComputePaddings(ids, eos_id):
  is_eos = tf.cast(tf.equal(ids, eos_id), tf.int32)
  # eos_in_prefix[i, j] = any(ids[i, k] == eos_id for k in range(j))
  eos_in_prefix = tf.cumsum(is_eos, axis=-1, exclusive=True)
  return tf.where(
      tf.equal(eos_in_prefix, 0), tf.zeros_like(ids), tf.ones_like(ids))
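
# A worked example of the exclusive-cumsum trick above: padding begins at the
# first position *after* the first eos, so the eos itself stays unpadded.
import tensorflow as tf

ids = tf.constant([[3, 1, 0, 0]])  # eos_id = 1
is_eos = tf.cast(tf.equal(ids, 1), tf.int32)
eos_in_prefix = tf.cumsum(is_eos, axis=-1, exclusive=True)  # [[0, 0, 1, 1]]
paddings = tf.where(tf.equal(eos_in_prefix, 0),
                    tf.zeros_like(ids), tf.ones_like(ids))  # [[0, 0, 1, 1]]
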
def ReverseAndGrad(self, theta, outputs, d_outputs, f_seed, g_seed,
                   *extra_inputs):
  """Implements Algorithm 1 in the RevNet paper.

  Args:
    theta: A NestedMap object containing weights' values of this layer and
      its children layers.
    outputs: A NestedMap: .split1 and .split2 corresponding to y1 and y2.
    d_outputs: A NestedMap: .split1 and .split2 corresponding to dy1 and
      dy2, the total derivatives.
    f_seed: Scalar tensor. The step seed used in forward for the f block.
    g_seed: Scalar tensor. The step seed used in forward for the g block.
      The step seeds are needed for deterministic randomness, e.g. to ensure
      dropout generates the same random mask in forward and reverse_grad.
    *extra_inputs: Additional inputs that will be passed to both f and g. No
      gradient will be computed for these inputs.

  Returns:
    A tuple of NestedMaps

    - inputs: .split1 and .split2 corresponding to x1 and x2.
    - d_inputs: .split1 and .split2 corresponding to dx1 and dx2, the total
      derivatives with respect to inputs.
    - d_theta: has the same structure as theta. The total derivatives with
      respect to weights.
  """
  # Stop gradient on the outputs to avoid circular symbolic dependency.
  y1 = tf.stop_gradient(outputs.split1)
  y2 = tf.stop_gradient(outputs.split2)
  dy1 = d_outputs.split1
  dy2 = d_outputs.split2

  # Computes the reverse.
  z1 = y1
  py_utils.ResetStepSeed(g_seed)
  gz1 = self.g_block.FProp(theta.g_block, z1, *extra_inputs)
  x2 = y2 - gz1
  py_utils.ResetStepSeed(f_seed)
  fx2 = self.f_block.FProp(theta.f_block, x2, *extra_inputs)
  x1 = z1 - fx2

  # Computes the gradients.
  dz1 = dy1 + tf.gradients(gz1, z1, dy2)[0]
  dx2 = dy2 + tf.gradients(fx2, x2, dz1)[0]
  dgw = tf.gradients(
      gz1,
      theta.g_block.Flatten(),
      dy2,
      unconnected_gradients=tf.UnconnectedGradients.ZERO)
  dgw = theta.g_block.Pack(dgw)
  dfw = tf.gradients(
      fx2,
      theta.f_block.Flatten(),
      dz1,
      unconnected_gradients=tf.UnconnectedGradients.ZERO)
  dfw = theta.f_block.Pack(dfw)

  return (py_utils.NestedMap(split1=x1, split2=x2),
          py_utils.NestedMap(split1=dz1, split2=dx2),
          py_utils.NestedMap(
              f_block=dfw,
              g_block=dgw,
              global_step=tf.zeros_like(theta.global_step)))
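
# A numeric check of the reversible algebra above with toy f and g blocks:
# the forward computes y1 = x1 + f(x2), y2 = x2 + g(y1); the reverse recovers
# x2 = y2 - g(y1) and then x1 = y1 - f(x2), so no activations need be stored.
import tensorflow as tf

f = lambda t: tf.tanh(t)
g = lambda t: 0.5 * t
x1, x2 = tf.constant(1.0), tf.constant(2.0)
y1 = x1 + f(x2)
y2 = x2 + g(y1)
rx2 = y2 - g(y1)   # == x2
rx1 = y1 - f(rx2)  # == x1
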