def Step(recurrent_theta, state0, inputs):
  """Computes one decoder step."""
  del inputs
  with tf.name_scope('single_sampler_step'):
    # Compute logits and states.
    bs_result, bs_state1 = pre_step_callback(
        recurrent_theta.theta,
        recurrent_theta.encoder_outputs,
        tf.expand_dims(state0.ids, 1),  # [batch, 1].
        state0.bs_state,
        num_hyps_per_beam=1)
    batch = tf.shape(bs_result.log_probs)[0]
    state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
    state1.logits = bs_result.log_probs
    # Sample ids from logits. [batch].
    state1.ids = tf.reshape(
        tf.random.stateless_categorical(
            state1.logits / p.temperature,
            num_samples=1,
            seed=tf.stack([recurrent_theta.random_seed, state0.timestep]),
            dtype=state0.ids.dtype,
            name='sample_next_id'), [batch])
    if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
      state1.ids = tf.where(
          tf.math.logical_and(bs_result.is_last_chunk,
                              tf.equal(state1.ids, p.target_eoc_id)),
          tf.fill(tf.shape(state1.ids), p.target_eos_id), state1.ids)
    state1.bs_state = post_step_callback(recurrent_theta.theta,
                                         recurrent_theta.encoder_outputs,
                                         state1.ids, bs_state1)
  return state1, py_utils.NestedMap()
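The step above hinges on `tf.random.stateless_categorical`: seeding with a (run seed, timestep) pair makes each decoder step deterministic yet different per step. A minimal, self-contained sketch of that idiom (the values here are hypothetical, not from the original source):

import tensorflow as tf

logits = tf.math.log([[0.1, 0.6, 0.3]])  # [batch=1, vocab=3].
temperature = 0.8  # Hypothetical stand-in for p.temperature.
for timestep in range(3):
  ids = tf.random.stateless_categorical(
      logits / temperature,
      num_samples=1,
      seed=tf.stack([1234, timestep]))  # Same seed pair -> same sample.
  print(int(tf.reshape(ids, [])))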
def SequenceAppendToken(x, x_paddings, token, extend=False):
  """Appends <token> to sequence `x`.

  Args:
    x: A sequence of tokens of shape [batch_size, x_len_max].
    x_paddings: The paddings of `x`.
    token: The token to append (of type integer).
    extend: Whether to extend `x` along the length dimension; this must be
      True if any sequence in `x` already has length `x_len_max`, or else an
      invalid sequence will be emitted.

  Returns:
    A tuple.
      - The new sequence, Tensor of shape [batch_size, x_len_max].
      - The new paddings, Tensor of shape [batch_size, x_len_max].
  """
  batch_size = py_utils.GetShape(x)[0]
  x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)
  if extend:
    x = tf.pad(x, [[0, 0], [0, 1]])
  # Mask all invalid entries of `x` to 0.
  x *= tf.sequence_mask(x_len, py_utils.GetShape(x)[1], x.dtype)
  # Append the <token> based on `x_len`.
  x += tf.scatter_nd(
      tf.stack([tf.range(batch_size), x_len], axis=1),
      tf.cast(tf.fill([batch_size], token), x.dtype), py_utils.GetShape(x))
  x_paddings = 1 - tf.sequence_mask(
      x_len + 1, py_utils.GetShape(x)[1], x_paddings.dtype)
  return x, x_paddings
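The core of the function is the `tf.scatter_nd` write at each row's `x_len` position. That trick can be exercised without lingvo; a sketch assuming right-padded int32 sequences:

import tensorflow as tf

# Hypothetical data, not from the original source: two right-padded rows of
# lengths 2 and 1; append eos=2 at each row's x_len position.
x = tf.constant([[3, 4, 0], [5, 0, 0]], tf.int32)
x_paddings = tf.constant([[0., 0., 1.], [0., 1., 1.]])
x_len = tf.cast(tf.reduce_sum(1 - x_paddings, 1), tf.int32)  # [2, 1].
x += tf.scatter_nd(
    tf.stack([tf.range(2), x_len], axis=1), tf.fill([2], 2), tf.shape(x))
print(x.numpy())  # [[3 4 2], [5 2 0]].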
def Update(self, new_value):
  """Accumulates (count, min, max) statistics from `new_value`."""
  state0 = self.GetValue()
  state1 = tf.stack([
      state0[0] + new_value[0],             # Accumulated count.
      tf.minimum(state0[1], new_value[1]),  # Running min.
      tf.maximum(state0[2], new_value[2]),  # Running max.
  ])
  self.SetValue(state1)
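Read together with `QuantizeTensors` below, which pushes `[1.0, batch_min, batch_max]` into this accumulator, the state is a (count, min, max) triple. A standalone illustration with hypothetical values:

import tensorflow as tf

state0 = tf.constant([2.0, -1.5, 3.0])     # Hypothetical prior [count, min, max].
new_value = tf.constant([1.0, -2.0, 2.5])  # This batch's contribution.
state1 = tf.stack([
    state0[0] + new_value[0],              # Accumulate count.
    tf.minimum(state0[1], new_value[1]),   # Running min.
    tf.maximum(state0[2], new_value[2]),   # Running max.
])
print(state1.numpy())  # [ 3. -2.  3.]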
def _MaybeStackExtraTheta(theta, all_vars, repeat):
  var_set = set([key for key, _ in all_vars.FlattenItems()])
  values = []
  for key, value in theta.FlattenItems():
    if key not in var_set and value is not None:
      # Replicate non-variable theta `repeat` times.
      value = tf.stack([value] * repeat)
    # Append every flattened item (stacked or not) so Pack sees them all.
    values.append(value)
  return theta.Pack(values)
def QuantizeTensors(self, t_name, ts, eval_only=False):
  p = self.params
  # Always straddle a real zero point.
  if self.do_eval:
    # At eval/inference time, use the memorized range.
    # Important: Don't capture these variables in training mode so as to
    # avoid extra/unnecessary captures.
    min_var = self._GetQStateVar(t_name, 'min')
    max_var = self._GetQStateVar(t_name, 'max')
    return [
        self._MaybeFakeQuant(t, min_var, max_var, num_bits=p.bits) for t in ts
    ]
  else:
    # At training time, use the batch calculated min/max.
    accumulator_name = self._GetAccumulatorNameForTensor(t_name)
    # Calculate min/max for all tensors.
    batch_min = 0.0
    batch_max = 0.0
    for t in ts:
      batch_min = tf.minimum(tf.reduce_min(t), batch_min)
      batch_max = tf.maximum(tf.reduce_max(t), batch_max)

    # New state.
    state1 = tf.stack([1.0, batch_min, batch_max])
    self.accumulators[accumulator_name].Update(state1)

    # Results.
    ts_out = []
    for i, t in enumerate(ts):
      if eval_only:
        # If only quantizing at eval time, still record ranges as above
        # but don't quantize.
        quant_t = t
      else:
        # If quantizing during training, skip quantization if it produces
        # NANs. Sometimes early in the training process, things are unstable
        # and ranges can produce numerical instability that makes it
        # impossible to perform a fake_quant.
        quant_t = self._MaybeFakeQuant(t, batch_min, batch_max,
                                       num_bits=p.bits)
        # TODO(laurenzo): Plumb quant_t_has_nans through state and report.
        quant_t_has_nans = tf.math.is_nan(quant_t)
        quant_t = tf.where(quant_t_has_nans, t, quant_t)
      ts_out.append(quant_t)
      summary_utils.histogram(
          '%s/%s_%d' % (self._qvars_scope.name, t_name, i), t)
    return ts_out
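`_MaybeFakeQuant` is lingvo-internal, so the following is an assumption about its behavior rather than its implementation: a plausible stand-in is the public `tf.quantization.fake_quant_with_min_max_vars`, applied to batch-observed ranges that straddle zero.

import tensorflow as tf

# Sketch of fake-quantizing against batch-observed ranges; assumes
# _MaybeFakeQuant behaves like the public TF fake-quant op.
t = tf.constant([-1.7, 0.0, 0.42, 2.3])
batch_min = tf.minimum(tf.reduce_min(t), 0.0)  # Straddle a real zero point.
batch_max = tf.maximum(tf.reduce_max(t), 0.0)
quant_t = tf.quantization.fake_quant_with_min_max_vars(
    t, batch_min, batch_max, num_bits=8)
print(quant_t.numpy())  # Values snapped to the 8-bit grid over [min, max].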
def SequenceConcat(x, x_paddings, y, y_paddings, pad=0):
  """Concats sequence `x` with sequence `y`.

  This function is length aware (based off the paddings).

  Args:
    x: A sequence of tokens of shape [batch_size, x_len_max].
    x_paddings: The paddings of `x`.
    y: A sequence of tokens of shape [batch_size, y_len_max].
    y_paddings: The paddings of `y`.
    pad: The <pad> token to fill the concatenated sequence (of type integer).

  Returns:
    A tuple.
      - Concatenation of `x` and `y` of shape
        [batch_size, x_len_max + y_len_max].
      - Paddings of the concatenation of shape
        [batch_size, x_len_max + y_len_max].
  """
  # Get the length (w/ eos).
  x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)
  y_len = tf.cast(tf.round(tf.reduce_sum(1 - y_paddings, 1)), tf.int32)

  batch_size = py_utils.GetShape(x)[0]
  y_len_max = py_utils.GetShape(y)[1]

  # Pad `x` with necessary <pad>.
  x = tf.concat([x, tf.fill(py_utils.GetShape(y), pad)], 1)
  # Replace all <pad> with 0.
  x = tf.where(tf.not_equal(x, pad), x, tf.fill(py_utils.GetShape(x), 0))

  # Compute the write indices of `y` in `xy`.
  indices = tf.stack([
      tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, y_len_max]),
      (tf.tile(tf.expand_dims(tf.range(y_len_max), 0), [batch_size, 1]) +
       tf.expand_dims(x_len, 1)),
  ], 2)

  xy = x + tf.scatter_nd(indices, y, py_utils.GetShape(x))

  # We need to remap all <pad> to `pad`.
  xy = tf.where(
      tf.less(tf.expand_dims(tf.range(py_utils.GetShape(xy)[1]), 0),
              tf.expand_dims(x_len + y_len, 1)), xy,
      tf.fill(py_utils.GetShape(xy), pad))

  xy_paddings = 1 - tf.sequence_mask(
      x_len + y_len, py_utils.GetShape(xy)[1], x_paddings.dtype)
  return xy, xy_paddings
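Assuming the same lingvo `py_utils` import the function already relies on, usage looks like this (toy inputs, not from the original source):

x = tf.constant([[7, 8, 0]], tf.int32)      # Hypothetical row of length 2.
x_pad = tf.constant([[0., 0., 1.]])
y = tf.constant([[9, 0]], tf.int32)         # Hypothetical row of length 1.
y_pad = tf.constant([[0., 1.]])
xy, xy_pad = SequenceConcat(x, x_pad, y, y_pad)
# xy     -> [[7 8 9 0 0]]
# xy_pad -> [[0. 0. 0. 1. 1.]]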
def _global_seed_from_inputs(input_floats):
  """Generates a random seed tensor based on input floats and the timestamp.

  Args:
    input_floats: a set of float input tensors that are derived from the
      input data (for example, input tokens). The important thing is that
      these are usually different for each batch.

  Returns:
    A tensor of shape=[2] with integer seed tensors derived from the inputs.
  """
  timestamp = tf.math.floormod(
      tf.cast(tf.timestamp(), dtype=tf.int64), 10000000)
  input_sum = tf.cast(
      tf.reduce_sum(tf.math.abs(input_floats)), dtype=tf.int64)
  return tf.stack([timestamp + input_sum, timestamp - input_sum], axis=-1)
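The returned tensor already has the `[2]` shape that TF's stateless random ops expect, so it can be plugged in directly. A minimal usage sketch with hypothetical inputs:

seed = _global_seed_from_inputs(tf.constant([1.5, -2.25]))  # Hypothetical.
noise = tf.random.stateless_normal([4], seed=seed)  # int64 seeds are accepted.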
def GenerateStepSeedPair(p, unused_global_step=None, op_seed=None):
  """Override py_utils.GenerateStepSeedPair to use GetOverWriteGlobalStep."""
  seed_dtype = tf.int32 if py_utils.use_tpu() else tf.int64
  if p.is_inference and p.random_seed is None:
    # Unlike tf.random*, stateless random ops are completely determined by
    # the passed-in seeds. This means at inference time the same inputs will
    # produce the same outputs, even if the model is supposed to have
    # randomness such as dropout during inference. We inject additional
    # randomness only during inference if the graph is exported with
    # random_seed=None as a workaround.
    return tf.random.uniform([2], maxval=seed_dtype.max, dtype=seed_dtype)

  with tf.name_scope('op_seed') as scope:
    global_step = tf.cast(GetOverWriteGlobalStep(), seed_dtype)
    step_seed = tf.cast(py_utils.GenerateSeedFromName(scope), seed_dtype)
    seeds = tf.stack([global_step, step_seed])

    if p.random_seed is not None:
      seeds += p.random_seed
    if op_seed is not None:
      seeds += op_seed
    return seeds
def GetState(self, theta):
  """Gets the state from theta."""
  p = self.params
  if p.is_inference:
    # State is not used for inference. Just return dummy.
    return tf.zeros([1], tf.float32)
  else:
    # Calculations/vars need to be float but these can be ints in the params.
    clip_end_step = tf.cast(p.clip_end_step, tf.float32)
    clip_start_step = tf.cast(p.clip_start_step, tf.float32)
    quant_start_step = tf.cast(p.quant_start_step, tf.float32)
    global_step = tf.cast(theta.global_step, tf.float32)

    # Will be negative if before clipping starts.
    clip_ratio = (
        tf.minimum(clip_end_step - clip_start_step,
                   global_step - clip_start_step) /
        tf.maximum(1.0, clip_end_step - clip_start_step))
    # Currently fq is either on (1.0) or off (-1.0). Progressive quantization
    # may later occupy 0..1.0.
    fq_ratio = tf.where(global_step < quant_start_step, -1.0, 1.0)
    return tf.stack([clip_ratio, fq_ratio])
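To make the clip_ratio ramp concrete, a worked example with hypothetical clip_start_step=1000 and clip_end_step=2000 (neither value comes from the source):

step  500: min(1000, -500) / max(1, 1000) = -0.5  (before clipping starts)
step 1500: min(1000,  500) / max(1, 1000) =  0.5  (mid-ramp)
step 3000: min(1000, 2000) / max(1, 1000) =  1.0  (ramp saturated)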
def SplitTensors(xs, num_splits):
  """Splits tensors in `xs` evenly into num_splits along the 1st dimension.

  Args:
    xs: A tuple of tensors. Each tensor's 1st dimension is the same size.
    num_splits: A python integer.

  Returns:
    A tuple of lists of tensors, num elements in the tuple = len(xs).

    i-th element in each list corresponds to i-th split of each tensor in xs
    along the first dimension of each tensor.
  """
  # Assert that the first dim of all tensors in xs is equal.
  batch_dims = [tf.shape(x)[0] for x in xs]
  all_batch_dims = tf.stack(batch_dims)
  all_batch_dims = py_utils.with_dependencies([
      py_utils.assert_equal(
          all_batch_dims,
          tf.shape(xs[0])[0],
          message='first dim of tensors in xs must match'),
      py_utils.assert_greater_equal(
          tf.shape(xs[0])[0],
          num_splits,
          message='first dim of tensors in xs must be greater than num_splits')
  ], all_batch_dims)

  splits = ComputeSplits(tf.shape(xs[0])[0], num_splits)
  # Add the above assertion into the compute graph.
  splits = py_utils.with_dependencies([all_batch_dims], splits)
  split_xs = [
      tf.split(axis=0, num_or_size_splits=splits, value=x) for x in xs
  ]
  return split_xs
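A usage sketch, assuming the lingvo helpers referenced above (`ComputeSplits`, `py_utils`) are in scope; the exact split sizes shown are an assumption about `ComputeSplits`, which balances a remainder across the first splits:

a = tf.zeros([10, 3])  # Hypothetical inputs sharing a first dim of 10.
b = tf.ones([10])
a_splits, b_splits = SplitTensors((a, b), num_splits=4)
# Each list holds 4 tensors whose first dims sum to 10 (e.g. 3+3+2+2), with
# tensor i of a_splits and b_splits covering the same rows.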
def FProp(self,
          theta,
          x,
          x_paddings=None,
          eos_id=1,
          force_sample_last_token=True):
  """Applies SymbolInsertionLayer.

  We take in `x`, which represents the groundtruth sequence (i.e., English
  sequence). We return a sampled rollin (observed) canvas (i.e., random
  subset of the English sequence), as well as the target (indices) for an
  insertion-based model (i.e., the targets given the random observed subset).

  Args:
    theta: Ignored, this can be None.
    x: The symbol ids of shape `[batch_size, time_dim]`.
    x_paddings: The paddings (1 or 0) of shape `[batch_size, time_dim]` where
      0 is valid and 1 is invalid.
    eos_id: The <eos> token id to represent end-of-slot.
    force_sample_last_token: Set True to force sample the last token of `x`.

  Returns:
    A `NestedMap`.
      - canvas: The canvas (based off of the `rollin_policy`) of shape
        [batch_size, c_dim]. Note that `c_dim` <= `time_dim` but need not be
        equal.
      - canvas_indices: The canvas indices (into `x`).
      - canvas_paddings: The paddings of `canvas_indices`.
      - target_indices: The target indices of shape [num_targets, 3].
        `num_targets` is the number of total targets in the entire batch.
        [:, 0] captures the batch, [:, 1] captures the slot, and [:, 2]
        captures the token. Each row [batch, slot, vocab] represents the
        indices of the target -- i.e., the batch, slot and vocab combination
        of the target. Typical usage of these indices is to tf.gather_nd
        the log-probs (from the softmax layer).
      - target_weights: The target weights.

  Raises:
    ValueError: If invalid params.
  """
  p = self.params

  batch_size = py_utils.GetShape(x)[0]
  time_dim = py_utils.GetShape(x)[1]

  if x_paddings is None:
    x_paddings = tf.zeros([batch_size, time_dim], tf.float32)

  oracle_policy = p.oracle_policy
  rollin_policy = (
      oracle_policy if p.rollin_policy == 'oracle' else p.rollin_policy)

  if rollin_policy != 'uniform':
    raise ValueError('Unknown or unsupported rollin policy: %s' %
                     rollin_policy)
  if oracle_policy != 'uniform':
    raise ValueError('Unknown or unsupported oracle policy: %s' %
                     oracle_policy)

  x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)

  # Compute the desired length per example in the batch.
  ratio = tf.random.uniform([batch_size], 0.0, 1.0, seed=p.random_seed)
  if force_sample_last_token:
    c_len = tf.minimum(
        tf.cast(ratio * tf.cast(x_len, tf.float32), tf.int32), x_len - 1) + 1
  else:
    c_len = tf.minimum(
        tf.cast(ratio * tf.cast(x_len + 1, tf.float32), tf.int32), x_len)
  # Compute the maximum length across the batch.
  c_len_max = tf.reduce_max(c_len)

  # Grab subset of random valid indices per example.
  z_logits = tf.cast(
      tf.expand_dims(tf.range(time_dim), 0) >= tf.expand_dims(x_len, 1),
      tf.float32) * -1e9
  if force_sample_last_token:
    # Force sample the last token -- i.e., as indexed by `x_len - 1`. We can
    # accomplish this by adding +LARGE_NUMBER to the logits.
    z_logits += tf.cast(
        tf.equal(tf.expand_dims(tf.range(time_dim), 0),
                 tf.expand_dims(x_len - 1, 1)), tf.float32) * 1e9
  # Gumbel-max trick to sample (we only sample valid positions per sample in
  # the batch).
  z = -tf.math.log(-tf.math.log(
      tf.random.uniform([batch_size, time_dim], seed=p.random_seed)))
  unused_c_values, c_indices = tf.nn.top_k(z_logits + z, time_dim)

  # Trim everything > c_len_max.
  c_indices = c_indices[:, :c_len_max]

  # Invalidate any indices >= c_len, we use the last index as the default
  # invalid index.
  c_indices = tf.where(
      tf.expand_dims(tf.range(c_len_max), 0) < tf.expand_dims(c_len, 1),
      c_indices, tf.fill(py_utils.GetShape(c_indices), time_dim - 1))

  # Materialize the canvas.
  c_indices = tf.sort(c_indices)
  c = tf.gather_nd(
      x,
      tf.stack([
          tf.reshape(
              tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                      [1, c_len_max]), [-1]),
          tf.reshape(c_indices, [-1])
      ], 1))
  c = tf.reshape(c, [batch_size, c_len_max])

  # Compute the paddings.
  c_paddings = 1 - tf.sequence_mask(c_len, c_len_max, dtype=x_paddings.dtype)
  c *= tf.cast(1 - c_paddings, tf.int32)

  indices = tf.concat([
      tf.reshape(
          tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, c_len_max]),
          [batch_size * c_len_max, 1]),
      tf.reshape(c_indices, [batch_size * c_len_max, 1])
  ], 1)
  x_token_is_observed = tf.scatter_nd(
      indices, tf.ones([batch_size * c_len_max], tf.int32),
      py_utils.GetShape(x))
  # `x_segments` captures which slot each `x` belongs to (both observed and
  # tokens that need to be observed).
  x_segments = tf.cumsum(x_token_is_observed, 1, exclusive=True)

  x_token_is_observed = tf.cast(x_token_is_observed, tf.bool)
  prev_x_token_is_observed = tf.pad(
      x_token_is_observed[:, :-1], [[0, 0], [1, 0]], constant_values=True)
  x_token_is_observed = tf.reshape(x_token_is_observed, [-1])
  prev_x_token_is_observed = tf.reshape(prev_x_token_is_observed, [-1])
  x_is_valid = tf.cast(1 - x_paddings, tf.bool)
  x_is_valid = tf.reshape(x_is_valid, [-1])

  # Remap all the observed to <eos>, note some of these need a zero weight
  # (or else there would be <eos> and valid token in the same slot).
  target_indices = tf.cast(tf.reshape(x, [-1, 1]), tf.int32)
  target_indices = tf.where(
      x_token_is_observed,
      tf.fill(py_utils.GetShape(target_indices), eos_id), target_indices)

  # TODO(williamchan): We give uniform 1.0 weight, however, math suggests
  # we may want to weigh this term by the original sequence length.
  target_weights = tf.ones_like(target_indices, tf.float32)

  # We need to set all the weights for <eos> which actually have valid
  # tokens in the slot to zero.
  target_weights = tf.where(
      x_token_is_observed & ~prev_x_token_is_observed,
      tf.zeros_like(target_weights), target_weights)

  # TODO(williamchan): Consider dropping the entries w/ weight zero.

  # Add the batch and slot indices.
  target_indices = tf.concat([
      tf.reshape(
          tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, time_dim]),
          [batch_size * time_dim, 1]),
      tf.reshape(x_segments, [-1, 1]), target_indices
  ], 1)

  # Select only the valid indices. The selected valid ones include slots w/
  # <eos>.
  target_indices = target_indices[x_is_valid]
  target_weights = target_weights[x_is_valid]

  return py_utils.NestedMap(
      canvas=c,
      canvas_indices=c_indices,
      canvas_paddings=c_paddings,
      target_indices=target_indices,
      target_weights=target_weights)
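The canvas sampling above uses the Gumbel-max trick: adding Gumbel noise to masked logits and taking top-k is equivalent to sampling k positions uniformly without replacement from the valid positions. A standalone sketch with toy values (not from the original source):

import tensorflow as tf

batch, time, k = 2, 6, 3  # Toy sizes for illustration.
valid_len = tf.constant([6, 4])
mask = tf.cast(
    tf.expand_dims(tf.range(time), 0) >= tf.expand_dims(valid_len, 1),
    tf.float32) * -1e9  # Invalid positions get effectively -inf logits.
gumbel = -tf.math.log(-tf.math.log(tf.random.uniform([batch, time])))
_, sampled_positions = tf.nn.top_k(mask + gumbel, k)
print(sampled_positions.numpy())  # k distinct valid indices per row.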
def Combined(x):
  ys = [s.Value(x) for s in self.schedules]
  return tf.reduce_min(tf.stack(ys), axis=0)
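Taking the pointwise minimum over sub-schedules yields the familiar warmup-then-decay shape. A minimal illustration with two hypothetical schedules (the constants are made up):

import tensorflow as tf

warmup = lambda step: tf.cast(step, tf.float32) / 1000.0
decay = lambda step: 4000.0 * tf.math.rsqrt(tf.cast(step, tf.float32) + 1.0)

def Combined(step):
  return tf.reduce_min(tf.stack([warmup(step), decay(step)]), axis=0)

print(float(Combined(tf.constant(100))))  # 0.1: warmup dominates early.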
def FProp(self, theta, *args):
  """Run multiple cells in different devices in a pipelining manner.

  Args:
    theta: A NestedMap object containing weights' values of this layer and
      its children layers.
    *args: Non-keyworded variable length argument list of input tensors.

  Returns:
    A list of output tensors.
  """
  # TODO(huangyp): handle optional None inputs.
  p = self.params
  if self.do_eval:
    outputs = copy.copy(args)
    for (name, l) in self._before_layers + self._cells:
      outputs = _ToTuple(outputs)
      outputs = l.FProp(theta[name], *outputs)
    return outputs

  num_cells = len(p.cell_tpl)
  cluster = self.cluster

  # Compute shapes of input and output tensors.
  input_shapes = self._get_input_shapes(*args)
  state_dtype = self._get_state_dtype(*args)
  state_shapes = self._CalculateOutputShapes(input_shapes)
  tf.logging.info('state_shapes={}'.format(state_shapes))

  def GetCellFn(i):
    """Gets the i-th feature extraction layer."""

    def CellFn(theta, state0, inputs):
      """A cell fn is executed inside of StackedRecurrent."""
      del state0

      def _FPropInputSetShape(name, t_shape):
        if t_shape is None:
          return None
        inputs[name].set_shape(t_shape.ToTensorShape().as_list())
        return inputs[name]

      if p.nested_map_fprop:
        # pylint: disable=protected-access
        fprop_inputs = state_shapes[i]._RecursiveMap(_FPropInputSetShape)
        # pylint: enable=protected-access
      else:
        fprop_inputs = []
        for input_idx, input_shape in enumerate(state_shapes[i]):
          name = 's{}'.format(input_idx)
          fprop_inputs.append(_FPropInputSetShape(name, input_shape))

      with py_utils.RemoveAssertContext(remove=True):
        with CellFnFPropOpReplacementWrapper():
          tf.logging.info('cell {} input {}'.format(i, fprop_inputs))
          mb_tensor = inputs[_MICRO_BATCH_STATE_NAME]
          SetOverWriteGlobalStep(mb_tensor)
          _, cell = self._cells[i]
          fprop_inputs = _ToTuple(fprop_inputs)
          outputs = cell.FProp(theta, *fprop_inputs)

      if p.nested_map_fprop:
        assert py_utils.IsCompatible(outputs, state_shapes[i + 1])
        state1 = outputs.Filter(lambda x: x is not None)
      else:
        state1 = py_utils.NestedMap()
        outputs = _ToTuple(outputs)
        assert len(outputs) == len(state_shapes[i + 1])
        for output_idx in range(len(outputs)):
          if outputs[output_idx] is not None:
            name = 's{}'.format(output_idx)
            state1[name] = outputs[output_idx]
      state1[_MICRO_BATCH_STATE_NAME] = mb_tensor
      return state1, py_utils.NestedMap()

    return CellFn

  cell_fns = []
  accumulator_layers = []
  thetas = []
  init_states = []
  devices = []
  for cell_idx in range(num_cells):
    cell_name, cell = self._cells[cell_idx]
    accumulator_layers.append(cell)
    cell_fns.append(GetCellFn(cell_idx))
    thetas.append(theta[cell_name])

    def _TfZeros(t_shape):
      if t_shape is None:
        return None
      return tf.zeros(t_shape.ToTensorShape().as_list(), dtype=state_dtype)

    if p.nested_map_fprop:
      init_state = py_utils.Transform(_TfZeros, state_shapes[cell_idx + 1])
      init_state = init_state.Filter(lambda x: x is not None)
    else:
      init_state = py_utils.NestedMap()
      for output_idx, state in enumerate(state_shapes[cell_idx + 1]):
        state = _TfZeros(state)
        if state is not None:
          name = 's{}'.format(output_idx)
          init_state[name] = state
    init_state[_MICRO_BATCH_STATE_NAME] = tf.cast(0, dtype=state_dtype)
    init_states.append(init_state)

    devices.append(cluster.WorkerDeviceInModelSplit(cell_idx))

  cell_grads = [None] * num_cells
  cell_outs = [lambda x: x] * num_cells
  cell_out_grads = [lambda x: x] * num_cells

  with tf.device(devices[0]):
    previous = _ToTuple(args)
    for (name, l) in self._before_layers:
      previous = l.FProp(theta[name], *previous)
      previous = _ToTuple(previous)

    def _StackAndSplit(x):
      # Split tensors into microbatches.
      if x is None:
        return None
      return tf.stack(tf.split(x, p.num_micro_batches, axis=p.batch_dim))

    if p.nested_map_fprop:
      inputs = py_utils.Transform(_StackAndSplit, previous[0])
      inputs = inputs.Filter(lambda x: x is not None)
    else:
      inputs = py_utils.NestedMap()
      for output_idx, output_tensor in enumerate(previous):
        output_tensor = _StackAndSplit(output_tensor)
        if output_tensor is not None:
          name = 's{}'.format(output_idx)
          inputs[name] = output_tensor
    gs_tensor = py_utils.GetGlobalStep()
    inputs[_MICRO_BATCH_STATE_NAME] = tf.stack([
        tf.cast(gs_tensor * p.num_micro_batches + t, dtype=state_dtype)
        for t in range(p.num_micro_batches)
    ])
  tf.logging.info('pipeline input = {}'.format(inputs))
  output_state, _ = recurrent.StackedRecurrent(
      devices=devices,
      cell_fns=cell_fns,
      cell_grads=cell_grads,
      cell_outs=cell_outs,
      cell_out_grads=cell_out_grads,
      thetas=thetas,
      init_states=init_states,
      inputs=inputs,
      accumulator_layers=accumulator_layers,
      unused_acc_state=True)

  with tf.device(devices[-1]):

    def _ReshapeRetVal(name, t_shape):
      """Restores shape for tensors in microbatches."""
      if t_shape is None:
        return None
      output_tensor = output_state[name]
      if p.batch_dim != 0:
        perm = list(range(1, p.batch_dim + 1)) + [0]
        perm += list(range(p.batch_dim + 1, t_shape.rank + 1))
        output_tensor = tf.transpose(output_tensor, perm=perm)
      output_shape = t_shape.ToTensorShape().as_list()
      output_shape[p.batch_dim] *= p.num_micro_batches
      output_tensor = tf.reshape(output_tensor, output_shape)
      return output_tensor

    # Construct the final return values from output_state.
    if p.nested_map_fprop:
      # pylint: disable=protected-access
      output_tensors = state_shapes[-1]._RecursiveMap(_ReshapeRetVal)
      # pylint: enable=protected-access
    else:
      output_tensors = []
      for output_idx, state_shape in enumerate(state_shapes[-1]):
        output_name = 's{}'.format(output_idx)
        output_tensor = _ReshapeRetVal(output_name, state_shape)
        output_tensors.append(output_tensor)
      if len(output_tensors) == 1:
        output_tensors = output_tensors[0]
      else:
        output_tensors = tuple(output_tensors)
    tf.logging.info('pipeline output = {}'.format(output_tensors))
    return output_tensors
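The microbatch plumbing reduces to a stack/split on the way in and a reshape on the way out. In isolation, with toy sizes and batch_dim == 0 (values not from the source):

import tensorflow as tf

x = tf.reshape(tf.range(12), [6, 2])        # Full batch: [batch=6, d=2].
stacked = tf.stack(tf.split(x, 3, axis=0))  # [num_micro_batches=3, 2, 2].
merged = tf.reshape(stacked, [6, 2])        # Inverse, as in _ReshapeRetVal.
assert bool(tf.reduce_all(tf.equal(x, merged)))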
def MergeBeamSearchOutputs(max_hyps_per_beam, beam_search_outputs):
  """Merges beam search hyps from multiple decoders.

  Args:
    max_hyps_per_beam: the number of top hyps in the merged results. Must be
      less than or equal to total number of input hyps.
    beam_search_outputs: a list of BeamSearchDecodeOutput objects. Must share
      the same source_batch and max sequence length.

  Returns:
    A BeamSearchDecodeOutput object containing max_hyps_per_beam hypotheses
    per beam.
  """
  source_batch = tf.shape(beam_search_outputs[0].topk_hyps)[0]
  value_dict = {}
  for output in beam_search_outputs:
    hyps_per_beam = py_utils.with_dependencies([
        py_utils.assert_equal(source_batch,
                              tf.shape(output.topk_hyps)[0]),
    ], tf.shape(output.topk_hyps)[1])
    for k, v in six.iteritems(output._asdict()):
      if v is None:
        continue
      if k == 'done_hyps':
        v = tf.transpose(v)
      if k not in value_dict:
        value_dict[k] = []
      value_dict[k].append(tf.reshape(v, [source_batch, hyps_per_beam, -1]))

  # Concatenate the tensors along the 'num_hyps_per_beam' dimension.
  concatenated = {}
  for k, values in six.iteritems(value_dict):
    if len(values) != len(beam_search_outputs):
      raise ValueError('Incomplete values for %s: %s' %
                       (k, beam_search_outputs))
    concatenated[k] = tf.concat(values, axis=1)

  scores = concatenated['topk_scores']
  scores = tf.where(
      tf.equal(concatenated['topk_lens'], 0),
      tf.fill(tf.shape(scores), -1e6), scores)
  scores = tf.squeeze(scores, -1)

  # Select top max_hyps_per_beam indices per beam.
  _, top_indices = tf.nn.top_k(scores, max_hyps_per_beam)
  batch_ids = tf.tile(
      tf.expand_dims(tf.range(source_batch), -1), [1, max_hyps_per_beam])
  # [source_batch, max_hyps_per_beam, 2]
  gather_indices = tf.stack([batch_ids, top_indices], axis=-1)

  # Gather the merged top hyps according to 'gather_indices'.
  top = beam_search_outputs[0]._asdict()
  total_hyps = source_batch * max_hyps_per_beam
  for k, v in six.iteritems(concatenated):
    v = tf.gather_nd(v, gather_indices)
    if k == 'done_hyps':
      v = tf.transpose(tf.reshape(v, [total_hyps, -1]))
    elif k == 'topk_hyps':
      v = tf.reshape(v, [source_batch, max_hyps_per_beam])
    elif k == 'topk_ids':
      v = tf.reshape(v, [total_hyps, -1])
    elif k in ('topk_lens', 'topk_scores', 'topk_decoded'):
      v = tf.reshape(v, [total_hyps])
    else:
      raise ValueError('Unexpected field: %s' % k)
    top[k] = v
  return BeamSearchDecodeOutput(**top)
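The per-beam selection above is a top_k followed by a batched gather_nd. The same pattern in isolation, with toy scores (not from the source):

import tensorflow as tf

scores = tf.constant([[0.1, 0.9, 0.4],
                      [0.7, 0.2, 0.8]])  # [source_batch=2, num_hyps=3].
_, top_idx = tf.nn.top_k(scores, k=2)
batch_ids = tf.tile(tf.expand_dims(tf.range(2), -1), [1, 2])
gather_indices = tf.stack([batch_ids, top_idx], axis=-1)  # [2, 2, 2].
print(tf.gather_nd(scores, gather_indices).numpy())  # [[0.9 0.4] [0.8 0.7]].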
def FProp(self,
          theta,
          transformer_input,
          paddings,
          src_segment_id=None,
          aux_vecs=None,
          aux_paddings=None,
          aux_segment_id=None):
  """Transforms source sequence of Tensors with Transformers layers.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    transformer_input: A sequence of input Tensors of [time, batch, dim]
      shape.
    paddings: A sequence of 0s and 1s indicating input paddings of [time,
      batch] shape.
    src_segment_id: A sequence of ints indicating segment ids of [time,
      batch] shape.
    aux_vecs: A sequence of input Tensors of [aux_time, batch, dim] shape, as
      context for the cross-attention layer.
    aux_paddings: A sequence of 0s and 1s indicating input paddings of
      [aux_time, batch] shape.
    aux_segment_id: A sequence of ints indicating segment ids of [aux_time,
      batch] shape.

  Returns:
    (outputs, out_paddings, segment_ids) tuple. `outputs` is of the shape
    [time, batch, depth], and `out_paddings` has shape [time, batch]. If
    is_transparent is True, can return a list of num_transformer_layers
    tensors of shape [time, batch, depth] if `self.do_eval` is False, and a
    [time, batch, depth, num_transparent_outputs] tensor if `self.do_eval` is
    True. If packed_input is True, also returns segment_id, otherwise returns
    None.
  """
  p = self.params
  if p.packed_input:
    assert src_segment_id is not None, ('Need to specify src_segment_id if '
                                        'packed input is supported.')

  outputs_list = [transformer_input]
  with tf.name_scope(p.name):
    for i, transformer_l in enumerate(self.trans):
      # For encoder, keys, values and queries are the same.
      transformer_output, _ = transformer_l.FProp(
          theta.trans[i],
          transformer_input,
          paddings,
          aux_vecs=aux_vecs,
          aux_paddings=aux_paddings,
          source_segment_id=src_segment_id,
          aux_segment_id=aux_segment_id)
      transformer_input = transformer_output
      outputs_list.append(transformer_output)

    if p.ln_output:
      transformer_output = self.layer_norm_out.FProp(theta.layer_norm_out,
                                                     transformer_output)

    # When is_transparent is set, it outputs a list of tensors during
    # training and the stacked tensors otherwise. This dual behavior is meant
    # to avoid excessive memory usage during training (which was prohibiting
    # training on TPUs), and simplify the beam search interface.
    if p.is_transparent:
      if p.num_transparent_outputs == 1:
        transformer_output = self.transparent_merger[0].FProp(
            theta.transparent_merger[0], outputs_list)
      else:
        transformer_output = []
        for i in range(p.num_transparent_outputs):
          merged_outputs = self.transparent_merger[i].FProp(
              theta.transparent_merger[i], outputs_list)
          transformer_output.append(merged_outputs)
        if self.do_eval:
          transformer_output = tf.stack(transformer_output, 3)

    return transformer_output, paddings, src_segment_id