Example #1
  def _GetBetaGamma(self, theta, inputs, **kwargs):
    p = self.params

    assert 'class_emb' in kwargs
    class_emb = kwargs['class_emb']

    # class_emb is a one-hot vector of shape [batch, class_emb_dim=num_classes].
    class_ids = tf.math.argmax(class_emb, axis=-1, output_type=tf.int32)
    # [batch, dim]
    # Not using matmul/einsum to avoid potential precision problem on TPU with
    # sparse inputs.
    beta = tf.gather(theta.beta, class_ids)
    gamma = tf.gather(theta.gamma, class_ids)
    if not p.gamma_zero_init and not p.gamma_one_init:
      # Note: the real gamma to use is 1 + gamma.
      gamma = 1.0 + gamma

    # Extend to [batch, 1, ... 1, dim]
    batch = py_utils.GetShape(inputs)[0]
    to_shape = tf.concat(
        [[batch],
         tf.ones([py_utils.GetRank(inputs) - 2], tf.int32), [self.params.dim]],
        axis=0)
    beta = tf.reshape(beta, to_shape)
    gamma = tf.reshape(gamma, to_shape)
    return beta, gamma
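
A quick standalone sketch (hypothetical shapes, independent of the class above) of why the code gathers by argmax ids instead of multiplying by the one-hot class_emb: both select the same rows, but gather sidesteps matmul rounding on TPU.

import tensorflow as tf

table = tf.random.normal([4, 8])              # [num_classes, dim]
class_emb = tf.one_hot([2, 0], depth=4)       # [batch, num_classes], one-hot rows
class_ids = tf.math.argmax(class_emb, axis=-1, output_type=tf.int32)

via_gather = tf.gather(table, class_ids)      # [batch, dim], exact row selection
via_matmul = tf.matmul(class_emb, table)      # same rows via a dense matmul
# In float32 both results agree; on TPU the matmul path can lose precision.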
Example #2
    def _InputBatch(self):
        p = self.params

        @tf.function
        def ReadData():
            x, y = io_ops.restore_v2(p.ckpt, [p.data, p.label], [''] * 2,
                                     [p.data_dtype, p.label_dtype])
            # Always convert to float32.
            return tf.cast(x, tf.float32), tf.cast(y, tf.float32)

        # Loads data and label into memory and keeps them around.
        data, label = ops.cached_call(f=ReadData.get_concrete_function(),
                                      T=[tf.float32, tf.float32])
        b, shape = self.InfeedBatchSize(), list(p.data_shape)
        data = tf.reshape(data, [-1] + shape)
        label = tf.reshape(label, [-1])
        label = py_utils.HasShape(label, [tf.shape(data)[0]])
        sample_ids = ops.random_permutation_sequence(
            num=p.num_samples,
            batch=b,
            repeat=p.repeat,
            seed=p.random_seed if p.random_seed else 0)
        n = tf.shape(sample_ids)[0]
        raw = py_utils.PadOrTrimTo(tf.gather(data, sample_ids), [b] + shape)
        ret = py_utils.NestedMap(
            raw=raw,
            data=self._Preprocess(raw),
            label=py_utils.PadOrTrimTo(tf.gather(label, sample_ids), [b]),
            weight=py_utils.PadOrTrimTo(tf.ones([n], dtype=tf.float32), [b]))
        if not py_utils.use_tpu():
            ret['sample_ids'] = sample_ids
        return ret
Example #3
    def testForwardPassWithDoubleBatch(self):
        with self.session(use_gpu=False) as sess:
            p = self._EncoderParams()
            bs = 2
            seq_len = 16
            tf.random.set_seed(8372749040)
            mt_enc = p.Instantiate()
            batch = py_utils.NestedMap()
            batch.ids = tf.constant(
                np.random.randint(low=0,
                                  high=63,
                                  size=[bs, seq_len],
                                  dtype=np.int32))
            # All-zero paddings; the expected outputs below assume no padding.
            batch.paddings = tf.zeros([bs, seq_len])

            other_batch = py_utils.NestedMap()
            other_batch.ids = tf.gather(batch.ids, [1, 0])
            other_batch.paddings = tf.gather(batch.paddings, [1, 0])
            lambdas = np.random.random((bs, seq_len))
            lambdas = tf.constant(lambdas, tf.float32)
            out = mt_enc.FProp(mt_enc.theta,
                               batch,
                               interpolation_batch=other_batch,
                               lambdas=[lambdas, 1 - lambdas])
            enc_out_sum = tf.reduce_sum(out.encoded, 0)

            tf.global_variables_initializer().run()
            actual_enc_out, actual_enc_out_sum = sess.run(
                [out.encoded, enc_out_sum])

            expected_enc_out_sum = [[
                -38.089085, -22.181915, 3.3765068, -45.2483, -58.186905,
                -3.4464571, 24.461462, 12.014615, 33.08178, 34.02244,
                23.391253, -15.515911, 0.72847706, 50.45283, -26.36325,
                21.799355
            ],
                                    [
                                        -37.716507, -12.993027, 7.148979,
                                        -39.70747, -57.864025, 2.2049172,
                                        29.571432, 18.955816, 30.406136,
                                        33.270325, 21.685469, -17.21592,
                                        1.3697424, 49.33187, -30.023928,
                                        22.915518
                                    ]]  # pyformat: disable

            self.assertAllEqual([seq_len, bs, p.model_dim],
                                actual_enc_out.shape)
            self.assertAllClose(expected_enc_out_sum,
                                actual_enc_out_sum,
                                rtol=1e-05,
                                atol=1e-05)
Example #4
 def _GetLangIds(self, source_id):
   """Look up the correct lang_id from the source_id tensor."""
   task_id = self._GetTaskIds(source_id)
   src_lang_id = task_id
   tgt_lang_id = task_id
   if self.params.task_to_src_lang_map:
     src_langs = tf.constant(self.params.task_to_src_lang_map, dtype=tf.int32)
     src_lang_id = tf.gather(src_langs, task_id)
   if self.params.task_to_tgt_lang_map:
     tgt_langs = tf.constant(self.params.task_to_tgt_lang_map, dtype=tf.int32)
     tgt_lang_id = tf.gather(tgt_langs, task_id)
   return src_lang_id, tgt_lang_id
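
A minimal sketch (hypothetical map values) of the lookup pattern above: the constant map acts as a table indexed by task id, so tf.gather translates each task id into its language id.

import tensorflow as tf

task_id = tf.constant([0, 2, 1])
task_to_src_lang_map = [5, 6, 7]             # hypothetical: task i -> lang id
src_langs = tf.constant(task_to_src_lang_map, dtype=tf.int32)
src_lang_id = tf.gather(src_langs, task_id)  # -> [5, 7, 6]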
Example #5
 def ReOrderHyps(x_in):
     """Reorders x_in based on prev hyp ids."""
     if (isinstance(x_in, tf.Tensor) and x_in.shape.ndims
             and x_in.shape.ndims > 0):
         if x_in.shape.ndims > 2 and not p.batch_major_state:
             x_out = tf.gather(x_in, old_hyp_ids, axis=1)
         else:
             x_out = tf.gather(x_in, old_hyp_ids)
         x_out.set_shape(x_in.get_shape())
         return x_out
     else:
         return x_in
Example #6
            def ApplyBias():
                """Bias and update log_probs and consistent."""
                def TileForBeamAndFlatten(tensor):
                    tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
                    # [num_hyps_per_beam, src_batch]
                    tensor = tf.tile(tensor, [num_hyps_per_beam, 1])
                    # tgt_batch = num_hyps_per_beam * src_batch
                    tgt_batch = tf.shape(step_ids)[0]
                    return tf.reshape(tensor, [tgt_batch])

                # Consistent if step_ids == labels from previous step
                # TODO(navari): Consider updating consistent only if weights > 0. Then
                # re-evaluate the need for bias_only_if_consistent=True.
                # Note that prev_label is incorrect for step 0 but is overridden later
                prev_label = TileForBeamAndFlatten(
                    tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
                is_step0 = tf.equal(time_step, 0)
                local_consistence = tf.logical_or(
                    is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
                consistent = tf.logical_and(states.consistent,
                                            local_consistence)

                # get label, weight slices corresponding to current time_step
                label = TileForBeamAndFlatten(
                    tf.gather(labels, time_step, axis=1))
                weight = TileForBeamAndFlatten(
                    tf.gather(weights, time_step, axis=1))
                if p.bias_only_if_consistent:
                    weight = weight * tf.cast(consistent, p.dtype)

                # convert from dense label to sparse label probs
                vocab_size = tf.shape(bs_results.log_probs)[1]
                uncertainty = tf.constant(
                    1e-10,
                    p.dtype)  # avoid 0 probs which may cause issues with log
                label_probs = tf.one_hot(
                    label,
                    vocab_size,
                    on_value=1 - uncertainty,
                    off_value=uncertainty / tf.cast(vocab_size - 1, p.dtype),
                    dtype=p.dtype)  # [tgt_batch, vocab_size]
                pred_probs = tf.exp(bs_results.log_probs)

                # interpolate predicted probs and label probs
                weight = tf.expand_dims(weight, 1)
                probs = py_utils.with_dependencies([
                    py_utils.assert_less_equal(weight, 1.),
                    py_utils.assert_greater_equal(weight, 0.)
                ], (1.0 - weight) * pred_probs + weight * label_probs)
                return tf.log(probs), consistent
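
A small sketch (assuming vocab_size=4 and float32) of the smoothed one-hot built above: the label receives mass 1 - uncertainty and the rest is spread evenly over the other tokens, so every row sums to 1 and the log never sees an exact zero.

import tensorflow as tf

vocab_size = 4
uncertainty = tf.constant(1e-10, tf.float32)
label = tf.constant([2])
label_probs = tf.one_hot(
    label,
    vocab_size,
    on_value=1 - uncertainty,
    off_value=uncertainty / tf.cast(vocab_size - 1, tf.float32))
# Row sum: (1 - u) + (vocab_size - 1) * u / (vocab_size - 1) == 1.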
Example #7
def ComputeMoments(inputs,
                   padding,
                   reduce_over_dims,
                   cumulative_axis=None,
                   enable_cross_replica_sum_on_tpu=False,
                   keepdims=False):
    """Computes mean and variance over the valid data points in inputs."""
    mask = 1.0 - padding
    inputs = py_utils.with_dependencies([
        py_utils.assert_equal(tf.rank(inputs), tf.rank(mask)),
        py_utils.assert_greater_equal(mask, tf.zeros_like(mask)),
    ], inputs)
    sum_v = tf.reduce_sum(inputs * tf.cast(mask, inputs.dtype),
                          reduce_over_dims,
                          keepdims=keepdims)
    count_v = tf.reduce_sum(mask, reduce_over_dims, keepdims=keepdims)

    if cumulative_axis is not None:
        sum_v = tf.math.cumsum(sum_v, axis=cumulative_axis)
        count_v = tf.math.cumsum(count_v, axis=cumulative_axis)
    # Input shape is guaranteed to be a multiple of mask shape because the
    # inputs * mask op above was successfully broadcasted.
    input_size_on_reduced_dims = tf.reduce_prod(
        tf.gather(tf.shape(inputs), reduce_over_dims))
    mask_size_on_reduced_dims = tf.reduce_prod(
        tf.gather(tf.shape(mask), reduce_over_dims))
    mask_multiplier = tf.math.truediv(input_size_on_reduced_dims,
                                      mask_size_on_reduced_dims)
    count_v *= tf.cast(mask_multiplier, count_v.dtype)
    if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
        sum_v = tf.tpu.cross_replica_sum(sum_v)
        count_v = tf.tpu.cross_replica_sum(count_v)

    count_v = tf.maximum(count_v, 1.0)
    mean = sum_v / count_v
    sum_vv = tf.reduce_sum((inputs - mean) * (inputs - mean) * mask,
                           reduce_over_dims,
                           keepdims=keepdims)
    if cumulative_axis is not None:
        sum_vv = tf.math.cumsum(sum_vv, axis=cumulative_axis)

    if py_utils.use_tpu() and enable_cross_replica_sum_on_tpu:
        sum_vv = tf.tpu.cross_replica_sum(sum_vv)

    variance = py_utils.with_dependencies([
        py_utils.assert_greater_equal(sum_vv, tf.zeros_like(sum_vv)),
    ], sum_vv / count_v)
    return mean, variance
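
A minimal usage sketch of ComputeMoments (assuming the py_utils assertions pass) on a [batch, time] input where 1.0 in `padding` marks invalid frames:

import tensorflow as tf

inputs = tf.constant([[1., 2., 3.], [4., 5., 6.]])
padding = tf.constant([[0., 0., 1.], [0., 1., 1.]])  # trailing frames padded
mean, variance = ComputeMoments(inputs, padding, reduce_over_dims=[1], keepdims=True)
# mean     -> [[1.5], [4.0]]   (only unpadded values contribute)
# variance -> [[0.25], [0.0]]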
Example #8
        def _TokenizeOneSentence(i, text, token_ids_ta, target_ids_ta,
                                 paddings_ta):
            """Tokenizes a single sentence."""
            if tf.is_tensor(i):
                text_i = tf.gather(text, i)
            else:
                text_i = text[i]
            ids = self._tokenizer.tokenize(text_i).merge_dims(0, -1)
            ids.set_shape([None])

            if append_eos:
                ids = tf.concat([ids, [self.eos_id]], axis=0)
            sos_ids = tf.concat([[self.sos_id], ids], axis=0)
            if p.prepend_sos:
                ids = sos_ids

            # This truncates after the EOS is added, so some sentences might
            # not have EOS at the end.
            token_ids_ta = token_ids_ta.write(
                i, py_utils.PadOrTrimTo(sos_ids, [max_length], 0))
            target_ids_ta = target_ids_ta.write(
                i, py_utils.PadOrTrimTo(ids, [max_length], 0))
            paddings_ta = paddings_ta.write(
                i,
                py_utils.PadOrTrimTo(tf.zeros_like(ids, dtype=tf.float32),
                                     [max_length], 1.))

            return i + 1, text, token_ids_ta, target_ids_ta, paddings_ta
Example #9
    def CreateDenseCoordinates(self, ranges):
        """Create a matrix of coordinate locations corresponding to a dense grid.

    Example: To create (x, y) coordinates over a 10x10 grid with
    step sizes 1, call ``CreateDenseCoordinates([(1, 10, 10), (1, 10, 10)])``.

    Args:
      ranges: A list of 3-tuples, each tuple is expected to contain (min, max,
        num_steps). Each list element corresponds to one dimension. Each tuple
        will be passed into np.linspace to create the values for a single
        dimension.

    Returns:
      tf.float32 tensor of shape [total_points, len(ranges)], where
      total_points = product of all num_steps.

    """
        total_points = int(np.prod([r_steps for _, _, r_steps in ranges]))
        cycle_steps = total_points
        stack_coordinates = []

        for r_start, r_stop, r_steps in ranges:
            values = tf.lin_space(tf.cast(r_start, tf.float32),
                                  tf.cast(r_stop, tf.float32),
                                  tf.cast(r_steps, tf.int32))
            cycle_steps //= r_steps
            gather_idx = (tf.range(total_points) // cycle_steps) % r_steps
            stack_coordinates.append(tf.gather(values, gather_idx))

        return tf.stack(stack_coordinates, axis=1)
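
As a worked instance of the cycle trick, a small call (hypothetical ranges, assuming an instance of the class above) and its expected output:

coords = self.CreateDenseCoordinates([(0, 1, 2), (0, 2, 3)])
# shape [6, 2]; the grid is enumerated row-major:
# [[0., 0.], [0., 1.], [0., 2.],
#  [1., 0.], [1., 1.], [1., 2.]]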
Example #10
 def _GetTaskIds(self, source_id):
     """Look up the correct task_id from the source_id tensor."""
     if self.params.file_pattern_task_ids:
         file_task_ids = tf.constant(self.params.file_pattern_task_ids,
                                     dtype=tf.int32)
         return tf.gather(file_task_ids, source_id)
     return source_id
Example #11
 def ReOrderHyps(x_in):
   """Reorders x_in based on prev hyp ids."""
   if (isinstance(x_in, tf.Tensor) and x_in.shape.ndims and
       x_in.shape.ndims > 0):
     if x_in.shape.ndims > 2 and not p.batch_major_state:
       # Use corrected indices only here for batch major compute as key/value
       # caches are the states being affected.
       correct_old_hyp_ids = (
           old_hyp_ids_in_cache_order
           if p.batch_major_compute else old_hyp_ids)
       x_out = tf.gather(x_in, correct_old_hyp_ids, axis=1)
     else:
       x_out = tf.gather(x_in, old_hyp_ids)
     x_out.set_shape(x_in.get_shape())
     return x_out
   else:
     return x_in
Example #12
 def _GetTaskIds(self, source_id):
   """Look up the correct task_id from the source_id tensor."""
   if self.params.file_pattern_task_ids:
     file_task_ids = tf.constant(
         self.params.file_pattern_task_ids, dtype=tf.int32)
     source_id = tf.gather(file_task_ids, source_id)
   src_task_id = source_id
   tgt_task_id = source_id
   if self.params.task_to_src_lang_map:
     src_lang_ids = tf.constant(
         self.params.task_to_src_lang_map, dtype=tf.int32)
     src_task_id = tf.gather(src_lang_ids, src_task_id)
   if self.params.task_to_tgt_lang_map:
     tgt_lang_ids = tf.constant(
         self.params.task_to_tgt_lang_map, dtype=tf.int32)
     tgt_task_id = tf.gather(tgt_lang_ids, tgt_task_id)
   return src_task_id, tgt_task_id
Example #13
                def ApplyBias():
                    """Bias and update log_probs and consistent."""

                    # Consistent if step_ids == labels from previous step
                    # TODO(navari): Consider updating consistent only if weights > 0. Then
                    # re-evaluate the need for bias_only_if_consistent=True.
                    # Note that prev_label is incorrect for step 0 but is overridden
                    # later
                    prev_label = TileForBeamAndFlatten(
                        tf.gather(labels, tf.maximum(time_step - 1, 0),
                                  axis=1))
                    is_step0 = tf.equal(time_step, 0)
                    local_consistence = tf.math.logical_or(
                        is_step0, tf.equal(prev_label, tf.squeeze(step_ids,
                                                                  1)))
                    consistent = tf.math.logical_and(states.consistent,
                                                     local_consistence)

                    # get label, weight slices corresponding to current time_step
                    label = TileForBeamAndFlatten(
                        tf.gather(labels, time_step, axis=1))
                    weight = TileForBeamAndFlatten(
                        tf.gather(weights, time_step, axis=1))
                    if p.bias_only_if_consistent:
                        weight = weight * tf.cast(consistent,
                                                  py_utils.FPropDtype(p))

                    # convert from dense label to sparse label probs
                    vocab_size = tf.shape(bs_results.log_probs)[1]
                    label_probs = tf.one_hot(label,
                                             vocab_size,
                                             dtype=py_utils.FPropDtype(
                                                 p))  # [tgt_batch, vocab_size]
                    pred_probs = tf.exp(bs_results.log_probs)

                    # interpolate predicted probs and label probs
                    weight = tf.expand_dims(weight, 1)
                    probs = py_utils.with_dependencies([
                        py_utils.assert_less_equal(weight, 1.),
                        py_utils.assert_greater_equal(weight, 0.)
                    ], (1.0 - weight) * pred_probs + weight * label_probs)
                    # Ensure that tf.math.log is applied to positive values.
                    probs = tf.maximum(probs,
                                       tf.constant(1e-12, dtype=probs.dtype))
                    return tf.math.log(probs), consistent
Example #14
 def _Padding():
   indices = tf.random.uniform([num_points_out - actual_num],
                               minval=0,
                               maxval=actual_num,
                               dtype=tf.int32,
                               seed=seed)
   padded = []
   for t in tensor_list:
     padded.append(tf.concat([t, tf.gather(t, indices, axis=0)], axis=0))
   return padded
Example #15
 def _ApplyMass(source_id):
     if self.params.file_pattern_task_ids:
         file_task_ids = tf.constant(
             self.params.file_pattern_task_ids, dtype=tf.int32)
         task_id = tf.gather(file_task_ids, source_id)
     else:
         task_id = source_id
     mass_task_ids = tf.constant(self.params.mass_task_ids,
                                 dtype=tf.int32)
     return tf.reduce_any(tf.equal(task_id, mass_task_ids))
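
A minimal sketch of the membership test on the last line: tf.equal broadcasts the scalar task_id against the id list, and tf.reduce_any collapses the elementwise comparison into a single boolean.

import tensorflow as tf

task_id = tf.constant(3)
mass_task_ids = tf.constant([1, 3, 5], dtype=tf.int32)
is_mass = tf.reduce_any(tf.equal(task_id, mass_task_ids))  # True: 3 is in the list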
Example #16
        def Step(recurrent_theta, state0, inputs):
            """Computes one decoder step."""
            if p.use_recurrent:
                del inputs
            with tf.name_scope('single_sampler_step'):
                # Compute logits and states.
                bs_result, bs_state1 = pre_step_callback(
                    decoder_theta,
                    recurrent_theta.encoder_outputs,
                    tf.expand_dims(state0.ids, 1),  # [batch, 1].
                    state0.bs_state,
                    num_hyps_per_beam=p.num_hyps_per_beam)
                batch = tf.shape(bs_result.log_probs)[0]
                state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
                state1.logits = bs_result.log_probs

                if p.top_k > 0:
                    topk_logits, topk_ids = tf.math.top_k(state1.logits,
                                                          k=p.top_k)
                    sample_logits = tf.nn.log_softmax(
                        topk_logits) if p.top_k_renormalize else topk_logits
                else:
                    sample_logits = state1.logits

                # Sample ids from logits. [batch].
                ids = tf.reshape(
                    tf.random.stateless_categorical(
                        sample_logits / p.temperature,
                        num_samples=1,
                        seed=tf.stack(
                            [recurrent_theta.random_seed, state0.timestep]),
                        dtype=state0.ids.dtype,
                        name='sample_next_id'), [batch])
                state1.ids = tf.gather(topk_ids, ids, axis=1,
                                       batch_dims=1) if p.top_k > 0 else ids

                if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
                    state1.ids = tf.where(
                        tf.math.logical_and(
                            bs_result.is_last_chunk,
                            tf.equal(state1.ids, p.target_eoc_id)),
                        tf.fill(tf.shape(state1.ids), p.target_eos_id),
                        state1.ids)
                state1.bs_state = post_step_callback(
                    decoder_theta, recurrent_theta.encoder_outputs, state1.ids,
                    bs_state1)
            if p.use_recurrent:
                return state1, py_utils.NestedMap()
            else:
                inputs.ids = inputs.ids.write(state0.timestep, state1.ids)
                inputs.logits = inputs.logits.write(state0.timestep,
                                                    state1.logits)
                return (recurrent_theta, state1, inputs)
Example #17
 def ReOrderHyps(key, x_in):
     """Reorders x_in based on prev hyp ids."""
     if random_seed_regex.match(key):
         # For keys like rnn_states[0].r, it is a shape [2] random seeds tensor
         # used for deterministic behavior and should not be reordered.
         return py_utils.HasShape(x_in, [2])
     correct_old_hyp_ids = (old_hyp_ids_in_cache_order
                            if p.batch_major_compute else old_hyp_ids)
     if (isinstance(x_in, tf.Tensor) and x_in.shape.ndims):
         if x_in.shape.ndims > 2 and not p.batch_major_state:
             # Use corrected indices only here for batch major compute as key/value
             # caches are the states being affected.
             x_out = tf.gather(x_in, correct_old_hyp_ids, axis=1)
         elif key in POSSIBLY_TIME_MAJOR_STATE_KEYS:
             x_out = tf.gather(x_in, old_hyp_ids, axis=-1)
         else:
             x_out = tf.gather(x_in, correct_old_hyp_ids)
         x_out.set_shape(x_in.get_shape())
         return x_out
     else:
         return x_in
Example #18
 def _ReshapeGather(tensor):
   """Reshapes tensor and then gathers using the nms indices."""
   tensor = tf.gather(
       tf.reshape(tensor, [batch_size, num_bboxes, -1]),
       per_cls_idxs,
       batch_dims=1)
   if not p.use_oriented_per_class_nms:
     # Tile so that the data fits the expected per class shape of
     # [batch_size, num_classes, ...]. When *not* using oriented NMS, the
     # num_classes dimension will be missing since the indices will not
     # have it.
     tensor = tf.tile(tensor[:, tf.newaxis, :, :],
                      [1, p.num_classes, 1, 1])
   return tensor
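
A minimal sketch of the batch_dims=1 gather used above: with batch_dims=1, each batch row selects from its own slice of the params tensor.

import tensorflow as tf

params = tf.constant([[10, 11, 12], [20, 21, 22]])  # [batch=2, num_bboxes=3]
idxs = tf.constant([[2, 0], [1, 1]])                # per-row indices
out = tf.gather(params, idxs, batch_dims=1)
# -> [[12, 10], [21, 21]]: row b gathers params[b] at idxs[b].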
Example #19
def CollectVarHistogram(vs_gs):
    """Adds histogram summaries for variables and gradients."""

    for name, (var, grad) in vs_gs.FlattenItems():
        with tf.device(var.device), tf.name_scope(name + '/summary'):
            if isinstance(grad, tf.IndexedSlices):
                var = tf.gather(var, grad.indices)
                grad = grad.values
            if var.dtype.is_complex:
                var = tf.abs(var)
                grad = tf.abs(grad)

        histogram('var_hist/' + name, var)
        histogram('grad_hist/' + name, grad)
Example #20
File: pruning.py Project: snsun/lingvo
    def _update_mask(self, weights, threshold):
        """Updates the mask for a given weight tensor.

    This functions first computes the cdf of the weight tensor, and estimates
    the threshold value such that 'desired_sparsity' fraction of weights
    have magnitude less than the threshold.

    Args:
      weights: The weight tensor that needs to be masked.
      threshold: The current threshold value. The function will compute a new
        threshold and return the exponential moving average using the current
        value of threshold.

    Returns:
      new_threshold: The new value of the threshold based on weights, and
        sparsity at the current global_step.
      new_mask: A tensor of the same size and shape as weights containing
        0 or 1 to indicate which of the values in weights falls below
        the threshold.

    Raises:
      ValueError: if sparsity is not defined
    """
        if self._sparsity is None:
            raise ValueError('Sparsity variable undefined')

        sparsity = self._get_sparsity(weights.op.name)
        with tf.name_scope(weights.op.name + '_pruning_ops'):
            abs_weights = tf.abs(weights)
            k = tf.cast(
                tf.round(
                    tf.cast(tf.size(abs_weights), tf.float32) *
                    (1 - sparsity)), tf.int32)
            # Sort the entire array
            values, _ = tf.nn.top_k(tf.reshape(abs_weights, [-1]),
                                    k=tf.size(abs_weights))
            # Grab the (k - 1)-th value, i.e., the k-th largest magnitude.
            current_threshold = tf.gather(values, k - 1)
            smoothed_threshold = tf.add_n([
                tf.multiply(current_threshold, 1 - self._spec.threshold_decay),
                tf.multiply(threshold, self._spec.threshold_decay)
            ])

            new_mask = tf.cast(
                tf.greater_equal(abs_weights, smoothed_threshold), tf.float32)

        return smoothed_threshold, new_mask
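
A toy sketch of the threshold estimate (assuming sparsity = 0.75 and ignoring the moving average): keeping the top 25% of magnitudes means the threshold is the k-th largest |w|.

import tensorflow as tf

weights = tf.constant([0.1, -0.4, 0.05, 0.3])
sparsity = 0.75
k = tf.cast(
    tf.round(tf.cast(tf.size(weights), tf.float32) * (1 - sparsity)), tf.int32)  # k == 1
values, _ = tf.nn.top_k(tf.abs(weights), k=tf.size(weights))  # [0.4, 0.3, 0.1, 0.05]
current_threshold = tf.gather(values, k - 1)                  # 0.4
mask = tf.cast(tf.greater_equal(tf.abs(weights), current_threshold), tf.float32)
# mask -> [0., 1., 0., 0.]: only the largest-magnitude weight survives.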
Example #21
def CollectVarHistogram(vs_gs):
    """Adds histogram summaries for variables and gradients."""

    for name, (var, grad) in vs_gs.FlattenItems():
        name = py_utils.SanitizeScopeKey(name)
        with tf.device(var.device), tf.name_scope(name + '/summary'):
            if isinstance(grad, tf.IndexedSlices):
                var = tf.gather(var, grad.indices)
                grad = grad.values
            if var.dtype.is_complex:
                var = tf.abs(var)
                grad = tf.abs(grad)

        if py_utils.IsEagerMode():
            histogram_v2(f'var_hist/{name}', var)
            histogram_v2(f'grad_hist/{name}', grad)
        else:
            histogram(f'var_hist/{name}', var)
            histogram(f'grad_hist/{name}', grad)
Example #22
def GetSentenceEmbeddings(inputs, segment_id):
  """Returns the average sentence embedding to gate by.

  Example::

    inputs: <tf.Variable 'Variable:0' shape=(10, 3) dtype=float64, numpy=
             array([[0.41258181, 0.61071571, 0.63777673],
                    [0.65571443, 0.54297766, 0.10288261],
                    [0.8577837 , 0.81915847, 0.61996602],
                    [0.46897136, 0.92662692, 0.32942232],
                    [0.60162383, 0.3385829 , 0.3408632 ],
                    [0.40774807, 0.86139635, 0.00927162],
                    [0.56126334, 0.51748817, 0.07791397],
                    [0.06595223, 0.95529216, 0.34458149],
                    [0.1238971 , 0.49897169, 0.25216722],
                    [0.11221774, 0.50284604, 0.84106974]])>
    segment_id: <tf.Variable 'Variable:0' shape=(10,) dtype=int64,
                 numpy=array([1, 1, 2, 0, 0, 3, 3, 3, 3, 0])>

  Args:
    inputs: GSM Tensor.
    segment_id: GS Tensor.

  Returns:
    sentence_embeddings: GSM Tensor that is an average of the input embeddings
    per segment.
  """
  reshaped_inputs = tf.reshape(inputs, [-1, inputs.shape[-1]])

  # We set num_segments to a large value so that shape is known at compile time.
  max_segments = py_utils.GetShape(reshaped_inputs)[0]
  # We change the padding to be max_segments - 1 instead of 0 because
  # tf.math.unsorted_segment_mean only accepts values between 1 and
  # max_segments.
  modified_segment_id = tf.cast(
      segment_id + max_segments * tf.cast(
          tf.equal(segment_id, 0), dtype=tf.dtypes.as_dtype(segment_id.dtype)) -
      1,
      dtype=tf.int32)
  reshaped_segment_id = tf.reshape(modified_segment_id, [-1])

  # Takes the mean of all segments, w/ 0s for the padding.
  params = tf.concat([
      tf.math.unsorted_segment_mean(reshaped_inputs, reshaped_segment_id,
                                    max_segments)[:-1],
      tf.zeros([1, reshaped_inputs.shape[-1]], dtype=reshaped_inputs.dtype)
  ],
                     axis=0)
  raw_sentence_embeddings = tf.gather(params, modified_segment_id)

  # sentence_embedding: <tf.Tensor: shape=(10, 3), dtype=float64, numpy=
  #                     array([[0.92657252, 0.40264503, 0.55494457],
  #                            [0.92657252, 0.40264503, 0.55494457],
  #                            [0.08002721, 0.02360659, 0.63688627],
  #                            [0.        , 0.        , 0.        ],
  #                            [0.        , 0.        , 0.        ],
  #                            [0.8138629 , 0.54451293, 0.48802852],
  #                            [0.8138629 , 0.54451293, 0.48802852],
  #                            [0.8138629 , 0.54451293, 0.48802852],
  #                            [0.8138629 , 0.54451293, 0.48802852],
  #                            [0.        , 0.        , 0.        ]])>
  sentence_embeddings = tf.reshape(raw_sentence_embeddings, inputs.shape)

  return sentence_embeddings
Example #23
 def _Slicing():
     # Choose a random set of indices.
     indices = tf.range(actual_num)
     indices = tf.random_shuffle(indices, seed=seed)[:num_points_out]
     return [tf.gather(t, indices, axis=0) for t in tensor_list]
Example #24
    def _AddNoise(self, batch):
        """Adding noise the src (see https://arxiv.org/pdf/1711.00043).

    This function implements 3 types of noise (hyperparams defined in
    self.params.denoise):
    1) slightly shuffle the sentence following p.shuffle_tok_range
    2) randomly drop tokens with probability p.drop_tok_prob
    3) randomly mask tokens with probability p.blank_tok_prob
    The noises are added to the input with probability p.noise_sent_prob.

    Args:
      batch: a `.NestedMap` of the input batch.
    """
        def IsSpecialExample(task_ids, special_task_ids):
            """A utility function indicates whether inputs belong to specific tasks.

      Args:
        task_ids: Task ids for the input batch. Tensor of shape [batch].
        special_task_ids: A list of specified task ids.

      Returns:
        A boolean tensor of shape [batch] indicating whether each sample
        belongs to one of the specified tasks.
      """
            batch_size = py_utils.GetShape(task_ids)[0]
            return tf.reduce_any(
                tf.equal(
                    tf.expand_dims(task_ids, -1),
                    tf.cast(
                        tf.broadcast_to(
                            special_task_ids,
                            [batch_size, len(special_task_ids)]), tf.int32)),
                -1)

        p = self.params.denoise
        batch_size = tf.shape(batch.src.ids)[0]
        source_max_len = tf.shape(batch.src.ids)[1]

        # Shuffle tokens according to p.shuffle_tok_range
        noise = tf.random.uniform([batch_size, source_max_len], 0,
                                  p.shuffle_tok_range + 1)

        # Don't shuffle eos or padding
        shuffle_tok_range = tf.fill([batch_size, source_max_len],
                                    float(p.shuffle_tok_range))
        shifted_paddings = tf.pad(batch.src.paddings[:, 1:], [[0, 0], [0, 1]],
                                  constant_values=1)
        noise = tf.where(tf.equal(shifted_paddings, 0), noise,
                         shuffle_tok_range)
        indices = tf.broadcast_to(tf.range(source_max_len, dtype=tf.int32),
                                  [batch_size, source_max_len])
        noisy_indices = tf.cast(indices, dtype=tf.float32) + noise
        permutations = tf.argsort(noisy_indices)
        stacked = tf.stack([batch.src.ids, permutations], axis=1)
        denoise_src_ids = tf.stack(tf.map_fn(lambda x: tf.gather(x[0], x[1]),
                                             stacked),
                                   axis=0)

        # Select tokens to drop with probability=p.drop_tok_prob
        random_drop_tok = tf.random.uniform([batch_size, source_max_len])
        # Don't drop eos token
        is_keep_tok = tf.math.logical_or(
            tf.greater(random_drop_tok, p.drop_tok_prob),
            tf.equal(denoise_src_ids, self._src_tokenizer.eos_id))
        denoise_src_ids = tf.ragged.boolean_mask(
            denoise_src_ids,
            is_keep_tok).to_tensor(default_value=0,
                                   shape=tf.shape(batch.src.ids))
        denoise_src_paddings = tf.ragged.boolean_mask(
            batch.src.paddings,
            is_keep_tok).to_tensor(default_value=1,
                                   shape=tf.shape(batch.src.ids))

        # Select tokens to blank with probability=p.blank_tok_prob
        # Don't blank eos token
        random_blank_tok = tf.random.uniform([batch_size, source_max_len])
        shifted_paddings = tf.pad(denoise_src_paddings[:, 1:],
                                  [[0, 0], [0, 1]],
                                  constant_values=1)
        is_blank_tok = tf.math.logical_and(
            tf.less(random_blank_tok, p.blank_tok_prob),
            tf.equal(shifted_paddings, 0))
        blank_id = tf.fill([batch_size, source_max_len], p.blank_id)
        denoise_src_ids = tf.where(is_blank_tok, blank_id, denoise_src_ids)

        # Select denoising task examples with probability=p.denoise_sent_prob
        random_uniform_sent = tf.random.uniform([batch_size])
        is_denoise_sent = tf.math.logical_and(
            tf.less(random_uniform_sent, p.noise_sent_prob),
            IsSpecialExample(self._GetTaskIds(batch.src.source_ids[:, 0]),
                             p.task_ids))
        batch.src.ids = tf.where(is_denoise_sent, denoise_src_ids,
                                 batch.src.ids)
        batch.src.paddings = tf.where(is_denoise_sent, denoise_src_paddings,
                                      batch.src.paddings)
        batch.src.ids_indicator = 1 - batch.src.paddings
        batch.src.weights = batch.src.ids_indicator
Example #25
    def _ApplyPacking(self, batch):
        """Packs a given batch.

    Note that this may change the batch size.

    This function packs the input batch and adds .segment_ids and .segment_pos
    fields to its `src` and `tgt` fields.

    Args:
      batch: a `.NestedMap` of input tensors to be packed. It is modified in
        place.
    """
        src_actual_seq_len = tf.math.reduce_sum(tf.cast(
            batch.src.ids_indicator, tf.int32),
                                                axis=1)
        tgt_actual_seq_len = tf.math.reduce_sum(tf.cast(
            batch.tgt.ids_indicator, tf.int32),
                                                axis=1)
        summary_utils.histogram('source_seq_lengths', src_actual_seq_len)
        summary_utils.histogram('target_seq_lengths', tgt_actual_seq_len)

        if not self.params.packing_factor:
            # Supply segment_ids and segment_pos with no packing.
            batch.src.segment_ids = batch.src.ids_indicator
            batch.src.segment_pos = _GetSegmentPos(batch.src.ids_indicator)
            batch.tgt.segment_ids = batch.tgt.ids_indicator
            batch.tgt.segment_pos = _GetSegmentPos(batch.tgt.ids_indicator)
            return

        (src_segment_ids, src_segment_pos, src_indices_in_input,
         tgt_segment_ids, tgt_segment_pos,
         tgt_indices_in_input) = ops.pack_sequences(
             src_actual_seq_len, tgt_actual_seq_len, self._ScaledBatchSize(),
             self.params.source_max_length, self.params.target_max_length)

        uniq_src_indices_in_input = tf.unique(
            tf.reshape(src_indices_in_input, [-1])).y
        uniq_tgt_indices_in_input = tf.unique(
            tf.reshape(tgt_indices_in_input, [-1])).y
        summary_utils.histogram(
            'packed_source_seq_lengths',
            tf.gather(src_actual_seq_len, uniq_src_indices_in_input, axis=0))
        summary_utils.histogram(
            'packed_target_seq_lengths',
            tf.gather(tgt_actual_seq_len, uniq_tgt_indices_in_input, axis=0))

        # Ratio of number of non-padded tokens. If < 1.0, we are dropping
        # input data because p.packing_factor is too high.
        src_orig_tokens_count = tf.cast(tf.reduce_sum(src_actual_seq_len),
                                        tf.float32)
        src_packed_tokens_count = tf.reduce_sum(
            tf.cast(src_segment_ids > 0, tf.float32))
        summary_utils.scalar('examples/src_packed_token_ratio',
                             src_packed_tokens_count / src_orig_tokens_count)
        tgt_orig_tokens_count = tf.cast(tf.reduce_sum(tgt_actual_seq_len),
                                        tf.float32)
        tgt_packed_tokens_count = tf.reduce_sum(
            tf.cast(tgt_segment_ids > 0, tf.float32))
        summary_utils.scalar('examples/tgt_packed_token_ratio',
                             tgt_packed_tokens_count / tgt_orig_tokens_count)

        # We deferred adding .paddings and use its complement .ids_indicator
        # exclusively so that we can apply the packing with padding set to 0 for all
        # fields.
        def ApplyPackingToSource(x):
            if x.dtype == tf.string:
                return ops.apply_packing(x, '\t', src_segment_ids,
                                         src_indices_in_input)
            return ops.apply_packing(x, 0, src_segment_ids,
                                     src_indices_in_input)

        src_paddings = ops.apply_packing(batch.src.paddings, 1,
                                         src_segment_ids, src_indices_in_input)
        batch.src = batch.src.Transform(ApplyPackingToSource)
        batch.src.paddings = src_paddings
        batch.src.segment_ids = tf.cast(src_segment_ids, tf.float32)
        batch.src.segment_pos = src_segment_pos

        def ApplyPackingToTarget(x):
            if x.dtype == tf.string:
                return ops.apply_packing(x, '\t', tgt_segment_ids,
                                         tgt_indices_in_input)
            return ops.apply_packing(x, 0, tgt_segment_ids,
                                     tgt_indices_in_input)

        tgt_paddings = ops.apply_packing(batch.tgt.paddings, 1,
                                         tgt_segment_ids, tgt_indices_in_input)
        batch.tgt = batch.tgt.Transform(ApplyPackingToTarget)
        batch.tgt.paddings = tgt_paddings
        batch.tgt.segment_ids = tf.cast(tgt_segment_ids, tf.float32)
        batch.tgt.segment_pos = tgt_segment_pos

        # The number of examples is indicated by the segment_ids of the target.
        num_segments = tf.math.reduce_max(batch.tgt.segment_ids, axis=1)
        num_examples = tf.reduce_sum(num_segments)
        # Note that this is per infeed value when p.use_per_host_infeed = True.
        metric_name = 'examples/num_packed_examples'
        summary_utils.scalar(metric_name, num_examples)
Example #26
        def Step(recurrent_theta, state0, inputs):
            """Computes one decoder step."""
            if p.use_recurrent:
                del inputs
            with tf.name_scope('single_sampler_step'):
                # Compute logits and states.
                bs_result, bs_state1 = pre_step_callback(
                    decoder_theta,
                    recurrent_theta.encoder_outputs,
                    tf.expand_dims(state0.ids, 1),  # [batch, 1].
                    state0.bs_state,
                    p.num_hyps_per_beam,
                    0)  # cur_step
                batch = tf.shape(bs_result.log_probs)[0]
                state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
                state1.logits = bs_result.log_probs
                sample_logits = state1.logits
                # Perform Nucleus Sampling. Assumes logits are in (-1e10, 1e3).
                if p.nucleus_p < 1.0:
                    max_logit = 1e3
                    min_logit = -1e10
                    sorted_logits = tf.sort(sample_logits,
                                            direction='DESCENDING',
                                            axis=-1)
                    sorted_probs = tf.nn.softmax(sorted_logits)
                    cumsum_probs = tf.math.cumsum(sorted_probs,
                                                  axis=-1,
                                                  exclusive=True)
                    masked_logits = tf.where(
                        cumsum_probs < p.nucleus_p, sorted_logits,
                        tf.ones_like(sorted_logits) * max_logit)
                    threshold = tf.math.reduce_min(masked_logits,
                                                   axis=-1,
                                                   keepdims=True)
                    sample_logits = tf.where(
                        sample_logits < threshold,
                        tf.ones_like(sorted_logits) * min_logit, sample_logits)
                # Note that here, we retain the possibility of applying both top_k
                # and nucleus filtering.
                if p.top_k > 0:
                    topk_logits, topk_ids = tf.math.top_k(sample_logits,
                                                          k=p.top_k)
                    sample_logits = tf.nn.log_softmax(
                        topk_logits) if p.top_k_renormalize else topk_logits

                # Sample ids from logits. [batch].
                ids = tf.reshape(
                    tf.random.stateless_categorical(
                        sample_logits / p.temperature,
                        num_samples=1,
                        seed=tf.stack(
                            [recurrent_theta.random_seed, state0.timestep]),
                        dtype=state0.ids.dtype,
                        name='sample_next_id'), [batch])
                state1.ids = tf.gather(topk_ids, ids, axis=1,
                                       batch_dims=1) if p.top_k > 0 else ids

                if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
                    state1.ids = tf.where(
                        tf.math.logical_and(
                            bs_result.is_last_chunk,
                            tf.equal(state1.ids, p.target_eoc_id)),
                        tf.fill(tf.shape(state1.ids), p.target_eos_id),
                        state1.ids)
                state1.bs_state = post_step_callback(
                    decoder_theta, recurrent_theta.encoder_outputs, state1.ids,
                    bs_state1)
            if p.use_recurrent:
                return state1, py_utils.NestedMap()
            else:
                inputs.ids = inputs.ids.write(state0.timestep, state1.ids)
                inputs.logits = inputs.logits.write(state0.timestep,
                                                    state1.logits)
                return (recurrent_theta, state1, inputs)
Example #27
  def _Pack(self, batch):
    """Packs a given batch.

    Note that this may change the batch size.

    This function packs the input batch and adds .segment_ids and .segment_pos
    fields to its `src` and `tgt` fields.

    Args:
      batch: a `.NestedMap` of input tensors to be packed. It is modified in
        place.
    """
    src_actual_seq_len = tf.math.reduce_sum(
        tf.cast(batch.src.ids_indicator, tf.int32), axis=1)
    tgt_actual_seq_len = tf.math.reduce_sum(
        tf.cast(batch.tgt.ids_indicator, tf.int32), axis=1)
    summary_utils.histogram('source_seq_lengths', src_actual_seq_len)
    summary_utils.histogram('target_seq_lengths', tgt_actual_seq_len)

    if not self.params.packing_factor:
      # Supply segment_ids and segment_pos with no packing.
      batch.src.segment_ids = batch.src.ids_indicator
      batch.src.segment_pos = _GetSegmentPos(batch.src.ids_indicator)
      batch.tgt.segment_ids = batch.tgt.ids_indicator
      batch.tgt.segment_pos = _GetSegmentPos(batch.tgt.ids_indicator)
      return

    (src_segment_ids, src_segment_pos, src_indices_in_input, tgt_segment_ids,
     tgt_segment_pos, tgt_indices_in_input) = ops.pack_sequences(
         src_actual_seq_len, tgt_actual_seq_len, self._ScaledBatchSize(),
         self.params.source_max_length, self.params.target_max_length)

    uniq_src_indices_in_input = tf.unique(
        tf.reshape(src_indices_in_input, [-1])).y
    uniq_tgt_indices_in_input = tf.unique(
        tf.reshape(tgt_indices_in_input, [-1])).y
    summary_utils.histogram(
        'packed_source_seq_lengths',
        tf.gather(src_actual_seq_len, uniq_src_indices_in_input, axis=0))
    summary_utils.histogram(
        'packed_target_seq_lengths',
        tf.gather(tgt_actual_seq_len, uniq_tgt_indices_in_input, axis=0))

    # We deferred adding .paddings and use its complement .ids_indicator
    # exclusively so that we can apply the packing with padding set to 0 for all
    # fields.
    def ApplyPackingToSource(x):
      if x.dtype == tf.string:
        return ops.apply_packing(x, '\t', src_segment_ids, src_indices_in_input)
      return ops.apply_packing(x, 0, src_segment_ids, src_indices_in_input)

    batch.src = batch.src.Transform(ApplyPackingToSource)
    batch.src.segment_ids = tf.cast(src_segment_ids, tf.float32)
    batch.src.segment_pos = src_segment_pos

    def ApplyPackingToTarget(x):
      if x.dtype == tf.string:
        return ops.apply_packing(x, '\t', tgt_segment_ids, tgt_indices_in_input)
      return ops.apply_packing(x, 0, tgt_segment_ids, tgt_indices_in_input)

    batch.tgt = batch.tgt.Transform(ApplyPackingToTarget)
    batch.tgt.segment_ids = tf.cast(tgt_segment_ids, tf.float32)
    batch.tgt.segment_pos = tgt_segment_pos
Example #28
def ExtractBlockContextV2(x,
                          block_size,
                          left_context,
                          right_context,
                          padding_val=0.0,
                          paddings=None):
    """Extracts temporal context for every block (without restrictions).

  This is a generalized implementation of ExtractBlockContext, where block_size,
  left_context, and right_context are 3 free parameters and we don't have
  constraints (other than l>=1, r>=0, block_size>0).

  Args:
    x: a tensor of [batch, time, dim].
    block_size: int. Number of time frames in a block.
    left_context: int. Left context size. Note that the actual left context is
      `left_context - 1` (this is to be compatible with ExtractBlockContext
      implementation).
    right_context: int. Right context size.
    padding_val: float. value on the padded frames.
    paddings: optional. If specified, it must be a tensor of [batch, time], and
      we will return a padding tensor indicating padding info for the returned
      tensor.

  Returns:
    (x_patches, x_paddings) where

    - x_patches: A tensor of
      [batch, num_blocks, context_size + block_size, dim] with necessary
      paddings, where context_size = (left_context - 1) + right_context,
      and output[:, i, ...] are
      x[:, start-left_context+1:end+right_context, ...], where
      start = i * block_size, end = (i + 1) * block_size.
    - x_paddings: None if paddings = None; else a
      [batch, num_blocks, context_size + block_size] tensor, indicating the
      padding info for the corresponding position in x_patches.

  Let's define some variables here:

  B: batch size
  T: input tensor length in time axis
  D: input tensor dimension in the last axis
  W: block size
  U: ceil(T/W)
  L: left context size
  R: right context size
  C: L-1+W+R, full block length

  Given a [B, T, D] tensor, the return is a [B, U, C, D] tensor
  where ret[b, u, :] is a 2D tensor of shape (L - 1 + W + R, D): the u-th
  block of the input tensor with (L - 1) left context frames and R right
  context frames.

  Implementation note:

  We use the following procedure to get the return tensor

  - first pad at the beginning and at the end:
    [B, T, D] -> [B, (L-1) + U*W + (L-1) + R, D]
  - build a [U, C] matrix of gather indices, one row per block
  - use one gather along the time axis to extract all blocks:
    [B, (L-1) + U*W + (L-1) + R, D] -> [B, U, C, D]

  TODO(yqw): after verifying correctness and benchmarking, consider replacing
  the v1 implementation.
  """
    # 0. basic shapes
    b, t, d = py_utils.GetShape(x, 3)
    w = block_size
    u = (t + w - 1) // w  # equivalent to math.ceil(t/w)
    l = left_context
    r = right_context
    c = l - 1 + r + w

    # The only constraints are block_size > 0, l >= 1, r >= 0.
    if w <= 0:
        raise ValueError(f'block size ({w}) must be greater than 0')
    if l < 1:
        raise ValueError(f'Left context ({left_context}) must be >= 1.')
    if r < 0:
        raise ValueError(f'Right context ({right_context}) must be >= 0')
    if paddings is not None:
        paddings = py_utils.HasShape(paddings, [b, t])

    # 1. do front and rear padding
    left_pad = l - 1
    # We need to make sure all u * w elements have a long enough context.
    right_pad = (u * w - t + l - 1 + r)
    x_padded = _DoPadding(x,
                          b,
                          left_pad,
                          right_pad,
                          d,
                          padding_val=padding_val)
    if paddings is not None:
        paddings = _DoPadding(paddings,
                              b,
                              left_pad,
                              right_pad,
                              d=None,
                              padding_val=1.0)

    # 2. generate gather indices
    # gather_indices is a [u, c] matrix like
    #  [ 0, .........,             c-1]
    #  [ w, .........,       w + (c-1)]
    #  [2w, ..........,     2w + (c-1)]
    #  [(u-1)*w, ...., (u-1)*w + (c-1)]
    gather_indices = (tf.tile(tf.expand_dims(tf.range(0, c), axis=0), (u, 1)) +
                      tf.tile(tf.expand_dims(tf.range(0, u * w, w), axis=1),
                              (1, c)))

    # 3. generate x_patches, shape [b, u, c, d]
    x_patches = tf.gather(x_padded, gather_indices, axis=1)

    if paddings is not None:
        # After the gather, paddings is a [b, u, c] tensor.
        paddings = tf.gather(paddings, gather_indices, axis=1)

    return x_patches, paddings
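
For intuition, a standalone sketch of the gather_indices construction from step 2 (assuming u=3 blocks, block size w=2, full block length c=4):

import tensorflow as tf

u, w, c = 3, 2, 4
gather_indices = (tf.tile(tf.expand_dims(tf.range(0, c), axis=0), (u, 1)) +
                  tf.tile(tf.expand_dims(tf.range(0, u * w, w), axis=1), (1, c)))
# -> [[0, 1, 2, 3],
#     [2, 3, 4, 5],
#     [4, 5, 6, 7]]
# Row i covers frames i*w .. i*w + c - 1 of the padded input, so a single
# gather along the time axis extracts every block with its context.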
Example #29
        def PreBeamSearchStepCallback(theta, encoder_outputs, step_ids, states,
                                      num_hyps_per_beam, *args, **kwargs):
            """Wrapper for adding bias to _PreBeamSearchStateCallback.

      Biases results.log_probs towards provided encoder_outputs.targets.

      Args:
        theta: a NestedMap of parameters.
        encoder_outputs: a NestedMap computed by encoder.
        step_ids: A tensor of shape [tgt_batch, 1].
        states: A `.NestedMap` of tensors representing states that the clients
          would like to keep track of for each of the active hyps.
        num_hyps_per_beam: Beam size.
        *args: additional arguments to _PreBeamSearchStepCallback.
        **kwargs: additional arguments to _PreBeamSearchStepCallback.

      Returns:
        A tuple (results, out_states).
        results: A `.NestedMap` of beam search results.
          atten_probs:
            The updated attention probs, of shape [tgt_batch, src_len].
          log_probs:
            Log prob for each of the tokens in the target vocab. This is of
            shape [tgt_batch, vocab_size].
        out_states: a `.NestedMap` The updated states. The states relevant here
          are:
          time_step: A scalar indicating current step of decoder.  Must be
            provided and maintained by subclass.
          consistent: A boolean vector of shape [tgt_batch, ] which tracks
            whether each hypothesis has exactly matched
            encoder_outputs.targets so far.
      """
            p = self.params
            time_step = states.time_step
            bs_results, out_states = self._PreBeamSearchStepCallback(
                theta, encoder_outputs, step_ids, states, num_hyps_per_beam,
                *args, **kwargs)

            labels = encoder_outputs.targets.labels
            weights = encoder_outputs.targets.weights

            def TileForBeamAndFlatten(tensor):
                tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
                # [num_hyps_per_beam, src_batch]
                tensor = tf.tile(tensor, [num_hyps_per_beam, 1])
                # tgt_batch = num_hyps_per_beam * src_batch
                tgt_batch = tf.shape(step_ids)[0]
                return tf.reshape(tensor, [tgt_batch])

            # Consistent if step_ids == labels from previous step
            # TODO(navari): Consider updating consistent only if weights > 0. Then
            # re-evaluate the need for bias_only_if_consistent=True.
            # Note that prev_label is incorrect for step 0 but is overridden later
            prev_label = TileForBeamAndFlatten(
                tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
            is_step0 = tf.equal(time_step, 0)
            local_consistence = tf.logical_or(
                is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
            out_states.consistent = tf.logical_and(states.consistent,
                                                   local_consistence)

            # get label, weight slices corresponding to current time_step
            label = TileForBeamAndFlatten(tf.gather(labels, time_step, axis=1))
            weight = TileForBeamAndFlatten(
                tf.gather(weights, time_step, axis=1))
            if p.bias_only_if_consistent:
                weight = weight * tf.cast(out_states.consistent, p.dtype)

            # convert from dense label to sparse label probs
            vocab_size = tf.shape(bs_results.log_probs)[1]
            uncertainty = tf.constant(
                1e-10,
                p.dtype)  # avoid 0 probs which may cause issues with log
            label_probs = tf.one_hot(label,
                                     vocab_size,
                                     on_value=1 - uncertainty,
                                     off_value=uncertainty /
                                     tf.cast(vocab_size - 1, p.dtype),
                                     dtype=p.dtype)  # [tgt_batch, vocab_size]
            pred_probs = tf.exp(bs_results.log_probs)

            # interpolate predicted probs and label probs
            weight = tf.expand_dims(weight, 1)
            probs = py_utils.with_dependencies([
                py_utils.assert_less_equal(weight, 1.),
                py_utils.assert_greater_equal(weight, 0.)
            ], (1.0 - weight) * pred_probs + weight * label_probs)

            bs_results.log_probs = tf.log(probs)

            return bs_results, out_states