Example #1
def _InputBatch(self):
    # All-ones dummy targets: every position is a real (non-padding) token.
    targets = tf.ones([self.params.batch_size, 1024], dtype=tf.int32)
    input_batch = py_utils.NestedMap()
    input_batch.tgt = py_utils.NestedMap()
    # Decoder input ids are the labels shifted right by one position.
    input_batch.tgt.ids = tf.roll(targets, 1, axis=1)
    input_batch.tgt.labels = targets
    input_batch.tgt.segment_ids = tf.minimum(targets, 1)
    input_batch.tgt.segment_pos = targets
    # Pin the static shape of every tensor in the batch.
    input_batch = input_batch.Transform(
        lambda t: tf.ensure_shape(t, (self.params.batch_size, 1024)))
    return input_batch
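A quick standalone sketch (mine, not part of the example above) of what tf.ensure_shape contributes here: it asserts the shape at runtime and merges it into the tensor's static shape, so downstream graph code sees fully known dimensions.

import tensorflow as tf

x = tf.ones([8, 1024], dtype=tf.int32)
y = tf.ensure_shape(x, (8, 1024))  # passes; y now has a fully static shape
# tf.ensure_shape(x, (8, 512)) would raise an error instead.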
Example #2
    def __init__(self, params):
        super().__init__(params)
        p = self.params

        (utt_ids, src_frames,
         src_paddings), self._bucket_keys = self._BuildDataSource()

        self._sample_ids = utt_ids

        src_frames, src_paddings = self._MaybePadSourceInputs(
            src_frames, src_paddings)

        # We expect src_inputs to be of shape
        # [batch_size, num_frames, feature_dim, channels].
        src_frames = tf.expand_dims(src_frames, axis=-1)

        if p.pad_to_max_seq_length:
            assert p.source_max_length
            assert p.target_max_length

            if all(x == p.bucket_batch_limit[0] for x in p.bucket_batch_limit):
                # Set the input batch size as an int rather than a tensor.
                src_frames_shape = (self.InfeedBatchSize(),
                                    p.source_max_length, p.frame_size, 1)
                src_paddings_shape = (self.InfeedBatchSize(),
                                      p.source_max_length)
            else:
                tf.logging.warning(
                    'Could not set static input shape since not all bucket '
                    'batch sizes are the same: %s', p.bucket_batch_limit)
                src_frames_shape = None
                src_paddings_shape = None

            src_frames = py_utils.PadBatchDimension(src_frames,
                                                    self.InfeedBatchSize(), 0)
            src_paddings = py_utils.PadBatchDimension(src_paddings,
                                                      self.InfeedBatchSize(),
                                                      1)
            # Pad with the smallest representable value of the ids' dtype.
            self._sample_ids = py_utils.PadBatchDimension(
                self._sample_ids, self.InfeedBatchSize(),
                self._sample_ids.dtype.min)

            src_frames = py_utils.PadSequenceDimension(src_frames,
                                                       p.source_max_length,
                                                       0,
                                                       shape=src_frames_shape)
            src_paddings = py_utils.PadSequenceDimension(
                src_paddings, p.source_max_length, 1, shape=src_paddings_shape)
            self._sample_ids = tf.ensure_shape(self._sample_ids,
                                               self.InfeedBatchSize())

        src = py_utils.NestedMap(src_inputs=src_frames, paddings=src_paddings)

        self._src = src
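The pad-to-max path above leans on lingvo's py_utils.PadSequenceDimension. A rough standalone equivalent (my sketch, not the lingvo implementation; assumes the time dimension never exceeds max_length):

import tensorflow as tf

def pad_sequence_dimension(x, max_length, pad_value):
    # Pad axis 1 of a [batch, time, ...] tensor up to max_length.
    pad_amount = max_length - tf.shape(x)[1]
    paddings = [[0, 0], [0, pad_amount]] + [[0, 0]] * (x.shape.rank - 2)
    return tf.pad(x, paddings, constant_values=pad_value)

frames = tf.zeros([4, 10, 80, 1])
frames = pad_sequence_dimension(frames, 16, 0.0)  # -> shape [4, 16, 80, 1]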
Example #3
  def FProp(self, theta, prepared_inputs, step_inputs, padding, state0):
    """Calls an embedding lookup and updates the state of token history.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer and
        its children layers.
      prepared_inputs: unused.
      step_inputs: A NestedMap containing a list called inputs. This list should
        contain a single float32 (will be converted to int32 later) tensor of
        shape [batch], where each value represents an index into the embedding
        table. (By convention, all Steps that can be used with StackStep must
        store inputs in step_inputs.inputs[], but in this step it does not make
        sense for that list to have more than one tensor in it).
      padding: unused.
      state0: A NestedMap containing the state of previous tokens.
      - prev_ids: A Tensor containing the n previous token ids. [batch,
        num_prev_tokens]. Each row is the token ids at t-1, ..., t-n.

    Returns:
      Embedding vectors [batch, p.embedding_dim] and new state

    """
    p = self.params

    # prepare token ids
    if p.include_current_token:
      ids = tf.concat([
          tf.cast(step_inputs.inputs[0][:, None], tf.float32),
          tf.cast(state0.prev_ids, tf.float32)
      ],
                      axis=-1)
    else:
      ids = state0.prev_ids

    # lookup embedding. ids.shape is [batch, num_tokens]
    ids = tf.cast(ids, tf.int32)
    embedding = self.emb.EmbLookup(theta.emb, ids)
    embedding = tf.reshape(embedding, [-1, p.embedding_dim])

    # update state
    state1 = state0.copy()
    if p.num_prev_tokens > 0:
      state1.prev_ids = tf.concat([
          tf.cast(step_inputs.inputs[0][:, None], tf.float32),
          tf.cast(state0.prev_ids[:, :-1], tf.float32)
      ],
                                  axis=-1)
    state1.prev_ids = tf.ensure_shape(
        state1.prev_ids, [None, p.num_prev_tokens],
        name='prev_ids_shape_validation')
    state1.embedding = embedding

    return py_utils.NestedMap(output=embedding), state1
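The state update keeps a sliding window of the most recent token ids. A toy illustration of the shift-and-drop pattern (values invented):

import tensorflow as tf

batch, num_prev_tokens = 2, 3
prev_ids = tf.zeros([batch, num_prev_tokens])  # ids at t-1, ..., t-n
new_token = tf.constant([[7.0], [9.0]])        # current step, [batch, 1]
# Prepend the newest token and drop the oldest, keeping the window size.
prev_ids = tf.concat([new_token, prev_ids[:, :-1]], axis=-1)
prev_ids = tf.ensure_shape(prev_ids, [None, num_prev_tokens])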
Example #4
    def Decode(self, encoder_outputs, dec_input, state):
        """The decoder model.

    Args:
      encoder_outputs: a NestedMap containing the following fields: encoded -
        the encoding of the spelling of the target feature state - hidden state
        of the encoder output neighbor_spellings_encoded - encoding of neighbor
        spellings or tf.constant(0) neighbor_pronunciations_encoded - encoding
        of neighbor pronunciations or tf.constant(0) -
        initial prediction of [<s>] of shape [*, 1]
      dec_input: if not None, then use this instead of the dec_input in
        encoder_outputs.
      state: previous state of decoding.

    Returns:
       res: a NestedMap() containing
         predictions - of shape [*, output_vocab_size]
         state - updated hidden state
         attention_weights
         neighbor_attention_weights
    """
        p = self.params

        context_vector, attention_weights = self.attention(
            state, encoder_outputs.encoded)
        (neighbor_context_vector,
         neighbor_attention_weights) = self._NeighborModelAttention(
             encoder_outputs, state)
        x = self.shared_out_emb(dec_input)  # pylint: disable=not-callable
        if p.use_neighbors:
            x = tf.concat([context_vector, neighbor_context_vector, x],
                          axis=-1)
        else:
            x = tf.concat([context_vector, x], axis=-1)

        output, state = self._gru_cell(x, state)
        x = self._fc(output)

        # If this fails then the checkpoint contains incorrect shapes or the
        # hparams are incompatible. No idea why TF doesn't check this anymore.
        x = tf.ensure_shape(x, [None, p.output_vocab_size])

        res = py_utils.NestedMap()
        res.predictions = x
        res.state = state
        res.attention_weights = attention_weights
        res.neighbor_attention_weights = neighbor_attention_weights
        return res
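The tf.ensure_shape call above doubles as a cheap consistency check between restored weights and the configured vocabulary size. A toy illustration (shapes invented):

import tensorflow as tf

logits = tf.zeros([4, 128])
logits = tf.ensure_shape(logits, [None, 128])  # passes; batch stays dynamic
# tf.ensure_shape(logits, [None, 256]) would raise, which is how a
# checkpoint/hparam mismatch on vocab size gets surfaced early.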
Example #5
  def BeamSearchDecode(self,
                       theta,
                       encoder_outputs,
                       num_hyps_per_beam_override=0,
                       init_beam_search_state=None,
                       pre_beam_search_step_callback=None,
                       post_beam_search_step_callback=None,
                       max_steps=None):
    """Performs beam-search based decoding.

    Args:
      theta: A NestedMap object containing weights' values of the decoder layer
        and its children layers.
      encoder_outputs: A NestedMap containing encoder outputs to be passed to
        the callbacks.
      num_hyps_per_beam_override: If set to a value <= 0, this parameter is
        ignored. If set to a value > 0, then this value will be used to override
        `p.num_hyps_per_beam`.
      init_beam_search_state: The `InitBeamSearchState` callback. Please refer
        to the class header comments for more details.
      pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
        Please refer to the class header comments for more details.
      post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
        Please refer to the class header comments for more details.
      max_steps: maximum beam search steps. If None, use
        self.params.target_seq_len.

    Returns:
      A `BeamSearchDecodeOutput`.
    """
    p = self.params
    num_hyps_per_beam = p.num_hyps_per_beam
    if num_hyps_per_beam_override > 0:
      num_hyps_per_beam = num_hyps_per_beam_override
    if max_steps is None:
      max_steps = p.target_seq_len

    initial_results, other_states = init_beam_search_state(
        theta, encoder_outputs, num_hyps_per_beam)

    num_hyps = tf.shape(initial_results.log_probs)[0]
    num_beams = num_hyps // num_hyps_per_beam

    if 'step_ids' in initial_results:
      # [num_hyps, 1]
      step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])
    else:
      step_ids = tf.fill([num_hyps, 1],
                         tf.constant(p.target_sos_id, dtype=tf.int32))

    min_score = -1e36
    best_scores = (tf.zeros(shape=[num_beams], dtype=p.dtype) + min_score)
    cumulative_scores = tf.zeros(shape=[num_hyps], dtype=p.dtype)
    in_scores = tf.zeros([max_steps, num_hyps], dtype=p.dtype)
    in_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
    in_prev_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.int32)
    in_done_hyps = tf.zeros([max_steps, num_hyps], dtype=tf.string)
    bs_atten_probs = tf.zeros(
        [max_steps, num_hyps,
         tf.shape(initial_results.atten_probs)[1]],
        dtype=p.dtype)
    cur_step = tf.constant(0, dtype=tf.int32)
    all_done = tf.constant(False, dtype=tf.bool)
    core_bs_states = (best_scores, cumulative_scores, in_scores, in_hyps,
                      in_prev_hyps, in_done_hyps, bs_atten_probs)

    def LoopContinue(cur_step, all_done, unused_step_ids, unused_core_bs_states,
                     unused_other_states_list):
      return tf.logical_and(cur_step < max_steps, tf.logical_not(all_done))

    def LoopBody(cur_step, unused_all_done, step_ids, core_bs_states,
                 other_states_list):
      (cur_step, all_done, new_step_ids, new_bs_states,
       new_other_states) = self._BeamSearchStep(
           theta, encoder_outputs, cur_step, step_ids, core_bs_states,
           other_states.Pack(other_states_list), num_hyps_per_beam,
           pre_beam_search_step_callback, post_beam_search_step_callback)
      return (cur_step, all_done, new_step_ids, new_bs_states,
              new_other_states.Flatten())

    flat_other_states = other_states.Flatten()
    _, _, _, final_bs_states, flat_final_other_states = tf.while_loop(
        LoopContinue,
        LoopBody,
        loop_vars=(cur_step, all_done, step_ids, core_bs_states,
                   flat_other_states),
        parallel_iterations=10,
        back_prop=False,
        swap_memory=False,
        shape_invariants=(tf.TensorShape(cur_step.get_shape()),
                          tf.TensorShape(all_done.get_shape()),
                          tf.TensorShape(step_ids.get_shape()),
                          _GetShapes(core_bs_states),
                          _GetShapes(flat_other_states, none_shapes=True)))
    # [target_seq_len, num_beams * num_hyps_per_beam].
    final_done_hyps = final_bs_states[5]
    final_other_states = other_states.Pack(flat_final_other_states)

    # TODO(rpang): avoid inspecting 'encoder_outputs'.
    source_paddings = encoder_outputs.padding
    if isinstance(source_paddings, py_utils.NestedMap):
      source_seq_lengths = tf.cast(
          tf.round(
              tf.reduce_sum(1.0 - tf.transpose(source_paddings.Flatten()[0]),
                            1)), tf.int32)
    else:
      source_seq_lengths = tf.cast(
          tf.round(tf.reduce_sum(1.0 - tf.transpose(source_paddings), 1)),
          tf.int32)

    # [num_beams, num_hyps_per_beam].
    topk_hyps = ops.top_k_terminated_hyps(
        final_done_hyps,
        source_seq_lengths,
        k=num_hyps_per_beam,
        num_hyps_per_beam=num_hyps_per_beam,
        length_normalization=p.length_normalization,
        coverage_penalty=p.coverage_penalty,
        target_seq_length_ratio=p.target_seq_length_ratio,
        eoc_id=p.target_eoc_id,
        merge_paths=p.merge_paths)
    # [num_beams * num_hyps_per_beam, ...].
    max_seq_length = 0 if isinstance(max_steps, tf.Tensor) else max_steps
    topk_ids, topk_lens, topk_scores = ops.unpack_hyp(
        tf.reshape(topk_hyps, [-1]), max_seq_length=max_seq_length)
    # [num_beams, num_hyps_per_beam].
    topk_scores = tf.reshape(topk_scores, tf.shape(topk_hyps))

    return BeamSearchDecodeOutput(final_done_hyps, topk_hyps, topk_ids,
                                  topk_lens, topk_scores, None,
                                  final_other_states)
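shape_invariants is what lets loop variables such as the hypothesis buffers keep dimensions that change or stay unknown across iterations. A self-contained toy (unrelated to the beam-search internals above):

import tensorflow as tf

i0 = tf.constant(0)
acc0 = tf.zeros([0, 2])  # grows along axis 0, so that axis must stay unknown

_, acc = tf.while_loop(
    lambda i, acc: i < 5,
    lambda i, acc: (i + 1, tf.concat([acc, tf.ones([1, 2])], axis=0)),
    loop_vars=(i0, acc0),
    shape_invariants=(i0.get_shape(), tf.TensorShape([None, 2])))
# acc ends up with runtime shape [5, 2].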
Example #6
def Cast(x):
    x = tf.ensure_shape(x, [p.batch_size, p.max_sequence_length])
    if x.dtype.is_floating:
        x = tf.cast(x, py_utils.FPropDtype(p))
    return x
Example #7
def ShapeAndCast(x):
    x = tf.ensure_shape(x, (self.InfeedBatchSize(), p.source_max_length))
    if x.dtype.is_floating:
        x = tf.cast(x, py_utils.FPropDtype(p))
    return x
Example #8
def _EnsureTgtShape(x):
    if x.dtype == tf.string:
        return tf.ensure_shape(x, [self._ScaledBatchSize()])
    return tf.ensure_shape(
        x, [self._ScaledBatchSize(), self.params.target_max_length])
Example #9
def _EnsureSrcShape(x):
    if x.dtype == tf.string:
        return tf.ensure_shape(x, [self._ScaledBatchSize()])
    return tf.ensure_shape(
        x, [self._ScaledBatchSize(), self.params.source_max_length])
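Examples #8 and #9 are per-tensor closures; in lingvo they would typically be mapped over an entire NestedMap of features. A hypothetical application (the batch map below is assumed, not from the source):

batch.src = batch.src.Transform(_EnsureSrcShape)  # checks every source tensor
batch.tgt = batch.tgt.Transform(_EnsureTgtShape)  # checks every target tensor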
Example #10
    def GreedySearchDecode(self,
                           theta,
                           encoder_outputs,
                           init_beam_search_state=None,
                           pre_beam_search_step_callback=None,
                           post_beam_search_step_callback=None,
                           max_steps=None):
        """Performs greedy-search based decoding.

    Args:
      theta: A NestedMap object containing weights' values of the decoder layer
        and its children layers.
      encoder_outputs: A NestedMap containing encoder outputs to be passed to
        the callbacks.
      init_beam_search_state: The `InitBeamSearchState` callback. Please refer
        to the class header comments for more details.
      pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
        Please refer to the class header comments for more details.
      post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
        Please refer to the class header comments for more details.
      max_steps: maximum beam search steps. If None, use
        self.params.target_seq_len.

    Returns:
      A tuple (hyp_ids, hyp_lens, done_hyps). Note that num_hyps is same as
      src_batch_size.

        - hyp_ids: [num_hyps, max_step]. Hyps end with <eos> token if the <eos>
          token is encountered during search.
        - hyp_lens: [num_hyps].
        - done_hyps: [num_hyps], whether or not an eos is encountered.
    """
        p = self.params
        if max_steps is None:
            max_steps = p.target_seq_len

        initial_results, other_states = init_beam_search_state(
            theta,
            encoder_outputs,
            1  # num_hyps_per_beam
        )

        num_hyps = tf.shape(initial_results.log_probs)[0]

        if 'step_ids' in initial_results:
            # [num_hyps, 1]
            step_ids = tf.ensure_shape(initial_results.step_ids, [None, 1])
        else:
            step_ids = tf.fill([num_hyps, 1],
                               tf.constant(p.target_sos_id, dtype=tf.int32))

        cur_step = tf.constant(0, dtype=tf.int32)
        done_hyps = inplace_ops.empty(shape=[num_hyps],
                                      dtype=tf.bool,
                                      init=True,
                                      name='done_hyps')
        hyp_lens = inplace_ops.empty(shape=[num_hyps],
                                     dtype=tf.int32,
                                     init=True,
                                     name='hyp_lens')
        hyp_ids = inplace_ops.empty(shape=[max_steps, num_hyps],
                                    dtype=tf.int32,
                                    init=True,
                                    name='hyp_ids')

        def LoopContinue(cur_step, unused_step_ids, unused_hyp_ids,
                         unused_hyp_lens, done_hyps, unused_other_states_list):
            return tf.logical_and(cur_step < max_steps,
                                  tf.logical_not(tf.reduce_all(done_hyps)))

        def LoopBody(cur_step, step_ids, hyp_ids, hyp_lens, done_hyps,
                     other_states_list):
            (cur_step, new_step_ids, hyp_ids, hyp_lens, done_hyps,
             new_other_states) = self._GreedySearchStep(
                 theta, encoder_outputs, cur_step, step_ids, hyp_ids, hyp_lens,
                 done_hyps, other_states.Pack(other_states_list),
                 pre_beam_search_step_callback, post_beam_search_step_callback)
            return (cur_step, new_step_ids, hyp_ids, hyp_lens, done_hyps,
                    new_other_states.Flatten())

        flat_other_states = other_states.Flatten()
        _, _, final_hyp_ids, final_hyp_lens, final_done_hyps, _ = tf.while_loop(
            LoopContinue,
            LoopBody,
            loop_vars=(cur_step, step_ids, hyp_ids, hyp_lens, done_hyps,
                       flat_other_states),
            parallel_iterations=10,
            back_prop=False,
            swap_memory=False,
            shape_invariants=(tf.TensorShape(cur_step.get_shape()),
                              tf.TensorShape(step_ids.get_shape()),
                              tf.TensorShape(hyp_ids.get_shape()),
                              tf.TensorShape(hyp_lens.get_shape()),
                              tf.TensorShape(done_hyps.get_shape()),
                              _GetShapes(flat_other_states, none_shapes=True)))

        # transpose hyp_ids so it matches BeamSearchDecode's output
        final_hyp_ids = tf.transpose(final_hyp_ids)
        return final_hyp_ids, final_hyp_lens, final_done_hyps
Example #11
def max_assignment(score: tf.Tensor,
                   *,
                   elementwise_upper_bound: tf.Tensor,
                   row_sums: tf.Tensor,
                   col_sums: tf.Tensor,
                   epsilon: float = 0.1,
                   num_iterations: int = 50,
                   use_epsilon_scaling: bool = True):
    """Differentiable max assignment with margin and upper bound constraints.

  Args:
    score: a 3D tensor of size [batch_size, n_rows, n_columns]. score[i, j, k]
      denotes the weight if the assignment on this entry is non-zero.
    elementwise_upper_bound: a 3D tensor of size [batch_size, n_rows,
      n_columns]. Each entry denotes the maximum value assignment[i, j, k] can
      take and must be a non-negative value. For example, upper_bound[i, j,
      k]=1.0 for binary assignment problem.
    row_sums: a 2D tensor of size [batch_size, n_rows]. The row sum constraint.
      The output assignment p[i, j, :] must sum to row_sums[i, j].
    col_sums: a 2D tensor of size [batch_size, n_columns]. The column sum
      constraint. The output assignment p[i, :, k] must sum to col_sums[i, k].
    epsilon: the epsilon coefficient of entropy regularization. The value should
      be within the range (0, 1]. `0.01` might work better than `0.1`. `0.1` may
      not make the assignment close enough to 0 or 1.
    num_iterations: the maximum number of iterations to perform.
    use_epsilon_scaling: whether to use epsilon scaling. In practice, the
      convergence of the iterative algorithm is much better if we start by
      solving the optimization with a larger epsilon value and re-use the
      solution (i.e. dual variables) for the instance with a smaller epsilon.
      This is called the epsilon scaling trick. See [Schmitzer 2019]
      (https://arxiv.org/pdf/1610.06519.pdf) as a reference. Here if
      use_epsilon_scaling=True, after each iteration we decrease the running
      epsilon by a constant factor until it reaches the target epsilon
      value. We found this to work well for gradient backward propagation,
      while the original scaling trick doesn't.

  Returns:
    A tuple with the following values.
      - assignment: a 3D tensor of size [batch_size, n_rows, n_columns].
        The output assignment.
      - used_iter: a scalar tensor indicating the number of iterations used.
      - eps: a scalar tensor indicating the stopping epsilon value.
      - delta: a scalar tensor indicating the stopping delta value (the relative
        change on the margins of assignment p in the last iteration).
  """

    # Check if all shapes are correct
    score_shape = score.shape
    bsz = score_shape[0]
    n = score_shape[1]
    m = score_shape[2]
    score = tf.ensure_shape(score, [bsz, n, m])
    elementwise_upper_bound = tf.ensure_shape(elementwise_upper_bound,
                                              [bsz, n, m])
    row_sums = tf.ensure_shape(tf.expand_dims(row_sums, axis=2), [bsz, n, 1])
    col_sums = tf.ensure_shape(tf.expand_dims(col_sums, axis=1), [bsz, 1, m])

    # the total sum of row sums must be equal to total sum of column sums
    sum_diff = tf.reduce_sum(row_sums, axis=1) - tf.reduce_sum(col_sums,
                                                               axis=2)
    sum_diff = tf.abs(sum_diff)
    tf.Assert(tf.reduce_all(sum_diff < 1e-6), [sum_diff])

    # Convert upper_bound constraint into another margin constraint
    # by adding auxiliary variables & scores. Tensor `a`, `b` and `c`
    # represent the margins (i.e. reduced sum) of 3 axes respectively.
    #
    max_row_sums = tf.reduce_sum(elementwise_upper_bound,
                                 axis=-1,
                                 keepdims=True)
    max_col_sums = tf.reduce_sum(elementwise_upper_bound,
                                 axis=-2,
                                 keepdims=True)
    score_ = tf.stack([score, tf.zeros_like(score)], axis=1)  # (bsz, 2, n, m)
    a = tf.stack([row_sums, max_row_sums - row_sums], axis=1)  # (bsz, 2, n, 1)
    b = tf.stack([col_sums, max_col_sums - col_sums], axis=1)  # (bsz, 2, 1, m)
    c = tf.expand_dims(elementwise_upper_bound, axis=1)  # (bsz, 1, n, m)

    # Clip log(0) to the large negative value -1e+36 to avoid
    # getting inf or NaN values in computation. A more negative
    # value would overflow to `-inf` in float32.
    #
    tf.Assert(tf.reduce_all(a >= 0), [a])
    tf.Assert(tf.reduce_all(b >= 0), [b])
    tf.Assert(tf.reduce_all(c >= 0), [c])
    log_a = tf.maximum(tf.math.log(a), -1e+36)
    log_b = tf.maximum(tf.math.log(b), -1e+36)
    log_c = tf.maximum(tf.math.log(c), -1e+36)

    # Initialize the dual variables of margin constraints
    u = tf.zeros_like(a)
    v = tf.zeros_like(b)
    w = tf.zeros_like(c)

    eps = tf.constant(1.0 if use_epsilon_scaling else epsilon,
                      dtype=score.dtype)
    epsilon = tf.constant(epsilon, dtype=score.dtype)

    def do_updates(cur_iter, eps, u, v, w):  # pylint: disable=unused-argument
        # Epsilon scaling, i.e. gradually decreasing `eps` until it
        # reaches the target `epsilon` value
        cur_iter = tf.cast(cur_iter, u.dtype)
        scaling = tf.minimum(0.6 * 1.04**cur_iter, 0.85)
        eps = tf.maximum(epsilon, eps * scaling)
        score_div_eps = score_ / eps

        # Update u
        log_q_1 = score_div_eps + (w + v) / eps
        log_q_1 = tf.reduce_logsumexp(log_q_1, axis=-1, keepdims=True)
        new_u = (log_a - tf.maximum(log_q_1, -1e+30)) * eps

        # Update v
        log_q_2 = score_div_eps + (w + new_u) / eps
        log_q_2 = tf.reduce_logsumexp(log_q_2, axis=-2, keepdims=True)
        new_v = (log_b - tf.maximum(log_q_2, -1e+30)) * eps

        # Update w
        log_q_3 = score_div_eps + (new_u + new_v) / eps
        log_q_3 = tf.reduce_logsumexp(log_q_3, axis=-3, keepdims=True)
        new_w = (log_c - tf.maximum(log_q_3, -1e+30)) * eps
        return eps, new_u, new_v, new_w

    def compute_relative_changes(eps, u, v, w, new_eps, new_u, new_v, new_w):
        prev_sum_uvw = tf.stop_gradient((u + v + w) / eps)
        sum_uvw = tf.stop_gradient((new_u + new_v + new_w) / new_eps)

        # Compute the relative changes on margins of P.
        # This will be used for stopping criteria.
        # Note the last update on w would guarantee the
        # margin constraint c is satisfied, so we don't
        # need to check it here.
        p = tf.exp(tf.stop_gradient(score_ / new_eps + sum_uvw))
        p_a = tf.reduce_sum(p, axis=-1, keepdims=True)
        p_b = tf.reduce_sum(p, axis=-2, keepdims=True)
        delta_a = tf.abs(a - p_a) / (a + 1e-6)
        delta_b = tf.abs(b - p_b) / (b + 1e-6)
        new_delta = tf.reduce_max(delta_a)
        new_delta = tf.maximum(new_delta, tf.reduce_max(delta_b))

        # Compute the relative changes on assignment solution P.
        # This will be used for stopping criteria.
        delta_p = tf.abs(tf.exp(prev_sum_uvw) -
                         tf.exp(sum_uvw)) / (tf.exp(sum_uvw) + 1e-6)
        new_delta = tf.maximum(new_delta, tf.reduce_max(delta_p))
        return new_delta

    for cur_iter in tf.range(num_iterations):
        prev_eps, prev_u, prev_v, prev_w = eps, u, v, w
        eps, u, v, w = do_updates(cur_iter, eps, u, v, w)
    delta = compute_relative_changes(prev_eps, prev_u, prev_v, prev_w, eps, u,
                                     v, w)
    cur_iter = num_iterations
    assignment = tf.exp((score_ + u + v + w) / eps)
    assignment = assignment[:, 0]
    return assignment, cur_iter, eps, delta
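A hedged usage sketch for the function above (the toy numbers are mine): a 1x2x2 binary assignment problem where every row and column of the solution must sum to 1.

import tensorflow as tf

score = tf.constant([[[1.0, 0.1],
                      [0.2, 0.9]]])
assignment, used_iter, eps, delta = max_assignment(
    score,
    elementwise_upper_bound=tf.ones([1, 2, 2]),
    row_sums=tf.ones([1, 2]),
    col_sums=tf.ones([1, 2]),
    epsilon=0.01)
# assignment approaches [[1, 0], [0, 1]], the highest-scoring matching.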
Example #12
    def __init__(self, params):
        super().__init__(params)
        p = self.params

        (utt_ids, audio_document_ids, num_utterances_in_audio_document,
         tgt_ids, tgt_labels, tgt_paddings, src_frames,
         src_paddings), self._bucket_keys = self._BuildDataSource()

        self._sample_ids = utt_ids

        src_frames, src_paddings = self._MaybePadSourceInputs(
            src_frames, src_paddings)

        # We expect src_inputs to be of shape
        # [batch_size, num_frames, feature_dim, channels].
        src_frames = tf.expand_dims(src_frames, axis=-1)

        # Convert target ids, labels, paddings, and weights from shape [batch_size,
        # 1, num_frames] to [batch_size, num_frames]
        tgt_ids = tf.squeeze(tgt_ids, axis=1)
        tgt_labels = tf.squeeze(tgt_labels, axis=1)
        tgt_paddings = tf.squeeze(tgt_paddings, axis=1)

        if p.pad_to_max_seq_length:
            assert p.source_max_length
            assert p.target_max_length

            if all(x == p.bucket_batch_limit[0] for x in p.bucket_batch_limit):
                # Set the input batch size as an int rather than a tensor.
                src_frames_shape = (self.InfeedBatchSize(),
                                    p.source_max_length, p.frame_size, 1)
                src_paddings_shape = (self.InfeedBatchSize(),
                                      p.source_max_length)
                tgt_shape = (self.InfeedBatchSize(), p.target_max_length)
            else:
                tf.logging.warning(
                    'Could not set static input shape since not all bucket '
                    'batch sizes are the same: %s', p.bucket_batch_limit)
                src_frames_shape = None
                src_paddings_shape = None
                tgt_shape = None

            src_frames = py_utils.PadBatchDimension(src_frames,
                                                    self.InfeedBatchSize(), 0)
            src_paddings = py_utils.PadBatchDimension(src_paddings,
                                                      self.InfeedBatchSize(),
                                                      1)
            tgt_ids = py_utils.PadBatchDimension(tgt_ids,
                                                 self.InfeedBatchSize(), 0)
            tgt_labels = py_utils.PadBatchDimension(tgt_labels,
                                                    self.InfeedBatchSize(), 0)
            tgt_paddings = py_utils.PadBatchDimension(tgt_paddings,
                                                      self.InfeedBatchSize(),
                                                      1)
            self._sample_ids = py_utils.PadBatchDimension(
                self._sample_ids, self.InfeedBatchSize(),
                type(self).PAD_INDEX)
            # For reasons I don't understand, the shape of self._sample_ids after the above is
            # [BatchSize, 1] rather than [BatchSize].
            self._sample_ids = tf.squeeze(self._sample_ids, axis=1)
            self._sample_ids = tf.ensure_shape(self._sample_ids,
                                               self.InfeedBatchSize())

            audio_document_ids = py_utils.PadBatchDimension(
                audio_document_ids, self.InfeedBatchSize(),
                type(self).PAD_INDEX)
            # For reasons I don't understand, the shape of audio_document_ids after the above is
            # [BatchSize, 1] rather than [BatchSize].
            audio_document_ids = tf.squeeze(audio_document_ids, axis=1)
            audio_document_ids = tf.ensure_shape(audio_document_ids,
                                                 self.InfeedBatchSize())

            num_utterances_in_audio_document = py_utils.PadBatchDimension(
                num_utterances_in_audio_document, self.InfeedBatchSize(),
                type(self).PAD_INDEX)
            # For reasons I don't understand, the shape of num_utterances_in_audio_document after the above is
            # [BatchSize, 1] rather than [BatchSize].
            num_utterances_in_audio_document = tf.squeeze(
                num_utterances_in_audio_document, axis=1)
            num_utterances_in_audio_document = tf.ensure_shape(
                num_utterances_in_audio_document, self.InfeedBatchSize())

            src_frames = py_utils.PadSequenceDimension(src_frames,
                                                       p.source_max_length,
                                                       0,
                                                       shape=src_frames_shape)
            src_paddings = py_utils.PadSequenceDimension(
                src_paddings, p.source_max_length, 1, shape=src_paddings_shape)
            tgt_ids = py_utils.PadSequenceDimension(tgt_ids,
                                                    p.target_max_length,
                                                    0,
                                                    shape=tgt_shape)
            tgt_labels = py_utils.PadSequenceDimension(tgt_labels,
                                                       p.target_max_length,
                                                       0,
                                                       shape=tgt_shape)
            tgt_paddings = py_utils.PadSequenceDimension(tgt_paddings,
                                                         p.target_max_length,
                                                         1,
                                                         shape=tgt_shape)

        tgt = py_utils.NestedMap(ids=tgt_ids,
                                 labels=tgt_labels,
                                 paddings=tgt_paddings,
                                 weights=1.0 - tgt_paddings)
        src = py_utils.NestedMap(src_inputs=src_frames, paddings=src_paddings)

        self._tgt = tgt
        self._src = src

        self._audio_document_ids = audio_document_ids
        self._num_utterances_in_audio_document = num_utterances_in_audio_document
Example #13
def DecodeWavPyFunc(input_bytes):
    sample_rate, audio = tf.py_function(read_wave_via_scipy, [input_bytes],
                                        [tf.int32, tf.float32])
    sample_rate = tf.ensure_shape(sample_rate, [])
    audio = tf.ensure_shape(audio, [None, 1])
    return sample_rate, audio
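Outputs of tf.py_function come back with unknown static shapes when traced into a graph (e.g. inside tf.function or Dataset.map), which is exactly why the two tf.ensure_shape calls above matter. A standalone illustration with a stub decoder (the stub is invented):

import numpy as np
import tensorflow as tf

def _fake_decode(_):
    return np.int32(16000), np.zeros([100, 1], np.float32)

sample_rate, audio = tf.py_function(_fake_decode, [tf.constant(b'')],
                                    [tf.int32, tf.float32])
# In graph mode both shapes are unknown here; restore them explicitly.
sample_rate = tf.ensure_shape(sample_rate, [])  # scalar rate in Hz
audio = tf.ensure_shape(audio, [None, 1])       # mono waveform, any length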