Example #1
    def FProp(self, theta, *args):
        p = self.params
        # Stacks any extra (non-repeated) theta entries so every entry has a leading p.repeat dim.
        theta_stack = _MaybeStackExtraTheta(theta.body, self.body.vars,
                                            p.repeat)

        def _ArgsToState(arg_list):
            """Returns a NestedMap from a list of FProp args."""
            state = py_utils.NestedMap()
            # Maintains a mapping from arg index to tensor; the state cannot
            # contain None tensors.
            for idx in range(len(args)):
                if arg_list[idx] is not None:
                    state['_s{}'.format(idx)] = arg_list[idx]
            return state

        def _StateToArgs(state):
            """Returns a list of FProp args from a NestedMap."""
            arg_list = []
            for idx in range(len(args)):
                attr = '_s{}'.format(idx)
                arg_list.append(state[attr] if attr in state else None)
                if arg_list[-1] is not None:
                    arg_list[-1].set_shape(args[idx].shape)
            return arg_list

        def _CellFn(unused_theta, state0, theta_i):
            """Recurrent cell function wrapper of body.FProp."""
            # Retrieves fprop arguments from state and sets shapes.
            fprop_inputs = _StateToArgs(state0)

            # Sets shapes for theta_i as well.
            for dst, src in zip(theta_i.Flatten(), theta_stack.Flatten()):
                if src is not None:
                    dst.set_shape(tf.TensorShape(src.shape.as_list()[1:]))

            # Runs the actual body.FProp
            fprop_outputs = self.body.FProp(theta_i, *fprop_inputs)
            fprop_outputs = _ToTuple(fprop_outputs)
            assert len(fprop_outputs) == len(fprop_inputs)

            # Passes fprop outputs to the next layer through state.
            state1 = _ArgsToState(fprop_outputs)
            return state1, py_utils.NestedMap()

        with tf.name_scope(p.name):
            # Add FProp arg list to state0.
            state0 = _ArgsToState(args)
            # Runs body.FProp p.repeat times using Recurrent; p.repeat is dim 0 of theta_stack.
            _, state1 = recurrent.Recurrent(
                theta=py_utils.NestedMap(),
                state0=state0,
                inputs=theta_stack,  # Pass cell_fn theta through inputs.
                cell_fn=_CellFn)

            # Retrieves fprop outputs from state1 and sets shapes.
            output_tensors = _StateToArgs(state1)
            return output_tensors[0] if len(args) == 1 else tuple(
                output_tensors)
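All of these examples rely on the same recurrent.Recurrent contract: cell_fn maps (theta, state0, inputs_t) to (state1, extras), inputs carry time in dim 0, and the call returns the accumulated per-step states plus the final state. A minimal self-contained sketch of that contract, assuming the lingvo package is importable and the code runs in a graph/tf.function context; the running-sum cell is purely illustrative:

import tensorflow as tf
from lingvo.core import py_utils, recurrent


def _SumCell(theta, state0, inputs_t):
    # cell_fn contract: (theta, state0, inputs_t) -> (state1, extras); extras may be empty.
    state1 = py_utils.NestedMap(total=state0.total + theta.scale * inputs_t.x)
    return state1, py_utils.NestedMap()


acc, final = recurrent.Recurrent(
    theta=py_utils.NestedMap(scale=tf.constant(2.0)),
    state0=py_utils.NestedMap(total=tf.zeros([3])),
    inputs=py_utils.NestedMap(x=tf.ones([5, 3])),  # dim 0 is time (5 steps).
    cell_fn=_SumCell)
# acc.total has shape [5, 3] (one running sum per step); final.total has shape [3].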
Example #2
    def FProp(self, theta, prepared_inputs, inputs, padding, state0, **kwargs):
        """Runs a Step layer over multiple timesteps using Recurrent.

    Args:
      theta: A NestedMap containing weights' values of this layer and its
        children layers.
      prepared_inputs: External inputs returned by Step.PrepareExternalInputs().
      inputs: A NestedMap of inputs of shape [time, batch_size, dim].
      padding: A 0/1 float tensor of shape [time, batch_size]; 1.0 means that
        this batch element is empty in this step.
      state0: A NestedMap containing the initial recurrent state.
      **kwargs: Additional kwargs to pass to Recurrent.

    Returns:
      A tuple (outputs, state1).

      - outputs: A NestedMap containing the accumulated outputs of all steps,
        containing Tensors shaped [time, batch_size, dim].
      - state1: A NestedMap containing the accumulated recurrent states,
        containing Tensors shaped [time, batch_size, dim].
    """
        def RnnStep(recurrent_theta, recurrent_state0, recurrent_inputs):
            """Compute a single timestep."""
            output, state1 = self.step.FProp(
                theta=recurrent_theta.theta,
                prepared_inputs=recurrent_theta.prepared_inputs,
                step_inputs=recurrent_inputs.inputs,
                padding=recurrent_inputs.padding,
                state0=recurrent_state0.state)
            recurrent_state1 = py_utils.NestedMap(output=output, state=state1)
            return recurrent_state1, py_utils.NestedMap()

        # In order to pass Step outputs through Recurrent, they need to be
        # included as part of state.
        output0, _ = self.step.FProp(theta.step, prepared_inputs,
                                     inputs.Transform(lambda x: x[0]),
                                     padding[0], state0)

        accumulated_states, _ = recurrent.Recurrent(
            theta=py_utils.NestedMap(theta=theta.step,
                                     prepared_inputs=prepared_inputs),
            state0=py_utils.NestedMap(output=output0, state=state0),
            inputs=py_utils.NestedMap(inputs=inputs, padding=padding),
            cell_fn=RnnStep,
            **kwargs)

        return accumulated_states.output, accumulated_states.state
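Usage note for the driver above: padding is a [time, batch] 0/1 float tensor where 1.0 marks an empty step. A small sketch of building it from per-example lengths; using tf.sequence_mask on the caller's side is an assumption, not part of the layer:

import tensorflow as tf

lengths = tf.constant([4, 2, 5])                              # batch of 3 sequences.
max_time = 5
mask = tf.sequence_mask(lengths, max_time, dtype=tf.float32)  # [batch, time], 1.0 = valid.
padding = 1.0 - tf.transpose(mask)                            # [time, batch], 1.0 = padded step.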
Example #3
    def FProp(self, theta, inputs, paddings, state0=None, segment_id=None):
        """Computes LSTM forward pass.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      inputs: A single tensor or a tuple of tensors with cardinality equal to
        rnn_cell.inputs_arity. For every input tensor, the first dimension is
        assumed to be time, second dimension batch, and third dimension depth.
      paddings: A tensor. First dim is time, second dim is batch, and third dim
        is expected to be 1.
      state0: If not None, the initial rnn state in a `.NestedMap`. Defaults to
        the cell's zero-state.
      segment_id: A tensor to support packed inputs. First dim is time, second
        dim is batch, and third dim is expected to be 1.

    Returns:
      A tuple (act, final_state):

      - act: A tensor of [time, batch, dims].
      - final_state: The final recurrent state in a `.NestedMap`.
    """
        p = self.params
        rcell = self.cell
        assert isinstance(rcell, rnn_cell.RNNCell)

        if not isinstance(inputs, (list, tuple)):
            inputs = [inputs]

        # Slicing wm into wm_i and wm_h outside the loop gives a ~20% speedup over
        # the regular LSTM baseline; keeping the slicing inside the loop gives only
        # a < 3% speedup.
        cell_theta = theta.cell.copy()
        num_input_nodes = p.cell.num_input_nodes
        cell_theta['wm_i'] = cell_theta.wm[:num_input_nodes, :]
        cell_theta['wm_h'] = cell_theta.wm[num_input_nodes:, :]
        tf.logging.vlog(1, 'cell_theta: %r', cell_theta)
        if p.packed_input:
            assert segment_id is not None
            reset_mask = rnn_layers.GeneratePackedInputResetMask(
                segment_id, is_reverse=False)
            reset_mask = py_utils.HasShape(reset_mask, tf.shape(paddings))
        else:
            reset_mask = tf.zeros_like(paddings)

        if p.reverse:
            inputs = [tf.reverse(x, [0]) for x in inputs]
            paddings = tf.reverse(paddings, [0])
            reset_mask = tf.reverse(reset_mask, [0])

        if not state0:
            batch_size = py_utils.GetShape(paddings)[1]
            state0 = rcell.zero_state(cell_theta, batch_size)

        # [T, B, H]
        proj_inputs = rcell.ProjectInputSequence(
            cell_theta, py_utils.NestedMap(act=inputs))
        proj_inputs = py_utils.NestedMap(proj_inputs=proj_inputs,
                                         padding=paddings,
                                         reset_mask=reset_mask)

        acc_state, final_state = recurrent.Recurrent(
            theta=cell_theta,
            state0=state0,
            inputs=proj_inputs,
            cell_fn=rcell.FPropWithProjectedInput,
            cell_type=rcell.layer_type,
            accumulator_layer=self,
            allow_implicit_capture=p.allow_implicit_capture)

        act = rcell.GetOutput(acc_state)
        if p.reverse:
            act = tf.reverse(act, [0])
        return act, final_state
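A toy illustration of the wm slicing noted in the comment above: the fused weight matrix is split once, outside the recurrent loop, so the input projection can be computed for all timesteps in one batched matmul. The shapes and the fused 4*hidden gate layout here are assumptions for illustration:

import tensorflow as tf

num_input, num_hidden = 8, 16
wm = tf.random.normal([num_input + num_hidden, 4 * num_hidden])  # fused LSTM weights.
wm_i = wm[:num_input, :]   # input block, usable for all timesteps at once.
wm_h = wm[num_input:, :]   # hidden block, applied to the state inside the loop.

x = tf.random.normal([10, 4, num_input])    # [time, batch, input].
proj_x = tf.einsum('tbi,ij->tbj', x, wm_i)  # [time, batch, 4 * hidden] input projections.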
Example #4
    def Sample(self, decoder_theta, encoder_outputs, random_seed,
               init_state_callback, pre_step_callback, post_step_callback):
        """Samples target sequences, one target sequence per source sequence.

    (Please see beam_search_helper.py for description of decoder callbacks.)

    Args:
      decoder_theta: A NestedMap object containing weights' values of the
        decoder layer and its children layers, to be passed to decoder
        callbacks.
      encoder_outputs: the outputs of the encoder, to be passed to callbacks.
      random_seed: a scalar int32 tensor representing the random seed.
      init_state_callback: decoder._InitBeamSearchStateCallback.
      pre_step_callback: decoder._PreBeamSearchStepCallback.
      post_step_callback: decoder._PostBeamSearchStepCallback.

    Returns:
      A NestedMap containing the following tensors

      - 'logits': [batch, max_target_length, vocab_size], representing the
        distribution from which target sequences are sampled.
      - 'ids': [batch, max_target_length] of int32, representing the target
        sequence ids, not including target_sos_id, but possibly ending with
        target_eos_id if end-of-sequence is reached before target_seq_len.
      - 'paddings': [batch, max_target_length] of 0/1, where 1 represents
        a padded timestep.
    """
        p = self.params
        assert p.temperature > 0
        if getattr(encoder_outputs, 'segment_id', 1) is None:
            # Remove None values, which are not supported by recurrent.
            del encoder_outputs['segment_id']
        # init_state_callback may modify 'encoder_outputs', e.g., by inserting
        # 'packed_src'.
        bs_result, bs_state = init_state_callback(decoder_theta,
                                                  encoder_outputs,
                                                  num_hyps_per_beam=1)
        # 'recurrent_theta' represents all cross-timestep information used by the
        # recurrent loop below, including layer theta and encoder outputs.
        recurrent_theta = py_utils.NestedMap(theta=decoder_theta,
                                             random_seed=random_seed,
                                             encoder_outputs=encoder_outputs)
        batch = tf.shape(bs_result.log_probs)[0]
        recurrent_state0 = py_utils.NestedMap(
            timestep=tf.zeros(shape=[], dtype=tf.int32),
            logits=bs_result.log_probs,
            # Start with target_sos_id.
            ids=tf.fill([batch], tf.cast(p.target_sos_id, tf.int32)),
            bs_state=bs_state)
        inputs = py_utils.NestedMap(dummy=tf.zeros([p.target_seq_len, batch]))

        def Step(recurrent_theta, state0, inputs):
            """Computes one decoder step."""
            del inputs
            with tf.name_scope('single_sampler_step'):
                # Compute logits and states.
                bs_result, bs_state1 = pre_step_callback(
                    recurrent_theta.theta,
                    recurrent_theta.encoder_outputs,
                    tf.expand_dims(state0.ids, 1),  # [batch, 1].
                    state0.bs_state,
                    num_hyps_per_beam=1)
                batch = tf.shape(bs_result.log_probs)[0]
                state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
                state1.logits = bs_result.log_probs
                # Sample ids from logits. [batch].
                state1.ids = tf.reshape(
                    tf.random.stateless_categorical(
                        state1.logits / p.temperature,
                        num_samples=1,
                        seed=tf.stack(
                            [recurrent_theta.random_seed, state0.timestep]),
                        dtype=state0.ids.dtype,
                        name='sample_next_id'), [batch])
                if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
                    state1.ids = tf.where(
                        tf.math.logical_and(
                            bs_result.is_last_chunk,
                            tf.equal(state1.ids, p.target_eoc_id)),
                        tf.fill(tf.shape(state1.ids), p.target_eos_id),
                        state1.ids)
                state1.bs_state = post_step_callback(
                    recurrent_theta.theta, recurrent_theta.encoder_outputs,
                    state1.ids, bs_state1)
            return state1, py_utils.NestedMap()

        accumulated_states, _ = recurrent.Recurrent(
            recurrent_theta,
            recurrent_state0,
            inputs,
            Step,
            allow_implicit_capture=True)
        result = py_utils.NestedMap(
            logits=tf.transpose(accumulated_states.logits, [1, 0, 2]),
            ids=tf.transpose(accumulated_states.ids))
        result.paddings = tf.cast(
            _ComputePaddings(result.ids, p.target_eos_id), result.logits.dtype)
        # Force ids to be eos_id if the timestep is padded.
        result.ids = tf.where(tf.equal(result.paddings, 0), result.ids,
                              tf.fill(tf.shape(result.ids), p.target_eos_id))
        static_batch_size = bs_result.log_probs.shape[0]
        result.ids.set_shape([static_batch_size, p.target_seq_len])
        result.paddings.set_shape([static_batch_size, p.target_seq_len])
        return result
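The per-step sampling above boils down to temperature-scaled stateless categorical sampling keyed on (random_seed, timestep), so re-running the same step with the same seed is reproducible. A stand-alone sketch with illustrative logits and seed values:

import tensorflow as tf

logits = tf.math.log([[0.1, 0.6, 0.3]])               # [batch=1, vocab=3].
temperature = 0.7
seed = tf.stack([tf.constant(1234), tf.constant(0)])  # [random_seed, timestep].
ids = tf.random.stateless_categorical(
    logits / temperature, num_samples=1, seed=seed)   # [batch, 1], deterministic per seed.
ids = tf.reshape(ids, [-1])                           # [batch].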
Example #5
    def FProp(self, theta, *args):
        """Runs p.repeat copies of self.body.FProp independently.

    Args:
      theta: Layer model parameters. The shape of each variable in theta is
        always [p.repeat, ...], and the i-th slice theta[i] becomes the theta of
        the i-th copy of self.body.
      *args: Input arguments. The shape of each tensor in args is always
        [p.repeat, ...], and the list [arg[i] for arg in args] becomes the
        inputs to the i-th copy of self.body.FProp.

    Returns:
      The accumulated output_tensors. Each tensor t in the return value has
      shape [p.repeat, ...], and the tuple (t[i] for t in output_tensors) is the
      return value of the i-th self.body.FProp.
    """
        p = self.params
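        # Checks that the leading dim of each non-None arg is p.repeat.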
        for arg in args:
            if arg is not None:
                arg = py_utils.HasShape(arg, [p.repeat], ndims=1)

        theta_stack = _MaybeStackExtraTheta(theta.body, self.body.vars,
                                            p.repeat)
        inputs = py_utils.NestedMap(theta=theta_stack, args=list(args))
        # Infer out_shapes from FPropMeta.
        out_shapes = self._InferOutShapes(args)

        def _CellFn(unused_theta, unused_state0, inputs):
            """Recurrent cell function wrapper of body.FProp."""
            # Sets shapes for both theta and inputs to self.body.FProp.
            for dst, src in zip(inputs.args + inputs.theta.Flatten(),
                                list(args) + theta_stack.Flatten()):
                if src is not None:
                    dst.set_shape(tf.TensorShape(src.shape.as_list()[1:]))

            # Runs the actual body.FProp
            fprop_outputs = self.body.FProp(inputs.theta, *inputs.args)
            fprop_outputs = _ToTuple(fprop_outputs)
            assert len(fprop_outputs) == len(out_shapes)
            # Passes fprop outputs to the next layer through state.
            state1 = py_utils.NestedMap(outputs=list(fprop_outputs))
            return state1, py_utils.NestedMap()

        with tf.name_scope(p.name):
            # Initializes state0 with zeros of the inferred output shapes.
            state0 = py_utils.NestedMap(outputs=[
                tf.zeros(shape, args[0].dtype) for shape in out_shapes
            ])
            # Runs body.FProp p.repeat times using Recurrent.
            acc_states, _ = recurrent.Recurrent(theta=py_utils.NestedMap(),
                                                state0=state0,
                                                inputs=inputs,
                                                cell_fn=_CellFn)

            # Retrieves fprop outputs from the accumulated states and sets shapes.
            output_tensors = tuple(acc_states.outputs)
            for out_idx in range(len(output_tensors)):
                output_tensors[out_idx].set_shape(
                    tf.TensorShape([p.repeat] + out_shapes[out_idx].as_list()))

            return output_tensors[0] if len(args) == 1 else tuple(
                output_tensors)
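Conceptually, this FProp applies self.body.FProp independently to each of the p.repeat slices of the stacked theta and args. A hedged analogy with tf.map_fn, using a toy elementwise body as a stand-in and ignoring the memory/rematerialization behavior that motivates Recurrent here:

import tensorflow as tf

repeat, batch, dim = 4, 2, 3
theta_stack = tf.random.normal([repeat, dim])         # [p.repeat, ...] stacked weights.
args_stack = tf.random.normal([repeat, batch, dim])   # [p.repeat, batch, dim] stacked inputs.


def _ToyBody(theta_i, x_i):
    # Stand-in for self.body.FProp on the i-th copy: scale inputs by that copy's weights.
    return x_i * theta_i[None, :]


out = tf.map_fn(lambda s: _ToyBody(s[0], s[1]), (theta_stack, args_stack),
                fn_output_signature=tf.float32)       # [p.repeat, batch, dim].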