Example #1
 def FProp(self, theta, inputs):
     p = self.params
     with tf.name_scope(p.name):
         vec = self.conv.FProp(theta.conv, inputs.vec)
         return py_utils.NestedMap(vec=vec, paddings=inputs.paddings)
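All of the snippets on this page revolve around py_utils.NestedMap, a dict subclass whose keys are also attributes. A minimal sketch of the access patterns the example above relies on (the field names here are illustrative, not part of the API):

from lingvo.core import py_utils

m = py_utils.NestedMap(vec=[1.0], paddings=[0.0])  # keyword construction
m.extra = py_utils.NestedMap(inner=2.0)            # attribute-style assignment
assert m.vec is m['vec']                           # attribute and dict access agree

# Flatten() lists the leaf values in a deterministic (key-sorted) order;
# Pack() rebuilds a structurally identical NestedMap from such a flat list.
m2 = m.Pack(m.Flatten())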
Example #2
 def FProp(self, theta, external_inputs, step_inputs, padding, state0):
     # Test stub: the inputs and state are strings, so this step simply
     # concatenates them into the output and the next state.
     output = ':'.join(step_inputs.inputs + [external_inputs]) + state0
     state1 = state0 + ':'.join(step_inputs.inputs)
     return py_utils.NestedMap(output=output), state1
Example #3
def GenericInput(processor, **kwargs):
    """Builds a generic input pipeline.

  Example usage::

    def ParseRecord(record):
      # Given a tf.string record, return a (NestedMap, bucketing key) pair.
      feature_map = ...
      features = tf.parse_single_example(record, feature_map)
      # Each example is represented by a NestedMap of tensors (without a
      # batch dimension).
      example = py_utils.NestedMap(field1=..., field2=...)
      # bucketing_key is a scalar convertible to tf.int32.
      # Use 1 if all examples are of the same size.
      bucketing_key = 1
      return example, bucketing_key

    input_batch, bucket_keys = GenericInput(ParseRecord, file_pattern=..., ...)
    # input_batch is a NestedMap of tensors, where dim 0 of each tensor
    # represents the batch dimension.
    input_batch.field1 = ...

  ParseRecord can also take both 'source_id' and 'record' as inputs (the arg
  names must be exactly 'source_id' and 'record')::

    def ParseRecord(source_id, record):
      # Given a tf.int32 source_id and a tf.string record, return a (NestedMap,
      # bucketing key) pair.
      example = py_utils.NestedMap(source_id=source_id, ...)
      ...
      return example, bucketing_key

    input_batch, bucket_keys = GenericInput(ParseRecord, file_pattern=..., ...)

  Args:
    processor: a function that takes either a tf.string record or a
      (source_id: tf.int32, record: tf.string) pair as input and returns a
      tuple (output, bucketing_key).
      `output` must be a NestedMap or a list of tensors representing an example.
      `bucketing_key` must be a scalar convertible to a tf.int32 tensor that
      represents the bucketing key (e.g., sequence length for sequence inputs).
      If `bucketing_key` is a negative number, the record is dropped.
    **kwargs: additional keyword args for x_ops.generic_input.

  Returns:
    A tuple of (outputs, bucket_keys):

    - outputs: a NestedMap or a list of tensors, similar to `processor`'s
      return, except every tensor will have an additional dimension 0 that
      represents the batch dimension.
    - bucket_keys: a tf.int32 vector.
  """
    output_tmpl = py_utils.NestedMap()

    def _FlatOutputProcessor(source_id, record):
        """Returns a flattened list of 'processor(inputs)'."""
        processor_spec = tf_inspect.getargspec(processor)
        tf.logging.debug('GenericInput.processor.argspec=%s', processor_spec)
        processor_args = set(processor_spec.args) - set(['self'])
        if len(processor_args) == 1:
            output, bucketing_key = processor(record)
        elif processor_args == set(['source_id', 'record']):
            output, bucketing_key = processor(source_id=source_id,
                                              record=record)
        else:
            raise ValueError(
                'GenericInput: processor should take either a single arg '
                'or two args named as "source_id" and "record". '
                'Actual: %s' % processor_args)
        if isinstance(output, list):
            assert output
            assert all(isinstance(x, tf.Tensor)
                       for x in output), '{}'.format(output)
        else:
            assert isinstance(output, py_utils.NestedMap), '{}'.format(output)
            assert output
            assert all(isinstance(x, tf.Tensor)
                       for x in output.Flatten()), '{}'.format(
                           output.DebugString())
        bucketing_key = tf.cast(bucketing_key, tf.int32)
        tf.logging.debug('Processor outputs=%s bucketing_key=%s', output,
                         bucketing_key)
        output_tmpl.out_values = output
        flat_output_tmpl = output_tmpl.Flatten()
        tf.logging.debug('Processor flat outputs=%s', flat_output_tmpl)
        tf.logging.debug('extra_inputs=%s extra_args=%s extra_vars=%s',
                         function.get_extra_inputs(),
                         function.get_extra_args(), function.get_extra_vars())
        assert not function.get_extra_args(), (
            'fns {} is not pure: extra_args={}'.format(
                processor, function.get_extra_args()))
        return flat_output_tmpl + [bucketing_key]

    proc_fn = tf.Defun(tf.int32, tf.string)(_FlatOutputProcessor)

    out_types = [
        tf.DType(a.type) for a in proc_fn.definition.signature.output_arg
    ]
    assert out_types[-1] == tf.int32, ('%s is not expected.' % out_types[-1])
    flat_outputs, bucket_keys = ops.gen_x_ops.generic_input(
        processor=proc_fn, out_types=out_types[:-1], **kwargs)
    tf.logging.debug('x_ops.generic_input flat_outputs=%s', flat_outputs)
    # Pack flat_outputs to outputs.
    outputs = output_tmpl.Pack(flat_outputs).out_values
    tf.logging.debug('x_ops.generic_input outputs=%s', outputs)
    return outputs, bucket_keys
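For completeness, a hypothetical processor satisfying the contract checked by _FlatOutputProcessor above, bucketing by sequence length. The feature name, file pattern, and bucket settings below are placeholders for illustration, not values from this codebase:

def ParseRecord(record):
    features = tf.io.parse_single_example(
        record, {'ids': tf.io.VarLenFeature(tf.int64)})
    ids = tf.cast(tf.sparse.to_dense(features['ids']), tf.int32)
    example = py_utils.NestedMap(ids=ids)
    # Records whose bucketing key is negative are dropped (see Args above).
    bucketing_key = tf.size(ids)
    return example, bucketing_key

input_batch, bucket_keys = GenericInput(
    ParseRecord,
    file_pattern='tfrecord:/tmp/data.tfrecord-*',  # placeholder pattern
    bucket_upper_bound=[128],
    bucket_batch_limit=[32])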
Example #4
    def Decode(self, input_batch):
        """Decode an input batch, computing predicted bboxes from residuals."""
        p = self.params

        predictions = self.ComputePredictions(self.theta, input_batch)
        bboxes_and_logits = self._BBoxesAndLogits(input_batch, predictions)
        predicted_bboxes = bboxes_and_logits.predicted_bboxes
        batch_size, num_bboxes, _ = py_utils.GetShape(predicted_bboxes, 3)
        classification_logits = bboxes_and_logits.classification_logits
        classification_logits = py_utils.HasShape(
            classification_logits, [batch_size, num_bboxes, p.num_classes])

        classification_scores = tf.sigmoid(classification_logits)

        _, per_example_dict = self.ComputeLoss(self.theta, predictions,
                                               input_batch)
        if 'score_scaler' in per_example_dict:
            classification_scores *= per_example_dict['score_scaler']

        with tf.device('/cpu:0'):
            # Decode the predicted bboxes, performing NMS.
            _, per_cls_bboxes, per_cls_bbox_scores, per_cls_valid_mask = (
                detection_decoder.DecodeWithNMS(
                    predicted_bboxes,
                    classification_scores,
                    nms_iou_threshold=p.nms_iou_threshold,
                    score_threshold=p.nms_score_threshold,
                    max_boxes_per_class=p.max_nms_boxes,
                    use_oriented_per_class_nms=p.use_oriented_per_class_nms))

            # per_cls_valid_mask is [batch, num_classes, num_boxes] Tensor that
            # indicates which boxes were selected by NMS. Each example will have a
            # different number of chosen bboxes, so the mask is present to allow us
            # to keep the boxes as a batched dense Tensor.
            #
            # We mask the scores by the per_cls_valid_mask so that none of the
            # invalid boxes will be interpreted as valid.
            per_cls_bbox_scores *= per_cls_valid_mask
            visualization_weights = py_utils.HasShape(
                per_cls_bbox_scores,
                [batch_size, p.num_classes, p.max_nms_boxes])

            # For top down visualization, filter boxes whose scores are not above the
            # visualization threshold.
            visualization_weights = tf.where(
                tf.greater_equal(visualization_weights,
                                 p.visualization_classification_threshold),
                visualization_weights, tf.zeros_like(visualization_weights))

        model_outputs = py_utils.NestedMap()
        model_outputs.per_class_predicted_bboxes = per_cls_bboxes
        model_outputs.per_class_predicted_bbox_scores = per_cls_bbox_scores
        model_outputs.per_class_valid_mask = per_cls_valid_mask

        decoder_outputs = py_utils.NestedMap({
            'per_class_predicted_bboxes':
            per_cls_bboxes,
            'per_class_predicted_bbox_scores':
            per_cls_bbox_scores,
            'per_class_valid_mask':
            per_cls_valid_mask,
            'visualization_weights':
            visualization_weights,
        })

        decoder_outputs.update(
            self.output_decoder.ProcessOutputs(input_batch, model_outputs))

        # Produce global step as an output (which is the step
        # of the checkpoint being decoded).
        decoder_outputs.global_step = py_utils.GetGlobalStep()

        return decoder_outputs
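A toy illustration of the score masking and visualization thresholding performed above (the values are made up):

scores = tf.constant([[0.9, 0.2, 0.6]])
valid_mask = tf.constant([[1.0, 1.0, 0.0]])
scores = scores * valid_mask  # invalid box zeroed: [[0.9, 0.2, 0.0]]
# With a visualization threshold of 0.5, low-score boxes are also zeroed.
vis = tf.where(tf.greater_equal(scores, 0.5), scores, tf.zeros_like(scores))
# vis == [[0.9, 0.0, 0.0]]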
Example #5
    def FProp(self, theta, input_batch):
        """Embeds source ids and transforms with TransformerStack.

    Args:
      theta: A `.NestedMap` object containing weights' values of this
        layer and its children layers.
      input_batch: A `.NestedMap` with fields:

        - ids: The inputs tensor. It is expected to be of shape [batch, time].
        - paddings: The paddings tensor. Expected shape [batch, time].

    Returns:
      A NestedMap containing:
        - encoded: The encoded features, either a tensor of shape [time, batch,
            depth], or a list of tensors if is_transparent is set in
            transformer_stack.
        - padding: of shape [time, batch]
        - segment_id: [time, batch] if packed inputs are supported by the model
            (and all layers), or None otherwise.
        - embedded_inputs: [time, batch, depth] embedded inputs tokens without
            positional encodings.
    """

        p = self.params
        with tf.name_scope(p.name):
            src_segment_id = None
            src_segment_pos = None
            input_ids = py_utils.with_dependencies([
                py_utils.assert_shape_match(tf.shape(input_batch.ids),
                                            tf.shape(input_batch.paddings)),
                py_utils.assert_equal(tf.rank(input_batch.ids), 2)
            ], input_batch.ids)

            if (not py_utils.use_tpu()
                    and tf.flags.FLAGS.transformer_encoder_truncates_inputs):
                max_seq_length = tf.cast(
                    tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings,
                                                1)), tf.int32)
                paddings = py_utils.with_dependencies([
                    py_utils.assert_equal(
                        tf.constant(True, tf.bool),
                        tf.reduce_all(
                            input_batch.paddings[:, max_seq_length:] > 0.5))
                ], input_batch.paddings)
                input_ids = input_ids[:, :max_seq_length]
                paddings = paddings[:, :max_seq_length]
                if p.packed_input:
                    src_segment_id = input_batch.segment_ids[:, :max_seq_length]
                    src_segment_pos = input_batch.segment_pos[:, :max_seq_length]
            else:
                paddings = input_batch.paddings
                if p.packed_input:
                    src_segment_id = input_batch.segment_ids
                    src_segment_pos = input_batch.segment_pos

            max_time = tf.shape(input_ids)[1]

            # Input token embeddings + positional embeddings
            input_embs = self.token_emb.EmbLookup(theta.token_emb,
                                                  tf.reshape(input_ids, [-1]))
            input_embs = tf.reshape(input_embs,
                                    [-1, max_time, p.token_emb.embedding_dim])
            # [time, batch, dim]
            orig_input_embs = tf.transpose(input_embs, [1, 0, 2])

            if p.packed_input:
                position_embs = self.position_emb.FPropWithPosition(
                    theta.position_emb, src_segment_pos)
            else:
                position_embs = self.position_emb.FProp(
                    theta.position_emb, max_time)
                position_embs = tf.reshape(
                    position_embs, [1, max_time, p.token_emb.embedding_dim])
            input_embs += position_embs
            if p.task_emb:
                input_embs += self.task_emb.EmbLookup(theta.task_emb,
                                                      input_batch.task_ids)

            if p.model_dim != p.token_emb.embedding_dim:
                input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs)

            paddings = tf.transpose(paddings)
            if p.packed_input:
                src_segment_id = tf.transpose(src_segment_id)
            input_embs = self.input_dropout.FProp(theta.input_dropout,
                                                  input_embs)

            # [time, batch, dim]
            transformer_input = tf.transpose(input_embs, [1, 0, 2])

        encoded, padding, segment_id = self.transformer_stack.FProp(
            theta.transformer_stack, transformer_input, paddings,
            src_segment_id)
        return py_utils.NestedMap(encoded=encoded,
                                  padding=padding,
                                  segment_id=segment_id,
                                  embedded_inputs=orig_input_embs)
Example #6
  def BuildTpuSubgraph(self):
    if self._ml_perf_log:
      mlp_log.mlperf_print('global_batch_size', self._ml_perf.global_batch_size)
      mlp_log.mlperf_print('max_sequence_length',
                           self._ml_perf.max_sequence_length)
      mlp_log.mlperf_print('opt_name', self._ml_perf.optimizer_name)
      mlp_log.mlperf_print('opt_base_learning_rate',
                           self._ml_perf.base_learning_rate)
      mlp_log.mlperf_print('opt_learning_rate_warmup_steps',
                           self._ml_perf.warmup_steps)

    with py_utils.OpportunisticVariableReuseScope(True):
      self._eval_metrics = metrics.TpuEvalMetrics()
      data_parallelism = self.data_parallelism

      def TpuTrainStep():
        """Train a shard of a batch on a single TPU core.

        Do not calculate loss metrics.

        Returns:
         [train_op].
        """
        self._train_model = self._train_task_params.Instantiate()
        self._model = self._train_model
        self._train_model.ConstructFPropBPropGraph()
        return [self._train_model.GetTask().train_op]

      @tpu_function.on_device_training_loop
      def TpuTrain():
        loop_result = tpu_training_loop.repeat(
            self._train_steps_per_loop,
            TpuTrainStep,
            inputs=[],
            name='train_loop')
        return loop_result

    py_utils.ResetStepSeed()

    def _DecodeFn():
      """Decode call to be compiled for TPU."""
      with py_utils.OpportunisticVariableReuseScope(True):
        with cluster_factory.SetEval(True):
          self._decode_model = self._decode_task_params.Instantiate()
          self._decode_model_task = self._decode_model.GetTask()
          if py_utils.use_tpu():
            input_batch = (
                self._decode_model_task.input_generator.CreateTpuFeeds())
          else:
            input_batch = self._decode_model_task.input_generator.SplitInputBatch(
                self.cluster.num_splits_per_client)
          metrics_dict = self._decode_model_task.Decode(input_batch)
          self.metrics_nm = py_utils.NestedMap(metrics_dict)
          return self.metrics_nm.Flatten()

    def TrainAndDecode():
      with tf.control_dependencies([TpuTrain()]):
        return _DecodeFn()

    self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
        TrainAndDecode,
        num_shards=data_parallelism,
        device_assignment=py_utils.GetTpuDeviceAssignment())

    self.metrics = py_utils.NestedMap(self.metrics_nm)
    self.metrics = self.metrics.Pack(batch_parallel_res)
    return None
Example #7
  def _PreBeamSearchStepCallback(self,
                                 theta,
                                 source_encs,
                                 source_paddings,
                                 step_ids,
                                 states,
                                 num_hyps_per_beam,
                                 additional_source_info=None):
    """Returns logits for sampling ids and the next model states.

    Args:
      source_encs: A tensor of shape [src_len, src_batch, source_dim].
      source_paddings: A tensor of shape [src_len, src_batch].
      step_ids: A tensor of shape [tgt_batch, 1].
      states: A `.NestedMap` of tensors representing states that the clients
          would like to keep track of for each of the active hyps.
      num_hyps_per_beam: Beam size.
      additional_source_info: a `.NestedMap` of tensors containing extra context
          information about the source that may be useful for decoding.
    Returns:
      A tuple (results, out_states).
      results: A `.NestedMap` of beam search results.
        atten_probs:
          The updated attention probs, of shape [tgt_batch, src_len].
        log_probs:
          Log prob for each of the tokens in the target vocab. This is of shape
          [tgt_batch, vocab_size].
      out_states: A `.NestedMap`. The updated states.
        rnn_states:
          Last state of the RNN.
        atten_context:
          Updated attention context vector.
        atten_states:
          Updated attention states.
    """
    p = self.params
    # additional_source_info is currently not used.
    del additional_source_info

    prev_rnn_states = states['rnn_states']
    prev_atten_context = states['atten_context']
    prev_atten_probs = states['atten_probs']
    prev_atten_states = states['atten_states']
    step_paddings = tf.zeros(py_utils.GetShape(step_ids), dtype=p.dtype)
    embs = self.emb.EmbLookup(theta.emb, tf.reshape(step_ids, [-1]))
    embs = self.ApplyClipping(theta, embs)
    atten_context, atten_probs, rnn_states, step_out, atten_states = (
        self._DecodeStep(theta, embs, step_paddings, prev_atten_context,
                         prev_rnn_states, prev_atten_states))
    atten_probs = tf.reshape(atten_probs, tf.shape(prev_atten_probs))

    logits = self.softmax.Logits(theta.softmax, [step_out])
    log_probs = self.fns.qlogsoftmax(
        logits, qmin=p.qlogsoftmax_range_min, qmax=0.0)

    if p.use_prev_atten_ctx:
      cur_atten_probs = prev_atten_probs
    else:
      cur_atten_probs = atten_probs

    bs_results = py_utils.NestedMap({
        'atten_probs': cur_atten_probs,  # the probs exposed to beam search
        'log_probs': log_probs,
    })
    new_states = py_utils.NestedMap({
        'rnn_states': rnn_states,
        'atten_context': atten_context,
        'atten_probs': atten_probs,  # the updated attention probs
        'atten_states': atten_states,
    })

    return bs_results, new_states
Example #8
 def FPropFullSequence(self, theta, ids, paddings):
   return self.FProp(theta, py_utils.NestedMap(ids=ids,
                                               paddings=paddings))['encoded']
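Hypothetical usage, assuming enc is an instantiated encoder of this class; shapes follow the FProp docstrings above:

ids = tf.zeros([8, 16], tf.int32)         # [batch, time] token ids
paddings = tf.zeros([8, 16], tf.float32)  # 0.0 = real token, 1.0 = padding
encoded = enc.FPropFullSequence(enc.theta, ids, paddings)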
Example #9
  def FProp(self, theta, input_batch):
    """Embeds source ids and transforms with TransformerStack.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      input_batch: A `.NestedMap` object containing: ids - The inputs tensor of
        shape [batch, time]. paddings - The ids' paddings of shape [batch,
        time].

    Returns:
      A `.NestedMap` object containing:
        encoded - The encoded features of shape [time, batch, dim] or [batch,
          time, dim], depending on p.output_data_format.
        padding - The encoded features' padding of shape [time, batch] or
          [batch, time].
        segment_id - The segmentation of packed inputs of shape [time, batch] or
          [batch, time] if it is supported by the model, or None otherwise.
        embedded_inputs - The embedded inputs tokens without positional
          encodings of shape [time, batch, dim] or [batch, time, dim].
    """

    p = self.params
    with tf.name_scope(p.name):
      # [batch, time]
      input_ids = input_batch.ids
      # [batch, time]
      paddings = input_batch.paddings

      # [batch, time]
      segment_ids = input_batch.segment_ids if p.packed_input else None

      batch = py_utils.GetShape(input_ids)[0]
      time = py_utils.GetShape(input_ids)[1]

      # Embedding layer.
      # [batch, time, dim]
      if not p.shared_emb:
        input_embs = self.token_emb.EmbLookup(theta.token_emb, input_ids)
      else:
        input_embs = self.softmax.EmbLookup(theta.softmax, input_ids)
      orig_input_embs = input_embs

      # [1, time, dim]
      if p.packed_input:
        positions = input_batch.segment_pos
        position_embs = tf.expand_dims(
            self.position_emb.FPropWithPosition(theta.position_emb, positions),
            0)
      else:
        position_embs = tf.expand_dims(
            self.position_emb.FProp(theta.position_emb, time), 0)

      # [batch, time, dim]
      input_embs += position_embs

      if p.input_dropout_tpl.fprop_dtype:
        input_embs = tf.cast(input_embs, p.input_dropout_tpl.fprop_dtype)
        paddings = tf.cast(paddings, p.input_dropout_tpl.fprop_dtype)

      input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs)
      # [batch, time, dim]
      transformer_input = input_embs
      # Explicitly set the input shape of the Transformer layers, to avoid
      # unknown-shape errors from tf.einsum on non-TPU devices.
      transformer_input = tf.reshape(transformer_input,
                                     [batch, time, p.model_dim])

      # Compute self-attention segment mask once.
      if p.packed_input:
        segment_mask = batch_major_attention.SegmentMask(
            segment_ids, segment_ids, dtype=transformer_input.dtype)
      else:
        segment_mask = tf.zeros([batch, 1, time, time])

      shape = py_utils.GetShape(transformer_input)
      batch_size = shape[0]
      seq_len = shape[1]
      paddings = tf.reshape(paddings, [batch_size, seq_len])
      encoded, padding = self.transformer_stack.FProp(theta.transformer_stack,
                                                      transformer_input,
                                                      paddings, segment_mask)

      if p.final_layer_norm:
        encoded = self.final_ln.FProp(theta.final_ln, encoded)

      seq_lengths = tf.cast(tf.reduce_sum(1. - padding, axis=1), tf.int32)

      if p.output_data_format == 'TBC':
        encoded = tf.transpose(encoded, [1, 0, 2])  # [time, batch, dim]
        padding = tf.transpose(padding)  # [time, batch]
        segment_ids = tf.transpose(segment_ids) if p.packed_input else None
        orig_input_embs = tf.transpose(orig_input_embs, [1, 0, 2])

      return py_utils.NestedMap(
          encoded=encoded,
          padding=padding,
          seq_lengths=seq_lengths,  # used by beam_search_helper.
          segment_id=segment_ids,
          embedded_inputs=orig_input_embs)
Example #10
 def MulSumFnMeta(x):
     return py_utils.NestedMap(flops=2, out_shapes=(x, ))
Example #11
 def AddFnMeta(x, y):
     del y
     return py_utils.NestedMap(flops=2, out_shapes=(x, ))
Example #12
 def _FnMeta(*shapes):
     return py_utils.NestedMap(flops=1, out_shapes=shapes)
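These *FnMeta functions all follow the same convention: given symbolic input shapes (tshape.Shape), return a NestedMap with an estimated flops cost and a tuple of output shapes. A sketch of such a meta function being evaluated; the one-flop-per-element cost model here is an assumption for illustration:

from lingvo.core import py_utils
from lingvo.core import tshape

def ElementwiseFnMeta(x):
    # One flop per output element; output shape equals input shape.
    return py_utils.NestedMap(flops=x.num_elements(), out_shapes=(x,))

meta = ElementwiseFnMeta(tshape.Shape([8, 16, 32]))
# meta.flops evaluates to 8 * 16 * 32; meta.out_shapes[0] is the input shape.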
Example #13
 def testGraphTensors(self):
     graph_tensors = layers.GraphTensors()
     graph_tensors.StoreTensor(
         't', py_utils.NestedMap(a=py_utils.NestedMap(b='c')))
     self.assertEqual('c', graph_tensors.GetTensor('t.a.b'))
Example #14
 def FPropMeta(cls, p, inputs):
     py_utils.CheckShapes(
         tuple(inputs.Filter(lambda x: x is not None).Flatten()))
     return py_utils.NestedMap(flops=1, out_shapes=(inputs, ))
Example #15
  def ReverseAndGrad(self, theta, outputs, d_outputs, f_seed, g_seed,
                     *extra_inputs):
    """Implements Algorithm 1 in the revnet paper.

    Args:
      theta: A NestedMap object containing weights' values of this layer and its
        children layers.
      outputs: A NestedMap: .split1 and .split2 corresponding to y1 and y2.
      d_outputs: A NestedMap: .split1 and .split2 corresponding to dy1 and dy2,
        the total derivatives.
      f_seed: Scalar tensor. The step seed used in forward for the f block.
      g_seed: Scalar tensor. The step seed used in forward for the g block. The
        step seeds are needed for deterministic randomness, e.g., to ensure
        dropout generates the same random mask in forward and reverse_grad.
      *extra_inputs: additional inputs that will be passed to both f and g. No
        gradient will be computed for these inputs.

    Returns:
      A tuple of NestedMaps

      - inputs: .split1 and .split2 corresponding to x1 and x2.
      - d_inputs: .split1 and .split2 corresponding to dx1 and dx2, the total
        derivatives with respect to inputs.
      - d_theta: has the same structure as theta. The total derivatives with
        respect to weights.

    """

    # Stop gradient on the outputs to avoid circular symbolic dependency.
    y1 = tf.stop_gradient(outputs.split1)
    y2 = tf.stop_gradient(outputs.split2)
    dy1 = d_outputs.split1
    dy2 = d_outputs.split2

    # Computes the reverse.
    z1 = y1
    py_utils.ResetStepSeed(g_seed)
    gz1 = self.g_block.FProp(theta.g_block, z1, *extra_inputs)
    x2 = y2 - gz1
    py_utils.ResetStepSeed(f_seed)
    fx2 = self.f_block.FProp(theta.f_block, x2, *extra_inputs)
    x1 = z1 - fx2

    # Computes the gradients.
    dz1 = dy1 + tf.gradients(gz1, z1, dy2)[0]
    dx2 = dy2 + tf.gradients(fx2, x2, dz1)[0]

    dgw = tf.gradients(
        gz1,
        theta.g_block.Flatten(),
        dy2,
        unconnected_gradients=tf.UnconnectedGradients.ZERO)
    dgw = theta.g_block.Pack(dgw)

    dfw = tf.gradients(
        fx2,
        theta.f_block.Flatten(),
        dz1,
        unconnected_gradients=tf.UnconnectedGradients.ZERO)
    dfw = theta.f_block.Pack(dfw)

    return (py_utils.NestedMap(split1=x1, split2=x2),
            py_utils.NestedMap(split1=dz1, split2=dx2),
            py_utils.NestedMap(
                f_block=dfw,
                g_block=dgw,
                global_step=tf.zeros_like(theta.global_step)))
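The reconstruction above relies on the RevNet coupling being exactly invertible: the forward pass computes y1 = x1 + f(x2) and y2 = x2 + g(y1), so the inputs can be recovered from the outputs alone. A small numpy check with stand-in f and g:

import numpy as np

f = np.tanh              # stand-ins for the f and g blocks
g = lambda v: 0.5 * v

x1, x2 = np.random.randn(4), np.random.randn(4)
y1 = x1 + f(x2)          # forward
y2 = x2 + g(y1)

x2_rec = y2 - g(y1)      # reverse, mirroring ReverseAndGrad (z1 = y1)
x1_rec = y1 - f(x2_rec)
assert np.allclose(x1, x1_rec) and np.allclose(x2, x2_rec)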
Example #16
 def InitBeamSearchCallBack(unused_theta, unused_encoder_outputs,
                            unused_num_hyps_per_beam):
     return py_utils.NestedMap(
         log_probs=tf.zeros([tgt_batch_size, vocab_size]),
         atten_probs=tf.zeros([tgt_batch_size,
                               0])), py_utils.NestedMap()
Example #17
 def _OutfeedEnqueue(self, per_example_tensors):
   if not per_example_tensors:
     return tf.no_op()
   per_example_tensors = py_utils.NestedMap(per_example_tensors)
   return tpu_ops.outfeed_enqueue_tuple(per_example_tensors.Flatten())
Example #18
def GetBeamSearchHelperResults(sess,
                               num_hyps_per_beam,
                               pass_seq_lengths=False,
                               force_eos_in_top_k=False):
    np.random.seed(9384758)
    tf.random.set_seed(8274758)
    vocab_size = 12
    src_len = 5
    tgt_len = 7
    src_batch_size = 2
    tgt_batch_size = src_batch_size * num_hyps_per_beam
    p = beam_search_helper.BeamSearchHelper.Params().Set(
        name='bsh',
        target_seq_len=tgt_len,
        force_eos_in_top_k=force_eos_in_top_k)
    bs_helper = p.Instantiate()

    def InitBeamSearchState(unused_theta, unused_encoder_outputs,
                            unused_num_hyps_per_beam):
        atten_probs = tf.constant(np.random.normal(size=(tgt_batch_size,
                                                         src_len)),
                                  dtype=tf.float32)
        return (py_utils.NestedMap({
            'log_probs':
            tf.zeros([tgt_batch_size, vocab_size]),
            'atten_probs':
            atten_probs,
        }), py_utils.NestedMap({'atten_probs': atten_probs}))

    def PreBeamSearchStepCallback(unused_theta, unused_encoder_outputs,
                                  unused_step_ids, states,
                                  unused_num_hyps_per_beam):
        atten_probs = tf.identity(states.atten_probs)
        logits = tf.random.normal([tgt_batch_size, vocab_size], seed=8273747)
        return (py_utils.NestedMap({
            'atten_probs': atten_probs,
            'log_probs': logits
        }), states)

    def PostBeamSearchStepCallback(unused_theta, unused_encoder_outputs,
                                   unused_new_step_ids, states):
        return states

    src_enc = tf.random.normal([src_len, src_batch_size, 8], seed=982774838)
    src_enc_padding = tf.constant(
        [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
        dtype=tf.float32)
    encoder_outputs = py_utils.NestedMap(encoded=src_enc,
                                         padding=src_enc_padding)
    if pass_seq_lengths:
        encoder_outputs['seq_lengths'] = tf.constant([4, 3], dtype=tf.int32)

    theta = py_utils.NestedMap()
    decoder_output = bs_helper.BeamSearchDecode(theta, encoder_outputs,
                                                num_hyps_per_beam,
                                                InitBeamSearchState,
                                                PreBeamSearchStepCallback,
                                                PostBeamSearchStepCallback)

    topk_ids, topk_lens, topk_scores = sess.run([
        decoder_output.topk_ids, decoder_output.topk_lens,
        decoder_output.topk_scores
    ])
    return topk_ids, topk_lens, topk_scores
Example #19
  def _PreBeamSearchStepCallback(self,
                                 theta,
                                 source_encs,
                                 source_paddings,
                                 step_ids,
                                 states,
                                 num_hyps_per_beam,
                                 additional_source_info=None):
    """Returns logits for sampling ids and the next model states.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      source_encs: A tensor of shape [src_len, src_batch, source_dim].
          Can be [time, batch, depth, num_layers] if is_transparent is set.
      source_paddings: A tensor of shape [src_len, src_batch].
      step_ids: A tensor of shape [tgt_batch, 1].
      states: A `.NestedMap` of tensors representing states that the clients
          would like to keep track of for each of the active hyps.
      num_hyps_per_beam: Beam size.
      additional_source_info: a `.NestedMap` of tensors containing extra context
          information about the source that may be useful for decoding.
    Returns:
      A tuple (results, out_states).
        results: A `.NestedMap` of beam search results.
          atten_probs:
            The updated attention probs, of shape [tgt_batch, src_len].
          log_probs:
            Log prob for each of the tokens in the target vocab. This is of
            shape [tgt_batch, vocab_size].
        out_states: A `.NestedMap`. The updated states.
           source_encs:
             A tensor of shape [src_batch, src_len, source_dim].
           source_paddings:
             A tensor of shape [src_batch, src_len].
           target_ids:
             Updated list of decoded ids, of shape [num_hyps, num decoded ids].
    """
    p = self.params
    # additional_source_info is currently not used.
    del additional_source_info

    target_time = states.time_step
    prefix_states = states.prefix_states

    new_states = states.Pack(states.Flatten())

    layer_out, updated_prefix_states = self.ExtendStep(
        theta, source_encs, source_paddings, tf.squeeze(step_ids, 1),
        target_time[0][0], prefix_states)

    new_states.prefix_states = updated_prefix_states
    new_states.time_step = target_time + 1

    softmax_input = tf.reshape(layer_out, [-1, p.softmax.input_dim])
    logits = self.softmax.Logits(theta.softmax, [softmax_input])

    num_hyps = py_utils.GetShape(step_ids)[0]
    source_len = py_utils.GetShape(source_encs)[0]
    # [time * batch, num_classes] -> [time, batch, num_classes]
    logits = tf.reshape(logits, (-1, num_hyps, p.softmax.num_classes))
    # [time, batch, num_classes] -> [batch, time, num_classes]
    logits = tf.transpose(logits, (1, 0, 2))

    # Dummy attention probs
    atten_probs = tf.ones([num_hyps, source_len]) / tf.to_float(source_len)

    # Only return logits for the last ids
    log_probs = tf.nn.log_softmax(tf.squeeze(logits, axis=1))

    bs_results = py_utils.NestedMap({
        'atten_probs': atten_probs,
        'log_probs': log_probs,
    })

    return bs_results, new_states
Example #20
    def testBeamSearchForceLastChunkEocInTopK(self, is_last_chunk,
                                              force_last_chunk_eoc_in_top_k,
                                              eos_score, expected_topk_lens,
                                              expected_topk_scores):
        with self.session() as sess:
            vocab_size = 30
            tgt_len = 10
            num_hyps_per_beam = 3
            src_batch_size = 2
            tgt_batch_size = src_batch_size * num_hyps_per_beam
            p = beam_search_helper.BeamSearchHelper.Params().Set(
                name='bsh',
                target_eoc_id=0,
                target_seq_len=tgt_len,
                num_hyps_per_beam=num_hyps_per_beam,
                beam_size=100000.0,  # Beam search until the end.
                force_last_chunk_eoc_in_top_k=force_last_chunk_eoc_in_top_k,
            )
            bs_helper = p.Instantiate()

            def InitBeamSearchCallBack(unused_theta, unused_encoder_outputs,
                                       unused_num_hyps_per_beam):
                return py_utils.NestedMap(
                    log_probs=tf.zeros([tgt_batch_size, vocab_size]),
                    atten_probs=tf.zeros([tgt_batch_size, 0]),
                    is_last_chunk=tf.zeros([tgt_batch_size],
                                           tf.bool)), py_utils.NestedMap()

            def PreBeamSearchStepCallback(unused_theta, unused_encoder_outputs,
                                          unused_step_ids, states,
                                          unused_num_hyps_per_beam):
                # Same probs for each id.
                logits = tf.zeros([tgt_batch_size, vocab_size])
                # Except eoc has slightly lower score.
                logits = logits - 1.0 * tf.expand_dims(
                    tf.one_hot(p.target_eoc_id, vocab_size), 0)
                # eos has a very low score (so hyps cannot terminate via eos).
                logits = logits + eos_score * tf.expand_dims(
                    tf.one_hot(p.target_eos_id, vocab_size), 0)
                return py_utils.NestedMap(
                    atten_probs=tf.zeros([tgt_batch_size, 0]),
                    log_probs=logits,
                    is_last_chunk=tf.fill([tgt_batch_size],
                                          value=is_last_chunk)), states

            def PostBeamSearchStepCallback(unused_theta,
                                           unused_encoder_outputs,
                                           unused_new_step_ids, states):
                return states

            encoder_outputs = py_utils.NestedMap(
                seq_lengths=tf.zeros([src_batch_size], dtype=tf.int32))
            theta = py_utils.NestedMap()

            beam_search_output = bs_helper.BeamSearchDecode(
                theta,
                encoder_outputs,
                init_beam_search_state=InitBeamSearchCallBack,
                pre_beam_search_step_callback=PreBeamSearchStepCallback,
                post_beam_search_step_callback=PostBeamSearchStepCallback)

            topk_lens, topk_scores = sess.run(
                [beam_search_output.topk_lens, beam_search_output.topk_scores])
            self.assertAllEqual(topk_lens, expected_topk_lens)
            self.assertAllClose(topk_scores, expected_topk_scores, atol=1e-6)
Example #21
    def __init__(self, params):
        """Layer constructor.

    Args:
      params: A `Params` object used to construct this layer.
    """
        assert params.name, ('Layer params for %s must have a "name"' %
                             self.__class__.__name__)

        tf_module_name = params.name
        tf_module_name = re.sub('[^a-zA-Z0-9_]+', '_', tf_module_name)
        tf_module_name = 'bbf_' + self.__class__.__name__ + '_' + tf_module_name
        py_utils.NestedMap.CheckKey(tf_module_name)

        # initialize the base class.
        super().__init__(tf_module_name)

        # Note AutoTracking doesn't work properly due to its inability to walk
        # through py_utils.NestedMap data structures which are used widely
        # throughout the Lingvo codebase. Also there seems to be some performance
        # hit in turning on auto-tracking in constructing graphs. For now, we
        # disable auto-tracking.
        # TODO(lingvo): Re-enable auto-tracking when fuller support is
        # added for key data structures used in Lingvo, and performance issue is
        # debugged more and understood better.
        self._setattr_tracking = False

        self._parent = None
        for parent in reversed(_LAYER_STACK.stack):
            if parent is not self:
                self._parent = parent
                break
        self._params = params.Copy()
        tf.logging.debug('Creating layer %s with params: \n %s \n',
                         self.__class__.__name__, str(params))
        # Vars created by this layer.
        self._private_vars = py_utils.NestedMap()
        # Theta derived from this layer's vars.
        self._private_theta = py_utils.NestedMap()
        # Child layers created by this layer through CreateChild/CreateChildren.
        self._private_children = py_utils.NestedMap()
        # Child layers created by this layer. A well-formed layer should
        # have self._private_children equal to self._children_list, i.e.,
        # all child layers are created using CreateChild/CreateChildren.
        self._children_list = []
        # Extra thetas that do not directly correspond to any underlying vars,
        # e.g., the concatenated sharded variables.
        self._extra_theta = py_utils.NestedMap()
        # All registered accumulators.
        self._private_accumulators = py_utils.NestedMap()
        # Layer-private functions. Add with AddFunction.
        self._private_fns = dict()
        # Mapping from variable names to its symbolic shape.
        # self._var_symbolic_shape_map['var_name'] will be a tuple of integers or
        # symbolic expressions, one for each dimension of the variable.
        self._var_symbolic_shape_map = {}

        self._is_variable_free = False
        self._variables_to_create = {}
        self._create_variables_status = _CreateLayerVariablesStatus.NOT_CALLED
        # Keep track of the tf.variable_scope(p.name) this layer creates so we can
        # reenter it without creating a new one.
        self._self_variable_scope = None
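A minimal, hypothetical subclass showing the constructor contract enforced above: params must carry a name, and child layers should be registered via CreateChild so that _private_children stays in sync with _children_list. The child template choice (FCLayer) is illustrative:

from lingvo.core import base_layer
from lingvo.core import layers as lingvo_layers

class MyLayer(base_layer.BaseLayer):

    @classmethod
    def Params(cls):
        p = super().Params()
        p.Define('child_tpl', lingvo_layers.FCLayer.Params(),
                 'Template for a child layer.')
        return p

    def __init__(self, params):
        super().__init__(params)
        p = self.params
        self.CreateChild('fc', p.child_tpl)

# Usage sketch:
# p = MyLayer.Params().Set(name='my_layer')
# layer = p.Instantiate()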
Example #22
    def testCustomStepIds(self):
        with self.session(use_gpu=False):
            np.random.seed(9384758)
            tf.random.set_seed(8274758)
            vocab_size = 12
            src_len = 5
            tgt_len = 7
            num_hyps_per_beam = 3
            src_batch_size = 2
            tgt_batch_size = src_batch_size * num_hyps_per_beam
            p = beam_search_helper.BeamSearchHelper.Params().Set(
                name='bsh', target_seq_len=tgt_len)
            bs_helper = p.Instantiate()

            def InitBeamSearchState(unused_theta, unused_encoder_outputs,
                                    unused_num_hyps_per_beam):
                atten_probs = tf.constant(
                    np.random.normal(size=(tgt_batch_size, src_len)),
                    dtype=tf.float32)
                return (py_utils.NestedMap({
                    'log_probs':
                    tf.zeros([tgt_batch_size, vocab_size]),
                    'atten_probs':
                    atten_probs,
                    'step_ids':
                    tf.zeros([tgt_batch_size, 1], dtype=tf.int32)
                }), py_utils.NestedMap({'atten_probs': atten_probs}))

            def PreBeamSearchStepCallback(unused_theta, unused_encoder_outputs,
                                          unused_step_ids, states,
                                          unused_num_hyps_per_beam):
                atten_probs = tf.identity(states.atten_probs)
                logits = tf.random.normal([tgt_batch_size, vocab_size],
                                          seed=8273747)
                return (py_utils.NestedMap({
                    'atten_probs': atten_probs,
                    'log_probs': logits
                }), states)

            def PostBeamSearchStepCallback(unused_theta,
                                           unused_encoder_outputs,
                                           unused_new_step_ids, states):
                return states

            src_enc = tf.random.normal([src_len, src_batch_size, 8],
                                       seed=982774838)
            src_enc_padding = tf.constant(
                [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
                dtype=tf.float32)
            encoder_outputs = py_utils.NestedMap(encoded=src_enc,
                                                 padding=src_enc_padding)

            theta = py_utils.NestedMap()
            decoder_output = bs_helper.BeamSearchDecode(
                theta, encoder_outputs, num_hyps_per_beam, InitBeamSearchState,
                PreBeamSearchStepCallback, PostBeamSearchStepCallback)

            topk_ids, topk_lens, topk_scores = self.evaluate([
                decoder_output.topk_ids, decoder_output.topk_lens,
                decoder_output.topk_scores
            ])
            print(np.array_repr(topk_ids))
            print(np.array_repr(topk_lens))
            print(np.array_repr(topk_scores))
            expected_topk_ids = [[4, 3, 4, 3, 2, 0, 0], [4, 3, 11, 2, 0, 0, 0],
                                 [4, 3, 6, 2, 0, 0, 0], [6, 0, 4, 6, 6, 11, 2],
                                 [6, 0, 4, 6, 1, 2, 0], [6, 0, 4, 6, 6, 2, 0]]
            expected_topk_lens = [5, 4, 4, 7, 6, 6]
            expected_topk_scores = [[8.27340603, 6.26949024, 5.59490776],
                                    [9.74691486, 8.46679497, 7.14809656]]
            self.assertAllEqual(expected_topk_ids, topk_ids.tolist())
            self.assertAllEqual(expected_topk_lens, topk_lens.tolist())
            self.assertAllClose(expected_topk_scores, topk_scores)
Example #23
 def zero_state(self, theta, batch_size):
     return py_utils.NestedMap(rnn=[
         self.rnn[i].zero_state(theta.rnn[i], batch_size)
         for i in range(len(self.rnn))
     ])
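Hypothetical usage, assuming stack is an instance of this layer wrapping several RNN cells: the returned state mirrors the layer structure, with one entry per cell under the rnn key:

state0 = stack.zero_state(theta, batch_size=8)
first_cell_state = state0.rnn[0]  # state NestedMap for the first RNN layer
num_layers = len(state0.rnn)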
Example #24
    def testGreedySearchHelper(self):
        with self.session(use_gpu=False):
            np.random.seed(9384758)
            tf.random.set_seed(8274758)
            vocab_size = 12
            src_len = 5
            tgt_len = 7
            src_batch_size = 2
            tgt_batch_size = src_batch_size
            p = beam_search_helper.GreedySearchHelper.Params().Set(
                name='gsh', target_seq_len=tgt_len)
            gs_helper = p.Instantiate()

            def InitGreedySearchState(unused_theta, unused_encoder_outputs,
                                      unused_num_hyps_per_beam):
                atten_probs = tf.constant(
                    np.random.normal(size=(tgt_batch_size, src_len)),
                    dtype=tf.float32)
                return (py_utils.NestedMap({
                    'log_probs':
                    tf.zeros([tgt_batch_size, vocab_size]),
                    'atten_probs':
                    atten_probs,
                }), py_utils.NestedMap({'atten_probs': atten_probs}))

            def PreGreedySearchStepCallback(unused_theta,
                                            unused_encoder_outputs,
                                            unused_step_ids, states,
                                            unused_num_hyps_per_beam):
                atten_probs = tf.identity(states.atten_probs)
                logits = tf.random.normal([tgt_batch_size, vocab_size],
                                          seed=8273747)
                return (py_utils.NestedMap({
                    'atten_probs': atten_probs,
                    'log_probs': logits
                }), states)

            def PostGreedySearchStepCallback(unused_theta,
                                             unused_encoder_outputs,
                                             unused_new_step_ids, states):
                return states

            src_enc = tf.random.normal([src_len, src_batch_size, 8],
                                       seed=982774838)
            src_enc_padding = tf.constant(
                [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
                dtype=tf.float32)
            encoder_outputs = py_utils.NestedMap(encoded=src_enc,
                                                 padding=src_enc_padding)

            theta = py_utils.NestedMap()
            (final_hyp_ids, final_hyp_lens,
             final_done_hyps) = gs_helper.GreedySearchDecode(
                 theta, encoder_outputs, InitGreedySearchState,
                 PreGreedySearchStepCallback, PostGreedySearchStepCallback)

            (final_hyp_ids, final_hyp_lens, final_done_hyps) = self.evaluate(
                [final_hyp_ids, final_hyp_lens, final_done_hyps])

            print(np.array_repr(final_hyp_ids))
            print(np.array_repr(final_hyp_lens))
            print(np.array_repr(final_done_hyps))

            expected_hyp_ids = [[2, 2, 6, 7, 1, 9, 4], [3, 9, 3, 9, 6, 5, 10]]
            expected_hyp_lens = [1, 7]
            expected_done_hyps = [True, False]
            self.assertAllEqual(expected_hyp_ids, final_hyp_ids.tolist())
            self.assertAllEqual(expected_hyp_lens, final_hyp_lens.tolist())
            self.assertAllEqual(expected_done_hyps, final_done_hyps.tolist())
Example #25
 def FPropMeta(cls, p, inputs, *args):
     dim1, dim2 = args[1][:2] if p.inputs_from_decoder else inputs[:2]
     logits = tshape.Shape([dim1, dim2, p.num_classes])
     return py_utils.NestedMap(flops=100, out_shapes=(logits, ))
Example #26
  def BuildInputBatch(self, batch_size, features_list, bucket_keys=None):
    """Builds an input batch.

    Args:
      batch_size: batch size to use, defaults to infeed batch size.
      features_list: Use this list to build the batch.
      bucket_keys: If not None, bucket_keys[i] is the bucketing key of the i-th
        sample.

    Returns:
      py_utils.NestedMap with feature names as keys and tensors as values.
    """
    p = self.params

    ret = py_utils.NestedMap()
    ret.bucket_keys = bucket_keys

    (src_ids, src_paddings, tgt_ids, tgt_paddings, tgt_labels,
     tgt_weights) = features_list
    if p.pad_to_max_seq_length:
      assert p.source_max_length

      if min(self.infeed_bucket_batch_limit) == max(
          self.infeed_bucket_batch_limit):
        source_shape = [
            min(self.infeed_bucket_batch_limit), p.source_max_length
        ]
        target_shape = [
            min(self.infeed_bucket_batch_limit), p.target_max_length
        ]
      else:
        source_shape = None
        target_shape = None
      src_ids = py_utils.PadSequenceDimension(src_ids, p.source_max_length, 0,
                                              source_shape)
      src_paddings = py_utils.PadSequenceDimension(src_paddings,
                                                   p.source_max_length, 1,
                                                   source_shape)
      tgt_ids = py_utils.PadSequenceDimension(tgt_ids, p.target_max_length, 0,
                                              target_shape)
      tgt_paddings = py_utils.PadSequenceDimension(tgt_paddings,
                                                   p.target_max_length, 1,
                                                   target_shape)
      tgt_labels = py_utils.PadSequenceDimension(tgt_labels,
                                                 p.target_max_length, 0,
                                                 target_shape)
      tgt_weights = py_utils.PadSequenceDimension(tgt_weights,
                                                  p.target_max_length, 0,
                                                  target_shape)

    ret.src = py_utils.NestedMap()
    ret.src.ids = tf.cast(src_ids, dtype=tf.int32)
    ret.src.paddings = src_paddings

    ret.tgt = py_utils.NestedMap()
    ret.tgt.ids = tgt_ids
    ret.tgt.labels = tf.cast(tgt_labels, dtype=tf.int32)
    ret.tgt.weights = tgt_weights
    ret.tgt.paddings = tgt_paddings

    if (self.params.fprop_dtype is None or
        self.params.dtype == self.params.fprop_dtype):
      return ret

    def _Cast(v):
      if not v.dtype.is_floating:
        return v
      return tf.cast(v, self.params.fprop_dtype)

    return ret.Transform(_Cast)
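A small sketch of the padding helper used above, under the assumption (consistent with the calls here) that py_utils.PadSequenceDimension(x, length, pad_val) pads dimension 1 of x up to length with pad_val:

x = tf.ones([2, 3], tf.float32)
padded = py_utils.PadSequenceDimension(x, 5, 1.0)
# padded has shape [2, 5]; the two appended steps are filled with 1.0,
# matching how paddings tensors mark non-content positions above.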
Example #27
    def Sample(self, decoder_theta, source_encs, source_paddings, random_seed,
               init_state_callback, pre_step_callback, post_step_callback):
        """Samples target sequences, one target sequence per source sequence.

    (Please see beam_search_helper.py for description of decoder callbacks.)

    Args:
      decoder_theta: A NestedMap object containing weights' values of the
        decoder layer and its children layers, to be passed to decoder
        callbacks.
      source_encs: source encoding, to be passed to decoder callbacks.
      source_paddings: source padding, to be passed to decoder callbacks.
      random_seed: a scalar int32 tensor representing the random seed.
      init_state_callback: decoder._InitBeamSearchStateCallback.
      pre_step_callback: decoder._PreBeamSearchStepCallback.
      post_step_callback: decoder._PostBeamSearchStepCallback.

    Returns:
      A NestedMap containing the following tensors:
      - 'logits': [batch, max_target_length, vocab_size], representing the
        distribution from which target sequences are sampled.
      - 'ids': [batch, max_target_length] of int32, representing the target
        sequence ids, not including target_sos_id, but maybe ending with
        target_eos_id if end-of-sequence is reached before target_seq_len.
      - 'paddings': [batch, max_target_length] of 0/1, where 1 represents
        a padded timestep.
    """
        p = self.params
        assert p.temperature > 0
        recurrent_theta = py_utils.NestedMap(theta=decoder_theta,
                                             random_seed=random_seed,
                                             source_encs=source_encs,
                                             source_paddings=source_paddings)
        bs_result, bs_state = init_state_callback(recurrent_theta.theta,
                                                  source_encs,
                                                  source_paddings,
                                                  num_hyps_per_beam=1)
        batch = tf.shape(bs_result.log_probs)[0]
        recurrent_state0 = py_utils.NestedMap(
            timestep=tf.zeros(shape=[], dtype=tf.int32),
            logits=bs_result.log_probs,
            # Start with target_sos_id.
            ids=tf.fill([batch], tf.to_int32(p.target_sos_id)),
            bs_state=bs_state)
        inputs = py_utils.NestedMap(dummy=tf.zeros([p.target_seq_len, batch]))

        def Step(recurrent_theta, state0, inputs):
            """Computes one decoder step."""
            del inputs
            with tf.name_scope('single_sampler_step'):
                # Compute logits and states.
                bs_result, bs_state1 = pre_step_callback(
                    recurrent_theta.theta,
                    recurrent_theta.source_encs,
                    recurrent_theta.source_paddings,
                    tf.expand_dims(state0.ids, 1),  # [batch, 1].
                    state0.bs_state,
                    num_hyps_per_beam=1)
                batch = tf.shape(bs_result.log_probs)[0]
                state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
                state1.logits = bs_result.log_probs
                # Sample ids from logits. [batch].
                state1.ids = tf.reshape(
                    tf.contrib.stateless.stateless_multinomial(
                        state1.logits / p.temperature,
                        num_samples=1,
                        seed=tf.stack(
                            [recurrent_theta.random_seed, state0.timestep]),
                        output_dtype=state0.ids.dtype,
                        name='sample_next_id'), [batch])
                if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
                    state1.ids = tf.where(
                        tf.logical_and(bs_result.is_last_chunk,
                                       tf.equal(state1.ids, p.target_eoc_id)),
                        tf.fill(tf.shape(state1.ids), p.target_eos_id),
                        state1.ids)
                state1.bs_state = post_step_callback(
                    recurrent_theta.theta, recurrent_theta.source_encs,
                    recurrent_theta.source_paddings, state1.ids, bs_state1)
            return state1, py_utils.NestedMap()

        accumulated_states, _ = recurrent.Recurrent(recurrent_theta,
                                                    recurrent_state0, inputs,
                                                    Step)
        result = py_utils.NestedMap(
            logits=tf.transpose(accumulated_states.logits, [1, 0, 2]),
            ids=tf.transpose(accumulated_states.ids))
        result.paddings = tf.cast(
            _ComputePaddings(result.ids, p.target_eos_id), result.logits.dtype)
        # Force ids to be eos_id if the timestep is padded.
        result.ids = tf.where(tf.equal(result.paddings, 0), result.ids,
                              tf.fill(tf.shape(result.ids), p.target_eos_id))
        static_batch_size = bs_result.log_probs.shape[0]
        result.ids.set_shape([static_batch_size, p.target_seq_len])
        result.paddings.set_shape([static_batch_size, p.target_seq_len])
        return result
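tf.contrib.stateless.stateless_multinomial comes from TF1's contrib; in TF >= 1.15 the equivalent call is tf.random.stateless_categorical. A sketch of the per-step sampling done in Step above, with random_seed and timestep as assumed scalar int32 tensors:

logits = tf.zeros([4, 12])  # [batch, vocab], placeholder values
ids = tf.random.stateless_categorical(
    logits / 0.8,  # temperature-scaled, as with p.temperature above
    num_samples=1,
    seed=tf.stack([random_seed, timestep]),  # deterministic per (seed, step)
    dtype=tf.int32)
ids = tf.reshape(ids, [4])  # [batch]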
Example #28
 def CellFn(theta, state0, unused_inputs):
   out_nmap = self._FProp(theta, state0)
   return out_nmap, py_utils.NestedMap()
Example #29
 def FPropMeta(cls, p, inputs, padding=None):
     py_utils.CheckShapes((inputs,))
     return py_utils.NestedMap(
         flops=inputs.num_elements() * _BN_FLOPS_PER_ELEMENT,
         out_shapes=(inputs,))
Example #30
 def FPropMeta(cls, p, inputs):
     py_utils.CheckShapes((inputs, ))
     return py_utils.NestedMap(flops=1, out_shapes=(inputs, ))