def FProp(self, theta, *args):
  p = self.params
  # Collects all variable keys and values into sets.
  theta_stack = _MaybeStackExtraTheta(theta.body, self.body.vars, p.repeat)

  def _ArgsToState(arg_list):
    """Returns a NestedMap from a list of FProp args."""
    state = py_utils.NestedMap()
    # Maintains a mapping from arg_idx to tensor. states cannot contain
    # None tensors.
    for idx in range(len(args)):
      if arg_list[idx] is not None:
        state['_s{}'.format(idx)] = arg_list[idx]
    return state

  def _StateToArgs(state):
    """Returns a list of FProp args from a NestedMap."""
    arg_list = []
    for idx in range(len(args)):
      attr = '_s{}'.format(idx)
      arg_list.append(state[attr] if attr in state else None)
      if arg_list[-1] is not None:
        arg_list[-1].set_shape(args[idx].shape)
    return arg_list

  def _CellFn(unused_theta, state0, theta_i):
    """Recurrent cell function wrapper of body.FProp."""
    # Retrieves fprop arguments from state and sets shapes.
    fprop_inputs = _StateToArgs(state0)
    # Sets shapes for theta_i as well.
    for dst, src in zip(theta_i.Flatten(), theta_stack.Flatten()):
      if src is not None:
        dst.set_shape(tf.TensorShape(src.shape.as_list()[1:]))
    # Runs the actual body.FProp.
    fprop_outputs = self.body.FProp(theta_i, *fprop_inputs)
    fprop_outputs = _ToTuple(fprop_outputs)
    assert len(fprop_outputs) == len(fprop_inputs)
    # Passes fprop outputs to the next layer through state.
    state1 = _ArgsToState(fprop_outputs)
    return state1, py_utils.NestedMap()

  with tf.name_scope(p.name):
    # Adds the FProp arg list to state0.
    state0 = _ArgsToState(args)
    # Runs body.FProp k times using Recurrent where k = dim 0 of theta_stack.
    _, state1 = recurrent.Recurrent(
        theta=py_utils.NestedMap(),
        state0=state0,
        inputs=theta_stack,  # Pass cell_fn theta through inputs.
        cell_fn=_CellFn)
    # Retrieves fprop outputs from state1 and sets shapes.
    output_tensors = _StateToArgs(state1)
    return output_tensors[0] if len(args) == 1 else tuple(output_tensors)
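# --- Illustrative sketch (not part of the layer above) ---
# A minimal, self-contained example of the cell_fn contract that
# recurrent.Recurrent expects: cell_fn(theta, state0, inputs_t) receives the
# t-th slice of `inputs` and returns (state1, extras), and Recurrent returns
# the per-step accumulated states plus the final state. The names _ToyCellFn
# and `total` and the toy shapes are hypothetical; the lingvo.core imports
# are assumed.
import tensorflow as tf
from lingvo.core import py_utils
from lingvo.core import recurrent


def _ToyCellFn(unused_theta, state0, inputs_t):
  # Accumulates a running sum of the per-timestep input slice.
  state1 = py_utils.NestedMap(total=state0.total + inputs_t.x)
  return state1, py_utils.NestedMap()


acc, final = recurrent.Recurrent(
    theta=py_utils.NestedMap(),
    state0=py_utils.NestedMap(total=tf.zeros([2])),
    inputs=py_utils.NestedMap(x=tf.ones([5, 2])),  # [time=5, batch=2].
    cell_fn=_ToyCellFn)
# acc.total is [time, batch] (one entry per step); final.total is [batch].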
def FProp(self, theta, prepared_inputs, inputs, padding, state0, **kwargs):
  """Runs a Step layer over multiple timesteps using Recurrent.

  Args:
    theta: A NestedMap containing weights' values of this layer and its
      children layers.
    prepared_inputs: External inputs returned by Step.PrepareExternalInputs().
    inputs: A NestedMap of inputs of shape [time, batch_size, dim].
    padding: A 0/1 float tensor of shape [time, batch_size]; 1.0 means that
      this batch element is empty in this step.
    state0: A NestedMap containing the initial recurrent state.
    **kwargs: Additional kwargs to pass to Recurrent.

  Returns:
    A tuple (outputs, state1).

    - outputs: A NestedMap containing the accumulated outputs of all steps,
      containing Tensors shaped [time, batch_size, dim].
    - state1: A NestedMap containing the accumulated recurrent states,
      containing Tensors shaped [time, batch_size, dim].
  """

  def RnnStep(recurrent_theta, recurrent_state0, recurrent_inputs):
    """Compute a single timestep."""
    output, state1 = self.step.FProp(
        theta=recurrent_theta.theta,
        prepared_inputs=recurrent_theta.prepared_inputs,
        step_inputs=recurrent_inputs.inputs,
        padding=recurrent_inputs.padding,
        state0=recurrent_state0.state)
    recurrent_state1 = py_utils.NestedMap(output=output, state=state1)
    return recurrent_state1, py_utils.NestedMap()

  # In order to pass Step outputs through Recurrent, they need to be
  # included as part of state.
  output0, _ = self.step.FProp(theta.step, prepared_inputs,
                               inputs.Transform(lambda x: x[0]), padding[0],
                               state0)
  accumulated_states, _ = recurrent.Recurrent(
      theta=py_utils.NestedMap(
          theta=theta.step, prepared_inputs=prepared_inputs),
      state0=py_utils.NestedMap(output=output0, state=state0),
      inputs=py_utils.NestedMap(inputs=inputs, padding=padding),
      cell_fn=RnnStep,
      **kwargs)
  return accumulated_states.output, accumulated_states.state
def FProp(self, theta, inputs, paddings, state0=None, segment_id=None):
  """Computes LSTM forward pass.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    inputs: A single tensor or a tuple of tensors with cardinality equal to
      rnn_cell.inputs_arity. For every input tensor, the first dimension is
      assumed to be time, second dimension batch, and third dimension depth.
    paddings: A tensor. First dim is time, second dim is batch, and third dim
      is expected to be 1.
    state0: If not None, the initial rnn state in a `.NestedMap`. Defaults to
      the cell's zero-state.
    segment_id: A tensor to support packed inputs. First dim is time, second
      dim is batch, and third dim is expected to be 1.

  Returns:
    A tuple (act, final_state):

    - act: A tensor of shape [time, batch, dims].
    - final_state: The final recurrent state in a `.NestedMap`.
  """
  p = self.params
  rcell = self.cell
  assert isinstance(rcell, rnn_cell.RNNCell)

  if not isinstance(inputs, (list, tuple)):
    inputs = [inputs]

  # Slicing wm to wm_{i,h} outside the loop to get 20% speedup over regular
  # LSTM baseline. Keeping slicing within the loop gives only < 3% speedup.
  cell_theta = theta.cell.copy()
  num_input_nodes = p.cell.num_input_nodes
  cell_theta['wm_i'] = cell_theta.wm[:num_input_nodes, :]
  cell_theta['wm_h'] = cell_theta.wm[num_input_nodes:, :]
  tf.logging.vlog(1, 'cell_theta: %r', cell_theta)

  if p.packed_input:
    assert segment_id is not None
    reset_mask = rnn_layers.GeneratePackedInputResetMask(
        segment_id, is_reverse=False)
    reset_mask = py_utils.HasShape(reset_mask, tf.shape(paddings))
  else:
    reset_mask = tf.zeros_like(paddings)

  if p.reverse:
    inputs = [tf.reverse(x, [0]) for x in inputs]
    paddings = tf.reverse(paddings, [0])
    reset_mask = tf.reverse(reset_mask, [0])

  if not state0:
    batch_size = py_utils.GetShape(paddings)[1]
    state0 = rcell.zero_state(cell_theta, batch_size)

  # [T, B, H]
  proj_inputs = rcell.ProjectInputSequence(cell_theta,
                                           py_utils.NestedMap(act=inputs))
  proj_inputs = py_utils.NestedMap(
      proj_inputs=proj_inputs, padding=paddings, reset_mask=reset_mask)

  acc_state, final_state = recurrent.Recurrent(
      theta=cell_theta,
      state0=state0,
      inputs=proj_inputs,
      cell_fn=rcell.FPropWithProjectedInput,
      cell_type=rcell.layer_type,
      accumulator_layer=self,
      allow_implicit_capture=p.allow_implicit_capture)

  act = rcell.GetOutput(acc_state)
  if p.reverse:
    act = tf.reverse(act, [0])
  return act, final_state
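# --- Illustrative sketch (not part of the layer above) ---
# A toy illustration (hypothetical shapes) of the wm_i / wm_h split above:
# the fused weight matrix stacks the rows that multiply the (projected)
# inputs on top of the rows that multiply the recurrent state, so slicing it
# once outside the time loop avoids re-slicing at every step.
import tensorflow as tf

num_input_nodes, num_hidden_nodes, gates_dim = 3, 4, 16
wm = tf.zeros([num_input_nodes + num_hidden_nodes, gates_dim])
wm_i = wm[:num_input_nodes, :]  # Rows applied to the input at each step.
wm_h = wm[num_input_nodes:, :]  # Rows applied to the previous state.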
def Sample(self, decoder_theta, encoder_outputs, random_seed,
           init_state_callback, pre_step_callback, post_step_callback):
  """Samples target sequences, one target sequence per source sequence.

  (Please see beam_search_helper.py for description of decoder callbacks.)

  Args:
    decoder_theta: A NestedMap object containing weights' values of the
      decoder layer and its children layers, to be passed to decoder
      callbacks.
    encoder_outputs: the outputs of the encoder, to be passed to callbacks.
    random_seed: a scalar int32 tensor representing the random seed.
    init_state_callback: decoder._InitBeamSearchStateCallback.
    pre_step_callback: decoder._PreBeamSearchStepCallback.
    post_step_callback: decoder._PostBeamSearchStepCallback.

  Returns:
    A NestedMap containing the following tensors

    - 'logits': [batch, max_target_length, vocab_size], representing the
      distribution from which target sequences are sampled.
    - 'ids': [batch, max_target_length] of int32, representing the target
      sequence ids, not including target_sos_id, but maybe ending with
      target_eos_id if end-of-sequence is reached before target_seq_len.
    - 'paddings': [batch, max_target_length] of 0/1, where 1 represents a
      padded timestep.
  """
  p = self.params
  assert p.temperature > 0
  if getattr(encoder_outputs, 'segment_id', 1) is None:
    # Remove None values, which are not supported by recurrent.
    del encoder_outputs['segment_id']
  # init_state_callback may modify 'encoder_outputs', e.g., by inserting
  # 'packed_src'.
  bs_result, bs_state = init_state_callback(
      decoder_theta, encoder_outputs, num_hyps_per_beam=1)
  # 'recurrent_theta' represents all cross-timestep information used by the
  # recurrent loop below, including layer theta and encoder outputs.
  recurrent_theta = py_utils.NestedMap(
      theta=decoder_theta,
      random_seed=random_seed,
      encoder_outputs=encoder_outputs)
  batch = tf.shape(bs_result.log_probs)[0]
  recurrent_state0 = py_utils.NestedMap(
      timestep=tf.zeros(shape=[], dtype=tf.int32),
      logits=bs_result.log_probs,
      # Start with target_sos_id.
      ids=tf.fill([batch], tf.cast(p.target_sos_id, tf.int32)),
      bs_state=bs_state)
  inputs = py_utils.NestedMap(dummy=tf.zeros([p.target_seq_len, batch]))

  def Step(recurrent_theta, state0, inputs):
    """Computes one decoder step."""
    del inputs
    with tf.name_scope('single_sampler_step'):
      # Compute logits and states.
      bs_result, bs_state1 = pre_step_callback(
          recurrent_theta.theta,
          recurrent_theta.encoder_outputs,
          tf.expand_dims(state0.ids, 1),  # [batch, 1].
          state0.bs_state,
          num_hyps_per_beam=1)
      batch = tf.shape(bs_result.log_probs)[0]
      state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
      state1.logits = bs_result.log_probs
      # Sample ids from logits. [batch].
      state1.ids = tf.reshape(
          tf.random.stateless_categorical(
              state1.logits / p.temperature,
              num_samples=1,
              seed=tf.stack([recurrent_theta.random_seed, state0.timestep]),
              dtype=state0.ids.dtype,
              name='sample_next_id'), [batch])
      if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
        state1.ids = tf.where(
            tf.math.logical_and(
                bs_result.is_last_chunk,
                tf.equal(state1.ids, p.target_eoc_id)),
            tf.fill(tf.shape(state1.ids), p.target_eos_id), state1.ids)
      state1.bs_state = post_step_callback(
          recurrent_theta.theta, recurrent_theta.encoder_outputs, state1.ids,
          bs_state1)
    return state1, py_utils.NestedMap()

  accumulated_states, _ = recurrent.Recurrent(
      recurrent_theta,
      recurrent_state0,
      inputs,
      Step,
      allow_implicit_capture=True)

  result = py_utils.NestedMap(
      logits=tf.transpose(accumulated_states.logits, [1, 0, 2]),
      ids=tf.transpose(accumulated_states.ids))
  result.paddings = tf.cast(
      _ComputePaddings(result.ids, p.target_eos_id), result.logits.dtype)
  # Force ids to be eos_id if the timestep is padded.
  result.ids = tf.where(
      tf.equal(result.paddings, 0), result.ids,
      tf.fill(tf.shape(result.ids), p.target_eos_id))
  static_batch_size = bs_result.log_probs.shape[0]
  result.ids.set_shape([static_batch_size, p.target_seq_len])
  result.paddings.set_shape([static_batch_size, p.target_seq_len])
  return result
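# --- Illustrative sketch (not part of the sampler above) ---
# A small standalone example (toy logits, hypothetical seed/timestep values)
# of the per-step stateless sampling used in Step above: stacking a fixed
# random seed with the current timestep makes each step's draw deterministic
# given (random_seed, timestep), which lets the sampler run inside
# recurrent.Recurrent without stateful randomness.
import tensorflow as tf

logits = tf.math.log([[0.1, 0.9], [0.8, 0.2]])  # [batch=2, vocab=2].
random_seed, timestep = 1234, 7
ids = tf.random.stateless_categorical(
    logits, num_samples=1, seed=tf.stack([random_seed, timestep]),
    dtype=tf.int32)
ids = tf.reshape(ids, [2])  # [batch] sampled token ids.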
def FProp(self, theta, *args):
  """Runs p.repeat copies of self.body.FProp independently.

  Args:
    theta: Layer model parameters. The shape of each variable in theta is
      always [p.repeat, ...]. And the i-th slice theta[i] becomes theta of
      the i-th copy of self.body.
    *args: Input arguments. The shape of each tensor in args is always
      [p.repeat, ...]. And the list [arg[i] for arg in args] becomes inputs
      to the i-th copy of self.body.FProp.

  Returns:
    The accumulated output_tensors. Each tensor t in the return has the
    shape [p.repeat, ...] and the tuple (t[i] for t in output_tensors) is
    the return tuple of the i-th self.body.FProp.
  """
  p = self.params
  for arg in args:
    if arg is not None:
      arg = py_utils.HasShape(arg, [p.repeat], ndims=1)

  theta_stack = _MaybeStackExtraTheta(theta.body, self.body.vars, p.repeat)
  inputs = py_utils.NestedMap(theta=theta_stack, args=list(args))
  # Infer out_shapes from FPropMeta.
  out_shapes = self._InferOutShapes(args)

  def _CellFn(unused_theta, unused_state0, inputs):
    """Recurrent cell function wrapper of body.FProp."""
    # Sets shapes for both theta and inputs to self.body.FProp.
    for dst, src in zip(inputs.args + inputs.theta.Flatten(),
                        list(args) + theta_stack.Flatten()):
      if src is not None:
        dst.set_shape(tf.TensorShape(src.shape.as_list()[1:]))
    # Runs the actual body.FProp.
    fprop_outputs = self.body.FProp(inputs.theta, *inputs.args)
    fprop_outputs = _ToTuple(fprop_outputs)
    assert len(fprop_outputs) == len(out_shapes)
    # Passes fprop outputs to the next layer through state.
    state1 = py_utils.NestedMap(outputs=list(fprop_outputs))
    return state1, py_utils.NestedMap()

  with tf.name_scope(p.name):
    # Initiates state0 with inferred output shapes.
    state0 = py_utils.NestedMap(
        outputs=[tf.zeros(shape, args[0].dtype) for shape in out_shapes])
    # Runs body.FProp p.repeat times using Recurrent.
    acc_states, _ = recurrent.Recurrent(
        theta=py_utils.NestedMap(),
        state0=state0,
        inputs=inputs,
        cell_fn=_CellFn)
    # Retrieves fprop outputs from acc_states and sets shapes.
    output_tensors = tuple(acc_states.outputs)
    for out_idx in range(len(output_tensors)):
      output_tensors[out_idx].set_shape(
          tf.TensorShape([p.repeat] + out_shapes[out_idx].as_list()))
    return output_tensors[0] if len(args) == 1 else tuple(output_tensors)