Ejemplo n.º 1
0
    def _ProcessMASSInput(self, source_id, src):
        """Builds MASS training features from one source string.

        The string is tokenized with the source tokenizer, masked by
        `self.mass_layer.Mask`, and packed into `src` / `tgt` feature maps.

        Args:
          source_id: Forwarded to `self._GetTaskIds` to obtain the source and
            target language ids.
          src: The input string tensor; it is reshaped to `[1]` before
            tokenization.

        Returns:
          A `NestedMap` with `src` and `tgt` sub-maps (ids, paddings, weights,
          task_ids and ids_indicator, plus labels for `tgt`) with `tf.squeeze`
          applied to every entry. When `py_utils.use_tpu()` is false the raw
          string is also attached as `strs` to both sub-maps.
        """
        # TODO(yuancao): By doing so we assume that right now for monolingual
        # eval/dev sets (xx->xx) are in double-column format (since it bypasses
        # the Mass op). Ideally we should add a dedicated eval/dev processing
        # procedure for unsupervised MT cases, so that single-column eval/devs
        # sets are also supported. This should not be handled by any specific
        # ops like Mass, but inside the TextPackedInput class.
        assert not self.do_eval, 'MASS input can only be used for training.'

        _, labels, paddings = self.StringsToIds(
            tf.reshape(src, [1]),
            is_source=True,
            key=self._src_tokenizer_key)
        weights = 1 - paddings
        actual_seq_len = tf.cast(tf.reduce_sum(weights, 1), tf.int32)
        src_lang_ids, tgt_lang_ids = self._GetTaskIds(source_id)

        mass_out = self.mass_layer.Mask(labels, weights, actual_seq_len)

        # Source side: masked ids, with task ids zeroed wherever weights are 0.
        src_task_ids = tf.cast(weights, dtype=tf.int32) * src_lang_ids
        src_features = py_utils.NestedMap(ids=mass_out.src.ids,
                                          paddings=paddings,
                                          weights=weights,
                                          task_ids=src_task_ids,
                                          ids_indicator=weights)

        # Target side: every position carries the target language id.
        tgt_task_ids = tf.ones_like(src_task_ids,
                                    dtype=tf.int32) * tgt_lang_ids
        tgt_features = py_utils.NestedMap(ids=mass_out.tgt.ids,
                                          labels=mass_out.tgt.labels,
                                          paddings=paddings,
                                          weights=mass_out.tgt.weights,
                                          task_ids=tgt_task_ids,
                                          ids_indicator=weights)

        if not py_utils.use_tpu():
            src_features.strs = src
            tgt_features.strs = src

        features = py_utils.NestedMap(src=src_features, tgt=tgt_features)
        return features.Transform(tf.squeeze)
Ejemplo n.º 2
0
    def FProp(self, theta, input_batch):
        """Runs the embedding lookup and the stacked RNN layers on a batch.

        Args:
          theta: Weights of this layer and its children; `emb`, `dropout` and
            `rnn` are read, plus `transparent_merger` / `final_proj` when the
            corresponding branches are taken.
          input_batch: Object with `ids` and `paddings` (asserted to be rank 2
            and of identical shape), and `segment_ids` when `p.packed_input`
            is set.

        Returns:
          A `NestedMap` with `encoded` (the final activations), `padding` (the
          transposed paddings with axis 2 squeezed out again) and `segment_id`
          (likewise squeezed, or None when the input is not packed).
        """
        p = self.params
        with tf.name_scope(p.name):
            shape_checks = [
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            [-1, -1]),
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            tf.shape(input_batch.paddings)),
            ]
            # Tie the transposed ids to the two shape assertions.
            inputs = py_utils.with_dependencies(shape_checks,
                                                tf.transpose(input_batch.ids))
            paddings = tf.expand_dims(tf.transpose(input_batch.paddings), 2)

            src_segment_id = None
            if p.packed_input:
                src_segment_id = tf.expand_dims(
                    tf.transpose(input_batch.segment_ids), 2)

            xs = self.emb.EmbLookup(theta.emb, inputs)
            xs = self.ApplyClipping(theta, xs)
            summary_utils.histogram('input_emb', xs)
            xs = self.dropout.FProp(theta.dropout, xs)

            # Stacked RNN layers. Each layer's output is kept for the optional
            # transparent merger below.
            outputs_list = []
            for i in range(p.num_lstm_layers):
                ys = self.rnn[i].FProp(theta.rnn[i],
                                       xs,
                                       paddings,
                                       segment_id=src_segment_id)
                ys = self.dropout.FProp(theta.dropout, ys)
                if i < p.residual_start:
                    xs = ys
                else:
                    # Residual skip connection, then clipping.
                    xs = self.ApplyClipping(theta, xs + ys)
                outputs_list.append(xs)
                summary_utils.histogram('layer_out_%s' % i, xs)

            if p.is_transparent:
                xs = self.transparent_merger.FProp(theta.transparent_merger,
                                                   outputs_list)

            if p.lstm_cell_size * 2 != p.encoder_out_dim:
                # Project to the configured output depth.
                xs = self.final_proj.FProp(theta.final_proj, xs, paddings)
                summary_utils.histogram('final_proj_out', xs)

            if src_segment_id is not None:
                src_segment_id = tf.squeeze(src_segment_id, [2])

            return py_utils.NestedMap(encoded=xs,
                                      padding=tf.squeeze(paddings, [2]),
                                      segment_id=src_segment_id)
  def FProp(self, theta, prepared_inputs, step_inputs, padding, state0):
    """Computes an attention context vector for one step.

    Args:
      theta: A NestedMap containing weights' values of this layer and its
        children layers; `theta.atten` is used.
      prepared_inputs: Pre-processed external inputs; `packed_src` is read.
      step_inputs: A NestedMap whose `inputs` list is concatenated along axis 1
        to form the attention query.
      padding: Passed to `py_utils.ApplyPadding` to mask the attention
        probabilities. The original notes describe it as a [batch, 1] 0/1
        float tensor where 1.0 marks an unused batch slot.
      state0: A NestedMap of state; `atten_state` is read.

    Returns:
      output, state1, defined as follows:
      - output: a NestedMap with `context` (the attention context) and `probs`
        (the attention probabilities after padding is applied).
      - state1: a NestedMap with `atten_context` and `atten_state`, to be fed
        to the next invocation.
    """
    query = tf.concat(step_inputs.inputs, axis=1)
    context, probs, atten_states = self.atten.ComputeContextVectorWithSource(
        theta.atten,
        prepared_inputs.packed_src,
        query,
        attention_state=state0.atten_state)
    probs = py_utils.ApplyPadding(padding, probs)
    output = py_utils.NestedMap(context=context, probs=probs)
    state1 = py_utils.NestedMap(
        atten_context=context, atten_state=atten_states)
    return output, state1
Ejemplo n.º 4
0
 def FnMeta(*shapes):
   """Maps input shapes to a NestedMap{flops, out_shapes}.

   Uses the enclosing `fn_out` / `fn_flops` callables when they are truthy.
   Otherwise the output shapes default to the input shapes and the flops to
   the sum of the input shapes' `size`.
   """
   out_shapes = shapes
   if fn_out:
     out_shapes = fn_out(*shapes)
     # Normalize a single shape to a 1-tuple.
     if isinstance(out_shapes, tshape.Shape):
       out_shapes = (out_shapes,)
   flops = fn_flops(*shapes) if fn_flops else sum(s.size for s in shapes)
   return py_utils.NestedMap(flops=flops, out_shapes=out_shapes)
 def ParseAndProcess(*cols):
     """Parses serialized TensorFlow examples into a preprocessed batch.

     Accepts one or two columns. With two columns the input is treated as a
     key/value store and the record is the second column; in both cases the
     record is the last column.
     """
     num_cols = len(cols)
     assert num_cols == 1 or num_cols == 2, (
         'BaseExampleInputGenerator supports one or two column input')
     parsed = tf.io.parse_example(cols[-1], self.GetFeatureSpec())
     return self._PreprocessInputBatch(py_utils.NestedMap(parsed))
  def FProp(self, theta, prepared_inputs, step_inputs, padding, state0):
    """Runs the query generator, then attention, for one decoder step.

    The previous attention context (from `state0`) is fed to the query
    generator, and the resulting query is fed to the attention step.

    Args:
      theta: A `.NestedMap` of weights; `query_generator` and `attention` are
        read.
      prepared_inputs: Pre-processed external inputs with `query_generator`
        and `attention` entries.
      step_inputs: Not read by this method.
      padding: Forwarded unchanged to both sub-steps. The original notes
        describe it as a [batch, 1] 0/1 float tensor where 1.0 marks an unused
        batch slot.
      state0: A NestedMap with `query_state` and `atten_state`.

    Returns:
      output, state1, where output is a NestedMap with `atten_context`,
      `atten_query` and `atten_probs`, and state1 is a NestedMap with
      `atten_state` and `query_state` for the next invocation.
    """
    query_inputs = py_utils.NestedMap(
        inputs=[state0.atten_state.atten_context])
    query_output, query_state1 = self.query_generator.FProp(
        theta.query_generator, prepared_inputs.query_generator, query_inputs,
        padding, state0.query_state)

    query = query_output.output
    atten_output, atten_state1 = self.attention.FProp(
        theta.attention, prepared_inputs.attention,
        py_utils.NestedMap(inputs=[query]), padding, state0.atten_state)

    output = py_utils.NestedMap(
        atten_context=atten_output.context,
        atten_query=query,
        atten_probs=atten_output.probs)
    state1 = py_utils.NestedMap(
        atten_state=atten_state1, query_state=query_state1)
    return output, state1
Ejemplo n.º 7
0
 def FPropMeta(cls, p, *args):
   """Returns flops and output shapes for `p.repeat` copies of `p.body`.

   Each non-None argument loses its first dimension before being handed to the
   body's FPropMeta. The body's flops are multiplied by `p.repeat`, and
   `p.repeat` is prepended to every non-None output shape.
   """
   py_utils.CheckShapes(args)
   input_shapes = []
   for arg in args:
     if arg is None:
       input_shapes.append(None)
     else:
       input_shapes.append(tshape.Shape(arg.get_shape().as_list()[1:]))
   meta = p.body.cls.FPropMeta(p.body, *input_shapes)
   py_utils.CheckShapes(meta.out_shapes)
   total = meta.flops * p.repeat
   out_shapes = tuple(
       None if s is None else tshape.Shape([p.repeat] + s[:])
       for s in meta.out_shapes)
   return py_utils.NestedMap(flops=total, out_shapes=out_shapes)
Ejemplo n.º 8
0
  def _ConsumeMap(self):
    """Parses the NestedMap that starts at the current token position.

    Consumes `(`, then zero or more comma-separated `key=item` entries, then
    `)`, advancing `self._i` past everything consumed.

    Returns:
      The parsed NestedMap (possibly empty).

    Raises:
      ValueError: If the tokens run out, or the current token is not `(`.
    """
    if self._i >= len(self._tokens):
      raise ValueError('Ran out of tokens while looking for a NestedMap.')
    if self._tokens[self._i] != '(':
      raise ValueError('Expected ( at token position %d' % (self._i))
    self._i += 1

    parsed = py_utils.NestedMap()
    if self._MaybeConsumeSymbol(')'):
      # Empty NestedMaps are allowed.
      return parsed

    while self._i < len(self._tokens):
      key = self._ConsumeKey()
      self._ConsumeSymbol('=')
      parsed[key] = self._ConsumeItem()
      if not self._MaybeConsumeSymbol(')'):
        # More entries must follow.
        self._ConsumeSymbol(',')
        continue
      return parsed
    raise ValueError('Ran out of tokens while looking for end of NestedMap.')
    def _BuildDataSourceWithMetadata(self, task_id=None):
        """Builds the input data source and records its bprop metadata.

        When `p.file_datasource` is unset but `p.file_pattern` is given, the
        data comes directly from `self._DataSourceFromFilePattern`. Otherwise
        it is built by `self.datasource.BuildDataSource`.

        According to the original notes, `p.file_pattern` may be a string
        file_pattern or a list of (file_pattern, weight,
        [bprop_variable_filter]) tuples. When bprop_variable_filter is used,
        batches always contain examples from the same source; otherwise
        examples from different sources may be mixed together.

        Args:
          task_id: Host index for partitioning input shards.

        Returns:
          A `.NestedMap` containing

          - data: a tuple of tf.Tensor or `.NestedMap` of tf.Tensor, same as
            `self._DataSourceFromFilePattern()`.
          - source_selected: a tensor of size [batch_size, number of data
            sources] or None (set to None here when the key is absent).
          - selected_bprop: a tensor of size [number of data sources], when
            provided by the data source.
          - bprop_variable_filters: a list of bprop_variable filters for each
            source, when provided by the data source.
        """
        p = self.params
        if p.file_datasource or not p.file_pattern:
            tf.logging.info(
                'Building data source %s with params %s and '
                'file_pattern %s', self.datasource, self.datasource.params,
                p.file_pattern)
            ret = self.datasource.BuildDataSource(
                self._DataSourceFromFilePattern, task_id=task_id)
        else:
            # This is a workaround for subclasses which have defined
            # their own data source-like functionality.
            tf.logging.info(
                'Creating data source-like output from class %s using '
                'file_pattern %s', self, p.file_pattern)
            ret = py_utils.NestedMap()
            ret.data = self._DataSourceFromFilePattern(p.file_pattern,
                                                       task_id=task_id)

        # Mirror optional bprop metadata onto this instance.
        metadata_attrs = (
            ('selected_bprop', '_bprop_onehot'),
            ('bprop_variable_filters', '_bprop_variable_filters'),
        )
        for key, attr_name in metadata_attrs:
            if key in ret:
                setattr(self, attr_name, getattr(ret, key))

        if 'source_selected' not in ret:
            ret.source_selected = None
        return ret
Ejemplo n.º 10
0
    def FPropWithProjectedInput(self, theta, state0, inputs):
        """Runs one cell step on inputs that were already projected.

        The input projection is expected to have been computed beforehand, so
        only the state mixing and the gates are evaluated here. The original
        notes describe this as a way to parallelize the input projection
        across time steps to accelerate training.

        Args:
          theta: a NestedMap of layer weights. Per the original notes it is
            expected to contain separate weight tensors for input and hidden
            state projections, under the keys 'wm_i' (input) and 'wm_h'
            (hidden state).
          state0: A NestedMap with the same structure as the return value of
            `self.zero_state()`. It is deep-copied before being reset, so the
            caller's copy is not modified here.
          inputs: A NestedMap with the following fields (shapes as stated in
            the original notes):
            - proj_inputs: A single Tensor of shape [batch, 4 * hidden_dim].
            - padding: A Tensor of shape [batch, 1].
            - reset_mask: A Tensor of shape [batch, 1].

        Returns:
          state1: A NestedMap of the same structure as `state0`.
          extras: An empty NestedMap.
        """
        effective_state0 = state0
        if self.params.reset_cell_state:
            # Reset a copy of the incoming state.
            effective_state0 = self._ResetState(state0.DeepCopy(), inputs)

        xmw = self._MixWithProjectedInput(theta, effective_state0,
                                          inputs.proj_inputs)

        # The gates receive a copy of the inputs with `act` holding the
        # projected inputs.
        gates_input = inputs.copy()
        gates_input.act = [inputs.proj_inputs]
        state1 = self._Gates(xmw, theta, effective_state0, gates_input)
        return state1, py_utils.NestedMap()
Ejemplo n.º 11
0
 def PrepareExternalInputs(self, theta, external_inputs):
     """Returns the prepared external inputs, e.g., packed_src for attention.

     Starts from a deep copy of `external_inputs` and, for every child that is
     a `Step` (or a tuple/list of `Step`s), replaces the entry under the
     child's name with that child's own prepared inputs.

     Raises:
       ValueError: If a child list mixes `Step` and non-`Step` elements.
     """
     if not external_inputs:
         external_inputs = py_utils.NestedMap()
     packed = external_inputs.DeepCopy()
     for name, child in six.iteritems(self.children):
         child_inputs = external_inputs.get(name, py_utils.NestedMap())
         if isinstance(child, (tuple, list)):
             prepared = [
                 sub.PrepareExternalInputs(theta[name][i], child_inputs)
                 for i, sub in enumerate(child)
                 if isinstance(sub, Step)
             ]
             if not prepared:
                 # No Step in this list: keep the copied entry as is.
                 continue
             if len(prepared) != len(child):
                 raise ValueError(
                     'Expecting child list to be instances of Step.')
             # Preserve the container type (tuple vs. list) of the child.
             packed[name] = type(child)(prepared)
         elif isinstance(child, Step):
             packed[name] = child.PrepareExternalInputs(
                 theta[name], child_inputs)
     return packed
def _common_gpipe_transformer_fprop_meta(p, inputs, *args):
    """Returns estimated flops and output shapes for a GPipe transformer layer.

    The flops figure is a rough estimate: 5 * src_time^2 * source_batch * dim,
    with the three factors unpacked from `inputs`. The output shapes are
    `inputs` followed by `args`, which are adjusted for the transparent
    encoder case (`p.is_transparent` set and `p.has_aux_atten` unset).
    """
    # TODO(huangyp): return accurate estimate of flops.
    py_utils.CheckShapes((inputs, ))
    src_time, source_batch, dim = inputs
    flops_per_element = 5
    flops = flops_per_element * src_time * src_time * source_batch * dim

    if not isinstance(args, tuple):
        args = (args, )

    transparent_encoder = not p.has_aux_atten and p.is_transparent
    if transparent_encoder:
        merger_tpl = p.transparent_merger_tpl
        if merger_tpl is not None:
            # Slots 5 and 6 become the inputs shape and the source count.
            sources_shape = tshape.Shape([merger_tpl.num_sources])
            args = args[:5] + (inputs, sources_shape)
        # Slot 6 holds one fewer than its previous first dimension.
        remaining_shape = tshape.Shape([args[6][0] - 1])
        args = args[:6] + (remaining_shape, )
        if p.final_enc_layer:
            args = args[:5] + (None, None)
    return py_utils.NestedMap(flops=flops, out_shapes=(inputs, ) + args)
Ejemplo n.º 13
0
 def _DecodeStep():
   """Decode call to be compiled for TPU.

   Instantiates the decode model from `self._decode_task_params`, attaches
   `self._decode_input` to its task, runs `Decode` on a dequeued batch and
   enqueues the flattened metrics on the outfeed.

   Returns:
     A single-element list holding the outfeed enqueue op.
   """
   with py_utils.OpportunisticVariableReuseScope(True):
     with cluster_factory.SetEval(True):
       # Build the decode model and register the decode input on its task.
       self._decode_model = self._decode_task_params.Instantiate()
       self._decode_model_task = self._decode_model.GetTask()
       self._decode_model_task.AddChild('input', self._decode_input)
       input_batch = self._decode_model_task.input_generator.TpuDequeueBatch(
       )
       metrics_dict = self._decode_model_task.Decode(input_batch)
       # Stored on self so the metrics structure is reachable after this call.
       self.metrics_nm = py_utils.NestedMap(metrics_dict)
       # Use TPU core 0 when self.spmd is set; otherwise pass an empty device
       # string.
       device = tpu.core(0) if self.spmd else ''
       with tf.device(device):
         outfeed_enqueue = tpu_ops.outfeed_enqueue_tuple(
             self.metrics_nm.Flatten())
         return [outfeed_enqueue]
Ejemplo n.º 14
0
 def FPropMeta(cls, p, *args):
     """Chains FPropMeta through `p.sub` and appends fetched activations.

     The last `p.num_act_inputs` arguments are treated as extra (activation)
     arguments and the rest as the sequential inputs. Each sub-layer's output
     shapes feed the next sub-layer. The first output shape of every layer
     named in `p.act_fetch_layers` is appended to the extra arguments.

     Returns:
       A NestedMap with the summed `flops` and `out_shapes` (the final
       sequential shapes followed by the extra arguments).
     """
     assert len(args) > p.num_act_inputs
     if p.num_act_inputs > 0:
         seq_args = args[:-p.num_act_inputs]
         extra_args = args[-p.num_act_inputs:]
     else:
         seq_args, extra_args = args, ()

     total = 0
     act_fetch_metas = {}
     for sub in p.sub:
         meta = sub.cls.FPropMeta(sub, *seq_args)
         if sub.name in p.act_fetch_layers:
             act_fetch_metas[sub.name] = meta.out_shapes[0]
         total += meta.flops
         seq_args = meta.out_shapes

     fetched = tuple(act_fetch_metas[name] for name in p.act_fetch_layers)
     return py_utils.NestedMap(flops=total,
                               out_shapes=seq_args + (extra_args + fetched))
Ejemplo n.º 15
0
    def ZeroState(self, theta, prepared_inputs, batch_size):
        """Creates the initial state for every sub-step, keyed by name.

        Args:
          theta: variables used by sub-steps, indexed by sub-step name.
          prepared_inputs: Output from a call to PrepareExternalInputs,
            indexed by sub-step name.
          batch_size: The number of items in the batch that FProp will
            process.

        Returns:
          A NestedMap mapping each sub-step name to its ZeroState result.
        """
        zero_states = py_utils.NestedMap()
        with tf.name_scope(self.params.name):
            for entry in self._seq:
                name = entry.name
                zero_states[name] = entry.step.ZeroState(
                    theta[name], prepared_inputs[name], batch_size)
        return zero_states
Ejemplo n.º 16
0
    def BuildDataSource(self, data_source_from_file_pattern_fn):
        """Builds one input batch mixed from `p.file_patterns` by `p.weights`.

        All file patterns are joined with commas and passed, together with the
        weights, to `data_source_from_file_pattern_fn`. Per the original
        notes, examples in the batch are mixed from the different sources
        proportionally to the weights.

        Args:
          data_source_from_file_pattern_fn: a function called with the
            comma-joined file pattern string and `input_source_weights`,
            returning an input batch.

        Returns:
          A NestedMap containing `data` (the result of the function above) and
          `bprop_variable_filters` (one empty string per file pattern).

        Raises:
          ValueError: If `p.file_patterns` or `p.weights` is not a list, their
            lengths differ, a file pattern is not a string, or a file pattern
            contains a comma.
        """
        p = self.params

        file_patterns = p.file_patterns
        if not isinstance(file_patterns, list):
            raise ValueError('Expected a list, got %s' % (file_patterns, ))
        weights = p.weights
        if not isinstance(weights, list):
            raise ValueError('Expected a list, got %s' % (weights, ))

        num_patterns, num_weights = len(file_patterns), len(weights)
        if num_patterns != num_weights:
            raise ValueError(
                'Expected p.file_patterns and p.weights to be the same length. '
                'Found %d file_patterns, and %d weights' %
                (num_patterns, num_weights))

        # TODO(rosenberg) confirm that weights are numeric
        if any(not isinstance(x, six.string_types) for x in file_patterns):
            raise ValueError(
                'Expected all elements of p.file_patterns to be strings')

        # Commas are reserved as the separator of the joined pattern string.
        offending = next((fp for fp in file_patterns if ',' in fp), None)
        if offending is not None:
            raise ValueError(
                'Can not use commas in file_pattern when within-batch '
                'mixing is used. file_pattern: %s' % (offending, ))

        ret = py_utils.NestedMap()
        ret.data = data_source_from_file_pattern_fn(
            ','.join(file_patterns), input_source_weights=weights)
        ret.bprop_variable_filters = [''] * num_patterns
        return ret
Ejemplo n.º 17
0
    def FProp(self, theta, prepared_inputs, step_inputs, padding, state0):
        """Returns the slice of every prepared input at time `state0.t`.

        Args:
          theta: unused.
          prepared_inputs: Output from a call to PrepareExternalInputs; every
            tensor in it is sliced along `params.axis`.
          step_inputs: unused.
          padding: unused.
          state0: A NestedMap with `t`, the current time index, produced by
            either ZeroState or a previous invocation of this FProp step.

        Returns:
          (output, state1), both of which are NestedMaps. output has the
          structure of `prepared_inputs` with the time axis removed from every
          tensor; state1 holds `t` incremented by one.
        """
        del theta, step_inputs, padding

        def _TimeSlice(tensor):
            """Return a slice of this tensor at time=state0.t."""
            axis = self.params.axis
            shape = py_utils.GetShape(tensor)
            # Start offsets: zero everywhere except t on the time axis,
            # e.g. axis=1 gives [0, t, 0, 0, ...].
            begin = tf.one_hot(axis, tf.rank(tensor), on_value=state0.t)
            # Slice extent: the full shape, but length 1 on the time axis,
            # e.g. axis=1 gives [shape[0], 1, shape[2], ...].
            size = tf.concat([
                shape[0:axis],
                tf.constant([1], dtype=tf.int32),
                shape[axis + 1:],
            ],
                             axis=0)
            # Cut out the single time step, then drop the time axis.
            return tf.squeeze(tf.slice(tensor, begin, size), axis=axis)

        output = prepared_inputs.Transform(_TimeSlice)
        return output, py_utils.NestedMap(t=state0.t + 1)
Ejemplo n.º 18
0
    def PrepareExternalInputs(self, theta, external_inputs):
        """Prepares external inputs by delegating to every sub-step.

        Args:
          theta: A `.NestedMap` object containing weights' values of this
            layer and its children layers; `theta.sub[i]` goes to sub-step i.
          external_inputs: A `.NestedMap` object passed unchanged to every
            sub-step. The structure of the internal fields is defined by the
            sub-steps.

        Returns:
          A `.NestedMap` whose `sub` list holds the prepared inputs of each
          sub-step, in order.
        """
        prepared = [
            step.PrepareExternalInputs(theta.sub[i], external_inputs)
            for i, step in enumerate(self.sub)
        ]
        return py_utils.NestedMap(sub=prepared)
Ejemplo n.º 19
0
    def ZeroState(self, theta, prepared_inputs, batch_size):
        """Builds the initial state of every sub-step.

        Args:
          theta: A `.NestedMap` object containing weights' values of this
            layer and its children layers; `theta.sub[i]` goes to sub-step i.
          prepared_inputs: An output from PrepareExternalInputs. The whole
            object is passed to every sub-step.
          batch_size: The number of items in the batch that FProp will
            process.

        Returns:
          A `.NestedMap` whose `sub` list holds a state0 object per sub-step,
          in order.
        """
        sub_states = [
            step.ZeroState(theta.sub[i], prepared_inputs, batch_size)
            for i, step in enumerate(self.sub)
        ]
        return py_utils.NestedMap(sub=sub_states)
Ejemplo n.º 20
0
    def _CellFn(unused_theta, state0, theta_i):
      """Recurrent cell function wrapping one call of body.FProp.

      The FProp arguments are unpacked from `state0`, and the outputs are
      packed back into the returned state so they reach the next iteration.
      """
      fprop_inputs = _StateToArgs(state0)

      # Give each per-iteration weight the static shape of its stacked
      # counterpart minus the leading dimension.
      for dst, src in zip(theta_i.Flatten(), theta_stack.Flatten()):
        if src is None:
          continue
        dst.set_shape(tf.TensorShape(src.shape.as_list()[1:]))

      fprop_outputs = _ToTuple(self.body.FProp(theta_i, *fprop_inputs))
      # Outputs must line up one-to-one with the inputs.
      assert len(fprop_outputs) == len(fprop_inputs)

      return _ArgsToState(fprop_outputs), py_utils.NestedMap()
 def FPropMeta(cls, p, inputs, *args):
     """Returns estimated flops and output shapes for the embedding layer.

     The source shape gains a trailing embedding dimension. When
     `p.add_tgt_embedding_layer` is set, the shape in `args[1]` does too.
     Slots 5 and 6 of the remaining arguments are set to None, and anything
     after slot 6 is kept only when `p.ret_task_ids` is set.
     """
     # TODO(ankurbpn): return accurate estimate of flops.
     py_utils.CheckShapes((inputs, ))
     token_emb = p.token_emb
     vocab = token_emb.vocab_size
     dim = token_emb.embedding_dim
     src_dim_0, src_dim_1 = inputs
     flops_per_element = 2  # Is this correct?
     flops = flops_per_element * src_dim_0 * src_dim_1 * dim * vocab

     if not isinstance(args, tuple):
         args = (args, )
     new_inputs = tshape.Shape([src_dim_0, src_dim_1, dim])
     new_args = list(args)
     if p.add_tgt_embedding_layer:
         tgt_dim_0, tgt_dim_1 = args[1]
         new_args[1] = tshape.Shape([tgt_dim_0, tgt_dim_1, dim])

     tail = new_args[7:] if p.ret_task_ids else []
     new_args = tuple(new_args[:5] + [None, None] + tail)
     return py_utils.NestedMap(flops=flops,
                               out_shapes=(new_inputs, ) + new_args)
  def ZeroState(self, theta, prepared_inputs, batch_size):
    """Builds the initial state from the query generator and attention steps.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers; `query_generator` and `attention` are read.
      prepared_inputs: Inputs pre-processed by PrepareExternalInputs, with
        `query_generator` and `attention` entries.
      batch_size: Number of elements in the batched input.

    Returns:
      state0, a NestedMap with `query_state` and `atten_state`, to pass to
      FProp on its first invocation.
    """
    query_state = self.query_generator.ZeroState(
        theta.query_generator, prepared_inputs.query_generator, batch_size)
    atten_state = self.attention.ZeroState(
        theta.attention, prepared_inputs.attention, batch_size)
    return py_utils.NestedMap(
        query_state=query_state, atten_state=atten_state)
Ejemplo n.º 23
0
  def FProp(self, theta, *args):
    """Runs the sub-layers in order, wiring tensors through named endpoints.

    Inputs are stored under `p.input_endpoints`. Each sub-layer then reads its
    inputs from, and writes its outputs to, the same tensor store according to
    its signature. The store is also kept on `self._fprop`.

    Args:
      theta: Weights, indexed by sub-layer name.
      *args: One tensor (or NestedMap of tensors) per input endpoint.

    Returns:
      The tensor stored under the single output endpoint, or a tuple of
      tensors when there are several output endpoints.

    Raises:
      ValueError: If the number of inputs differs from the number of input
        endpoints.
    """
    p = self.params

    graph_tensors = GraphTensors()
    self._fprop = graph_tensors
    with tf.name_scope(p.name):
      num_required, num_provided = len(p.input_endpoints), len(args)
      if num_required != num_provided:
        raise ValueError(
            'Wrong number of inputs for {}: required={}, provided={}'.format(
                p.name, num_required, num_provided))
      for endpoint, value in zip(p.input_endpoints, args):
        if not isinstance(value, py_utils.NestedMap):
          assert isinstance(value, tf.Tensor)
        else:
          assert all(isinstance(x, tf.Tensor) for x in value.Flatten()), value
        graph_tensors.StoreTensor(endpoint, value)

      for i, (name, sig, ch) in enumerate(self._seq):
        th = theta[name]
        # Resolve the signature's input names to the stored tensors.
        resolved = py_utils.NestedMap(inputs=sig.inputs).Transform(
            graph_tensors.GetTensor)
        input_args = resolved.inputs
        tf.logging.vlog(1, 'signature: %s', p.sub[i][0])
        tf.logging.vlog(1, 'GraphLayer: call %s %s %d %s', ch.params.name,
                             ch, len(input_args), str(input_args))
        ch_out = ch.FProp(th, *input_args)
        if len(sig.outputs) == 1:
          # Wrap a lone output so it pairs up with its endpoint name.
          ch_out = (ch_out,)
        assert len(sig.outputs) == len(ch_out)
        for endpoint, value in zip(sig.outputs, ch_out):
          graph_tensors.StoreTensor(endpoint, value)

      layer_out = tuple(graph_tensors.GetTensor(x) for x in p.output_endpoints)
      return layer_out[0] if len(layer_out) == 1 else layer_out
  def ZeroState(self, theta, prepared_inputs, batch_size):
    """Builds the initial attention state by attending with an all-zero query.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers; `theta.atten` is used.
      prepared_inputs: Inputs pre-processed by PrepareExternalInputs; `src` and
        `packed_src` are read.
      batch_size: Number of elements in the batched input.

    Returns:
      state0, a NestedMap with `atten_context` and `atten_state`, to pass to
      FProp on its first invocation.
    """
    p = self.params
    # The maximum sequence length is the first entry of the source shape.
    max_seq_length = py_utils.GetShape(prepared_inputs.src, 3)[0]
    initial_atten_state = self.atten.ZeroAttentionState(
        max_seq_length, batch_size)
    zero_query = tf.zeros([batch_size, p.atten.query_dim],
                          dtype=py_utils.FPropDtype(p))
    context, _, atten_states = self.atten.ComputeContextVectorWithSource(
        theta.atten,
        prepared_inputs.packed_src,
        zero_query,
        attention_state=initial_atten_state)
    return py_utils.NestedMap(
        atten_context=context, atten_state=atten_states)
Ejemplo n.º 25
0
    def FProp(self, theta, prepared_inputs, step_inputs, padding, state0):
        """Runs the wrapped stateless layer for one step.

        Args:
          theta: A `.NestedMap` object containing weights' values of this
            layer and its children layers; `theta.layer` is used.
          prepared_inputs: unused.
          step_inputs: A NestedMap containing 'inputs', which are passed
            directly to the layer.
          padding: Forwarded to the layer as the `padding` keyword argument
            unless it is None. The original notes describe it as a 0/1 float
            tensor of shape [batch_size], where 1.0 marks an empty batch
            element.
          state0: unused.

        Returns:
          (output, state1), where output is the output of the layer, and
          state1 is an empty NestedMap.
        """
        del state0, prepared_inputs
        extra_kwargs = {} if padding is None else {'padding': padding}
        output = self.layer.FProp(theta.layer, step_inputs.inputs,
                                  **extra_kwargs)
        return output, py_utils.NestedMap()
 def FPropMeta(cls, p, inputs, *args):
     """Returns a fixed flops figure and the logits shape.

     The two leading logits dimensions come from `args[1]` when
     `p.inputs_from_decoder` is set, otherwise from `inputs`. The last
     dimension is `p.num_classes`.
     """
     shape_source = args[1] if p.inputs_from_decoder else inputs
     dim1, dim2 = shape_source[:2]
     logits = tshape.Shape([dim1, dim2, p.num_classes])
     return py_utils.NestedMap(flops=100, out_shapes=(logits, ))
Ejemplo n.º 27
0
    def FProp(self,
              theta,
              x,
              x_paddings=None,
              eos_id=1,
              force_sample_last_token=True):
        """Applies SymbolInsertionLayer.

    We take in a `x`, which represents the groundtruth sequence (i.e., English
    sequence). We return a sampled rollin (observed) canvas (i.e., random subset
    of the English sequence), as well as the target (indices) for an
    insertion-based model (i.e., the targets given the random observed subset).

    Outline of the computation:
      1. Draw a random canvas length per example, bounded by the number of valid
         tokens in that example.
      2. Pick that many valid positions at random (Gumbel-max trick + top_k),
         optionally forcing the last valid token to be picked.
      3. Gather the picked tokens, in their original order, into the canvas.
      4. For every valid position of `x`, emit a (batch, slot, token) target,
         where observed tokens are remapped to `eos_id`.

    Args:
      theta: Ignored, this can be None.
      x: The symbol ids of shape `[batch_size, time_dim]`.
      x_paddings: The paddings (1 or 0) of shape `[batch_size, time_dim]` where
        0 is valid and 1 is invalid. If None, every position is treated as
        valid.
      eos_id: The <eos> token id to represent end-of-slot.
      force_sample_last_token: Set True to force sample the last token of `x`.

    Returns:
      A `NestedMap`.
        - canvas: The canvas (based off of the `rollin_policy`) of shape
          [batch_size, c_dim]. Note that, `c_dim` <= `time_dim` but need not be
          equal. Padded canvas positions are zeroed.
        - canvas_indices: The canvas indices (into `x`). Positions past an
          example's canvas length hold the filler index `time_dim - 1`.
        - canvas_paddings: The paddings of `canvas_indices`.
        - target_indices: The target indices of shape [num_targets, 3].
          `num_targets` is the number of total targets in the entire batch.
          [:, 0] captures the batch, [:, 1] captures the slot, and [:, 2]
          captures the token. Each row [batch, slot, vocab] represents the
          indices of the target -- i.e., the batch, slot and vocab combination
          of the target. Typical usage of these indices is to tf.gather_nd
          the log-probs (from the softmax layer).
        - target_weights: The target weights.

    Raises:
      ValueError: If the effective rollin policy or the oracle policy is not
        'uniform'.
    """
        p = self.params

        batch_size = py_utils.GetShape(x)[0]
        time_dim = py_utils.GetShape(x)[1]

        # Without explicit paddings, treat every position as valid.
        if x_paddings is None:
            x_paddings = tf.zeros([batch_size, time_dim], tf.float32)

        oracle_policy = p.oracle_policy
        # A rollin policy of 'oracle' means "use whatever the oracle policy is".
        rollin_policy = (oracle_policy
                         if p.rollin_policy == 'oracle' else p.rollin_policy)

        # Only the 'uniform' policy is implemented, for both rollin and oracle.
        if rollin_policy != 'uniform':
            raise ValueError('Unknown or unsupported rollin policy: %s' %
                             rollin_policy)
        if oracle_policy != 'uniform':
            raise ValueError('Unknown or unsupported oracle policy: %s' %
                             oracle_policy)

        # Number of valid (non-padded) tokens per example.
        x_len = tf.cast(tf.round(tf.reduce_sum(1 - x_paddings, 1)), tf.int32)

        # Compute the desired length per example in the batch.
        ratio = tf.random.uniform([batch_size], 0.0, 1.0, seed=p.random_seed)
        if force_sample_last_token:
            # Clamp to `x_len - 1` then add 1: the canvas always has room for the
            # forced last token and never exceeds `x_len`.
            c_len = tf.minimum(
                tf.cast(ratio * tf.cast(x_len, tf.float32), tf.int32),
                x_len - 1) + 1
        else:
            # Scale by `x_len + 1` and clamp to `x_len`, so both an empty canvas
            # and a full canvas are possible outcomes.
            c_len = tf.minimum(
                tf.cast(ratio * tf.cast(x_len + 1, tf.float32), tf.int32),
                x_len)
        # Compute the maximum length across the batch.
        c_len_max = tf.reduce_max(c_len)

        # Grab subset of random valid indices per example.
        # Padded positions (index >= x_len) get a -1e9 logit so that they rank
        # after every valid position in the top_k below.
        z_logits = tf.cast(
            tf.expand_dims(tf.range(time_dim), 0) >= tf.expand_dims(x_len, 1),
            tf.float32) * -1e9
        if force_sample_last_token:
            # Force sample the last token -- i.e., as indexed by `x_len - 1`. We can
            # accomplish this by add +LARGE_NUMBER to the logits.
            z_logits += tf.cast(
                tf.equal(tf.expand_dims(tf.range(time_dim), 0),
                         tf.expand_dims(x_len - 1, 1)), tf.float32) * 1e9
        # Gumbel-max trick to sample (we only sample valid positions per sample in
        # the batch).
        z = -tf.math.log(-tf.math.log(
            tf.random.uniform([batch_size, time_dim], seed=p.random_seed)))
        # top_k with k == time_dim ranks every position by its noisy logit; only
        # the resulting indices are used.
        unused_c_values, c_indices = tf.nn.top_k(z_logits + z, time_dim)

        # Trim everything > c_len_max.
        c_indices = c_indices[:, :c_len_max]

        # Invalidate any indices >= c_len, we use the last index as the default
        # invalid index.
        c_indices = tf.where(
            tf.expand_dims(tf.range(c_len_max), 0) < tf.expand_dims(c_len, 1),
            c_indices, tf.fill(py_utils.GetShape(c_indices), time_dim - 1))

        # Materialize the canvas.
        # Sorting restores the original token order; the filler value
        # `time_dim - 1` is the largest possible index, so fillers end up at the
        # tail of each row.
        c_indices = tf.sort(c_indices)
        # Pair every canvas index with its batch row to form [batch, position]
        # coordinates into `x`.
        c = tf.gather_nd(
            x,
            tf.stack([
                tf.reshape(
                    tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                            [1, c_len_max]), [-1]),
                tf.reshape(c_indices, [-1])
            ], 1))
        c = tf.reshape(c, [batch_size, c_len_max])

        # Compute the paddings.
        c_paddings = 1 - tf.sequence_mask(
            c_len, c_len_max, dtype=x_paddings.dtype)
        # Zero out the canvas entries that were gathered from filler indices.
        c *= tf.cast(1 - c_paddings, tf.int32)

        # Same [batch, position] coordinates as above, laid out as
        # [batch_size * c_len_max, 2] for scatter_nd.
        indices = tf.concat([
            tf.reshape(
                tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                        [1, c_len_max]), [batch_size * c_len_max, 1]),
            tf.reshape(c_indices, [batch_size * c_len_max, 1])
        ], 1)
        # Mark the positions of `x` that appear in the canvas with a nonzero
        # count.
        # NOTE(review): filler entries of `c_indices` are scattered too, so column
        # `time_dim - 1` gets a nonzero count for every example with
        # c_len < c_len_max. That column is dropped by `x_is_valid` when it is
        # padding; confirm this is harmless when x_len == time_dim and
        # `force_sample_last_token` is False.
        x_token_is_observed = tf.scatter_nd(
            indices, tf.ones([batch_size * c_len_max], tf.int32),
            py_utils.GetShape(x))
        # `x_segments` captures which slot each `x` belongs to (both observed and
        # tokens that need to be observed).
        # The exclusive cumsum accumulates the observed marks strictly to the
        # left of each position.
        x_segments = tf.cumsum(x_token_is_observed, 1, exclusive=True)

        x_token_is_observed = tf.cast(x_token_is_observed, tf.bool)
        # Shift right by one step; the position before the first token is
        # defined as observed (True).
        prev_x_token_is_observed = tf.pad(x_token_is_observed[:, :-1],
                                          [[0, 0], [1, 0]],
                                          constant_values=True)
        # Flatten the masks to [batch_size * time_dim] to line up with the
        # flattened targets below.
        x_token_is_observed = tf.reshape(x_token_is_observed, [-1])
        prev_x_token_is_observed = tf.reshape(prev_x_token_is_observed, [-1])
        x_is_valid = tf.cast(1 - x_paddings, tf.bool)
        x_is_valid = tf.reshape(x_is_valid, [-1])

        # Remap all the observed to <eos>, note some of these need a zero weight
        # (or else there would be <eos> and valid token in the same slot).
        target_indices = tf.cast(tf.reshape(x, [-1, 1]), tf.int32)
        target_indices = tf.where(
            x_token_is_observed,
            tf.fill(py_utils.GetShape(target_indices), eos_id), target_indices)

        # TODO(williamchan): We give uniform 1.0 weight, however, math suggests
        # we may want to weigh this term by the original sequence length.
        target_weights = tf.ones_like(target_indices, tf.float32)

        # We need to set all the weights for <eos> which actually have valid tokens
        # in the slot to zero.
        # An observed token whose predecessor is unobserved closes a slot that
        # already holds at least one real token to insert.
        target_weights = tf.where(
            x_token_is_observed & ~prev_x_token_is_observed,
            tf.zeros_like(target_weights), target_weights)

        # TODO(williamchan): Consider dropping the entries w/ weight zero.

        # Add the batch and slot indices.
        # Columns after the concat are [batch, slot, token].
        target_indices = tf.concat([
            tf.reshape(
                tf.tile(tf.expand_dims(tf.range(batch_size), 1),
                        [1, time_dim]), [batch_size * time_dim, 1]),
            tf.reshape(x_segments, [-1, 1]), target_indices
        ], 1)

        # Select only the valid indices. The selected valid ones include slots w/
        # <eos>.
        target_indices = target_indices[x_is_valid]
        target_weights = target_weights[x_is_valid]

        return py_utils.NestedMap(canvas=c,
                                  canvas_indices=c_indices,
                                  canvas_paddings=c_paddings,
                                  target_indices=target_indices,
                                  target_weights=target_weights)
Ejemplo n.º 28
0
 def __init__(self):
   """Starts the instance with an empty `NestedMap` in `_named_tensors`."""
   empty_map = py_utils.NestedMap()
   self._named_tensors = empty_map
Ejemplo n.º 29
0
 def FPropMeta(cls, p, *args):
   """Returns the body's flops scaled by `p.repeat`, with `args` as out shapes.

   Args:
     cls: The class this is invoked on (presumably bound via a `@classmethod`
       decorator that is not visible here).
     p: Layer params; `body` and `repeat` are read.
     *args: Input shapes. They are validated, forwarded to the body's
       `FPropMeta`, and returned unchanged as `out_shapes`.

   Returns:
     A `NestedMap` with `flops` and `out_shapes`.
   """
   py_utils.CheckShapes(args)
   body_params = p.body
   body_meta = body_params.cls.FPropMeta(body_params, *args)
   # The body's reported output shapes are validated but not returned.
   py_utils.CheckShapes(body_meta.out_shapes)
   return py_utils.NestedMap(
       flops=body_meta.flops * p.repeat, out_shapes=args)
Ejemplo n.º 30
0
 def FPropMeta(cls, p, *args):
   """Reports zero flops and passes the input shapes through as out shapes.

   Args:
     cls: The class this is invoked on (presumably bound via a `@classmethod`
       decorator that is not visible here).
     p: Layer params; not read here.
     *args: Input shapes, validated with `py_utils.CheckShapes`.

   Returns:
     A `NestedMap` with `flops=0` and `out_shapes` set to `args`.
   """
   py_utils.CheckShapes(args)
   meta = py_utils.NestedMap(flops=0, out_shapes=args)
   return meta