Example #1
    def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
      """Internal while_loop body.

      Args:
        time: scalar int32 tensor.
        outputs_ta: structure of TensorArray.
        state: (structure of) state tensors and TensorArrays.
        inputs: (structure of) input tensors.
        finished: bool tensor (keeping track of what's finished).
        sequence_lengths: int32 tensor (keeping track of time of finish).

      Returns:
        `(time + 1, outputs_ta, next_state, next_inputs, next_finished,
          next_sequence_lengths)`.
      """
      (next_outputs, decoder_state, next_inputs,
       decoder_finished) = decoder.step(time, inputs, state)
      next_finished = math_ops.logical_or(decoder_finished, finished)
      if maximum_iterations is not None:
        next_finished = math_ops.logical_or(
            next_finished, time + 1 >= maximum_iterations)
      next_sequence_lengths = array_ops.where(
          math_ops.logical_and(math_ops.logical_not(finished), next_finished),
          array_ops.fill(array_ops.shape(sequence_lengths), time + 1),
          sequence_lengths)

      nest.assert_same_structure(state, decoder_state)
      nest.assert_same_structure(outputs_ta, next_outputs)
      nest.assert_same_structure(inputs, next_inputs)

      # Zero out output values past finish
      if impute_finished:
        emit = nest.map_structure(
            lambda out, zero: array_ops.where(finished, zero, out),
            next_outputs,
            zero_outputs)
      else:
        emit = next_outputs

      # Copy through states past finish
      def _maybe_copy_state(new, cur):
        # TensorArrays and scalar states get passed through.
        if isinstance(cur, tensor_array_ops.TensorArray):
          pass_through = True
        else:
          new.set_shape(cur.shape)
          pass_through = (new.shape.ndims == 0)
        return new if pass_through else array_ops.where(finished, cur, new)

      if impute_finished:
        next_state = nest.map_structure(
            _maybe_copy_state, decoder_state, state)
      else:
        next_state = decoder_state

      outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                      outputs_ta, emit)
      return (time + 1, outputs_ta, next_state, next_inputs, next_finished,
              next_sequence_lengths)
Example #2
  def _decode(self, image_buffer, image_format):
    """Decodes the image buffer.

    Args:
      image_buffer: The tensor representing the encoded image.
      image_format: The image format for the image in `image_buffer`.

    Returns:
      A decoded image.
    """
    def decode_png():
      return image_ops.decode_png(image_buffer, self._channels)
    def decode_raw():
      return parsing_ops.decode_raw(image_buffer, dtypes.uint8)
    def decode_jpg():
      return image_ops.decode_jpeg(image_buffer, self._channels)

    image = control_flow_ops.case({
        math_ops.logical_or(math_ops.equal(image_format, 'png'),
                            math_ops.equal(image_format, 'PNG')): decode_png,
        math_ops.logical_or(math_ops.equal(image_format, 'raw'),
                            math_ops.equal(image_format, 'RAW')): decode_raw,
    }, default=decode_jpg, exclusive=True)

    image.set_shape([None, None, self._channels])
    if self._shape is not None:
      image = array_ops.reshape(image, self._shape)

    return image
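
The `case` op above dispatches on the runtime value of `image_format`. A minimal standalone sketch of the same dispatch pattern, written against the public TF 2.x API with illustrative format strings and branch values:

import tensorflow as tf

fmt = tf.constant('png')
result = tf.case(
    [(tf.logical_or(tf.equal(fmt, 'png'), tf.equal(fmt, 'PNG')),
      lambda: tf.constant(1)),
     (tf.logical_or(tf.equal(fmt, 'raw'), tf.equal(fmt, 'RAW')),
      lambda: tf.constant(2))],
    default=lambda: tf.constant(3), exclusive=True)
# result -> 1, since the 'png' predicate is the one that holds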
Example #3
  def _decode(self, image_buffer, image_format):
    """Decodes the image buffer.

    Args:
      image_buffer: The tensor representing the encoded image.
      image_format: The image format for the image in `image_buffer`.

    Returns:
      A tensor that represents the decoded image of shape `self._shape`, or
      `(?, ?, self._channels)` if `self._shape` is not specified.
    """

    def decode_png():
      return image_ops.decode_png(
          image_buffer, self._channels, dtype=self._dtype)

    def decode_raw():
      return parsing_ops.decode_raw(image_buffer, out_type=self._dtype)

    def decode_jpg():
      if self._dtype != dtypes.uint8:
        raise ValueError(
            'jpeg decoder can only be used to decode to tf.uint8 but %s was '
            'requested for a jpeg image.' % self._dtype)
      return image_ops.decode_jpeg(image_buffer, self._channels)

    # For RGBA images JPEG is not a valid decoder option.
    if self._channels > 3:
      pred_fn_pairs = {
          math_ops.logical_or(
              math_ops.equal(image_format, 'raw'),
              math_ops.equal(image_format, 'RAW')): decode_raw,
      }
      default_decoder = decode_png
    else:
      pred_fn_pairs = {
          math_ops.logical_or(
              math_ops.equal(image_format, 'png'),
              math_ops.equal(image_format, 'PNG')): decode_png,
          math_ops.logical_or(
              math_ops.equal(image_format, 'raw'),
              math_ops.equal(image_format, 'RAW')): decode_raw,
      }
      default_decoder = decode_jpg

    image = control_flow_ops.case(
        pred_fn_pairs, default=default_decoder, exclusive=True)

    image.set_shape([None, None, self._channels])
    if self._shape is not None:
      image = array_ops.reshape(image, self._shape)

    return image
Example #4
  def _decode(self, image_buffer, image_format):
    """Decodes the image buffer.

    Args:
      image_buffer: The tensor representing the encoded image.
      image_format: The image format for the image in `image_buffer`. If image
        format is `raw`, all images are expected to be in this format, otherwise
        this op can decode a mix of `jpg` and `png` formats.

    Returns:
      A tensor that represents the decoded image of shape `self._shape`, or
      `(?, ?, self._channels)` if `self._shape` is not specified.
    """
    def decode_image():
      """Decodes a png or jpg based on the headers."""
      return image_ops.decode_image(image_buffer, self._channels)

    def decode_raw():
      """Decodes a raw image."""
      return parsing_ops.decode_raw(image_buffer, out_type=self._dtype)

    pred_fn_pairs = {
        math_ops.logical_or(
            math_ops.equal(image_format, 'raw'),
            math_ops.equal(image_format, 'RAW')): decode_raw,
    }
    image = control_flow_ops.case(
        pred_fn_pairs, default=decode_image, exclusive=True)

    image.set_shape([None, None, self._channels])
    if self._shape is not None:
      image = array_ops.reshape(image, self._shape)

    return image
Example #5
  def pdf(self, x, name="pdf"):
    """The PDF of observations in `x` under these Uniform distribution(s).

    Args:
      x: tensor of dtype `dtype`, must be broadcastable with `a` and `b`.
      name: The name to give this op.

    Returns:
      pdf: tensor of dtype `dtype`, the pdf values of `x`. If `x` is `nan`, will
          return `nan`.
    """
    with ops.name_scope(self.name):
      with ops.op_scope([self.a, self.b, x], name):
        x = ops.convert_to_tensor(x, name="x")
        if x.dtype != self.dtype:
          raise TypeError("Input x dtype does not match dtype: %s vs. %s" %
                          (x.dtype, self.dtype))

        broadcasted_x = x * self._ones()
        return math_ops.select(
            math_ops.is_nan(broadcasted_x), broadcasted_x, math_ops.select(
                math_ops.logical_or(broadcasted_x < self.a,
                                    broadcasted_x > self.b),
                array_ops.zeros_like(broadcasted_x),
                (1.0 / self.range()) * array_ops.ones_like(broadcasted_x)))
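
The nested `select` above is a piecewise definition: `nan` inputs propagate, values outside `[a, b]` get density 0, and values inside get `1 / (b - a)`. A minimal NumPy sketch of the same logic, with illustrative bounds `a = 0`, `b = 2` (not taken from the example):

import numpy as np

a, b = 0.0, 2.0  # illustrative bounds
x = np.array([-1.0, 1.0, 3.0, np.nan])
# nan passes through; 0 outside [a, b]; 1/(b - a) inside.
pdf = np.where(np.isnan(x), x,
               np.where((x < a) | (x > b), 0.0, 1.0 / (b - a)))
# pdf -> [0. , 0.5, 0. , nan]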
Example #6
  def body(time, outputs_ta, state, inputs, finished):
    """Internal while_loop body.

    Args:
      time: scalar int32 tensor.
      outputs_ta: structure of TensorArray.
      state: (structure of) state tensors and TensorArrays.
      inputs: (structure of) input tensors.
      finished: 1-D bool tensor.

    Returns:
      `(time + 1, outputs_ta, next_state, next_inputs, next_finished)`.
    """
    (next_outputs, decoder_state, next_inputs, decoder_finished) = decoder.step(
        time, inputs, state)
    next_finished = math_ops.logical_or(decoder_finished, finished)

    nest.assert_same_structure(state, decoder_state)
    nest.assert_same_structure(outputs_ta, next_outputs)
    nest.assert_same_structure(inputs, next_inputs)

    # Zero out output values past finish
    emit = nest.map_structure(
        lambda out, zero: array_ops.where(finished, zero, out), next_outputs,
        zero_outputs)

    # Copy through states past finish
    def _maybe_copy_state(new, cur):
      return (new if isinstance(cur, tensor_array_ops.TensorArray) else
              array_ops.where(finished, cur, new))

    next_state = nest.map_structure(_maybe_copy_state, decoder_state, state)
    outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                    outputs_ta, emit)
    return (time + 1, outputs_ta, next_state, next_inputs, next_finished)
Example #7
        def body(time, elements_finished, current_input, emit_ta, state, loop_state):
            """Internal while loop body for raw_rnn.

            Args:
              time: time scalar.
              elements_finished: batch-size vector.
              current_input: possibly nested tuple of input tensors.
              emit_ta: possibly nested tuple of output TensorArrays.
              state: possibly nested tuple of state tensors.
              loop_state: possibly nested tuple of loop state tensors.

            Returns:
              Tuple having the same size as Args but with updated values.
            """
            (next_output, cell_state) = cell(current_input, state)

            nest.assert_same_structure(state, cell_state)
            nest.assert_same_structure(cell.output_size, next_output)

            next_time = time + 1
            (next_finished, next_input, next_state, emit_output, next_loop_state) = loop_fn(
                next_time, next_output, cell_state, loop_state
            )

            nest.assert_same_structure(state, next_state)
            nest.assert_same_structure(current_input, next_input)
            nest.assert_same_structure(emit_ta, emit_output)

            # If loop_fn returns None for next_loop_state, just reuse the
            # previous one.
            loop_state = loop_state if next_loop_state is None else next_loop_state

            def _copy_some_through(current, candidate):
                """Copy some tensors through via array_ops.where."""
                current_flat = nest.flatten(current)
                candidate_flat = nest.flatten(candidate)
                # pylint: disable=g-long-lambda,cell-var-from-loop
                result_flat = [
                    _on_device(
                        lambda: array_ops.where(elements_finished, current_i, candidate_i), device=candidate_i.op.device
                    )
                    for (current_i, candidate_i) in zip(current_flat, candidate_flat)
                ]
                # pylint: enable=g-long-lambda,cell-var-from-loop
                return nest.pack_sequence_as(structure=current, flat_sequence=result_flat)

            emit_output = _copy_some_through(zero_emit, emit_output)
            next_state = _copy_some_through(state, next_state)

            emit_output_flat = nest.flatten(emit_output)
            emit_ta_flat = nest.flatten(emit_ta)

            elements_finished = math_ops.logical_or(elements_finished, next_finished)

            emit_ta_flat = [ta.write(time, emit) for (ta, emit) in zip(emit_ta_flat, emit_output_flat)]

            emit_ta = nest.pack_sequence_as(structure=emit_structure, flat_sequence=emit_ta_flat)

            return (next_time, elements_finished, next_input, emit_ta, next_state, loop_state)
Example #8
def _dynamic_rank_in(actual_rank, given_ranks):
  if len(given_ranks) < 1:
    return ops.convert_to_tensor(False)
  result = math_ops.equal(given_ranks[0], actual_rank)
  for given_rank in given_ranks[1:]:
    result = math_ops.logical_or(
        result, math_ops.equal(given_rank, actual_rank))
  return result
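
A short usage sketch of the helper above, written against the public TF API (the tensor and the rank set are assumptions for illustration):

import tensorflow as tf

t = tf.zeros([2, 3])
actual_rank = tf.rank(t)  # dynamic rank: 2
# Mirrors _dynamic_rank_in(actual_rank, [1, 2, 4]):
result = tf.equal(1, actual_rank)
for given_rank in [2, 4]:
    result = tf.logical_or(result, tf.equal(given_rank, actual_rank))
# result -> True, since the actual rank 2 is in {1, 2, 4}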
Example #9
  def next_inputs(self, time, outputs, state, sample_ids, name=None):
    with ops.name_scope(name, "ScheduledOutputTrainingHelperNextInputs",
                        [time, outputs, state, sample_ids]):
      (finished, base_next_inputs, state) = (
          super(ScheduledOutputTrainingHelper, self).next_inputs(
              time=time,
              outputs=outputs,
              state=state,
              sample_ids=sample_ids,
              name=name))
      sample_ids = math_ops.cast(sample_ids, dtypes.bool)

      def maybe_sample():
        """Perform scheduled sampling."""

        def maybe_concatenate_auxiliary_inputs(outputs_, indices=None):
          """Concatenate outputs with auxiliary inputs, if they exist."""
          if self._auxiliary_input_tas is None:
            return outputs_

          next_time = time + 1
          auxiliary_inputs = nest.map_structure(
              lambda ta: ta.read(next_time), self._auxiliary_input_tas)
          if indices is not None:
            auxiliary_inputs = array_ops.gather_nd(auxiliary_inputs, indices)
          return nest.map_structure(
              lambda x, y: array_ops.concat((x, y), -1),
              outputs_, auxiliary_inputs)

        if self._next_inputs_fn is None:
          return array_ops.where(
              sample_ids, maybe_concatenate_auxiliary_inputs(outputs),
              base_next_inputs)

        where_sampling = math_ops.cast(
            array_ops.where(sample_ids), dtypes.int32)
        where_not_sampling = math_ops.cast(
            array_ops.where(math_ops.logical_not(sample_ids)), dtypes.int32)
        outputs_sampling = array_ops.gather_nd(outputs, where_sampling)
        inputs_not_sampling = array_ops.gather_nd(base_next_inputs,
                                                  where_not_sampling)
        sampled_next_inputs = maybe_concatenate_auxiliary_inputs(
            self._next_inputs_fn(outputs_sampling), where_sampling)

        base_shape = array_ops.shape(base_next_inputs)
        return (array_ops.scatter_nd(indices=where_sampling,
                                     updates=sampled_next_inputs,
                                     shape=base_shape)
                + array_ops.scatter_nd(indices=where_not_sampling,
                                       updates=inputs_not_sampling,
                                       shape=base_shape))

      all_finished = math_ops.reduce_all(finished)
      no_samples = math_ops.logical_not(math_ops.reduce_any(sample_ids))
      next_inputs = control_flow_ops.cond(
          math_ops.logical_or(all_finished, no_samples),
          lambda: base_next_inputs, maybe_sample)
      return (finished, next_inputs, state)
Example #10
 def _prob(self, x):
   broadcasted_x = x * array_ops.ones(self.batch_shape_tensor())
   return array_ops.where(
       math_ops.is_nan(broadcasted_x),
       broadcasted_x,
       array_ops.where(
           math_ops.logical_or(broadcasted_x < self.low,
                               broadcasted_x >= self.high),
           array_ops.zeros_like(broadcasted_x),
           array_ops.ones_like(broadcasted_x) / self.range()))
Example #11
 def _prob(self, x):
   broadcasted_x = x * array_ops.ones(self.batch_shape())
   return array_ops.where(
       math_ops.is_nan(broadcasted_x),
       broadcasted_x,
       array_ops.where(
           math_ops.logical_or(broadcasted_x < self.a,
                               broadcasted_x > self.b),
           array_ops.zeros_like(broadcasted_x),
           (1. / self.range()) * array_ops.ones_like(broadcasted_x)))
Example #12
  def _decode(self, image_buffer, image_format):
    """Decodes the image buffer.

    Args:
      image_buffer: The tensor representing the encoded image.
      image_format: The image format for the image in `image_buffer`. If image
        format is `raw`, all images are expected to be in this format, otherwise
        this op can decode a mix of `jpg` and `png` formats.

    Returns:
      A tensor that represents the decoded image of shape `self._shape`, or
      `(?, ?, self._channels)` if `self._shape` is not specified.
    """

    def decode_image():
      """Decodes a image based on the headers."""
      return math_ops.cast(
          image_ops.decode_image(image_buffer, channels=self._channels),
          self._dtype)

    def decode_jpeg():
      """Decodes a jpeg image with specified '_dct_method'."""
      return math_ops.cast(
          image_ops.decode_jpeg(
              image_buffer,
              channels=self._channels,
              dct_method=self._dct_method), self._dtype)

    def check_jpeg():
      """Checks if an image is jpeg."""
      # For jpeg, we directly use image_ops.decode_jpeg rather than decode_image
      # in order to feed the jpeg-specific parameter 'dct_method'.
      return control_flow_ops.cond(
          image_ops.is_jpeg(image_buffer),
          decode_jpeg,
          decode_image,
          name='cond_jpeg')

    def decode_raw():
      """Decodes a raw image."""
      return parsing_ops.decode_raw(image_buffer, out_type=self._dtype)

    pred_fn_pairs = {
        math_ops.logical_or(
            math_ops.equal(image_format, 'raw'),
            math_ops.equal(image_format, 'RAW')): decode_raw,
    }
    image = control_flow_ops.case(
        pred_fn_pairs, default=check_jpeg, exclusive=True)

    image.set_shape([None, None, self._channels])
    if self._shape is not None:
      image = array_ops.reshape(image, self._shape)

    return image
Example #13
 def max_reduce_fn(state, value):
   """Computes the maximum shape to pad to."""
   condition = math_ops.reduce_all(
       math_ops.logical_or(
           math_ops.less_equal(value.dense_shape, padded_shape),
           math_ops.equal(padded_shape, -1)))
   assert_op = control_flow_ops.Assert(condition, [
       "Actual shape greater than padded shape: ", value.dense_shape,
       padded_shape
   ])
   with ops.control_dependencies([assert_op]):
     return math_ops.maximum(state, value.dense_shape)
Example #14
def softplus_inverse(x, name=None):
  """Computes the inverse softplus, i.e., x = softplus_inverse(softplus(x)).

  Mathematically this op is equivalent to:

  ```none
  softplus_inverse = log(exp(x) - 1.)
  ```

  Args:
    x: `Tensor`. Non-negative (not enforced), floating-point.
    name: A name for the operation (optional).

  Returns:
    `Tensor`. Has the same type/shape as input `x`.
  """
  with ops.name_scope(name, "softplus_inverse", values=[x]):
    x = ops.convert_to_tensor(x, name="x")
    # We begin by deriving a more numerically stable softplus_inverse:
    # x = softplus(y) = Log[1 + exp{y}], (which means x > 0).
    # ==> exp{x} = 1 + exp{y}                                (1)
    # ==> y = Log[exp{x} - 1]                                (2)
    #       = Log[(exp{x} - 1) / exp{x}] + Log[exp{x}]
    #       = Log[(1 - exp{-x}) / 1] + Log[exp{x}]
    #       = Log[1 - exp{-x}] + x                           (3)
    # (2) is the "obvious" inverse, but (3) is more stable than (2) for large x.
    # For small x (e.g. x = 1e-10), (3) will become -inf since 1 - exp{-x} will
    # be zero.  To fix this, we use 1 - exp{-x} approx x for small x > 0.
    #
    # In addition to the numerically stable derivation above, we clamp
    # small/large values to be congruent with the logic in:
    # tensorflow/core/kernels/softplus_op.h
    #
    # Finally, we set the input to one whenever the input is too large or too
    # small. This ensures that no unchosen codepath is +/- inf. This is
    # necessary to ensure the gradient doesn't get NaNs. Recall that the
    # gradient of `where` behaves like `pred*pred_true + (1-pred)*pred_false`
    # thus an `inf` in an unselected path results in `0*inf=nan`. We are careful
    # to overwrite `x` with ones only when we will never actually use this
    # value.  Note that we use ones and not zeros since `log(expm1(0.)) = -inf`.
    threshold = np.log(np.finfo(x.dtype.as_numpy_dtype).eps) + 2.
    is_too_small = math_ops.less(x, np.exp(threshold))
    is_too_large = math_ops.greater(x, -threshold)
    too_small_value = math_ops.log(x)
    too_large_value = x
    # This `where` will ultimately be a NOP because we won't select this
    # codepath whenever we used the surrogate `ones_like`.
    x = array_ops.where(math_ops.logical_or(is_too_small, is_too_large),
                        array_ops.ones_like(x), x)
    y = x + math_ops.log(-math_ops.expm1(-x))  # == log(expm1(x))
    return array_ops.where(is_too_small, too_small_value,
                           array_ops.where(is_too_large, too_large_value, y))
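
The stability claim behind step (3) of the derivation is easy to check numerically; a small NumPy sketch (not part of the original code):

import numpy as np

x = np.array([0.5, 5.0, 20.0])
naive = np.log(np.exp(x) - 1.0)      # form (2), overflows for very large x
stable = x + np.log1p(-np.exp(-x))   # form (3), the rewrite used above
assert np.allclose(naive, stable)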
Example #15
def _maybe_convert_labels(y_true):
  """Converts binary labels into -1/1."""
  are_zeros = math_ops.equal(y_true, 0)
  are_ones = math_ops.equal(y_true, 1)
  is_binary = math_ops.reduce_all(math_ops.logical_or(are_zeros, are_ones))

  def _convert_binary_labels():
    # Convert the binary labels to -1 or 1.
    return 2. * y_true - 1.

  updated_y_true = smart_cond.smart_cond(is_binary,
                                         _convert_binary_labels, lambda: y_true)
  return updated_y_true
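
The conversion itself is just the affine map `2 * y - 1`, which sends 0 to -1 and 1 to 1, applied only when all labels are binary. A minimal eager sketch of the same path:

import tensorflow as tf

y_true = tf.constant([0., 1., 1., 0.])
is_binary = tf.reduce_all(tf.logical_or(tf.equal(y_true, 0),
                                        tf.equal(y_true, 1)))
converted = tf.where(is_binary, 2. * y_true - 1., y_true)
# converted -> [-1., 1., 1., -1.]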
Example #16
    def body(time, elements_finished, current_input,
             emit_ta, state, loop_state):
      """Internal while loop body for raw_rnn.

      Args:
        time: time scalar.
        elements_finished: batch-size vector.
        current_input: possibly nested tuple of input tensors.
        emit_ta: possibly nested tuple of output TensorArrays.
        state: possibly nested tuple of state tensors.
        loop_state: possibly nested tuple of loop state tensors.

      Returns:
        Tuple having the same size as Args but with updated values.
      """
      (next_output, cell_state) = cell(current_input, state)

      nest.assert_same_structure(state, cell_state)
      nest.assert_same_structure(cell.output_size, next_output)

      next_time = time + 1
      (next_finished, next_input, next_state, emit_output,
       next_loop_state) = loop_fn(
           next_time, next_output, cell_state, loop_state)

      nest.assert_same_structure(state, next_state)
      nest.assert_same_structure(current_input, next_input)
      nest.assert_same_structure(emit_ta, emit_output)

      # If loop_fn returns None for next_loop_state, just reuse the
      # previous one.
      loop_state = loop_state if next_loop_state is None else next_loop_state

      def _copy_some_through(current, candidate):
        """Copy some tensors through via array_ops.where."""
        def copy_fn(cur_i, cand_i):
          return _on_device(
              lambda: array_ops.where(elements_finished, cur_i, cand_i),
              device=cand_i.op.device)
        return nest.map_structure(copy_fn, current, candidate)

      emit_output = _copy_some_through(zero_emit, emit_output)
      next_state = _copy_some_through(state, next_state)

      emit_ta = nest.map_structure(
          lambda ta, emit: ta.write(time, emit), emit_ta, emit_output)

      elements_finished = math_ops.logical_or(elements_finished, next_finished)

      return (next_time, elements_finished, next_input,
              emit_ta, next_state, loop_state)
Example #17
def sparsemax_loss(logits, sparsemax, labels, name=None):
  """Computes sparsemax loss function [1].

  [1]: https://arxiv.org/abs/1602.02068

  Args:
    logits: A `Tensor`. Must be one of the following types: `half`, `float32`,
      `float64`.
    sparsemax: A `Tensor`. Must have the same type as `logits`.
    labels: A `Tensor`. Must have the same type as `logits`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `logits`.
  """

  with ops.name_scope(name, "sparsemax_loss",
                      [logits, sparsemax, labels]) as name:
    logits = ops.convert_to_tensor(logits, name="logits")
    sparsemax = ops.convert_to_tensor(sparsemax, name="sparsemax")
    labels = ops.convert_to_tensor(labels, name="labels")

    # In the paper, they call the logits z.
    # A constant can be subtracted from the logits to make the algorithm
    # more numerically stable in theory. However, there is really no major
    # source of numerical instability in this algorithm.
    z = logits

    # sum over support
    # Use a conditional where instead of a multiplication to support z = -inf.
    # If z = -inf, and there is no support (sparsemax = 0), a multiplication
    # would cause 0 * -inf = nan, which is not correct in this case.
    sum_s = array_ops.where(
        math_ops.logical_or(sparsemax > 0, math_ops.is_nan(sparsemax)),
        sparsemax * (z - 0.5 * sparsemax), array_ops.zeros_like(sparsemax))

    # - z_k + ||q||^2
    q_part = labels * (0.5 * labels - z)
    # Fix the case where labels = 0 and z = -inf, where q_part would
    # otherwise be 0 * -inf = nan. But since labels = 0, no cost for
    # z = -inf should be considered.
    # The code below also covers the case where z = inf. However, in this
    # case the sparsemax will be nan, which means sum_s will also be nan,
    # and therefore this case doesn't need additional special treatment.
    q_part_safe = array_ops.where(
        math_ops.logical_and(math_ops.equal(labels, 0), math_ops.is_inf(z)),
        array_ops.zeros_like(z), q_part)

    return math_ops.reduce_sum(sum_s + q_part_safe, axis=1)
Example #18
    def control_map_fn(x, y):

      def multiply():
        return x * 2

      def divide():
        return x // 2

      pred_fn_pairs = {
          math_ops.logical_or(math_ops.equal(y, 2), math_ops.equal(y, 3)):
              divide,
      }

      return control_flow_ops.case(
          pred_fn_pairs, default=multiply, exclusive=True)
Example #19
 def maybe_update_masks():
   with ops.name_scope(self._spec.name):
     is_step_within_pruning_range = math_ops.logical_and(
         math_ops.greater_equal(self._global_step,
                                self._spec.begin_pruning_step),
         # If end_pruning_step is negative, keep pruning forever!
         math_ops.logical_or(
             math_ops.less_equal(self._global_step,
                                 self._spec.end_pruning_step),
             math_ops.less(self._spec.end_pruning_step, 0)))
     is_pruning_step = math_ops.less_equal(
         math_ops.add(self._last_update_step, self._spec.pruning_frequency),
         self._global_step)
     return math_ops.logical_and(is_step_within_pruning_range,
                                 is_pruning_step)
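
The same schedule predicate in plain Python, to make the boolean logic explicit (the names here are illustrative, not from the pruning spec):

def is_mask_update_step(step, begin_step, end_step, last_update_step, frequency):
    # A negative end_step means "keep pruning forever".
    within_range = step >= begin_step and (step <= end_step or end_step < 0)
    update_due = last_update_step + frequency <= step
    return within_range and update_due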
Example #20
        def body(time, elements_finished, current_input, state_ta, emit_ta,
                 state, loop_state):
            (next_output, cell_state) = cell(current_input, state)

            nest.assert_same_structure(state, cell_state)
            nest.assert_same_structure(cell.output_size, next_output)

            next_time = time + 1
            (next_finished, next_input, next_state, emit_output,
             next_loop_state) = loop_fn(next_time, next_output, cell_state,
                                        loop_state)

            nest.assert_same_structure(state, next_state)
            nest.assert_same_structure(current_input, next_input)
            nest.assert_same_structure(emit_ta, emit_output)

            # If loop_fn returns None for next_loop_state, just reuse the previous one.
            loop_state = loop_state if next_loop_state is None else next_loop_state

            def _copy_some_through(current, candidate):
                """Copy some tensors through via array_ops.where."""
                def copy_fn(cur_i, cand_i):
                    # TensorArray and scalar get passed through.
                    if isinstance(cur_i, tensor_array_ops.TensorArray):
                        return cand_i
                    if cur_i.shape.ndims == 0:
                        return cand_i
                    # Otherwise propagate the old or the new value.
                    with ops.colocate_with(cand_i):
                        return array_ops.where(elements_finished, cur_i,
                                               cand_i)

                return nest.map_structure(copy_fn, current, candidate)

            emit_output = _copy_some_through(zero_emit, emit_output)
            next_state = _copy_some_through(state, next_state)

            emit_ta = nest.map_structure(lambda ta, emit: ta.write(time, emit),
                                         emit_ta, emit_output)
            state_ta = nest.map_structure(
                lambda ta, state: ta.write(time, state), state_ta, next_state)

            elements_finished = math_ops.logical_or(elements_finished,
                                                    next_finished)

            return (next_time, elements_finished, next_input, state_ta,
                    emit_ta, next_state, loop_state)
Example #21
    def loop_fn(time, cell_output, cell_state, loop_state):
        next_cell_state = initial_state if cell_output is None else cell_state

        elements_finished = math_ops.logical_or(
            time >= sequence_length,
            cell.termination_condition(next_cell_state))
        finished = math_ops.reduce_all(elements_finished)

        next_input = control_flow_ops.cond(
            finished, lambda: array_ops.zeros_like(initial_input),
            lambda: initial_input
            if cell_output is None else cell.output_function(next_cell_state))
        emit_output = next_input[0] if cell_output is None else next_input

        next_loop_state = None
        return (elements_finished, next_input, next_cell_state, emit_output,
                next_loop_state)
Example #22
    def convert_nan_or_inf_to_zero(self, grad):
        """Replace grad tensor with zero tensor if grad is NaN or Inf.

     This is mainly for improving training stability. We skip updating the
     variable by setting the grad to zero when there is NaN or Inf.

    Args:
      grad: Input gradient.

    Returns:
      a Tensor with the dtype equal to grad dtype.
    """
        return array_ops.where(
            math_ops.reduce_any(
                math_ops.logical_or(math_ops.is_nan(grad),
                                    math_ops.is_inf(grad))),
            array_ops.zeros_like(grad, dtype=grad.dtype), grad)
Example #23
 def build_main_test():
     """Main iteration condition."""
     # TODO(b/138857806): The optimizer should handle this.
     # LogicalAnd is slow on GPU so we avoid adding it if `delta` is a
     # compile time constant.
     delta_const = tensor_util.constant_value(delta)
     if delta_const is not None:
         # Support single element arrays.
         delta_const = np.asscalar(delta_const)
         if delta_const >= 0:
             return iterate < limit
         else:
             return iterate > limit
     else:
         return math_ops.logical_or(
             math_ops.logical_and(delta >= 0, iterate < limit),
             math_ops.logical_and(delta < 0, iterate > limit))
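
The combined predicate reproduces Python range semantics for either sign of the step; restated in plain Python for illustration:

def should_continue(iterate, limit, delta):
    # Continue while counting up toward limit, or down toward it.
    return (delta >= 0 and iterate < limit) or (delta < 0 and iterate > limit)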
Example #24
    def aug_test():
        # TODO(b/159713842): Remove once constant folding works.
        const_delta = tensor_util.constant_value(delta)
        if const_delta is not None:
            if const_delta >= 0:
                main_test = iterate.value < limit
            else:
                main_test = iterate.value > limit
        else:
            main_test = math_ops.logical_or(
                math_ops.logical_and(delta >= 0, iterate.value < limit),
                math_ops.logical_and(delta < 0, iterate.value > limit))

        if extra_test is not None:
            main_test = control_flow_ops.cond(main_test, extra_test,
                                              lambda: False)
        return main_test
Example #25
    def loop_fn(time, cell_output, cell_state, loop_state):
        next_cell_state = initial_state if cell_output is None else cell_state

        elements_finished = math_ops.logical_or(
            time >= sequence_length,
            cell.termination_condition(next_cell_state)
        )
        finished = math_ops.reduce_all(elements_finished)

        next_input = control_flow_ops.cond(
            finished,
            lambda: array_ops.zeros_like(initial_input),
            lambda: initial_input if cell_output is None else cell.output_function(next_cell_state)
        )
        emit_output = next_input[0] if cell_output is None else next_input

        next_loop_state = None
        return (elements_finished, next_input, next_cell_state, emit_output, next_loop_state)
Example #26
 def maybe_update_masks():
     with ops.name_scope(self._spec.name):
         is_step_within_pruning_range = math_ops.logical_and(
             self._spec.pruning_on,
             math_ops.logical_and(
                 math_ops.greater_equal(self._global_step,
                                        self._spec.begin_pruning_step),
                 # If end_pruning_step is negative, keep pruning forever!
                 math_ops.logical_or(
                     math_ops.less_equal(self._global_step,
                                         self._spec.end_pruning_step),
                     math_ops.less(self._spec.end_pruning_step, 0))))
         is_pruning_step = math_ops.less_equal(
             math_ops.add(self._last_update_step,
                          self._spec.pruning_frequency),
             self._global_step)
         return math_ops.logical_and(is_step_within_pruning_range,
                                     is_pruning_step)
Example #27
def _shape_tensor_compatible(expected_shape, actual_shape):
    """Returns whether actual_shape is compatible with expected_shape.

  Note that -1 in `expected_shape` is recognized as unknown dimension.

  Args:
    expected_shape: Integer list defining the expected shape, or tensor of same.
    actual_shape: Shape of the tensor to test.
  Returns:
    New tensor.
  """
    with ops.name_scope('shape_tensor_equal',
                        values=[expected_shape, actual_shape]) as scope:
        return math_ops.reduce_all(math_ops.logical_or(
            math_ops.equal(expected_shape, -1),
            math_ops.equal(expected_shape, actual_shape, 'equal'),
            name='exclude_partial_shape'),
                                   name=scope)
Example #28
def _shape_tensor_compatible(expected_shape, actual_shape):
  """Returns whether actual_shape is compatible with expected_shape.

  Note that -1 in `expected_shape` is recognized as unknown dimension.

  Args:
    expected_shape: Integer list defining the expected shape, or tensor of same.
    actual_shape: Shape of the tensor to test.
  Returns:
    New tensor.
  """
  with ops.name_scope('shape_tensor_equal',
                      values=[expected_shape, actual_shape]) as scope:
    return math_ops.reduce_all(
        math_ops.logical_or(
            math_ops.equal(expected_shape, -1),
            math_ops.equal(expected_shape, actual_shape, 'equal'),
            name='exclude_partial_shape'),
        name=scope)
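
A short eager sketch of the `-1` wildcard semantics above, using the public API with illustrative shapes:

import tensorflow as tf

expected = tf.constant([-1, 3])  # -1 matches any first dimension
actual = tf.constant([5, 3])
compatible = tf.reduce_all(tf.logical_or(tf.equal(expected, -1),
                                         tf.equal(expected, actual)))
# compatible -> True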
Example #29
def _check_batch_beam(t, batch_size, beam_width):
    """Returns an Assert operation checking that the elements of the stacked
  TensorArray can be reshaped to [batch_size, beam_size, -1]. At this point,
  the TensorArray elements have a known rank of at least 1.
  """
    error_message = (
        "TensorArray reordering expects elements to be "
        "reshapable to [batch_size, beam_size, -1] which is "
        "incompatible with the dynamic shape of %s elements. "
        "Consider setting reorder_tensor_arrays to False to disable "
        "TensorArray reordering during the beam search." % (t.name))
    rank = t.shape.ndims
    shape = array_ops.shape(t)
    if rank == 2:
        condition = math_ops.equal(shape[1], batch_size * beam_width)
    else:
        condition = math_ops.logical_or(
            math_ops.equal(shape[1], batch_size * beam_width),
            math_ops.logical_and(math_ops.equal(shape[1], batch_size),
                                 math_ops.equal(shape[2], beam_width)))
    return control_flow_ops.Assert(condition, [error_message])
Example #30
        def body(time, elements_finished, current_input, state_ta, emit_ta, state, loop_state):
            (next_output, cell_state) = cell(current_input, state)

            nest.assert_same_structure(state, cell_state)
            nest.assert_same_structure(cell.output_size, next_output)

            next_time = time + 1
            (next_finished, next_input, next_state, emit_output,
             next_loop_state) = loop_fn(next_time, next_output, cell_state, loop_state)

            nest.assert_same_structure(state, next_state)
            nest.assert_same_structure(current_input, next_input)
            nest.assert_same_structure(emit_ta, emit_output)

            # If loop_fn returns None for next_loop_state, just reuse the previous one.
            loop_state = loop_state if next_loop_state is None else next_loop_state

            def _copy_some_through(current, candidate):
                """Copy some tensors through via array_ops.where."""
                def copy_fn(cur_i, cand_i):
                    # TensorArray and scalar get passed through.
                    if isinstance(cur_i, tensor_array_ops.TensorArray):
                        return cand_i
                    if cur_i.shape.ndims == 0:
                        return cand_i
                    # Otherwise propagate the old or the new value.
                    with ops.colocate_with(cand_i):
                        return array_ops.where(elements_finished, cur_i, cand_i)
                return nest.map_structure(copy_fn, current, candidate)

            emit_output = _copy_some_through(zero_emit, emit_output)
            next_state = _copy_some_through(state, next_state)

            emit_ta = nest.map_structure(lambda ta, emit: ta.write(time, emit), emit_ta, emit_output)
            state_ta = nest.map_structure(lambda ta, state: ta.write(time, state), state_ta, next_state)

            elements_finished = math_ops.logical_or(elements_finished, next_finished)

            return (next_time, elements_finished, next_input, state_ta,
                    emit_ta, next_state, loop_state)
Example #31
def _clip_by_value_grad(op, grad):
    """Returns grad of clip_by_value."""
    x = op.inputs[0]
    y = op.inputs[1]
    z = op.inputs[2]
    gdtype = grad.dtype
    sx = array_ops.shape(x)
    sy = array_ops.shape(y)
    sz = array_ops.shape(z)
    gradshape = array_ops.shape(grad)
    zeros = array_ops.zeros(gradshape, gdtype)
    xymask = math_ops.less(x, y)
    xzmask = math_ops.greater(x, z)
    rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
    rx, rz = gen_array_ops.broadcast_gradient_args(sx, sz)
    xgrad = array_ops.where(math_ops.logical_or(xymask, xzmask), zeros, grad)
    ygrad = array_ops.where(xymask, grad, zeros)
    zgrad = array_ops.where(xzmask, grad, zeros)
    gx = array_ops.reshape(math_ops.reduce_sum(xgrad, rx), sx)
    gy = array_ops.reshape(math_ops.reduce_sum(ygrad, ry), sy)
    gz = array_ops.reshape(math_ops.reduce_sum(zgrad, rz), sz)
    return (gx, gy, gz)
Example #32
def _check_batch_beam(t, batch_size, beam_width):
  """Returns an Assert operation checking that the elements of the stacked
  TensorArray can be reshaped to [batch_size, beam_size, -1]. At this point,
  the TensorArray elements have a known rank of at least 1.
  """
  error_message = ("TensorArray reordering expects elements to be "
                   "reshapable to [batch_size, beam_size, -1] which is "
                   "incompatible with the dynamic shape of %s elements. "
                   "Consider setting reorder_tensor_arrays to False to disable "
                   "TensorArray reordering during the beam search."
                   % (t.name))
  rank = t.shape.ndims
  shape = array_ops.shape(t)
  if rank == 2:
    condition = math_ops.equal(shape[1], batch_size * beam_width)
  else:
    condition = math_ops.logical_or(
        math_ops.equal(shape[1], batch_size * beam_width),
        math_ops.logical_and(
            math_ops.equal(shape[1], batch_size),
            math_ops.equal(shape[2], beam_width)))
  return control_flow_ops.Assert(condition, [error_message])
Example #33
def _clip_by_value_grad(op, grad):
  """Returns grad of clip_by_value."""
  x = op.inputs[0]
  y = op.inputs[1]
  z = op.inputs[2]
  gdtype = grad.dtype
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  sz = array_ops.shape(z)
  gradshape = array_ops.shape(grad)
  zeros = array_ops.zeros(gradshape, gdtype)
  xymask = math_ops.less(x, y)
  xzmask = math_ops.greater(x, z)
  rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
  rx, rz = gen_array_ops.broadcast_gradient_args(sx, sz)
  xgrad = array_ops.where(math_ops.logical_or(xymask, xzmask), zeros, grad)
  ygrad = array_ops.where(xymask, grad, zeros)
  zgrad = array_ops.where(xzmask, grad, zeros)
  gx = array_ops.reshape(math_ops.reduce_sum(xgrad, rx), sx)
  gy = array_ops.reshape(math_ops.reduce_sum(ygrad, ry), sy)
  gz = array_ops.reshape(math_ops.reduce_sum(zgrad, rz), sz)
  return (gx, gy, gz)
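
The masks route the incoming gradient: `x` receives it only inside the clip range, `y` where `x` fell below, and `z` where `x` exceeded. A quick sanity check against the public API (TF 2.x, illustrative values):

import tensorflow as tf

x = tf.Variable([-2.0, 0.5, 3.0])
with tf.GradientTape() as tape:
    clipped = tf.clip_by_value(x, 0.0, 1.0)
grad = tape.gradient(clipped, x)
# grad -> [0., 1., 0.]: the gradient flows only where x is inside [0, 1]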
Example #34
    def control_map_fn(x, y):

      def multiply():
        return x * 2

      def divide():
        return x // 2

      def defaults_two():
        return control_flow_ops.cond(
            math_ops.equal(math_ops.mod(x, 2), 0),
            multiply,
            divide,
            name="cond_mult")

      pred_fn_pairs = {
          math_ops.logical_or(math_ops.equal(y, 2), math_ops.equal(y, 3)):
              defaults_two,
      }

      return control_flow_ops.case(
          pred_fn_pairs, default=multiply, exclusive=True)
Example #35
def decode_image_with_file_name(contents, file_name, channels=None, name=None):
    # Cannot use TensorFlow's decode_image function because the image file
    # sometimes lacks the proper format prefix; as long as the file extension
    # is correct, this function will work.
    assert len(contents.get_shape().as_list()) == 0 and len(
        file_name.get_shape().as_list()) == 0
    ext = get_tf_string_extension(file_name)

    def _png():
        return gen_image_ops.decode_png(contents, channels)

    def _jpeg():
        is_jpeg = tf.logical_or(math_ops.equal(ext, 'jpg', name='is_jpg'),
                                math_ops.equal(ext, 'JPG', name='is_jpg_cap'))
        decode_msg = 'Unable to decode bytes as JPEG or PNG.'
        assert_decode = control_flow_ops.Assert(is_jpeg, [decode_msg])
        with ops.control_dependencies([
                assert_decode,
        ]):
            return gen_image_ops.decode_jpeg(contents, channels)

    is_png = math_ops.logical_or(math_ops.equal(ext, 'png', name='is_png'),
                                 math_ops.equal(ext, 'PNG', name='is_png_cap'))
    return control_flow_ops.cond(is_png, _png, _jpeg, name='cond_png')
Example #36
def next_inputs_fn(self, time, outputs, state, sample_ids):
    def next_inputs(self, time, outputs, state):
        next_time = time + 1
        finished = (next_time >= self._sequence_length)
        all_finished = math_ops.reduce_all(finished)
        def read_from_ta(inp):
            return inp.read(next_time)
        next_inputs = control_flow_ops.cond(
            all_finished, lambda: self._zero_inputs,
            lambda: nest.map_structure(read_from_ta, self._input_tas))
        return (finished, next_inputs, state)
    (finished, base_next_inputs, state) = next_inputs(self, time, outputs, state)
    def maybe_sample():
        where_sampling = math_ops.cast(
            array_ops.where(sample_ids), dtypes.int32)
        where_not_sampling = math_ops.cast(
            array_ops.where(math_ops.logical_not(sample_ids)), dtypes.int32)
        outputs_sampling = array_ops.gather_nd(outputs, where_sampling)
        inputs_not_sampling = array_ops.gather_nd(base_next_inputs,
                                                  where_not_sampling)
        base_shape = array_ops.shape(base_next_inputs)
        z_mean, z_logstd = tf.split(outputs_sampling, 2, 1)
        # Draw the sampled next inputs from the predicted Normal distribution
        # (scatter_nd needs a tensor, not a Distribution object).
        sampled_next_inputs = normal.Normal(z_mean, tf.exp(z_logstd)).sample()
        return (array_ops.scatter_nd(indices=where_sampling,
                                     updates=sampled_next_inputs,
                                     shape=base_shape)
                + array_ops.scatter_nd(indices=where_not_sampling,
                                       updates=inputs_not_sampling,
                                       shape=base_shape))
    all_finished = math_ops.reduce_all(finished)
    no_samples = math_ops.logical_not(math_ops.reduce_any(sample_ids))
    next_inputs = control_flow_ops.cond(
        math_ops.logical_or(all_finished, no_samples),
        lambda: base_next_inputs, maybe_sample)
    return (finished, next_inputs, state)
Example #37
def decode(image_buffer, image_format):
    def decode_image():
        return tf.cast(image_ops.decode_image(image_buffer, channels=3),
                       tf.uint8)

    def decode_jpeg():
        return tf.cast(image_ops.decode_jpeg(image_buffer, channels=3),
                       tf.uint8)

    def check_jpeg():
        return tf.cond(image_ops.is_jpeg(image_buffer), decode_jpeg,
                       decode_image)

    def decode_raw():
        return parsing_ops.decode_raw(image_buffer, out_type=tf.uint8)

    image = tf.cond(
        math_ops.logical_or(math_ops.equal(image_format, 'raw'),
                            math_ops.equal(image_format, 'RAW')), decode_raw,
        check_jpeg)

    image.set_shape([None, None, 3])

    return image
Example #38
    def _decode(self, image_buffer, image_format, image_height, image_width):
        """Decodes the image buffer.

        Args:
          image_buffer: The tensor representing the encoded image.
          image_format: The image format for the image in `image_buffer`. If image
            format is `raw`, all images are expected to be in this format, otherwise
            this op can decode a mix of `jpg` and `png` formats.
          image_height: The height to use when reshaping the decoded image.
          image_width: The width to use when reshaping the decoded image.

        Returns:
          A tensor that represents the decoded image, reshaped to
          `[image_height, image_width, 3]`.
        """

        def decode_image():
            """Decodes a png or jpg based on the headers."""
            return image_ops.decode_image(image_buffer, self._channels)

        def decode_raw():
            """Decodes a raw image."""
            return parsing_ops.decode_raw(image_buffer, out_type=self._dtype)

        pred_fn_pairs = {
            math_ops.logical_or(
                math_ops.equal(image_format, 'raw'),
                math_ops.equal(image_format, 'RAW')): decode_raw,
        }

        if self._dtype == dtypes.uint8:
            image = control_flow_ops.case(pred_fn_pairs, default=decode_image, exclusive=True)
        else:
            image = decode_raw()

        image = array_ops.reshape(image, tf.stack([image_height, image_width, 3]))

        return image
Example #39
def dynamic_decode(decoder,
                   output_time_major=False,
                   impute_finished=False,
                   maximum_iterations=None,
                   parallel_iterations=32,
                   swap_memory=False,
                   scope=None):
    """Perform dynamic decoding with `decoder`.

    Calls initialize() once and step() repeatedly on the Decoder object.

    Args:
      decoder: A `Decoder` instance.
      output_time_major: Python boolean.  Default: `False` (batch major).  If
        `True`, outputs are returned as time major tensors (this mode is faster).
        Otherwise, outputs are returned as batch major tensors (this adds extra
        time to the computation).
      impute_finished: Python boolean.  If `True`, then states for batch
        entries which are marked as finished get copied through and the
        corresponding outputs get zeroed out.  This causes some slowdown at
        each time step, but ensures that the final state and outputs have
        the correct values and that backprop ignores time steps that were
        marked as finished.
      maximum_iterations: `int32` scalar, maximum allowed number of decoding
        steps.  Default is `None` (decode until the decoder is fully done).
      parallel_iterations: Argument passed to `tf.while_loop`.
      swap_memory: Argument passed to `tf.while_loop`.
      scope: Optional variable scope to use.

    Returns:
      `(final_outputs, final_state, final_sequence_lengths)`.

    Raises:
      TypeError: if `decoder` is not an instance of `Decoder`.
      ValueError: if `maximum_iterations` is provided but is not a scalar.
    """
    if not isinstance(decoder, Decoder):
        raise TypeError("Expected decoder to be type Decoder, but saw: %s" %
                        type(decoder))

    with variable_scope.variable_scope(scope, "decoder") as varscope:
        # Properly cache variable values inside the while_loop
        if varscope.caching_device is None:
            varscope.set_caching_device(lambda op: op.device)

        if maximum_iterations is not None:
            maximum_iterations = ops.convert_to_tensor(
                maximum_iterations,
                dtype=dtypes.int32,
                name="maximum_iterations")
            if maximum_iterations.get_shape().ndims != 0:
                raise ValueError("maximum_iterations must be a scalar")

        initial_finished, initial_inputs, initial_state = decoder.initialize()

        zero_outputs = _create_zero_outputs(decoder.output_size,
                                            decoder.output_dtype,
                                            decoder.batch_size)

        if maximum_iterations is not None:
            initial_finished = math_ops.logical_or(initial_finished,
                                                   0 >= maximum_iterations)
        initial_sequence_lengths = array_ops.zeros_like(initial_finished,
                                                        dtype=dtypes.int32)
        initial_time = constant_op.constant(0, dtype=dtypes.int32)

        def _shape(batch_size, from_shape):
            if (not isinstance(from_shape, tensor_shape.TensorShape)
                    or from_shape.ndims == 0):
                return tensor_shape.TensorShape(None)
            else:
                batch_size = tensor_util.constant_value(
                    ops.convert_to_tensor(batch_size, name="batch_size"))
                return tensor_shape.TensorShape([batch_size
                                                 ]).concatenate(from_shape)

        def _create_ta(s, d):
            return tensor_array_ops.TensorArray(dtype=d,
                                                size=0,
                                                dynamic_size=True,
                                                element_shape=_shape(
                                                    decoder.batch_size, s))

        initial_outputs_ta = nest.map_structure(_create_ta,
                                                decoder.output_size,
                                                decoder.output_dtype)

        def condition(unused_time, unused_outputs_ta, unused_state,
                      unused_inputs, finished, unused_sequence_lengths):
            return math_ops.logical_not(math_ops.reduce_all(finished))

        def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
            """Internal while_loop body.

            Args:
              time: scalar int32 tensor.
              outputs_ta: structure of TensorArray.
              state: (structure of) state tensors and TensorArrays.
              inputs: (structure of) input tensors.
              finished: bool tensor (keeping track of what's finished).
              sequence_lengths: int32 tensor (keeping track of time of finish).

            Returns:
              `(time + 1, outputs_ta, next_state, next_inputs, next_finished,
                next_sequence_lengths)`.
            """
            (next_outputs, decoder_state, next_inputs,
             decoder_finished) = decoder.step(time, inputs, state)
            if decoder.tracks_own_finished:
                next_finished = decoder_finished
            else:
                next_finished = math_ops.logical_or(decoder_finished, finished)
            if maximum_iterations is not None:
                next_finished = math_ops.logical_or(
                    next_finished, time + 1 >= maximum_iterations)
            next_sequence_lengths = array_ops.where(
                math_ops.logical_and(math_ops.logical_not(finished),
                                     next_finished),
                array_ops.fill(array_ops.shape(sequence_lengths), time + 1),
                sequence_lengths)

            nest.assert_same_structure(state, decoder_state)
            nest.assert_same_structure(outputs_ta, next_outputs)
            nest.assert_same_structure(inputs, next_inputs)

            # Zero out output values past finish
            if impute_finished:
                emit = nest.map_structure(
                    lambda out, zero: array_ops.where(finished, zero, out),
                    next_outputs, zero_outputs)
            else:
                emit = next_outputs

            # Copy through states past finish
            def _maybe_copy_state(new, cur):
                # TensorArrays and scalar states get passed through.
                if isinstance(cur, tensor_array_ops.TensorArray):
                    pass_through = True
                else:
                    new.set_shape(cur.shape)
                    pass_through = (new.shape.ndims == 0)
                return new if pass_through else array_ops.where(
                    finished, cur, new)

            if impute_finished:
                next_state = nest.map_structure(_maybe_copy_state,
                                                decoder_state, state)
            else:
                next_state = decoder_state

            outputs_ta = nest.map_structure(
                lambda ta, out: ta.write(time, out), outputs_ta, emit)
            return (time + 1, outputs_ta, next_state, next_inputs,
                    next_finished, next_sequence_lengths)

        res = control_flow_ops.while_loop(
            condition,
            body,
            loop_vars=[
                initial_time,
                initial_outputs_ta,
                initial_state,
                initial_inputs,
                initial_finished,
                initial_sequence_lengths,
            ],
            parallel_iterations=parallel_iterations,
            swap_memory=swap_memory)

        final_outputs_ta = res[1]
        final_state = res[2]
        final_sequence_lengths = res[5]

        final_outputs = nest.map_structure(lambda ta: ta.stack(),
                                           final_outputs_ta)

        try:
            final_outputs, final_state = decoder.finalize(
                final_outputs, final_state, final_sequence_lengths)
        except NotImplementedError:
            pass

        if not output_time_major:
            final_outputs = nest.map_structure(_transpose_batch_time,
                                               final_outputs)

    return final_outputs, final_state, final_sequence_lengths
Example #40
 def add_or_or(x1, x2):
     if x1.dtype == dtypes.bool:
         assert x2.dtype == dtypes.bool
         return math_ops.logical_or(x1, x2)
     return math_ops.add(x1, x2)
Example #41
 def max_or_or(x1, x2):
     if x1.dtype == dtypes.bool:
         assert x2.dtype == dtypes.bool
         return math_ops.logical_or(x1, x2)
     return math_ops.maximum(x1, x2)
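
Both helpers dispatch on dtype so one reduction routine can handle boolean and numeric tensors alike; a minimal usage sketch with public-API equivalents:

import tensorflow as tf

def max_or_or(x1, x2):
    # logical_or for bools, maximum otherwise (mirrors the helper above).
    if x1.dtype == tf.bool:
        return tf.logical_or(x1, x2)
    return tf.maximum(x1, x2)

max_or_or(tf.constant(True), tf.constant(False))  # -> True
max_or_or(tf.constant(3.0), tf.constant(5.0))     # -> 5.0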
Example #42
def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
                      beam_width, end_token, length_penalty_weight):
    """Performs a single step of Beam Search Decoding.

  Args:
    time: Beam search time step, should start at 0. At time 0 we assume
      that all beams are equal and consider only the first beam for
      continuations.
    logits: Logits at the current time step. A tensor of shape
      `[batch_size, beam_width, vocab_size]`
    next_cell_state: The next state from the cell, e.g. an instance of
      AttentionWrapperState if the cell is attentional.
    beam_state: Current state of the beam search.
      An instance of `BeamSearchDecoderState`.
    batch_size: The batch size for this input.
    beam_width: Python int.  The size of the beams.
    end_token: The int32 end token.
    length_penalty_weight: Float weight to penalize length. Disabled with 0.0.

  Returns:
    A new beam state.
  """

    static_batch_size = tensor_util.constant_value(batch_size)

    # Calculate the current lengths of the predictions
    prediction_lengths = beam_state.lengths
    previously_finished = beam_state.finished

    # Calculate the total log probs for the new hypotheses
    # Final Shape: [batch_size, beam_width, vocab_size]
    step_log_probs = nn_ops.log_softmax(logits)
    #step_log_probs",Tensor shape=(?, 10, 56136)
    step_log_probs = _mask_probs(step_log_probs, end_token,
                                 previously_finished)
    #step_log_probs_masked (?, 10, 56136)
    total_probs = array_ops.expand_dims(beam_state.log_probs,
                                        2) + step_log_probs
    #total_probs (?, 10, 56136)
    # Calculate the continuation lengths by adding to all continuing beams.
    vocab_size = logits.shape[-1].value or array_ops.shape(logits)[-1]
    lengths_to_add = array_ops.one_hot(
        indices=array_ops.tile(array_ops.reshape(end_token, [1, 1]),
                               [batch_size, beam_width]),
        depth=vocab_size,
        on_value=constant_op.constant(0, dtype=dtypes.int64),
        off_value=constant_op.constant(1, dtype=dtypes.int64),
        dtype=dtypes.int64)
    # lengths_to_add shape: (?, 10, 56136)
    add_mask = (1 - math_ops.to_int64(previously_finished))
    # add_mask shape: (?, 10), dtype=int64
    lengths_to_add = array_ops.expand_dims(add_mask, 2) * lengths_to_add
    # lengths_to_add shape: (?, 10, 56136)
    new_prediction_lengths = (lengths_to_add +
                              array_ops.expand_dims(prediction_lengths, 2))
    # new_prediction_lengths shape: (?, 10, 56136)
    # Calculate the scores for each beam
    scores = _get_scores(log_probs=total_probs,
                         sequence_lengths=new_prediction_lengths,
                         length_penalty_weight=length_penalty_weight)
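    # Note on the masks below: tf.constant pads a short value list by
    # repeating its last element, so value=[dtype.min, 0] with
    # shape=[vocab_size] expands to [min, 0, 0, ..., 0]. scores_mask thus
    # adds dtype.min to token id 0's score on every beam, and scores_mask2
    # does the same for token id 5, effectively banning those tokens.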
    scores_mask = tf.constant([step_log_probs.dtype.min, 0],
                              dtype=dtypes.float32,
                              shape=[vocab_size],
                              name='mask')
    scores_masked = tf.add(scores, scores_mask)
    scores_mask2 = tf.constant([0, 0, 0, 0, 0, step_log_probs.dtype.min, 0],
                               dtype=dtypes.float32,
                               shape=[vocab_size],
                               name='mask2')
    scores_masked = tf.add(scores_mask2, scores_masked)

    def new_scores(scores_masked):
        scores_no_stop = tf.constant([0, 0, step_log_probs.dtype.min, 0],
                                     dtype=dtypes.float32,
                                     shape=[vocab_size],
                                     name='no_stop')
        scores = tf.add(scores_masked, scores_no_stop)
        return scores

    # Constrain the length. Note: the predicate below is `time < 0`, which is
    # never true, so the extra `new_scores` masking is effectively disabled;
    # the commented-out alternative was `time < 9`.
    scores = control_flow_ops.cond(
        time < 0,
        lambda: new_scores(scores_masked),
        lambda: scores_masked)

    # scores shape: (?, 10, 56136) = [batch_size, beam_width, vocab_size]
    time = ops.convert_to_tensor(time, name="time")
    # During the first time step we only consider the initial beam
    scores_shape = array_ops.shape(scores)
    # scores_shape: (3,)
    scores_to_flat_1 = array_ops.reshape(scores, [batch_size, 2, -1])
    print("scores_to_flat_1", scores_to_flat_1)
    scores_to_0 = scores[:, 0]
    scores_to_1 = scores[:, -1]
    scores_to_flat_2 = tf.concat([scores_to_0, scores_to_1], 1)
    scores_flat = control_flow_ops.cond(
        time > 0, lambda: scores_to_flat_1,
        lambda: array_ops.reshape(scores_to_flat_2, [batch_size, 2, -1]))
    num_available_beam = control_flow_ops.cond(
        time > 0, lambda: math_ops.reduce_prod(scores_shape[1:]),
        lambda: math_ops.reduce_prod(scores_shape[2:]))
    #scores_flat", shape=(?, ?)
    #num_available_beam" shape=()
    # Pick the next beams according to the specified successors function
    next_beam_size = math_ops.minimum(
        ops.convert_to_tensor(beam_width,
                              dtype=dtypes.int32,
                              name="beam_width"), num_available_beam)
    #scores_t = tf.reshape(scores_flat,[batch_size,2,-1])
    ############################
    #input_words=['entrencheds01', 'entrencheds02', 'forgev01', 'forgev04', \
    #             'hitn02', 'hitn03', 'vaultn02', 'vaultn04', 'deepa03', \
    #             'deeps02', 'admitv01', 'admitv02', 'plantn01', 'plantn02',\
    #             'squaren01', 'squaren05', 'drawv05', 'drawv06', 'spellv03', \
    #             'spellv02', 'shotn02', 'shotn04', 'coachv01', 'coachv02', 'casen05',\
    #             'casen09', 'focusn01', 'focusn02', 'tasten01', 'tasten04', 'footn01', \
    #             'footv01']
    input_words = get_words()
    return_list = prior_scores(input_words)
    return_array = np.array(return_list)
    return_tensor = tf.convert_to_tensor(return_array)
    tiling = [1, 5, 1]
    prior_mask = tf.tile(tf.expand_dims(return_tensor, 1), tiling)
    prior_mask = tf.cast(prior_mask, tf.float32)
    prior_mask = array_ops.reshape(prior_mask, [batch_size, -1])
    #print ("prior_mask",prior_mask)
    scores_sum = tf.reduce_sum(scores_to_flat_1, 1)

    #print ("scores_sum_1",scores_sum)
    #def cal_scores_sum(scores_sum,prior_mask):
    #    return tf.add(scores_sum,prior_mask)
    #scores_sum = control_flow_ops.cond(
    #    time > 0,
    #    lambda: cal_scores_sum(scores_sum,prior_mask),
    #    lambda: scores_sum)
    #scores_sum=tf.add(scores_sum,prior_mask)
    #print ("scores_sum_2",scores_sum)
    ############################

    #scores_final=tf.concat([scores_sum, scores_sum],1)
    def cal_scores_indices(scores_to_0, scores_to_1):
        next_beam_scores_1, word_indices_1 = nn_ops.top_k(scores_to_0, k=5)
        print("ori next_beam_scores_1,word_indices_1", next_beam_scores_1)
        print("ori word_indices_1", word_indices_1)
        next_beam_scores_2, word_indices_2 = nn_ops.top_k(scores_to_1, k=5)
        next_beam_scores = tf.concat([next_beam_scores_1, next_beam_scores_2],
                                     1)
        word_indices = tf.concat(
            [word_indices_1, word_indices_2 + 9 * vocab_size], 1)
        return next_beam_scores, word_indices

    def cal_scores_indices_t1(scores_final, next_beam_size):
        next_beam_scores_1, word_indices_1 = nn_ops.top_k(scores_final, k=5)
        #next_beam_scores_1, word_indices_1=sample(next_beam_scores_1,word_indices_1)
        print("next_beam_scores_1", next_beam_scores_1)
        print("word_indices_1", word_indices_1)
        next_beam_scores = tf.concat([next_beam_scores_1, next_beam_scores_1],
                                     1)
        word_indices = tf.concat(
            [word_indices_1, word_indices_1 + 5 * vocab_size], 1)
        return next_beam_scores, word_indices

    next_beam_scores, word_indices = control_flow_ops.cond(
        time > 0, lambda: cal_scores_indices_t1(scores_sum, next_beam_size),
        lambda: cal_scores_indices(scores_to_0, scores_to_1))

    next_beam_scores.set_shape([static_batch_size, beam_width])
    word_indices.set_shape([static_batch_size, beam_width])
    #shape=(?, ?)
    # Pick out the probs, beam_ids, and states according to the chosen predictions

    next_beam_probs = _tensor_gather_helper(gather_indices=word_indices,
                                            gather_from=total_probs,
                                            batch_size=batch_size,
                                            range_size=beam_width * vocab_size,
                                            gather_shape=[-1],
                                            name="next_beam_probs")
    # Note: just doing the following
    #   math_ops.to_int32(word_indices % vocab_size,
    #       name="next_beam_word_ids")
    # would be a lot cleaner but for reasons unclear, that hides the results of
    # the op which prevents capturing it with tfdbg debug ops.
    raw_next_word_ids = math_ops.mod(word_indices,
                                     vocab_size,
                                     name="next_beam_word_ids")
    #raw_next_word_ids shape=(?, 10)
    next_word_ids = math_ops.to_int32(raw_next_word_ids)
    next_beam_ids = math_ops.to_int32(word_indices / vocab_size,
                                      name="next_beam_parent_ids")

    # Append new ids to current predictions
    previously_finished = _tensor_gather_helper(
        gather_indices=next_beam_ids,
        gather_from=previously_finished,
        batch_size=batch_size,
        range_size=beam_width,
        gather_shape=[-1])
    next_finished = math_ops.logical_or(previously_finished,
                                        math_ops.equal(next_word_ids,
                                                       end_token),
                                        name="next_beam_finished")

    # Calculate the length of the next predictions.
    # 1. Finished beams remain unchanged
    # 2. Beams that are now finished (EOS predicted) remain unchanged
    # 3. Beams that are not yet finished have their length increased by 1
    lengths_to_add = math_ops.to_int64(
        math_ops.not_equal(next_word_ids, end_token))
    lengths_to_add = (1 - math_ops.to_int64(next_finished)) * lengths_to_add
    next_prediction_len = _tensor_gather_helper(gather_indices=next_beam_ids,
                                                gather_from=beam_state.lengths,
                                                batch_size=batch_size,
                                                range_size=beam_width,
                                                gather_shape=[-1])
    next_prediction_len += lengths_to_add

    # Pick out the cell_states according to the next_beam_ids. We use a
    # different gather_shape here because the cell_state tensors, i.e.
    # the tensors that would be gathered from, all have dimension
    # greater than two and we need to preserve those dimensions.
    # pylint: disable=g-long-lambda
    next_cell_state = nest.map_structure(
        lambda gather_from: _maybe_tensor_gather_helper(
            gather_indices=next_beam_ids,
            gather_from=gather_from,
            batch_size=batch_size,
            range_size=beam_width,
            gather_shape=[batch_size * beam_width, -1]), next_cell_state)
    # pylint: enable=g-long-lambda

    next_state = BeamSearchDecoderState(cell_state=next_cell_state,
                                        log_probs=next_beam_probs,
                                        lengths=next_prediction_len,
                                        finished=next_finished)
    print('next_beam_probs', next_beam_probs)
    output = BeamSearchDecoderOutput(scores=next_beam_scores,
                                     predicted_ids=next_word_ids,
                                     parent_ids=next_beam_ids)

    return output, next_state
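`_get_scores` is not shown in this example; in the upstream beam-search helper it rescales the total log probabilities by a GNMT-style length penalty (Wu et al., 2016). A sketch of that formulation, assuming this example's helper matches the upstream one:

import tensorflow.compat.v1 as tf

def get_scores_sketch(log_probs, sequence_lengths, length_penalty_weight):
    # ((5 + len) / 6) ** alpha; with alpha == 0 the penalty is 1.0 and the
    # scores reduce to the raw log probabilities.
    penalty = tf.pow(
        (5.0 + tf.to_float(sequence_lengths)) / 6.0, length_penalty_weight)
    return log_probs / penalty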
Example #43
def dynamic_multi_decode(decoders,
                         ma_policy,
                         maps_g2l,
                         word_embeddings,
                         policy_mode=None,
                         output_time_major=False,
                         impute_finished=False,
                         maximum_iterations=None,
                         parallel_iterations=32,
                         swap_memory=False,
                         scope=None):
    """Perform dynamic decoding with `decoder`.

    Calls initialize() once and step() repeatedly on the Decoder object.

    Args:
      decoder: A `Decoder` instance.
      output_time_major: Python boolean.  Default: `False` (batch major).  If
        `True`, outputs are returned as time major tensors (this mode is faster).
        Otherwise, outputs are returned as batch major tensors (this adds extra
        time to the computation).
      impute_finished: Python boolean.  If `True`, then states for batch
        entries which are marked as finished get copied through and the
        corresponding outputs get zeroed out.  This causes some slowdown at
        each time step, but ensures that the final state and outputs have
        the correct values and that backprop ignores time steps that were
        marked as finished.
      maximum_iterations: `int32` scalar, maximum allowed number of decoding
         steps.  Default is `None` (decode until the decoder is fully done).
      parallel_iterations: Argument passed to `tf.while_loop`.
      swap_memory: Argument passed to `tf.while_loop`.
      scope: Optional variable scope to use.

    Returns:
      `(final_outputs, final_state, final_sequence_lengths)`.

    Raises:
      TypeError: if `decoder` is not an instance of `Decoder`.
      ValueError: if `maximum_iterations` is provided but is not a scalar.
    """

    decoders_zero_outputs = []
    final_outputs_ta = None

    def _shape(batch_size, from_shape):
        if not isinstance(from_shape, tensor_shape.TensorShape):
            return tensor_shape.TensorShape(None)
        else:
            batch_size = tensor_util.constant_value(
                ops.convert_to_tensor(batch_size, name="batch_size"))
            return tensor_shape.TensorShape([batch_size
                                             ]).concatenate(from_shape)

    def _create_ta(s, d):
        return tensor_array_ops.TensorArray(dtype=d,
                                            size=0,
                                            dynamic_size=True,
                                            element_shape=_shape(
                                                decoders[0].batch_size, s))

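    # Note: the while_loop stops when the *first* decoder reports all beams
    # finished; the other decoders' finished flags do not gate termination
    # on their own.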
    def condition(unused_time, unused_outputs_ta, unused_state, unused_inputs,
                  finished, unused_sequence_lengths):
        return math_ops.logical_not(math_ops.reduce_all(finished[0]))

    def local2global(outputs, lno):
        # logits = output_projection_layers[lno](outputs)
        logits = tf.transpose(outputs.rnn_output, [1, 0])
        global_logits = tf.transpose(tf.gather(logits, maps_g2l[lno][0]),
                                     [1, 0]) * maps_g2l[lno][1]
        return global_logits

    def global2local(gid, lno):
        lid = tf.gather(maps_g2l[lno][0], gid)
        return tf.nn.embedding_lookup(word_embeddings[lno], lid)

    def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
        """Internal while_loop body.

        Args:
          time: scalar int32 tensor.
          outputs_ta: structure of TensorArray.
          state: (structure of) state tensors and TensorArrays. list
          inputs: (structure of) input tensors. list
          finished: bool tensor (keeping track of what's finished). list
          sequence_lengths: int32 tensor (keeping track of time of finish).

        Returns:
          `(time + 1, outputs_ta, next_state, next_inputs, next_finished,
            next_sequence_lengths)`.
          ```
        """
        decoders_next_outputs = []
        decoders_next_states = []
        decoders_next_inputs = []
        decoders_next_finished = []
        decoders_next_seqlen = []
        decoders_next_ta = []

        outputs_collection = []

        decoder_cnt = 0
        for decoder in decoders:
            (next_outputs, decoder_state, next_inputs,
             decoder_finished) = decoder.step(time, inputs[decoder_cnt],
                                              state[decoder_cnt])
            next_finished = math_ops.logical_or(decoder_finished,
                                                finished[decoder_cnt])
            if maximum_iterations is not None:
                next_finished = math_ops.logical_or(
                    next_finished, time + 1 >= maximum_iterations)

            nest.assert_same_structure(state[decoder_cnt], decoder_state)
            nest.assert_same_structure(outputs_ta[decoder_cnt], next_outputs)
            nest.assert_same_structure(inputs[decoder_cnt], next_inputs)
            # Zero out output values past finish
            if impute_finished:
                emit = nest.map_structure(
                    lambda out, zero: array_ops.where(finished[decoder_cnt],
                                                      zero, out), next_outputs,
                    decoders_zero_outputs[decoder_cnt])
            else:
                emit = next_outputs

            outputs_collection.append(local2global(next_outputs, decoder_cnt))

            # Copy through states past finish
            def _maybe_copy_state(new, cur):
                # TensorArrays and scalar states get passed through.
                if isinstance(cur, tensor_array_ops.TensorArray):
                    pass_through = True
                else:
                    new.set_shape(cur.shape)
                    pass_through = (new.shape.ndims == 0)
                return new if pass_through else array_ops.where(
                    finished[decoder_cnt], cur, new)

            next_state = None
            if impute_finished:
                next_state = nest.map_structure(_maybe_copy_state,
                                                decoder_state,
                                                state[decoder_cnt])
            else:
                next_state = decoder_state
            this_outputs_ta = nest.map_structure(
                lambda ta, out: ta.write(time, out), outputs_ta[decoder_cnt],
                emit)

            next_sequence_lengths = array_ops.where(
                math_ops.logical_and(
                    math_ops.logical_not(finished[decoder_cnt]),
                    next_finished),
                array_ops.fill(array_ops.shape(sequence_lengths[decoder_cnt]),
                               time + 1), sequence_lengths[decoder_cnt])

            decoders_next_inputs.append(next_inputs)
            decoders_next_outputs.append(next_outputs)
            decoders_next_states.append(next_state)
            decoders_next_finished.append(next_finished)
            decoders_next_seqlen.append(next_sequence_lengths)
            decoders_next_ta.append(this_outputs_ta)
            decoder_cnt += 1

        ma_weights = tf.nn.softmax(tf.matmul(outputs_collection[0], ma_policy),
                                   -1)
        print('ma_weights_shape:', ma_weights)
        if policy_mode == 'FULL':
            outputs_collection = outputs_collection[1:]
        outputs_collection = tf.transpose(
            ops.convert_to_tensor(outputs_collection, dtype=dtypes.float32),
            [2, 1, 0])
        print('all_outputs_shape:', outputs_collection)
        final_outputs = tf.transpose(
            tf.reduce_sum(outputs_collection * ma_weights, -1), [1, 0])
        # final_outputs=tf.transpose(outputs_collection, [2,1,0])[0]
        print('final_outputs_shape:', final_outputs)
        sample_ids = math_ops.cast(math_ops.argmax(final_outputs, axis=-1),
                                   dtypes.int32)

        wrapped_final_outputs = BasicDecoderOutput(final_outputs, sample_ids)
        decoders_next_ta.append(
            nest.map_structure(lambda ta, out: ta.write(time, out),
                               outputs_ta[-2], wrapped_final_outputs))
        decoders_next_ta.append(
            nest.map_structure(lambda ta, out: ta.write(time, out),
                               outputs_ta[-1], ma_weights))

        for dno in range(len(decoders)):
            decoders_next_inputs[dno] = global2local(sample_ids, dno)

        outputs_ta = decoders_next_ta
        next_inputs = decoders_next_inputs
        next_state = decoders_next_states
        next_finished = decoders_next_finished
        next_seqlen = decoders_next_seqlen

        return (time + 1, outputs_ta, next_state, next_inputs, next_finished,
                next_seqlen)

    with variable_scope.variable_scope(scope, "decoder") as varscope:
        # Properly cache variable values inside the while_loop
        if varscope.caching_device is None:
            varscope.set_caching_device(lambda op: op.device)
        decoder_cnt = 0
        decoders_initial_finished = []
        decoders_initial_seqlen = []
        decoders_initial_inputs = []
        decoders_initial_state = []
        decoders_outputs_tas = []
        decoders_weights_ta = nest.map_structure(_create_ta, len(decoders),
                                                 tf.float32)

        if maximum_iterations is not None:
            maximum_iterations = ops.convert_to_tensor(
                maximum_iterations,
                dtype=dtypes.int32,
                name="maximum_iterations")
            if maximum_iterations.get_shape().ndims != 0:
                raise ValueError("maximum_iterations must be a scalar")

        initial_time = constant_op.constant(0, dtype=dtypes.int32)

        for decoder in decoders:
            decoder_cnt += 1
            if not isinstance(decoder, Decoder):
                raise TypeError(
                    "Expected decoder %d to be type Decoder, but saw: %s" %
                    (decoder_cnt, type(decoder)))
            (initial_finished, initial_inputs,
             initial_state) = decoder.initialize()
            zero_outputs = _create_zero_outputs(decoder.output_size,
                                                decoder.output_dtype,
                                                decoder.batch_size)
            if maximum_iterations is not None:
                initial_finished = math_ops.logical_or(initial_finished,
                                                       0 >= maximum_iterations)
            initial_sequence_lengths = array_ops.zeros_like(initial_finished,
                                                            dtype=dtypes.int32)
            initial_outputs_ta = nest.map_structure(_create_ta,
                                                    decoder.output_size,
                                                    decoder.output_dtype)

            decoders_initial_finished.append(initial_finished)
            decoders_initial_seqlen.append(initial_sequence_lengths)
            decoders_initial_inputs.append(initial_inputs)
            decoders_initial_state.append(initial_state)
            decoders_zero_outputs.append(zero_outputs)
            decoders_outputs_tas.append(initial_outputs_ta)

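        # Two extra slots are appended below: one (reusing the first
        # decoder's TensorArray structure) for the combined global-vocabulary
        # outputs, and one for the per-step mixture weights; both are written
        # each step in body() as outputs_ta[-2] and outputs_ta[-1].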
        decoders_outputs_tas.append(decoders_outputs_tas[0])
        decoders_outputs_tas.append(decoders_weights_ta)

        res = control_flow_ops.while_loop(
            condition,
            body,
            loop_vars=[
                initial_time, decoders_outputs_tas, decoders_initial_state,
                decoders_initial_inputs, decoders_initial_finished,
                decoders_initial_seqlen
            ],
            parallel_iterations=parallel_iterations,
            swap_memory=swap_memory)

        final_outputs_ta = res[1][-2:]
        final_state = res[2][0]
        final_sequence_lengths = res[5]

        final_outputs = nest.map_structure(lambda ta: ta.stack(),
                                           final_outputs_ta)

        # try:
        #     final_outputs, final_state = decoders[0].finalize(
        #         final_outputs, final_state, final_sequence_lengths)
        # except NotImplementedError:
        #     pass

        if not output_time_major:
            final_outputs = nest.map_structure(_transpose_batch_time,
                                               final_outputs)

    return final_outputs, final_state, final_sequence_lengths
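The combination step inside body() is easier to read in isolation: local2global lifts each decoder's logits into the shared vocabulary, a softmax over `ma_policy` applied to the first decoder's global logits yields per-decoder mixture weights, and the final output is their weighted sum. A standalone sketch of that arithmetic (hypothetical shapes; assumes TensorFlow 1.x):

import tensorflow.compat.v1 as tf

batch, global_vocab, n_decoders = 4, 100, 3
# Stand-ins for the stacked local2global outputs, one slice per decoder.
global_logits = tf.random_normal([n_decoders, batch, global_vocab])
ma_policy = tf.random_normal([global_vocab, n_decoders])
# Mixture weights come from the first decoder's global logits.
ma_weights = tf.nn.softmax(tf.matmul(global_logits[0], ma_policy), -1)
stacked = tf.transpose(global_logits, [1, 2, 0])      # [batch, vocab, dec]
final = tf.reduce_sum(stacked * tf.expand_dims(ma_weights, 1), -1)
sample_ids = tf.argmax(final, axis=-1, output_type=tf.int32)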
def _beam_search_step(time, logits, beam_state, batch_size, beam_width,
                      end_token, length_penalty_weight):
    """Performs a single step of Beam Search Decoding.

    Args:
      time: Beam search time step, should start at 0. At time 0 we assume
        that all beams are equal and consider only the first beam for
        continuations.
      logits: Logits at the current time step. A tensor of shape
        `[batch_size, beam_width, vocab_size]`.
      beam_state: Current state of the beam search.
        An instance of `BeamSearchDecoderState`.
      batch_size: The batch size for this input.
      beam_width: Python int.  The size of the beams.
      end_token: The int32 end token.
      length_penalty_weight: Float weight to penalize length. Disabled with 0.0.

    Returns:
      A `(BeamSearchDecoderOutput, BeamSearchDecoderState)` tuple for the
      next step.
    """
    static_batch_size = tensor_util.constant_value(batch_size)

    # Calculate the current lengths of the predictions
    prediction_lengths = beam_state.lengths
    previously_finished = beam_state.finished

    # Calculate the total log probs for the new hypotheses
    # Final Shape: [batch_size, beam_width, vocab_size]
    step_log_probs = nn_ops.log_softmax(logits)
    step_log_probs = _mask_probs(step_log_probs, end_token,
                                 previously_finished)
    total_probs = array_ops.expand_dims(beam_state.log_probs,
                                        2) + step_log_probs

    # Calculate the continuation lengths by adding to all continuing beams.
    vocab_size = logits.shape[-1].value
    lengths_to_add = array_ops.one_hot(indices=array_ops.tile(
        array_ops.reshape(end_token, [1, 1]), [batch_size, beam_width]),
                                       depth=vocab_size,
                                       on_value=0,
                                       off_value=1)
    add_mask = (1 - math_ops.to_int32(previously_finished))
    lengths_to_add = array_ops.expand_dims(add_mask, 2) * lengths_to_add
    new_prediction_lengths = (lengths_to_add +
                              array_ops.expand_dims(prediction_lengths, 2))

    # Calculate the scores for each beam
    scores = _get_scores(log_probs=total_probs,
                         sequence_lengths=new_prediction_lengths,
                         length_penalty_weight=length_penalty_weight)

    time = ops.convert_to_tensor(time, name="time")
    # During the first time step we only consider the initial beam
    scores_shape = array_ops.shape(scores)
    scores_flat = control_flow_ops.cond(
        time > 0, lambda: array_ops.reshape(scores, [batch_size, -1]),
        lambda: scores[:, 0])
    num_available_beam = control_flow_ops.cond(
        time > 0, lambda: math_ops.reduce_prod(scores_shape[1:]),
        lambda: math_ops.reduce_prod(scores_shape[2:]))

    # Pick the next beams according to the specified successors function
    next_beam_size = math_ops.minimum(
        ops.convert_to_tensor(beam_width,
                              dtype=dtypes.int32,
                              name="beam_width"), num_available_beam)
    next_beam_scores, word_indices = nn_ops.top_k(scores_flat,
                                                  k=next_beam_size)
    next_beam_scores.set_shape([static_batch_size, beam_width])
    word_indices.set_shape([static_batch_size, beam_width])

    # Pick out the probs, beam_ids, and states according to the chosen predictions
    next_beam_probs = _tensor_gather_helper(
        gather_indices=word_indices,
        gather_from=total_probs,
        range_input=batch_size,
        range_size=beam_width * vocab_size,
        final_shape=[static_batch_size, beam_width])

    next_word_ids = math_ops.to_int32(word_indices % vocab_size)
    next_beam_ids = math_ops.to_int32(word_indices / vocab_size)

    # Append new ids to current predictions
    previously_finished = _tensor_gather_helper(
        gather_indices=next_beam_ids,
        gather_from=previously_finished,
        range_input=batch_size,
        range_size=beam_width,
        final_shape=[static_batch_size, beam_width])
    next_finished = math_ops.logical_or(
        previously_finished, math_ops.equal(next_word_ids, end_token))

    # Calculate the length of the next predictions.
    # 1. Finished beams remain unchanged
    # 2. Beams that are now finished (EOS predicted) remain unchanged
    # 3. Beams that are not yet finished have their length increased by 1
    lengths_to_add = math_ops.to_int32(
        math_ops.not_equal(next_word_ids, end_token))
    lengths_to_add = (1 - math_ops.to_int32(next_finished)) * lengths_to_add
    next_prediction_len = _tensor_gather_helper(
        gather_indices=next_beam_ids,
        gather_from=beam_state.lengths,
        range_input=batch_size,
        range_size=beam_width,
        final_shape=[static_batch_size, beam_width])
    next_prediction_len += lengths_to_add

    next_state = BeamSearchDecoderState(cell_state=beam_state.cell_state,
                                        log_probs=next_beam_probs,
                                        lengths=next_prediction_len,
                                        finished=next_finished)

    output = BeamSearchDecoderOutput(scores=next_beam_scores,
                                     predicted_ids=next_word_ids,
                                     parent_ids=next_beam_ids)

    return output, next_state
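Neither variant of `_tensor_gather_helper` appears in these examples; the upstream implementation offsets each batch's beam-local indices into a flattened index space before gathering. A rough sketch of that core idea (an approximation of the `batch_size`/`gather_shape` variant, omitting final-shape and naming details):

import tensorflow.compat.v1 as tf

def tensor_gather_helper_sketch(gather_indices, gather_from, batch_size,
                                range_size, gather_shape):
    # Row i's indices are offset by i * range_size so they address the
    # flattened [batch_size * range_size, ...] view of gather_from.
    range_ = tf.expand_dims(tf.range(batch_size) * range_size, 1)
    flat_indices = tf.reshape(gather_indices + range_, [-1])
    output = tf.gather(tf.reshape(gather_from, gather_shape), flat_indices)
    # Restore the leading [batch_size, beam_width] layout.
    return tf.reshape(output, [batch_size, -1])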
Example #45
def dynamic_decode(decoder,
                   output_time_major=False,
                   impute_finished=False,
                   maximum_iterations=None,
                   parallel_iterations=32,
                   swap_memory=False,
                   scope=None):
  """Perform dynamic decoding with `decoder`.

  Calls initialize() once and step() repeatedly on the Decoder object.

  Args:
    decoder: A `Decoder` instance.
    output_time_major: Python boolean.  Default: `False` (batch major).  If
      `True`, outputs are returned as time major tensors (this mode is faster).
      Otherwise, outputs are returned as batch major tensors (this adds extra
      time to the computation).
    impute_finished: Python boolean.  If `True`, then states for batch
      entries which are marked as finished get copied through and the
      corresponding outputs get zeroed out.  This causes some slowdown at
      each time step, but ensures that the final state and outputs have
      the correct values and that backprop ignores time steps that were
      marked as finished.
    maximum_iterations: `int32` scalar, maximum allowed number of decoding
       steps.  Default is `None` (decode until the decoder is fully done).
    parallel_iterations: Argument passed to `tf.while_loop`.
    swap_memory: Argument passed to `tf.while_loop`.
    scope: Optional variable scope to use.

  Returns:
    `(final_outputs, final_state)`.

  Raises:
    TypeError: if `decoder` is not an instance of `Decoder`.
    ValueError: if maximum_iterations is provided but is not a scalar.
  """
  if not isinstance(decoder, Decoder):
    raise TypeError("Expected decoder to be type Decoder, but saw: %s" %
                    type(decoder))

  with variable_scope.variable_scope(scope or "decoder") as varscope:
    # Properly cache variable values inside the while_loop
    if varscope.caching_device is None:
      varscope.set_caching_device(lambda op: op.device)

    if maximum_iterations is not None:
      maximum_iterations = ops.convert_to_tensor(
          maximum_iterations, dtype=dtypes.int32, name="maximum_iterations")
      if maximum_iterations.get_shape().ndims != 0:
        raise ValueError("maximum_iterations must be a scalar")

    initial_finished, initial_inputs, initial_state = decoder.initialize()

    zero_outputs = _create_zero_outputs(decoder.output_size,
                                        decoder.output_dtype,
                                        decoder.batch_size)

    if maximum_iterations is not None:
      initial_finished = math_ops.logical_or(
          initial_finished, 0 >= maximum_iterations)
    initial_time = constant_op.constant(0, dtype=dtypes.int32)

    def _shape(batch_size, from_shape):
      if not isinstance(from_shape, tensor_shape.TensorShape):
        return tensor_shape.TensorShape(None)
      else:
        batch_size = tensor_util.constant_value(
            ops.convert_to_tensor(
                batch_size, name="batch_size"))
        return tensor_shape.TensorShape([batch_size]).concatenate(from_shape)

    def _create_ta(s, d):
      return tensor_array_ops.TensorArray(
          dtype=d,
          size=0,
          dynamic_size=True,
          element_shape=_shape(decoder.batch_size, s))

    initial_outputs_ta = nest.map_structure(_create_ta, decoder.output_size,
                                            decoder.output_dtype)

    def condition(unused_time, unused_outputs_ta, unused_state, unused_inputs,
                  finished):
      return math_ops.logical_not(math_ops.reduce_all(finished))

    def body(time, outputs_ta, state, inputs, finished):
      """Internal while_loop body.

      Args:
        time: scalar int32 tensor.
        outputs_ta: structure of TensorArray.
        state: (structure of) state tensors and TensorArrays.
        inputs: (structure of) input tensors.
        finished: 1-D bool tensor.

      Returns:
        `(time + 1, outputs_ta, next_state, next_inputs, next_finished)`.
      """
      (next_outputs, decoder_state, next_inputs,
       decoder_finished) = decoder.step(time, inputs, state)
      next_finished = math_ops.logical_or(decoder_finished, finished)
      if maximum_iterations is not None:
        next_finished = math_ops.logical_or(
            next_finished, time + 1 >= maximum_iterations)

      nest.assert_same_structure(state, decoder_state)
      nest.assert_same_structure(outputs_ta, next_outputs)
      nest.assert_same_structure(inputs, next_inputs)

      # Zero out output values past finish
      if impute_finished:
        emit = nest.map_structure(
            lambda out, zero: array_ops.where(finished, zero, out),
            next_outputs,
            zero_outputs)
      else:
        emit = next_outputs

      # Copy through states past finish
      def _maybe_copy_state(new, cur):
        # TensorArrays and scalar states get passed through.
        if isinstance(cur, tensor_array_ops.TensorArray):
          pass_through = True
        else:
          new.set_shape(cur.shape)
          pass_through = (new.shape.ndims == 0)
        return new if pass_through else array_ops.where(finished, cur, new)

      if impute_finished:
        next_state = nest.map_structure(
            _maybe_copy_state, decoder_state, state)
      else:
        next_state = decoder_state

      outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                      outputs_ta, emit)
      return (time + 1, outputs_ta, next_state, next_inputs, next_finished)

    res = control_flow_ops.while_loop(
        condition,
        body,
        loop_vars=[
            initial_time, initial_outputs_ta, initial_state, initial_inputs,
            initial_finished
        ],
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory)

    final_outputs_ta = res[1]
    final_state = res[2]

    final_outputs = nest.map_structure(lambda ta: ta.stack(), final_outputs_ta)
    if not output_time_major:
      final_outputs = nest.map_structure(_transpose_batch_time, final_outputs)

  if hasattr(decoder, "finalize"):
    final_outputs, final_state = decoder.finalize(final_outputs, final_state)

  return final_outputs, final_state
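The loop above relies on a small `Decoder` contract: `initialize()` returns `(finished, first_inputs, initial_state)`, `step()` returns `(outputs, state, next_inputs, finished)`, and `batch_size`/`output_size`/`output_dtype` describe the emitted structure. A toy decoder satisfying that contract (a hypothetical sketch, ignoring the `isinstance` check above) makes the surface area concrete:

import tensorflow.compat.v1 as tf

class CountUpDecoder(object):
    """Hypothetical toy decoder: emits a running count, stops at max_steps."""

    def __init__(self, batch_size, max_steps):
        self._batch_size = batch_size
        self._max_steps = max_steps

    @property
    def batch_size(self):
        return tf.constant(self._batch_size)

    @property
    def output_size(self):
        return tf.TensorShape([])  # one scalar per batch entry

    @property
    def output_dtype(self):
        return tf.int32

    def initialize(self):
        finished = tf.fill([self._batch_size], False)
        first_inputs = tf.zeros([self._batch_size], tf.int32)
        return finished, first_inputs, first_inputs

    def step(self, time, inputs, state):
        outputs = inputs + 1  # "decode" by counting up
        finished = tf.fill([self._batch_size], time + 1 >= self._max_steps)
        return outputs, state, outputs, finished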
Example #46
def dynamic_decode(decoder,
                   output_time_major=False,
                   impute_finished=False,
                   maximum_iterations=None,
                   parallel_iterations=32,
                   swap_memory=False,
                   scope=None,
                   **kwargs):
    """Perform dynamic decoding with `decoder`.

    Calls initialize() once and step() repeatedly on the Decoder object.

    Args:
      decoder: A `Decoder` instance.
      output_time_major: Python boolean.  Default: `False` (batch major). If
        `True`, outputs are returned as time major tensors (this mode is
        faster). Otherwise, outputs are returned as batch major tensors (this
        adds extra time to the computation).
      impute_finished: Python boolean.  If `True`, then states for batch
        entries which are marked as finished get copied through and the
        corresponding outputs get zeroed out.  This causes some slowdown at
        each time step, but ensures that the final state and outputs have
        the correct values and that backprop ignores time steps that were
        marked as finished.
      maximum_iterations: `int32` scalar, maximum allowed number of decoding
         steps.  Default is `None` (decode until the decoder is fully done).
      parallel_iterations: Argument passed to `tf.while_loop`.
      swap_memory: Argument passed to `tf.while_loop`.
      scope: Optional variable scope to use.
      **kwargs: dict, other keyword arguments for dynamic_decode. It might
        contain arguments for `BaseDecoder` to initialize, which takes all
        tensor inputs during call().

    Returns:
      `(final_outputs, final_state, final_sequence_lengths)`.

    Raises:
      TypeError: if `decoder` is not an instance of `Decoder`.
      ValueError: if `maximum_iterations` is provided but is not a scalar.
    """
    if not isinstance(decoder, (Decoder, BaseDecoder)):
        raise TypeError(
            "Expected decoder to be type Decoder, but saw: %s" % type(decoder))

    with variable_scope.variable_scope(scope, "decoder") as varscope:
        # Determine context types.
        ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
        is_xla = control_flow_util.GetContainingXLAContext(ctxt) is not None
        in_while_loop = (control_flow_util.GetContainingWhileContext(ctxt) is
                         not None)
        # Properly cache variable values inside the while_loop.
        # Don't set a caching device when running in a loop, since it is
        # possible that train steps could be wrapped in a tf.while_loop. In that
        # scenario caching prevents forward computations in loop iterations from
        # re-reading the updated weights.
        if not context.executing_eagerly() and not in_while_loop:
            if varscope.caching_device is None:
                varscope.set_caching_device(lambda op: op.device)

        if maximum_iterations is not None:
            maximum_iterations = ops.convert_to_tensor(
                maximum_iterations,
                dtype=dtypes.int32,
                name="maximum_iterations")
            if maximum_iterations.get_shape().ndims != 0:
                raise ValueError("maximum_iterations must be a scalar")

        if isinstance(decoder, Decoder):
            initial_finished, initial_inputs, initial_state = \
                decoder.initialize()
        else:
            # For BaseDecoder that takes tensor inputs during call.
            decoder_init_input = kwargs.pop("decoder_init_input", None)
            decoder_init_kwargs = kwargs.pop("decoder_init_kwargs", {})
            initial_finished, initial_inputs, initial_state = \
                decoder.initialize(decoder_init_input, **decoder_init_kwargs)

        zero_outputs = _create_zero_outputs(
            decoder.output_size, decoder.output_dtype, decoder.batch_size)

        if is_xla and maximum_iterations is None:
            raise ValueError(
                "maximum_iterations is required for XLA compilation.")
        if maximum_iterations is not None:
            initial_finished = math_ops.logical_or(initial_finished,
                                                   0 >= maximum_iterations)
        initial_sequence_lengths = array_ops.zeros_like(
            initial_finished, dtype=dtypes.int32)
        initial_time = constant_op.constant(0, dtype=dtypes.int32)

        def _shape(batch_size, from_shape):
            if (not isinstance(from_shape, tensor_shape.TensorShape)
                    or from_shape.ndims == 0):
                return None
            else:
                batch_size = tensor_util.constant_value(
                    ops.convert_to_tensor(batch_size, name="batch_size"))
                return tensor_shape.TensorShape(
                    [batch_size]).concatenate(from_shape)

        dynamic_size = maximum_iterations is None or not is_xla

        def _create_ta(s, d):
            return tensor_array_ops.TensorArray(
                dtype=d,
                size=0 if dynamic_size else maximum_iterations,
                dynamic_size=dynamic_size,
                element_shape=_shape(decoder.batch_size, s))

        initial_outputs_ta = nest.map_structure(
            _create_ta, decoder.output_size, decoder.output_dtype)

        def condition(unused_time, unused_outputs_ta, unused_state,
                      unused_inputs, finished, unused_sequence_lengths):
            return math_ops.logical_not(math_ops.reduce_all(finished))

        def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
            """Internal while_loop body.

            Args:
              time: scalar int32 tensor.
              outputs_ta: structure of TensorArray.
              state: (structure of) state tensors and TensorArrays.
              inputs: (structure of) input tensors.
              finished: bool tensor (keeping track of what's finished).
              sequence_lengths: int32 tensor (keeping track of time of finish).

            Returns:
              `(time + 1, outputs_ta, next_state, next_inputs, next_finished,
                next_sequence_lengths)`.
            """
            (next_outputs, decoder_state, next_inputs,
             decoder_finished) = decoder.step(time, inputs, state)
            if decoder.tracks_own_finished:
                next_finished = decoder_finished
            else:
                next_finished = math_ops.logical_or(decoder_finished, finished)
            next_sequence_lengths = array_ops.where(
                math_ops.logical_not(finished),
                array_ops.fill(array_ops.shape(sequence_lengths), time + 1),
                sequence_lengths)

            nest.assert_same_structure(state, decoder_state)
            nest.assert_same_structure(outputs_ta, next_outputs)
            nest.assert_same_structure(inputs, next_inputs)

            # Zero out output values past finish
            if impute_finished:
                emit = nest.map_structure(
                    lambda out, zero: array_ops.where(finished, zero, out),
                    next_outputs, zero_outputs)
            else:
                emit = next_outputs

            # Copy through states past finish
            def _maybe_copy_state(new, cur):
                # TensorArrays and scalar states get passed through.
                if isinstance(cur, tensor_array_ops.TensorArray):
                    pass_through = True
                else:
                    new.set_shape(cur.shape)
                    pass_through = (new.shape.ndims == 0)
                return new if pass_through else array_ops.where(
                    finished, cur, new)

            if impute_finished:
                next_state = nest.map_structure(_maybe_copy_state,
                                                decoder_state, state)
            else:
                next_state = decoder_state

            outputs_ta = nest.map_structure(
                lambda ta, out: ta.write(time, out), outputs_ta, emit)
            return (time + 1, outputs_ta, next_state, next_inputs,
                    next_finished, next_sequence_lengths)

        res = control_flow_ops.while_loop(
            condition,
            body,
            loop_vars=(
                initial_time,
                initial_outputs_ta,
                initial_state,
                initial_inputs,
                initial_finished,
                initial_sequence_lengths,
            ),
            parallel_iterations=parallel_iterations,
            maximum_iterations=maximum_iterations,
            swap_memory=swap_memory)

        final_outputs_ta = res[1]
        final_state = res[2]
        final_sequence_lengths = res[5]

        final_outputs = nest.map_structure(lambda ta: ta.stack(),
                                           final_outputs_ta)

        try:
            final_outputs, final_state = decoder.finalize(
                final_outputs, final_state, final_sequence_lengths)
        except NotImplementedError:
            pass

        if not output_time_major:
            final_outputs = nest.map_structure(_transpose_batch_time,
                                               final_outputs)

    return final_outputs, final_state, final_sequence_lengths
def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
                      beam_width, end_token, length_penalty_weight):
  """Performs a single step of Beam Search Decoding.

  Args:
    time: Beam search time step, should start at 0. At time 0 we assume
      that all beams are equal and consider only the first beam for
      continuations.
    logits: Logits at the current time step. A tensor of shape
      `[batch_size, beam_width, vocab_size]`
    next_cell_state: The next state from the cell, e.g. an instance of
      AttentionWrapperState if the cell is attentional.
    beam_state: Current state of the beam search.
      An instance of `BeamSearchDecoderState`.
    batch_size: The batch size for this input.
    beam_width: Python int.  The size of the beams.
    end_token: The int32 end token.
    length_penalty_weight: Float weight to penalize length. Disabled with 0.0.

  Returns:
    A `(BeamSearchDecoderOutput, BeamSearchDecoderState)` tuple for the
    next step.
  """
  static_batch_size = tensor_util.constant_value(batch_size)

  # Calculate the current lengths of the predictions
  prediction_lengths = beam_state.lengths
  previously_finished = beam_state.finished

  # Calculate the total log probs for the new hypotheses
  # Final Shape: [batch_size, beam_width, vocab_size]
  step_log_probs = nn_ops.log_softmax(logits)
  step_log_probs = _mask_probs(step_log_probs, end_token, previously_finished)
  total_probs = array_ops.expand_dims(beam_state.log_probs, 2) + step_log_probs

  # Calculate the continuation lengths by adding to all continuing beams.
  vocab_size = logits.shape[-1].value or array_ops.shape(logits)[-1]
  lengths_to_add = array_ops.one_hot(
      indices=array_ops.fill([batch_size, beam_width], end_token),
      depth=vocab_size,
      on_value=np.int64(0),
      off_value=np.int64(1),
      dtype=dtypes.int64)
  add_mask = math_ops.to_int64(math_ops.logical_not(previously_finished))
  lengths_to_add *= array_ops.expand_dims(add_mask, 2)
  new_prediction_lengths = (
      lengths_to_add + array_ops.expand_dims(prediction_lengths, 2))

  # Calculate the scores for each beam
  scores = _get_scores(
      log_probs=total_probs,
      sequence_lengths=new_prediction_lengths,
      length_penalty_weight=length_penalty_weight,
      dtype=logits.dtype)

  time = ops.convert_to_tensor(time, name="time")
  # During the first time step we only consider the initial beam
  scores_shape = array_ops.shape(scores)
  scores_flat = array_ops.reshape(scores, [batch_size, -1])

  # Pick the next beams according to the specified successors function
  next_beam_size = ops.convert_to_tensor(
      beam_width, dtype=dtypes.int32, name="beam_width")
  next_beam_scores, word_indices = nn_ops.top_k(scores_flat, k=next_beam_size)

  next_beam_scores.set_shape([static_batch_size, beam_width])
  word_indices.set_shape([static_batch_size, beam_width])

  # Pick out the probs, beam_ids, and states according to the chosen predictions
  next_beam_probs = _tensor_gather_helper(
      gather_indices=word_indices,
      gather_from=total_probs,
      batch_size=batch_size,
      range_size=beam_width * vocab_size,
      gather_shape=[-1],
      name="next_beam_probs")
  # Note: just doing the following
  #   math_ops.to_int32(word_indices % vocab_size,
  #       name="next_beam_word_ids")
  # would be a lot cleaner but for reasons unclear, that hides the results of
  # the op which prevents capturing it with tfdbg debug ops.
  raw_next_word_ids = math_ops.mod(
      word_indices, vocab_size, name="next_beam_word_ids")
  next_word_ids = math_ops.to_int32(raw_next_word_ids)
  next_beam_ids = math_ops.to_int32(
      word_indices / vocab_size, name="next_beam_parent_ids")

  # Append new ids to current predictions
  previously_finished = _tensor_gather_helper(
      gather_indices=next_beam_ids,
      gather_from=previously_finished,
      batch_size=batch_size,
      range_size=beam_width,
      gather_shape=[-1])
  next_finished = math_ops.logical_or(
      previously_finished,
      math_ops.equal(next_word_ids, end_token),
      name="next_beam_finished")

  # Calculate the length of the next predictions.
  # 1. Finished beams remain unchanged.
  # 2. Beams that are now finished (EOS predicted) have their length
  #    increased by 1.
  # 3. Beams that are not yet finished have their length increased by 1.
  lengths_to_add = math_ops.to_int64(math_ops.logical_not(previously_finished))
  next_prediction_len = _tensor_gather_helper(
      gather_indices=next_beam_ids,
      gather_from=beam_state.lengths,
      batch_size=batch_size,
      range_size=beam_width,
      gather_shape=[-1])
  next_prediction_len += lengths_to_add

  # Pick out the cell_states according to the next_beam_ids. We use a
  # different gather_shape here because the cell_state tensors, i.e.
  # the tensors that would be gathered from, all have dimension
  # greater than two and we need to preserve those dimensions.
  # pylint: disable=g-long-lambda
  next_cell_state = nest.map_structure(
      lambda gather_from: _maybe_tensor_gather_helper(
          gather_indices=next_beam_ids,
          gather_from=gather_from,
          batch_size=batch_size,
          range_size=beam_width,
          gather_shape=[batch_size * beam_width, -1]),
      next_cell_state)
  # pylint: enable=g-long-lambda

  next_state = BeamSearchDecoderState(
      cell_state=next_cell_state,
      log_probs=next_beam_probs,
      lengths=next_prediction_len,
      finished=next_finished)

  output = BeamSearchDecoderOutput(
      scores=next_beam_scores,
      predicted_ids=next_word_ids,
      parent_ids=next_beam_ids)

  return output, next_state
    def _time_step(time, input_ta_t, output_ta_t, extra_output_ta_t,
                   last_coverage, state, finished):
        """Take a time step of the dynamic RNN.
    Args:
      time: int32 scalar Tensor.
      output_ta_t: List of `TensorArray`s that represent the output.
      state: nested tuple of vector tensors that represent the state.
    Returns:
      The tuple (time + 1, output_ta_t with updated flow, new_state).
    """

        input_t = tuple(ta.read(time) for ta in input_ta_t)
        # Restore some shape information
        for input_, shape in zip(input_t, inputs_got_shape):
            input_.set_shape(shape[1:])

        input_t = nest.pack_sequence_as(structure=inputs,
                                        flat_sequence=input_t)

        rnn_state = state if not is_infer else state[0]
        if use_coverage:
            # coverage_t = tuple(ta.read(time) for ta in coverage_ta_t)
            # coverage_t = nest.pack_sequence_as(structure=encoded_fert_init, flat_sequence=coverage_t)
            # coverage_t.set_shape(coverage_shape)
            ctx, att_weights, new_coverage = decoder.attention_step(
                rnn_state, state_size, context, att_sequence_length,
                last_coverage, encoded_fertility)

            # new_coverage = nest.flatten(new_coverage)
            # coverage_ta_t = tuple(ta.write(time + 1, coverage)
            #                       for ta, coverage in zip(coverage_ta_t, new_coverage))
        else:
            new_coverage = last_coverage
            ctx, att_weights = decoder.attention_step(rnn_state,
                                                      state_size,
                                                      context,
                                                      use_coverage=False)

        call_cell = lambda: cell(array_ops.concat([input_t, ctx], 1),
                                 rnn_state)

        if sequence_length is not None:
            (output,
             new_state) = _rnn_step(time=time,
                                    sequence_length=sequence_length,
                                    min_sequence_length=min_sequence_length,
                                    max_sequence_length=max_sequence_length,
                                    zero_output=zero_output,
                                    state=rnn_state,
                                    call_cell=call_cell,
                                    state_size=state_size,
                                    skip_conditionals=True)
        else:
            (output, new_state) = call_cell()
            assert is_infer, "Output is zeroed manually only when inferring."
            output = nest.map_structure(
                lambda out, zero: array_ops.where(finished, zero, out), output,
                zero_output)

        logits = decoder.logit_fn(output)
        if is_infer:
            mix_state = (new_state, state[1])
            (new_output, extra_output, new_state, new_input_t,
             search_finished) = decoder.search_step(time, output, logits,
                                                    mix_state)

            new_input_t = nest.flatten(new_input_t)
            input_ta_t = tuple(
                ta.write(time + 1, new_input_)
                for ta, new_input_ in zip(input_ta_t, new_input_t))

            new_finished = math_ops.logical_or(search_finished, finished)
            extra_output = nest.pack_sequence_as(structure=extra_output_ta,
                                                 flat_sequence=extra_output)
            output = new_output
        else:
            extra_output = logits
            extra_output = nest.flatten(extra_output)
            new_finished = finished

        # Pack state if using state tuples
        output = nest.flatten(output)

        output_ta_t = tuple(
            ta.write(time, out) for ta, out in zip(output_ta_t, output))

        if is_infer:
            write_ta_ = lambda t, o: t.write(time, o)
            map_write_ = lambda ta, out: nest.map_structure(write_ta_, ta, out)
            extra_output_ta_t = tuple(
                map_write_(ta, out)
                for ta, out in zip(extra_output_ta_t, extra_output))
        else:
            extra_output_ta_t = tuple(
                ta.write(time, out)
                for ta, out in zip(extra_output_ta_t, extra_output))

        return (time + 1, input_ta_t, output_ta_t, extra_output_ta_t,
                new_coverage, new_state, new_finished)
Example #49
    def next_inputs(self, time, outputs, state, sample_ids, name=None):
        """Gets the next inputs for next step."""
        with ops.name_scope(name, "ScheduledOutputTrainingHelperNextInputs",
                            [time, outputs, state, sample_ids]):
            (finished, base_next_inputs,
             state) = (super(ScheduledOutputTrainingHelper,
                             self).next_inputs(time=time,
                                               outputs=outputs,
                                               state=state,
                                               sample_ids=sample_ids,
                                               name=name))
            sample_ids = math_ops.cast(sample_ids, dtypes.bool)

            def maybe_sample():
                """Perform scheduled sampling."""
                def maybe_concatenate_auxiliary_inputs(outputs_, indices=None):
                    """Concatenate outputs with auxiliary inputs, if they exist."""
                    if self._auxiliary_input_tas is None:
                        return outputs_

                    next_time = time + 1
                    auxiliary_inputs = nest.map_structure(
                        lambda ta: ta.read(next_time),
                        self._auxiliary_input_tas)
                    if indices is not None:
                        auxiliary_inputs = array_ops.gather_nd(
                            auxiliary_inputs, indices)
                    return nest.map_structure(
                        lambda x, y: array_ops.concat((x, y), -1), outputs_,
                        auxiliary_inputs)

                if self._next_inputs_fn is None:
                    return array_ops.where(
                        sample_ids,
                        maybe_concatenate_auxiliary_inputs(outputs),
                        base_next_inputs)

                where_sampling = math_ops.cast(array_ops.where(sample_ids),
                                               dtypes.int32)
                where_not_sampling = math_ops.cast(
                    array_ops.where(math_ops.logical_not(sample_ids)),
                    dtypes.int32)
                outputs_sampling = array_ops.gather_nd(outputs, where_sampling)
                inputs_not_sampling = array_ops.gather_nd(
                    base_next_inputs, where_not_sampling)
                sampled_next_inputs = maybe_concatenate_auxiliary_inputs(
                    self._next_inputs_fn(outputs_sampling), where_sampling)

                base_shape = array_ops.shape(base_next_inputs)
                return (array_ops.scatter_nd(indices=where_sampling,
                                             updates=sampled_next_inputs,
                                             shape=base_shape) +
                        array_ops.scatter_nd(indices=where_not_sampling,
                                             updates=inputs_not_sampling,
                                             shape=base_shape))

            all_finished = math_ops.reduce_all(finished)
            no_samples = math_ops.logical_not(math_ops.reduce_any(sample_ids))
            next_inputs = control_flow_ops.cond(
                math_ops.logical_or(all_finished, no_samples),
                lambda: base_next_inputs, maybe_sample)
            return (finished, next_inputs, state)
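The gather/scatter dance at the end of `maybe_sample` splits the batch into sampled and non-sampled rows, transforms only the sampled ones, and stitches the two subsets back together. A minimal sketch of that recombination with the public TF API (the tensor values are invented for illustration):

import tensorflow as tf

sample_ids = tf.constant([True, False, True, False])
outputs = tf.constant([[1.], [2.], [3.], [4.]])               # decoder outputs
base_next_inputs = tf.constant([[10.], [20.], [30.], [40.]])  # teacher-forced

where_sampling = tf.cast(tf.where(sample_ids), tf.int32)
where_not_sampling = tf.cast(tf.where(tf.logical_not(sample_ids)), tf.int32)

# Each scatter_nd writes its subset of rows into a zero tensor of the full
# batch shape; the index sets are disjoint, so the sum reassembles the batch.
sampled = tf.gather_nd(outputs, where_sampling)
kept = tf.gather_nd(base_next_inputs, where_not_sampling)
base_shape = tf.shape(base_next_inputs)
next_inputs = (tf.scatter_nd(where_sampling, sampled, base_shape) +
               tf.scatter_nd(where_not_sampling, kept, base_shape))
# next_inputs == [[1.], [20.], [3.], [40.]]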
Example #50
0
def stratified_sample(tensors, labels, target_probs, batch_size,
                      init_probs=None, enqueue_many=False, queue_capacity=16,
                      threads_per_queue=1, name=None):
  """Stochastically creates batches based on per-class probabilities.

  This method discards examples. Internally, it creates one queue to amortize
  the cost of disk reads, and one queue to hold the properly-proportioned
  batch. See `stratified_sample_unknown_dist` for a function that performs
  stratified sampling with one queue per class and doesn't require knowing the
  class data-distribution ahead of time.

  Args:
    tensors: List of tensors for data. All tensors are either one item or a
        batch, according to enqueue_many.
    labels: Tensor for label of data. Label is a single integer or a batch,
        depending on enqueue_many. It is not a one-hot vector.
    target_probs: Target class proportions in batch. An object whose type has a
        registered Tensor conversion function.
    batch_size: Size of batch to be returned.
    init_probs: Class proportions in the data. An object whose type has a
        registered Tensor conversion function, or `None` for estimating the
        initial distribution.
    enqueue_many: Bool. If true, interpret input tensors as having a batch
        dimension.
    queue_capacity: Capacity of the large queue that holds input examples.
    threads_per_queue: Number of threads for the large queue that holds input
        examples and for the final queue with the proper class proportions.
    name: Optional prefix for ops created by this function.
  Raises:
    ValueError: enqueue_many is True and labels doesn't have a batch
        dimension, or if enqueue_many is False and labels isn't a scalar.
    ValueError: enqueue_many is True, and batch dimension on data and labels
        don't match.
    ValueError: if probs don't sum to one.
    ValueError: if a zero initial probability class has a nonzero target
        probability.
    TFAssertion: if labels aren't integers in [0, num classes).
  Returns:
    (data_batch, label_batch), where data_batch is a list of tensors of the same
        length as `tensors`

  Example:
    # Get tensor for a single data and label example.
    data, label = data_provider.Get(['data', 'label'])

    # Get stratified batch according to per-class probabilities.
    target_probs = [...distribution you want...]
    [data_batch], labels = tf.contrib.training.stratified_sample(
        [data], label, target_probs)

    # Run batch through network.
    ...
  """
  with ops.name_scope(name, 'stratified_sample', tensors + [labels]):
    tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensors)
    labels = ops.convert_to_tensor(labels)
    target_probs = ops.convert_to_tensor(target_probs, dtype=dtypes.float32)
    # Reduce the case of a single example to that of a batch of size 1.
    if not enqueue_many:
      tensor_list = [array_ops.expand_dims(tensor, 0) for tensor in tensor_list]
      labels = array_ops.expand_dims(labels, 0)

    # If `init_probs` is `None`, set up online estimation of data distribution.
    if init_probs is None:
      # We use `target_probs` to get the number of classes, so its shape must be
      # fully defined at graph construction time.
      target_probs.get_shape().assert_is_fully_defined()
      init_probs = _estimate_data_distribution(
          labels, target_probs.get_shape().num_elements())
    else:
      init_probs = ops.convert_to_tensor(init_probs, dtype=dtypes.float32)

    # Validate that input is consistent.
    tensor_list, labels, [init_probs, target_probs] = _verify_input(
        tensor_list, labels, [init_probs, target_probs])

    # Check that all zero initial probabilities also have zero target
    # probabilities.
    assert_op = control_flow_ops.Assert(
        math_ops.reduce_all(math_ops.logical_or(
            math_ops.not_equal(init_probs, 0),
            math_ops.equal(target_probs, 0))),
        ['All classes with zero initial probability must also have zero target '
         'probability: ', init_probs, target_probs])
    init_probs = control_flow_ops.with_dependencies([assert_op], init_probs)

    # Calculate acceptance sampling probabilities.
    accept_probs = _calculate_acceptance_probabilities(init_probs, target_probs)
    proportion_rejected = math_ops.reduce_sum((1 - accept_probs) * init_probs)
    accept_probs = control_flow_ops.cond(
        math_ops.less(proportion_rejected, .5),
        lambda: accept_probs,
        lambda: logging_ops.Print(  # pylint: disable=g-long-lambda
            accept_probs, [accept_probs],
            message='Proportion of examples rejected by sampler is high.',
            first_n=10))

    # Make a single queue to hold input examples. Reshape output so examples
    # don't have singleton batch dimension.
    batched = input_ops.batch(tensor_list + [labels],
                              batch_size=1,
                              num_threads=threads_per_queue,
                              capacity=queue_capacity,
                              enqueue_many=True)
    val_list = [array_ops.squeeze(x, [0]) for x in batched[:-1]]
    label = array_ops.squeeze(batched[-1], [0])

    # Set up second queue containing batches that have the desired class
    # proportions.
    cur_prob = array_ops.gather(accept_probs, label)
    keep_input = random_ops.random_uniform([]) < cur_prob
    batched = _conditional_batch(
        val_list + [label],
        keep_input,
        batch_size,
        num_threads=threads_per_queue)
    return batched[:-1], batched[-1]
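`_calculate_acceptance_probabilities` is not shown here. One standard derivation, offered as an assumption about what it computes rather than a quote of it: accept each class in proportion to target/init, normalized so the most under-represented class is always accepted. A numpy sketch:

import numpy as np

def acceptance_probabilities(init_probs, target_probs):
    # Accept class i with probability proportional to target_i / init_i,
    # scaled so the class most under-represented in the data is always kept.
    ratio = np.asarray(target_probs, float) / np.asarray(init_probs, float)
    return ratio / ratio.max()

init = np.array([0.7, 0.2, 0.1])
target = np.array([1 / 3, 1 / 3, 1 / 3])
accept = acceptance_probabilities(init, target)
# Matches the proportion_rejected check above: when this is large, the
# sampler discards most of its input and the Print warning fires.
proportion_rejected = np.sum((1 - accept) * init)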
Example #51
0
        def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
            """Internal while_loop body.

      Args:
        time: scalar int32 tensor.
        outputs_ta: structure of TensorArray.
        state: (structure of) state tensors and TensorArrays.
        inputs: (structure of) input tensors.
        finished: bool tensor (keeping track of what's finished).
        sequence_lengths: int32 tensor (keeping track of time of finish).

      Returns:
        `(time + 1, outputs_ta, next_state, next_inputs, next_finished,
          next_sequence_lengths)`.
        ```
      """
            (next_outputs, decoder_state, next_inputs,
             decoder_finished) = decoder.step(time, inputs, state)
            if decoder.tracks_own_finished:
                next_finished = decoder_finished
            else:
                next_finished = math_ops.logical_or(decoder_finished, finished)
            if maximum_iterations is not None:
                next_finished = math_ops.logical_or(
                    next_finished, time + 1 >= maximum_iterations)
            next_sequence_lengths = array_ops.where(
                math_ops.logical_and(math_ops.logical_not(finished),
                                     next_finished),
                array_ops.fill(array_ops.shape(sequence_lengths), time + 1),
                sequence_lengths)

            nest.assert_same_structure(state, decoder_state)
            nest.assert_same_structure(outputs_ta, next_outputs)
            nest.assert_same_structure(inputs, next_inputs)

            # Zero out output values past finish
            if impute_finished:
                emit = nest.map_structure(
                    lambda out, zero: array_ops.where(finished, zero, out),
                    next_outputs, zero_outputs)
            else:
                emit = next_outputs

            # Copy through states past finish
            def _maybe_copy_state(new, cur):
                # TensorArrays and scalar states get passed through.
                if isinstance(cur, tensor_array_ops.TensorArray):
                    pass_through = True
                else:
                    new.set_shape(cur.shape)
                    pass_through = (new.shape.ndims == 0)
                return new if pass_through else array_ops.where(
                    finished, cur, new)

            if impute_finished:
                next_state = nest.map_structure(_maybe_copy_state,
                                                decoder_state, state)
            else:
                next_state = decoder_state

            outputs_ta = nest.map_structure(
                lambda ta, out: ta.write(time, out), outputs_ta, emit)
            return (time + 1, outputs_ta, next_state, next_inputs,
                    next_finished, next_sequence_lengths)
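The sequence-length bookkeeping above is easy to sanity-check in isolation: an entry is stamped with `time + 1` exactly on the step where it flips from unfinished to finished. A small numpy sketch of that update:

import numpy as np

time = 4
finished = np.array([False, True, False])       # state before this step
next_finished = np.array([True, True, False])   # state after decoder.step
sequence_lengths = np.array([0, 3, 0])

# Only entries finishing on this step (previously unfinished, now finished)
# get their length written; everything else is left alone.
just_finished = ~finished & next_finished
next_sequence_lengths = np.where(just_finished, time + 1, sequence_lengths)
# -> [5, 3, 0]: entry 0 finished now, entry 1 keeps its old length,
#    entry 2 is still running.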
Example #52
0
    def next_inputs(self, time, sample_ids):
        finished = math_ops.logical_or(
            tf.greater_equal(time + 1, self.max_step),
            tf.equal(self.eos_id, sample_ids))
        return finished, self.lookup(sample_ids)
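This helper stops a sequence either when the step budget runs out or when EOS is emitted; the scalar budget test broadcasts against the per-example EOS test. A quick eager-mode sketch with invented ids:

import tensorflow as tf

max_step = 8
eos_id = 2
time = 7
sample_ids = tf.constant([2, 5, 2, 1])

finished = tf.logical_or(
    tf.greater_equal(time + 1, max_step),   # scalar: step budget exhausted
    tf.equal(eos_id, sample_ids))           # per-example: EOS emitted
# time + 1 == max_step here, so every entry is finished regardless of ids.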
Example #53
0
def _beam_search_step(time, logits, beam_state, batch_size, beam_width,
                      end_token, length_penalty_weight):
  """Performs a single step of Beam Search Decoding.

  Args:
    time: Beam search time step, should start at 0. At time 0 we assume
      that all beams are equal and consider only the first beam for
      continuations.
    logits: Logits at the current time step. A tensor of shape `[B, vocab_size]`
    beam_state: Current state of the beam search. An instance of `BeamState`
    batch_size: The batch size for this input.
    beam_width: The size of the beams.
    end_token: The int32 end token.
    length_penalty_weight: Float weight to penalize length. Disabled with 0.0.

  Returns:
    A new beam state.
  """
  static_batch_size = tensor_util.constant_value(batch_size)

  # Calculate the current lengths of the predictions
  prediction_lengths = beam_state.lengths
  previously_finished = beam_state.finished

  # Calculate the total log probs for the new hypotheses
  # Final Shape: [batch_size, beam_width, vocab_size]
  probs = nn_ops.log_softmax(logits)
  probs = _mask_probs(probs, end_token, previously_finished)
  total_probs = array_ops.expand_dims(beam_state.log_probs, 2) + probs

  # Calculate the continuation lengths by adding to all continuing beams.
  vocab_size = logits.get_shape().as_list()[-1]
  lengths_to_add = array_ops.one_hot(
      array_ops.tile(
          array_ops.reshape(end_token, [1, 1]), [batch_size, beam_width]),
      vocab_size, 0, 1)
  add_mask = (1 - math_ops.to_int32(previously_finished))
  lengths_to_add = array_ops.expand_dims(add_mask, 2) * lengths_to_add
  new_prediction_lengths = array_ops.expand_dims(prediction_lengths,
                                                 2) + lengths_to_add

  # Calculate the scores for each beam
  scores = _get_scores(
      log_probs=total_probs,
      sequence_lengths=new_prediction_lengths,
      length_penalty_weight=length_penalty_weight)

  scores_flat = array_ops.reshape(scores, [batch_size, -1])
  # During the first time step we only consider the initial beam
  scores_flat = control_flow_ops.cond(
      ops.convert_to_tensor(time) > 0, lambda: scores_flat,
      lambda: scores[:, 0])

  # Pick the next beams according to the specified successors function
  next_beam_scores, word_indices = nn_ops.top_k(scores_flat, k=beam_width)
  next_beam_scores.set_shape([static_batch_size, beam_width])
  word_indices.set_shape([static_batch_size, beam_width])

  # Pick out the probs, beam_ids, and states according to the chosen predictions
  next_beam_probs = _tensor_gather_helper(
      gather_indices=word_indices,
      gather_from=total_probs,
      range_input=batch_size,
      range_size=beam_width * vocab_size,
      final_shape=[static_batch_size, beam_width])

  next_word_ids = math_ops.to_int32(word_indices % vocab_size)
  next_beam_ids = math_ops.to_int32(word_indices / vocab_size)

  # Append new ids to current predictions
  previously_finished = _tensor_gather_helper(
      gather_indices=next_beam_ids,
      gather_from=previously_finished,
      range_input=batch_size,
      range_size=beam_width,
      final_shape=[static_batch_size, beam_width])
  next_finished = math_ops.logical_or(previously_finished,
                                      math_ops.equal(next_word_ids, end_token))

  # Calculate the length of the next predictions.
  # 1. Finished beams remain unchanged
  # 2. Beams that are now finished (EOS predicted) remain unchanged
  # 3. Beams that are not yet finished have their length increased by 1
  lengths_to_add = math_ops.to_int32(
      math_ops.not_equal(next_word_ids, end_token))
  lengths_to_add = (1 - math_ops.to_int32(next_finished)) * lengths_to_add
  next_prediction_len = _tensor_gather_helper(
      gather_indices=next_beam_ids,
      gather_from=beam_state.lengths,
      range_input=batch_size,
      range_size=beam_width,
      final_shape=[static_batch_size, beam_width])
  next_prediction_len += lengths_to_add

  next_state = BeamSearchDecoderState(
      cell_state=beam_state.cell_state,
      log_probs=next_beam_probs,
      lengths=next_prediction_len,
      finished=next_finished)

  output = BeamSearchDecoderOutput(
      scores=next_beam_scores,
      predicted_ids=next_word_ids,
      parent_ids=next_beam_ids)

  return output, next_state
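`_get_scores` is not shown. The TF beam-search decoder of this era used the GNMT-style penalty `((5 + len) / 6) ** weight` and divided the total log probability by it; the numpy sketch below assumes that formulation and should be read as illustrative:

import numpy as np

def length_penalty(lengths, weight):
    # GNMT-style penalty; weight 0.0 makes the penalty 1 (disabled),
    # matching the "Disabled with 0.0" note in the docstring.
    return ((5.0 + lengths) / 6.0) ** weight

def get_scores(log_probs, sequence_lengths, length_penalty_weight):
    return log_probs / length_penalty(sequence_lengths.astype(float),
                                      length_penalty_weight)

log_probs = np.array([-3.0, -6.0])
lengths = np.array([3, 8])
# With weight > 0 the (negative) log prob of a longer hypothesis is divided
# by a larger penalty, counteracting the bias toward short sequences.
scores = get_scores(log_probs, lengths, 0.6)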
Example #54
0
def stratified_sample(tensors, labels, init_probs, target_probs, batch_size,
                      enqueue_many=False, queue_capacity=16,
                      threads_per_queue=1, name=None):
  """Stochastically creates batches based on per-class probabilities.

  This method discards examples. Internally, it creates one queue to amortize
  the cost of disk reads, and one queue to hold the properly-proportioned
  batch. See `stratified_sample_unknown_dist` for a function that performs
  stratified sampling with one queue per class and doesn't require knowing the
  class data-distribution ahead of time.

  Args:
    tensors: List of tensors for data. All tensors are either one item or a
        batch, according to enqueue_many.
    labels: Tensor for label of data. Label is a single integer or a batch,
        depending on enqueue_many. It is not a one-hot vector.
    init_probs: Class proportions in the data. An object whose type has a
        registered Tensor conversion function.
    target_probs: Target class proportions in batch. An object whose type has a
        registered Tensor conversion function.
    batch_size: Size of batch to be returned.
    enqueue_many: Bool. If true, interpret input tensors as having a batch
        dimension.
    queue_capacity: Capacity of the large queue that holds input examples.
    threads_per_queue: Number of threads for the large queue that holds input
        examples and for the final queue with the proper class proportions.
    name: Optional prefix for ops created by this function.
  Raises:
    ValueError: enqueue_many is True and labels doesn't have a batch
        dimension, or if enqueue_many is False and labels isn't a scalar.
    ValueError: enqueue_many is True, and batch dimension on data and labels
        don't match.
    ValueError: if probs don't sum to one.
    ValueError: if a zero initial probability class has a nonzero target
        probability.
    TFAssertion: if labels aren't integers in [0, num classes).
  Returns:
    (data_batch, label_batch), where data_batch is a list of tensors of the same
        length as `tensors`

  Example:
    # Get tensor for a single data and label example.
    data, label = data_provider.Get(['data', 'label'])

    # Get stratified batch according to per-class probabilities.
    init_probs = [1.0/NUM_CLASSES for _ in range(NUM_CLASSES)]
    target_probs = [...distribution you want...]
    [data_batch], labels = tf.contrib.framework.sampling_ops.stratified_sample(
        [data], label, init_probs, target_probs)

    # Run batch through network.
    ...
  """
  with ops.op_scope(tensors + [labels], name, 'stratified_sample'):
    tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensors)
    labels = ops.convert_to_tensor(labels)
    init_probs = ops.convert_to_tensor(init_probs, dtype=dtypes.float32)
    target_probs = ops.convert_to_tensor(target_probs, dtype=dtypes.float32)
    # Reduce the case of a single example to that of a batch of size 1.
    if not enqueue_many:
      tensor_list = [array_ops.expand_dims(tensor, 0) for tensor in tensor_list]
      labels = array_ops.expand_dims(labels, 0)

    # Validate that input is consistent.
    tensor_list, labels, [init_probs, target_probs] = _verify_input(
        tensor_list, labels, [init_probs, target_probs])

    # Check that all zero initial probabilities also have zero target
    # probabilities.
    assert_op = logging_ops.Assert(math_ops.reduce_all(math_ops.logical_or(
        math_ops.not_equal(init_probs, 0),
        math_ops.equal(target_probs, 0))), [init_probs, target_probs])
    init_probs = control_flow_ops.with_dependencies([assert_op], init_probs)

    # Calculate acceptance sampling probabilities.
    accept_probs = _calculate_acceptance_probabilities(init_probs, target_probs)
    proportion_rejected = math_ops.reduce_sum((1 - accept_probs) * init_probs)
    accept_probs = control_flow_ops.cond(
        math_ops.less(proportion_rejected, .5),
        lambda: accept_probs,
        lambda: logging_ops.Print(  # pylint: disable=g-long-lambda
            accept_probs, [accept_probs],
            message='Proportion of examples rejected by sampler is high.',
            first_n=10))

    # Make a single queue to hold input examples. Reshape output so examples
    # don't have singleton batch dimension.
    batched = input_ops.batch(tensor_list + [labels],
                              batch_size=1,
                              num_threads=threads_per_queue,
                              capacity=queue_capacity,
                              enqueue_many=True)
    val_list = [array_ops.squeeze(x, [0]) for x in batched[:-1]]
    label = array_ops.squeeze(batched[-1], [0])

    # Set up second queue containing batches that have the desired class
    # proportions.
    batched = _get_stratified_batch_from_tensors(
        val_list, label, accept_probs, batch_size, threads_per_queue)
    return batched[:-1], batched[-1]
Example #55
0
def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
                      beam_width, end_token, length_penalty_weight):
  """Performs a single step of Beam Search Decoding.

  Args:
    time: Beam search time step, should start at 0. At time 0 we assume
      that all beams are equal and consider only the first beam for
      continuations.
    logits: Logits at the current time step. A tensor of shape
      `[batch_size, beam_width, vocab_size]`
    next_cell_state: The next state from the cell, e.g. an instance of
      AttentionWrapperState if the cell is attentional.
    beam_state: Current state of the beam search.
      An instance of `BeamSearchDecoderState`.
    batch_size: The batch size for this input.
    beam_width: Python int.  The size of the beams.
    end_token: The int32 end token.
    length_penalty_weight: Float weight to penalize length. Disabled with 0.0.

  Returns:
    A new beam state.
  """
  static_batch_size = tensor_util.constant_value(batch_size)

  # Calculate the current lengths of the predictions
  prediction_lengths = beam_state.lengths
  previously_finished = beam_state.finished

  # Calculate the total log probs for the new hypotheses
  # Final Shape: [batch_size, beam_width, vocab_size]
  step_log_probs = nn_ops.log_softmax(logits)
  step_log_probs = _mask_probs(step_log_probs, end_token, previously_finished)
  total_probs = array_ops.expand_dims(beam_state.log_probs, 2) + step_log_probs

  # Calculate the continuation lengths by adding to all continuing beams.
  vocab_size = logits.shape[-1].value or array_ops.shape(logits)[-1]
  lengths_to_add = array_ops.one_hot(
      indices=array_ops.tile(
          array_ops.reshape(end_token, [1, 1]), [batch_size, beam_width]),
      depth=vocab_size,
      on_value=constant_op.constant(0, dtype=dtypes.int64),
      off_value=constant_op.constant(1, dtype=dtypes.int64),
      dtype=dtypes.int64)
  add_mask = (1 - math_ops.to_int64(previously_finished))
  lengths_to_add = array_ops.expand_dims(add_mask, 2) * lengths_to_add
  new_prediction_lengths = (
      lengths_to_add + array_ops.expand_dims(prediction_lengths, 2))

  # Calculate the scores for each beam
  scores = _get_scores(
      log_probs=total_probs,
      sequence_lengths=new_prediction_lengths,
      length_penalty_weight=length_penalty_weight)

  time = ops.convert_to_tensor(time, name="time")
  # During the first time step we only consider the initial beam
  scores_shape = array_ops.shape(scores)
  scores_flat = control_flow_ops.cond(
      time > 0,
      lambda: array_ops.reshape(scores, [batch_size, -1]),
      lambda: scores[:, 0])
  num_available_beam = control_flow_ops.cond(
      time > 0, lambda: math_ops.reduce_prod(scores_shape[1:]),
      lambda: math_ops.reduce_prod(scores_shape[2:]))

  # Pick the next beams according to the specified successors function
  next_beam_size = math_ops.minimum(
      ops.convert_to_tensor(beam_width, dtype=dtypes.int32, name="beam_width"),
      num_available_beam)
  next_beam_scores, word_indices = nn_ops.top_k(scores_flat, k=next_beam_size)

  next_beam_scores.set_shape([static_batch_size, beam_width])
  word_indices.set_shape([static_batch_size, beam_width])

  # Pick out the probs, beam_ids, and states according to the chosen predictions
  next_beam_probs = _tensor_gather_helper(
      gather_indices=word_indices,
      gather_from=total_probs,
      batch_size=batch_size,
      range_size=beam_width * vocab_size,
      gather_shape=[-1],
      name="next_beam_probs")
  # Note: just doing the following
  #   math_ops.to_int32(word_indices % vocab_size,
  #       name="next_beam_word_ids")
  # would be a lot cleaner but for reasons unclear, that hides the results of
  # the op which prevents capturing it with tfdbg debug ops.
  raw_next_word_ids = math_ops.mod(word_indices, vocab_size,
                                   name="next_beam_word_ids")
  next_word_ids = math_ops.to_int32(raw_next_word_ids)
  next_beam_ids = math_ops.to_int32(word_indices / vocab_size,
                                    name="next_beam_parent_ids")

  # Append new ids to current predictions
  previously_finished = _tensor_gather_helper(
      gather_indices=next_beam_ids,
      gather_from=previously_finished,
      batch_size=batch_size,
      range_size=beam_width,
      gather_shape=[-1])
  next_finished = math_ops.logical_or(previously_finished,
                                      math_ops.equal(next_word_ids, end_token),
                                      name="next_beam_finished")

  # Calculate the length of the next predictions.
  # 1. Finished beams remain unchanged
  # 2. Beams that are now finished (EOS predicted) remain unchanged
  # 3. Beams that are not yet finished have their length increased by 1
  lengths_to_add = math_ops.to_int64(
      math_ops.not_equal(next_word_ids, end_token))
  lengths_to_add = (1 - math_ops.to_int64(next_finished)) * lengths_to_add
  next_prediction_len = _tensor_gather_helper(
      gather_indices=next_beam_ids,
      gather_from=beam_state.lengths,
      batch_size=batch_size,
      range_size=beam_width,
      gather_shape=[-1])
  next_prediction_len += lengths_to_add

  # Pick out the cell_states according to the next_beam_ids. We use a
  # different gather_shape here because the cell_state tensors, i.e.
  # the tensors that would be gathered from, all have dimension
  # greater than two and we need to preserve those dimensions.
  # pylint: disable=g-long-lambda
  next_cell_state = nest.map_structure(
      lambda gather_from: _maybe_tensor_gather_helper(
          gather_indices=next_beam_ids,
          gather_from=gather_from,
          batch_size=batch_size,
          range_size=beam_width,
          gather_shape=[batch_size * beam_width, -1]),
      next_cell_state)
  # pylint: enable=g-long-lambda

  next_state = BeamSearchDecoderState(
      cell_state=next_cell_state,
      log_probs=next_beam_probs,
      lengths=next_prediction_len,
      finished=next_finished)

  output = BeamSearchDecoderOutput(
      scores=next_beam_scores,
      predicted_ids=next_word_ids,
      parent_ids=next_beam_ids)

  return output, next_state
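`_tensor_gather_helper` is referenced throughout but not shown. Its job is a batched gather: offset each row's indices by `row * range_size`, flatten, gather, then reshape. A numpy reconstruction under that assumption (not the library code itself):

import numpy as np

def tensor_gather_helper(gather_indices, gather_from, batch_size,
                         range_size, gather_shape):
    # Turn per-row indices into indices over the flattened rows, so a single
    # 1-D gather picks e.g. (beam, word) pairs per batch entry.
    offsets = np.arange(batch_size)[:, None] * range_size
    flat_indices = (gather_indices + offsets).reshape(-1)
    output = gather_from.reshape(gather_shape)[flat_indices]
    return output.reshape(gather_indices.shape + output.shape[1:])

total_probs = np.arange(2 * 3 * 4).reshape(2, 3, 4)  # [batch, beam, vocab]
word_indices = np.array([[1, 5], [0, 11]])           # flat beam*vocab picks
next_beam_probs = tensor_gather_helper(
    word_indices, total_probs, batch_size=2, range_size=3 * 4,
    gather_shape=[-1])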
Example #56
0
    def updateEnv(_position, _step, _mapNo):
      """Update env_state according to the current position and step.

      Args:
        _position: a 2D Tensor of shape [batch_size, 3].
        _step: a 2D Tensor of shape [batch_size, 1], where 0 --> no action,
          1 --> move forward 1 step, 2 --> turn right, 3 --> turn left,
          4 --> turn back.
        _mapNo: a 1D int32 Tensor of length batch_size.

      Returns:
        position: a 2D Tensor of shape [batch_size, 3], the new position after
          taking the step.
        env: a 2D Tensor of shape [batch_size, env_size], the environment
          state at the new position.
      """
      if _mapNo is None:
        raise ValueError("Invalid argument mapNo in updateEnv!")
      if _position is None:
        raise ValueError("Invalid argument position in updateEnv!")
      new_env = []
      new_pos = []
      # If _step is None, take no step and return the environment
      # representation of each position.
      if _step is None:
        new_pos = _position
        for j in xrange(batch_size):
          vec = array_ops.slice(mapIdx, array_ops.pack([_mapNo[j], _position[j,0], _position[j,1], _position[j,2], 0]), [1,1,1,1,state_size])
          new_env.append(array_ops.squeeze(vec))
        new_env = array_ops.reshape(array_ops.pack(new_env), [batch_size, state_size])
        return new_pos, new_env
      
      else:

        def f_move(ppos): # move forward 1 step
          return control_flow_ops.cond(math_ops.equal(ppos[2],0), 
            lambda:array_ops.pack([ppos[0], ppos[1]-1, ppos[2]]), lambda:control_flow_ops.cond(math_ops.equal(ppos[2],1),
              lambda:array_ops.pack([ppos[0]+1, ppos[1], ppos[2]]), lambda:control_flow_ops.cond(math_ops.equal(ppos[2],2),
                lambda:array_ops.pack([ppos[0], ppos[1]+1, ppos[2]]), lambda:array_ops.pack([ppos[0]-1, ppos[1], ppos[2]]))))
            
        def f_right(ppos): # turn right
          return control_flow_ops.cond(math_ops.equal(ppos[2],0),
            lambda: array_ops.pack([ppos[0],ppos[1], 1]), lambda:control_flow_ops.cond(math_ops.equal(ppos[2],1),
              lambda: array_ops.pack([ppos[0], ppos[1], 2]), lambda:control_flow_ops.cond(math_ops.equal(ppos[2],2),
                lambda: array_ops.pack([ppos[0], ppos[1], 3]), lambda: array_ops.pack([ppos[0], ppos[1], 0]))))
        
        def f_left(ppos): # turn left
          return control_flow_ops.cond(math_ops.equal(ppos[2], 0),
            lambda: array_ops.pack([ppos[0], ppos[1], 3]), lambda: control_flow_ops.cond(math_ops.equal(ppos[2],1),
              lambda: array_ops.pack([ppos[0], ppos[1], 0]), lambda:control_flow_ops.cond(math_ops.equal(ppos[2],2),
                lambda:array_ops.pack([ppos[0], ppos[1], 1]), lambda:array_ops.pack([ppos[0],ppos[1],2]))))
        
        def f_back(ppos): # turn back
          return control_flow_ops.cond(math_ops.equal(ppos[2],0),
            lambda:array_ops.pack([ppos[0], ppos[1], 2]), lambda:control_flow_ops.cond(math_ops.equal(ppos[2],1),
              lambda:array_ops.pack([ppos[0], ppos[1], 3]), lambda: control_flow_ops.cond(math_ops.equal(ppos[2],2),
                lambda:array_ops.pack([ppos[0], ppos[1], 0]), lambda:array_ops.pack([ppos[0], ppos[1], 1]))))

        def ffn4(sstep, ppos):
          # Fall through to the unchanged position; referencing the loop
          # variable j here would capture it by reference, so use ppos.
          return control_flow_ops.cond(math_ops.equal(sstep, data_utils.turnBack_ID),
            lambda: f_back(ppos), lambda: ppos)

        def ffn3(sstep, ppos): 
          return control_flow_ops.cond(math_ops.equal(sstep, data_utils.turnLeft_ID),
          lambda:f_left(ppos), lambda:ffn4(sstep, ppos))

        def ffn2(sstep, ppos): 
          return control_flow_ops.cond(math_ops.equal(sstep, data_utils.turnRight_ID),
          lambda:f_right(ppos), lambda:ffn3(sstep, ppos))

        def ffn1(sstep, ppos): 
          return control_flow_ops.cond(math_ops.equal(sstep, data_utils.moveAct_ID),
          lambda:f_move(ppos), lambda:ffn2(sstep, ppos))


        for j in xrange(batch_size):
          #update position
          temp_pos = control_flow_ops.cond(math_ops.equal(_step[j], data_utils.noAct_ID),
            lambda:_position[j,:], lambda:ffn1(_step[j], _position[j,:]))
          new_pos.append(control_flow_ops.cond(math_ops.logical_or(math_ops.greater(temp_pos[0], 24),
            math_ops.logical_or(math_ops.greater(temp_pos[1], 24),
              math_ops.logical_or(math_ops.less(temp_pos[0], 0), math_ops.less(temp_pos[1],0)))),
            lambda:_position[j,:], lambda:temp_pos))
          # new_pos.append(temp_pos)

          # update env
          new_env.append(array_ops.reshape(
              array_ops.slice(mapIdx, array_ops.pack([_mapNo[j], new_pos[-1][0], new_pos[-1][1], new_pos[-1][2], 0]), [1,1,1,1,state_size]),
              [state_size]))
        
        new_pos = array_ops.pack(new_pos)
        new_env = array_ops.pack(new_env)
        return new_pos, new_env
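The ladder of nested `cond`s above encodes simple heading arithmetic: headings 0-3 cycle under turns (`right` is +1 mod 4, `left` is -1, `back` is +2), and a forward move adds a per-heading displacement. A plain-Python sketch of the same state machine, assuming the heading reading implied by `f_move`:

import numpy as np

# Per-heading displacement for a forward move, transcribed from f_move:
# heading 0 -> (0, -1), 1 -> (+1, 0), 2 -> (0, +1), 3 -> (-1, 0).
MOVE_DELTA = np.array([[0, -1], [1, 0], [0, 1], [-1, 0]])

def step_position(pos, action):
    """pos = (row, col, heading); action in {noop, move, right, left, back}."""
    row, col, heading = pos
    if action == 'move':
        drow, dcol = MOVE_DELTA[heading]
        row, col = row + drow, col + dcol
    elif action == 'right':
        heading = (heading + 1) % 4
    elif action == 'left':
        heading = (heading - 1) % 4
    elif action == 'back':
        heading = (heading + 2) % 4
    # Out-of-bounds moves are rejected, as in the 0..24 clamp above.
    if not (0 <= row <= 24 and 0 <= col <= 24):
        return pos
    return (row, col, heading)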
Example #57
0
def stratified_sample(tensors,
                      labels,
                      target_probs,
                      batch_size,
                      init_probs=None,
                      enqueue_many=False,
                      queue_capacity=16,
                      threads_per_queue=1,
                      name=None):
    """Stochastically creates batches based on per-class probabilities.

    This method discards examples. Internally, it creates one queue to amortize
    the cost of disk reads, and one queue to hold the properly-proportioned
    batch.

    Args:
      tensors: List of tensors for data. All tensors are either one item or a
          batch, according to enqueue_many.
      labels: Tensor for label of data. Label is a single integer or a batch,
          depending on `enqueue_many`. It is not a one-hot vector.
      target_probs: Target class proportions in batch. An object whose type has a
          registered Tensor conversion function.
      batch_size: Size of batch to be returned.
      init_probs: Class proportions in the data. An object whose type has a
          registered Tensor conversion function, or `None` for estimating the
          initial distribution.
      enqueue_many: Bool. If true, interpret input tensors as having a batch
          dimension.
      queue_capacity: Capacity of the large queue that holds input examples.
      threads_per_queue: Number of threads for the large queue that holds input
          examples and for the final queue with the proper class proportions.
      name: Optional prefix for ops created by this function.
    Raises:
      ValueError: If `tensors` isn't iterable.
      ValueError: `enqueue_many` is True and labels doesn't have a batch
          dimension, or if `enqueue_many` is False and labels isn't a scalar.
      ValueError: `enqueue_many` is True, and batch dimension on data and labels
          don't match.
      ValueError: if probs don't sum to one.
      ValueError: if a zero initial probability class has a nonzero target
          probability.
      TFAssertion: if labels aren't integers in [0, num classes).
    Returns:
      (data_batch, label_batch), where data_batch is a list of tensors of the same
          length as `tensors`

    Example:
      # Get tensor for a single data and label example.
      data, label = data_provider.Get(['data', 'label'])

      # Get stratified batch according to per-class probabilities.
      target_probs = [...distribution you want...]
      [data_batch], labels = tf.contrib.training.stratified_sample(
          [data], label, target_probs)

      # Run batch through network.
      ...
    """
    with ops.name_scope(name, 'stratified_sample', list(tensors) + [labels]):
        tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensors)
        labels = ops.convert_to_tensor(labels)
        target_probs = ops.convert_to_tensor(target_probs,
                                             dtype=dtypes.float32)
        # Reduce the case of a single example to that of a batch of size 1.
        if not enqueue_many:
            tensor_list = [
                array_ops.expand_dims(tensor, 0) for tensor in tensor_list
            ]
            labels = array_ops.expand_dims(labels, 0)

        # If `init_probs` is `None`, set up online estimation of data distribution.
        if init_probs is None:
            # We use `target_probs` to get the number of classes, so its shape must be
            # fully defined at graph construction time.
            target_probs.get_shape().assert_is_fully_defined()
            init_probs = _estimate_data_distribution(
                labels,
                target_probs.get_shape().num_elements())
        else:
            init_probs = ops.convert_to_tensor(init_probs,
                                               dtype=dtypes.float32)

        # Validate that input is consistent.
        tensor_list, labels, [init_probs, target_probs
                              ] = _verify_input(tensor_list, labels,
                                                [init_probs, target_probs])

        # Check that all zero initial probabilities also have zero target
        # probabilities.
        assert_op = control_flow_ops.Assert(
            math_ops.reduce_all(
                math_ops.logical_or(math_ops.not_equal(init_probs, 0),
                                    math_ops.equal(target_probs, 0))),
            [
                'All classes with zero initial probability must also have zero target '
                'probability: ', init_probs, target_probs
            ])
        init_probs = control_flow_ops.with_dependencies([assert_op],
                                                        init_probs)

        # Calculate acceptance sampling probabilities.
        accept_probs = _calculate_acceptance_probabilities(
            init_probs, target_probs)
        proportion_rejected = math_ops.reduce_sum(
            (1 - accept_probs) * init_probs)
        accept_probs = control_flow_ops.cond(
            math_ops.less(proportion_rejected, .5),
            lambda: accept_probs,
            lambda: logging_ops.Print(  # pylint: disable=g-long-lambda
                accept_probs, [accept_probs],
                message='Proportion of examples rejected by sampler is high.',
                first_n=10))

        # Make a single queue to hold input examples. Reshape output so examples
        # don't have singleton batch dimension.
        batched = input_ops.batch(tensor_list + [labels],
                                  batch_size=1,
                                  num_threads=threads_per_queue,
                                  capacity=queue_capacity,
                                  enqueue_many=True)
        val_list = [array_ops.squeeze(x, [0]) for x in batched[:-1]]
        label = array_ops.squeeze(batched[-1], [0])

        # Set up second queue containing batches that have the desired class
        # proportions.
        cur_prob = array_ops.gather(accept_probs, label)
        batched = input_ops.maybe_batch(
            val_list + [label],
            keep_input=random_ops.random_uniform([]) < cur_prob,
            batch_size=batch_size,
            num_threads=threads_per_queue)
        return batched[:-1], batched[-1]
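`_estimate_data_distribution` is not shown, and any running estimate of per-class frequencies would satisfy the `init_probs is None` path. A minimal smoothed-count estimator, purely illustrative and not the library's implementation:

import numpy as np

class RunningClassDistribution:
    """Laplace-smoothed running estimate of P(class)."""

    def __init__(self, num_classes, smoothing=1.0):
        self.counts = np.full(num_classes, smoothing)

    def update(self, labels):
        self.counts += np.bincount(labels, minlength=len(self.counts))
        return self.counts / self.counts.sum()

est = RunningClassDistribution(num_classes=3)
init_probs = est.update(np.array([0, 0, 1, 0, 2]))
# Feeding an estimate like this in as init_probs reproduces the
# init_probs-is-None path of stratified_sample above.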
Example #58
0
    def pair_weights(self, sorted_labels):
        """See `_LambdaWeight`."""
        with ops.name_scope(None, 'dcg_lambda_weight', (sorted_labels, )):
            valid_pair, sorted_labels = self._get_valid_pairs_and_clean_labels(
                sorted_labels)
            gain = self._gain_fn(sorted_labels)
            if self._normalized:
                gain *= self._inverse_max_dcg(sorted_labels)
            pair_gain = array_ops.expand_dims(gain, 2) - array_ops.expand_dims(
                gain, 1)
            pair_gain *= math_ops.to_float(valid_pair)

            list_size = array_ops.shape(sorted_labels)[1]
            topn = self._topn or list_size
            rank = math_ops.range(list_size) + 1

            def _discount_for_relative_rank_diff():
                """Rank-based discount in the LambdaLoss paper."""
                # The LambdaLoss is not well defined when topn is active and
                # topn < list_size. We cap the rank of examples to topn + 1 so
                # that the rank difference is capped to topn. This is just a
                # convenient upper bound when topn is active. We need to
                # revisit this.
                capped_rank = array_ops.where(
                    math_ops.greater(rank, topn),
                    array_ops.ones_like(rank) * (topn + 1), rank)
                rank_diff = math_ops.to_float(
                    math_ops.abs(
                        array_ops.expand_dims(capped_rank, 1) -
                        array_ops.expand_dims(capped_rank, 0)))
                pair_discount = array_ops.where(
                    math_ops.greater(rank_diff, 0),
                    math_ops.abs(
                        self._rank_discount_fn(rank_diff) -
                        self._rank_discount_fn(rank_diff + 1)),
                    array_ops.zeros_like(rank_diff))
                return pair_discount

            def _discount_for_absolute_rank():
                """Standard discount in the LambdaMART paper."""
                # When the rank discount is (1 / rank) for example, the discount is
                # |1 / r_i - 1 / r_j|. When i or j > topn, the discount becomes 0.
                rank_discount = array_ops.where(
                    math_ops.greater(rank, topn),
                    array_ops.zeros_like(math_ops.to_float(rank)),
                    self._rank_discount_fn(math_ops.to_float(rank)))
                pair_discount = math_ops.abs(
                    array_ops.expand_dims(rank_discount, 1) -
                    array_ops.expand_dims(rank_discount, 0))
                return pair_discount

            u = _discount_for_relative_rank_diff()
            v = _discount_for_absolute_rank()
            pair_discount = (
                1. - self._smooth_fraction) * u + self._smooth_fraction * v
            pair_weight = math_ops.abs(pair_gain) * pair_discount
            if self._topn is None:
                return pair_weight
            pair_mask = math_ops.logical_or(
                array_ops.expand_dims(math_ops.less_equal(rank, self._topn),
                                      1),
                array_ops.expand_dims(math_ops.less_equal(rank, self._topn),
                                      0))
            return pair_weight * math_ops.to_float(pair_mask)
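The `_gain_fn` and `_rank_discount_fn` hooks are not shown. With the standard LambdaMART choices (gain `2**label - 1`, discount `1 / log2(1 + rank)`), the absolute-rank branch reduces to a few lines of numpy; this sketch assumes those choices and omits normalization and topn:

import numpy as np

def gain_fn(labels):
    return 2.0 ** labels - 1.0          # standard exponential gain

def rank_discount_fn(rank):
    return 1.0 / np.log2(1.0 + rank)    # standard log discount

labels = np.array([2.0, 0.0, 1.0])      # one list, already sorted by score
rank = np.arange(len(labels)) + 1.0

gain = gain_fn(labels)
discount = rank_discount_fn(rank)
# |G_i - G_j| * |D_i - D_j|: how much swapping items i and j changes DCG,
# which is the pair weight _discount_for_absolute_rank builds above.
pair_weight = (np.abs(gain[:, None] - gain[None, :]) *
               np.abs(discount[:, None] - discount[None, :]))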
Example #59
0
        def body(time, elements_finished, current_input, emit_ta, state,
                 loop_state):
            """Internal while loop body for raw_rnn.

            Args:
              time: time scalar.
              elements_finished: batch-size vector.
              current_input: possibly nested tuple of input tensors.
              emit_ta: possibly nested tuple of output TensorArrays.
              state: possibly nested tuple of state tensors.
              loop_state: possibly nested tuple of loop state tensors.

            Returns:
              Tuple having the same size as Args but with updated values.
            """
            (next_output, cell_state) = cell(current_input, state)

            nest.assert_same_structure(state, cell_state)
            nest.assert_same_structure(cell.output_size, next_output)

            next_time = time + 1
            (next_finished, next_input, next_state, emit_output,
             next_loop_state) = loop_fn(next_time, next_output, cell_state,
                                        loop_state)

            nest.assert_same_structure(state, next_state)
            nest.assert_same_structure(current_input, next_input)
            nest.assert_same_structure(emit_ta, emit_output)

            # If loop_fn returns None for next_loop_state, just reuse the
            # previous one.
            loop_state = loop_state if next_loop_state is None else next_loop_state

            def _copy_some_through(current, candidate):
                """Copy some tensors through via array_ops.where."""
                current_flat = nest.flatten(current)
                candidate_flat = nest.flatten(candidate)
                # pylint: disable=g-long-lambda,cell-var-from-loop
                result_flat = [
                    _on_device(lambda: array_ops.where(elements_finished,
                                                       current_i, candidate_i),
                               device=candidate_i.op.device)
                    for (current_i,
                         candidate_i) in zip(current_flat, candidate_flat)
                ]
                # pylint: enable=g-long-lambda,cell-var-from-loop
                return nest.pack_sequence_as(structure=current,
                                             flat_sequence=result_flat)

            emit_output = _copy_some_through(zero_emit, emit_output)
            next_state = _copy_some_through(state, next_state)

            emit_output_flat = nest.flatten(emit_output)
            emit_ta_flat = nest.flatten(emit_ta)

            elements_finished = math_ops.logical_or(elements_finished,
                                                    next_finished)

            emit_ta_flat = [
                ta.write(time, emit)
                for (ta, emit) in zip(emit_ta_flat, emit_output_flat)
            ]

            emit_ta = nest.pack_sequence_as(structure=emit_structure,
                                            flat_sequence=emit_ta_flat)

            return (next_time, elements_finished, next_input, emit_ta,
                    next_state, loop_state)
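This body relies on the `raw_rnn` `loop_fn` contract: given `(time, cell_output, cell_state, loop_state)`, return `(finished, next_input, next_cell_state, emit_output, next_loop_state)`, where `cell_output is None` marks the time-0 initialization call. A minimal teacher-forced `loop_fn` in the TF1 style of this example (a sketch; the factory arguments are assumptions about what the caller has in scope):

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops, control_flow_ops, math_ops

def make_loop_fn(cell, inputs_ta, sequence_length, batch_size, input_depth):
    def loop_fn(time, cell_output, cell_state, loop_state):
        emit_output = cell_output  # None at time == 0 fixes the emit structure
        if cell_output is None:    # time == 0: initialization call
            next_cell_state = cell.zero_state(batch_size, dtypes.float32)
        else:
            next_cell_state = cell_state
        elements_finished = (time >= sequence_length)
        # Once every sequence is done, stop reading the TensorArray and feed
        # zeros of the input shape instead.
        next_input = control_flow_ops.cond(
            math_ops.reduce_all(elements_finished),
            lambda: array_ops.zeros([batch_size, input_depth], dtypes.float32),
            lambda: inputs_ta.read(time))
        return (elements_finished, next_input, next_cell_state,
                emit_output, None)
    return loop_fn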