Code Example #1
    def _inverse_pth_root_graph(self, epsilon):
        graph = tf.Graph()
        with graph.as_default():
            exponent_t = tf.reshape(
                tf.placeholder(dtype=tf.float32, name="exponent", shape=None),
                [])
            # Apply exponent multiplier.
            exponent_t = exponent_t * self._exponent_multiplier
            input_t = tf.placeholder(dtype=tf.float32,
                                     name="input",
                                     shape=None)
            # For p = 2, 4 or 8, we use the iterative Newton-Schur method for
            # computing the inverse-pth root.
            either_p_2_4_8 = tf.logical_or(
                tf.logical_or(tf.equal(-1.0 / exponent_t, 2),
                              tf.equal(-1.0 / exponent_t, 4)),
                tf.equal(-1.0 / exponent_t, 8))
            # 4096 is the largest dimension for which SVD is still tractable.
            greater_than_4096 = tf.greater(tf.shape(input_t)[0], 4096)
            run_specialized_iterative_method = tf.logical_and(
                greater_than_4096, either_p_2_4_8)
            specialized_fn = functools.partial(
                self._specialized_inverse_pth_root, input_t, exponent_t,
                epsilon)
            generalized_fn = functools.partial(
                self._generalized_inverse_pth_root, input_t, exponent_t,
                epsilon)
            output, diff = tf.cond(run_specialized_iterative_method,
                                   specialized_fn, generalized_fn)

            tf.identity(output, "output")
            tf.identity(tf.cast(diff, tf.float32), "diff")
        return graph.as_graph_def().SerializeToString()
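
A minimal consumption sketch for the serialized graph above, assuming TF1-style sessions. The names `serialized` and `matrix` are hypothetical and only for illustration; the tensor names ("input", "exponent", "output", "diff") come from the code itself.

import numpy as np
import tensorflow as tf

graph_def = tf.GraphDef()
graph_def.ParseFromString(serialized)   # bytes returned by the method above (hypothetical variable)
matrix = np.eye(4, dtype=np.float32)    # toy symmetric PSD input
with tf.Graph().as_default() as g:
    tf.import_graph_def(graph_def, name='')
    with tf.Session(graph=g) as sess:
        output, diff = sess.run(
            ['output:0', 'diff:0'],
            feed_dict={'input:0': matrix, 'exponent:0': -0.25})  # -1/exponent = 4, i.e. p = 4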
Code Example #2
def infinite_repeat(body_fn, infeed_queue):
    """Builds infinite loop.

  Args:
    body_fn: a Python function that builds the loop body.
    infeed_queue: if not None, the infeed queue from which additional arguments
      are dequeued and appended to the inputs of body_fn.

  Returns:
    The final values of the loop-carried tensors.
  """
    def to_list(x):
        if isinstance(x, (list, tuple)):
            return list(x)
        else:
            return [x]

    def body_fn_wrapper(i, *args):
        return [i + 1] + to_list(body_fn(*args))

    outputs = training_loop.while_loop(
        # Infinite loop. Using only tf.constant(True) causes the XLA graph to
        # appear to be stateful.
        lambda i, *args: tf.logical_or(tf.constant(True), i < 10000),
        body_fn_wrapper,
        inputs=[0],
        infeed_queue=infeed_queue)
    outputs = to_list(outputs)
    if len(outputs) == 1:
        # Returns the Op rather than an empty list.
        return outputs[0].op
    else:
        return outputs[1:]
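
The always-true condition is specific to training_loop.while_loop on TPU. A rough, self-contained analogue of the same counter-plus-body wrapper pattern with plain tf.while_loop (assumptions: no infeed queue, finite bound) might look like this:

import tensorflow as tf

def body_fn(x):
    return x + 1.0

def body_fn_wrapper(i, x):
    # Mirror the counter-plus-body structure used above.
    return [i + 1, body_fn(x)]

i_final, x_final = tf.while_loop(lambda i, x: i < 10,
                                 body_fn_wrapper,
                                 loop_vars=[tf.constant(0), tf.constant(0.0)])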
Code Example #3
    def dec_callback(self, tgt_id, tgt_pos, tgt_segment_id, tgt_mask,
                     dec_state, t):
        del tgt_pos, tgt_segment_id

        [buf] = dec_state
        if tgt_id.shape == (self.batch_size, self.beam_size):
            buf = inplace_ops.alias_inplace_update(buf, t, tgt_id)
        else:
            div = int(tgt_id.shape[1] // self.beam_size)
            for i, x_i in enumerate(tf.split(tgt_id, div, 1)):
                buf = inplace_ops.alias_inplace_update(buf, t + i, x_i)

        buf1 = tf.transpose(buf, [1, 0, 2])
        buf1 = tf.reshape(buf1,
                          [self.batch_size, self.max_steps * self.beam_size])

        # select next_tgt_id as a function of previous target tokens
        if self.rule == '+1':
            next_tgt_id = (tgt_id + 1)
            next_tgt_id %= self.vocab_size
        elif self.rule == 'sum':
            # sum over all previous tokens in tgt_mask
            next_tgt_id = tf.einsum('BT,BKT->BK', buf1,
                                    tf.cast(tgt_mask, tf.int32))
            next_tgt_id %= self.vocab_size
        elif self.rule == 'fib':
            # select last token according to tgt_mask
            m = tgt_mask
            m *= tf.cast(
                tf.equal(tf.cumsum(m, -1),
                         tf.reduce_sum(m, -1, keepdims=True) - 1), m.dtype)
            last_tgt_id = tf.einsum('BT,BKT->BK', buf1, tf.cast(m, tf.int32))
            next_tgt_id = (last_tgt_id + tgt_id) % self.vocab_size

        # with a lower probability add extra +1 to the correct next_tgt_id
        n = self.vocab_size
        logits = 5 * tf.one_hot(next_tgt_id % n, n)
        logits += 4 * tf.one_hot((next_tgt_id + 1) % n, n)
        logits += 3 * tf.one_hot((next_tgt_id + 2) % n, n)
        logits += 2 * tf.one_hot((next_tgt_id + 3) % n, n)
        logits += 1 * tf.one_hot((next_tgt_id + 4) % n, n)

        # increase eos_score if current tgt_id contains 9
        eos_id = 0
        tgt_id_contains_9 = tf.logical_or(tf.equal(tgt_id % 10, 9),
                                          tf.equal((tgt_id // 10) % 10, 9))
        logits += 9 * tf.einsum('V,BK->BKV', tf.one_hot(
            eos_id, self.vocab_size), tf.cast(tgt_id_contains_9, tf.float32))

        # tie-breaking -- lower token id wins a little bit
        tie = np.arange(0., 1., 1. / n)
        tie /= tie.sum()
        logits -= tie

        logits = tf.nn.log_softmax(logits)

        dec_state = [buf]
        return logits, dec_state
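
A toy NumPy check of the logits ladder built above (assuming a vocab size of 10): the intended token gets score 5, its successors get 4 down to 1, and the small tie-breaker does not change the argmax.

import numpy as np

n, next_tgt_id = 10, 3
logits = np.zeros(n)
for rank in range(5):
    logits[(next_tgt_id + rank) % n] += 5 - rank   # scores 5, 4, 3, 2, 1
tie = np.arange(0., 1., 1. / n)
logits -= tie / tie.sum()                          # lower token id wins ties
print(logits.argmax())                             # -> 3, the intended next token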
Code Example #4
    def _GreedySearchStep(self, theta, encoder_outputs, cur_step, step_ids,
                          hyp_ids, hyp_lens, done_hyps, other_states,
                          pre_beam_search_step_callback,
                          post_beam_search_step_callback):
        """Extend greedy search hyps for one step.

    Args:
      theta: A `.NestedMap` object containing weights' values of the decoder
        layer and its children layers.
      encoder_outputs: A `.NestedMap` containing encoder outputs to be passed to
        the callbacks.
      cur_step: A scalar int tensor, the current time step, 0-based.
      step_ids: An int tensor of shape [num_hyps, 1]. The input ids to the
        current search step.
      hyp_ids: An int tensor of shape [num_hyps, tgt_seq_len].
      hyp_lens: Valid length of all the hyps. Tokens after eos ids are not
        counted.
      done_hyps: Whether or not a hyp has finished.
      other_states: A `.NestedMap` of other beam search states. This
        `.NestedMap` is managed and updated by the client. It is expected that
        each of its member tensors are of rank >= 1. t[i, ...] is the state of
        the i-th hyp at the beginning of this search step.
      pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
        See class header comments for more details.
      post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
        See class header comments for more details.

    Returns:
      A tuple of the following elements for the next greedy search step:
      (next step, new_step_ids, hyp_ids, hyp_lens, done_hyps, other_states)
    """
        p = self.params
        # Increment hyp_lens by 1 if the hyp is not finished yet.
        hyp_lens = hyp_lens + (1 - tf.cast(done_hyps, tf.int32))

        bs_results, new_other_states = pre_beam_search_step_callback(
            theta,
            encoder_outputs,
            step_ids,
            other_states,
            num_hyps_per_beam=1)
        new_step_ids = tf.arg_max(bs_results.log_probs, 1)
        new_step_ids = tf.cast(new_step_ids, tf.int32)
        new_step_ids = tf.reshape(new_step_ids, tf.shape(step_ids))
        final_other_states = post_beam_search_step_callback(
            theta, encoder_outputs, new_step_ids, new_other_states)

        # Stash new_step_ids into the right slot.
        new_step_ids_1d = tf.reshape(new_step_ids, [-1])
        hyp_ids = inplace_ops.alias_inplace_update(hyp_ids, cur_step,
                                                   new_step_ids_1d)
        # Update done_hyps if the new step id is the end-of-sequence token.
        done_hyps = tf.logical_or(done_hyps,
                                  tf.equal(new_step_ids_1d, p.target_eos_id))

        return (cur_step + 1, new_step_ids, hyp_ids, hyp_lens, done_hyps,
                final_other_states)
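
The core of the greedy step restated with plain ops, as a sketch only (tf.arg_max in the snippet is the older alias of tf.argmax). The toy log_probs and eos_id are assumptions for illustration.

import tensorflow as tf

log_probs = tf.constant([[0.1, 2.0, 0.3],
                         [1.5, 0.2, 0.1]])                      # [num_hyps, vocab]
eos_id = 2                                                      # assumed EOS id
new_step_ids = tf.cast(tf.argmax(log_probs, axis=1), tf.int32)  # greedy pick per hyp
done_hyps = tf.equal(new_step_ids, eos_id)                      # marks finished hyps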
Code Example #5
File: base_decoder.py  Project: nemo628/lingvo
            def ApplyBias():
                """Bias and update log_probs and consistent."""
                def TileForBeamAndFlatten(tensor):
                    tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
                    tensor = tf.tile(
                        tensor,
                        [num_hyps_per_beam, 1])  # [num_hyps_per_beam, src_batch]
                    tgt_batch = tf.shape(step_ids)[0]  # num_hyps_per_beam*src_batch
                    return tf.reshape(tensor, [tgt_batch])

                # Consistent if step_ids == labels from previous step
                # TODO(navari): Consider updating consistent only if weights > 0. Then
                # re-evaluate the need for bias_only_if_consistent=True.
                # Note that prev_label is incorrect for step 0 but is overridden later
                prev_label = TileForBeamAndFlatten(
                    tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
                is_step0 = tf.equal(time_step, 0)
                local_consistence = tf.logical_or(
                    is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
                consistent = tf.logical_and(states.consistent,
                                            local_consistence)

                # get label, weight slices corresponding to current time_step
                label = TileForBeamAndFlatten(
                    tf.gather(labels, time_step, axis=1))
                weight = TileForBeamAndFlatten(
                    tf.gather(weights, time_step, axis=1))
                if p.bias_only_if_consistent:
                    weight = weight * tf.cast(consistent, p.dtype)

                # convert from dense label to sparse label probs
                vocab_size = tf.shape(bs_results.log_probs)[1]
                uncertainty = tf.constant(
                    1e-10,
                    p.dtype)  # avoid 0 probs which may cause issues with log
                label_probs = tf.one_hot(
                    label,
                    vocab_size,
                    on_value=1 - uncertainty,
                    off_value=uncertainty / tf.cast(vocab_size - 1, p.dtype),
                    dtype=p.dtype)  # [tgt_batch, vocab_size]
                pred_probs = tf.exp(bs_results.log_probs)

                # interpolate predicted probs and label probs
                weight = tf.expand_dims(weight, 1)
                probs = py_utils.with_dependencies([
                    py_utils.assert_less_equal(weight, 1.),
                    py_utils.assert_greater_equal(weight, 0.)
                ], (1.0 - weight) * pred_probs + weight * label_probs)
                return tf.log(probs), consistent
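
A NumPy sketch of the interpolation at the end of ApplyBias (toy numbers, not lingvo code): weight 0 keeps the model's prediction, weight 1 forces the smoothed label distribution.

import numpy as np

vocab_size, uncertainty, label = 3, 1e-10, 1
label_probs = np.full(vocab_size, uncertainty / (vocab_size - 1))
label_probs[label] = 1.0 - uncertainty
pred_probs = np.array([0.7, 0.2, 0.1])
for weight in (0.0, 0.5, 1.0):
    probs = (1.0 - weight) * pred_probs + weight * label_probs
    print(weight, np.log(probs))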
Code Example #6
File: pruning.py  Project: snsun/lingvo
def maybe_update_masks():
    with tf.name_scope(self._spec.name):
        is_step_within_pruning_range = tf.logical_and(
            tf.greater_equal(self._global_step,
                             self._spec.begin_pruning_step),
            # If end_pruning_step is negative, keep pruning forever!
            tf.logical_or(
                tf.less_equal(self._global_step,
                              self._spec.end_pruning_step),
                tf.less(self._spec.end_pruning_step, 0)))
        is_pruning_step = tf.less_equal(
            tf.add(self._last_update_step, self._spec.pruning_frequency),
            self._global_step)
        return tf.logical_and(is_step_within_pruning_range, is_pruning_step)
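
A plain-Python restatement of the pruning-schedule predicate above, useful for sanity-checking a schedule outside the graph (sketch; parameter names are illustrative):

def should_update_masks(step, begin, end, freq, last_update):
    in_range = step >= begin and (step <= end or end < 0)  # end < 0 means prune forever
    due = last_update + freq <= step
    return in_range and due

should_update_masks(step=300, begin=100, end=-1, freq=50, last_update=200)  # -> True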
Code Example #7
File: learner.py  Project: lbxcfx/lingvo
    def ScaleGradients(self, var_grads, gradient_adjuster=None):
        """Scales gradients according to training params.

    Args:
      var_grads: a `.NestedMap` whose values are (var, grad) pairs.
      gradient_adjuster: if not None, a function that mutates a given var_grads.

    Returns:
      A `.NestedMap` containing:
      - has_nan_or_inf: a scalar of 0 or 1, indicating whether there is any NaN
        or Inf in input gradients.
      - final_var_grads: a `.NestedMap` whose values are (var, grad) pairs,
        where gradients have already been scaled.
      - grad_scale: the gradient scale. 0 if gradient updates should be skipped
        for the step. (Optional, only returned in case global norm clipping is
        used.)
    """
        p = self.params

        # Computes gradients' norm and adds their summaries. Note that all_grad_norm
        # may be nan, which may cause grad_scale to be nan.
        for name, vg in var_grads.FlattenItems():
            summary_utils.AddNormSummary(name + '/' + p.name,
                                         py_utils.NestedMap(s=vg))
        all_grad_norm = tf.sqrt(
            py_utils.SumSquared([
                g for (_, g) in py_utils.NestedMap(child=var_grads).Flatten()
            ]))
        all_var_norm = tf.sqrt(
            py_utils.SumSquared([
                v for (v, _) in py_utils.NestedMap(child=var_grads).Flatten()
            ]))
        grad_norm_is_nan_or_inf = tf.logical_or(tf.is_nan(all_grad_norm),
                                                tf.is_inf(all_grad_norm))

        # Optional gradient adjustment. Note that this happens after computing
        # all_grad_norm.
        if gradient_adjuster is not None:
            tf.logging.info('gradient_adjuster=%s', gradient_adjuster)
            var_grads = gradient_adjuster(var_grads)

        # Handles NaN/Inf gradients.
        has_nan_or_inf = py_utils.HasNanOrInfGradient(var_grads)
        # Grad norm can still be inf even if none of the individual grad is inf.
        has_nan_or_inf = tf.logical_or(has_nan_or_inf, grad_norm_is_nan_or_inf)

        return_values = py_utils.NestedMap()
        if p.clip_gradient_single_norm_to_value:
            # Currently using both types of clipping simultaneously is unsupported.
            if p.clip_gradient_norm_to_value:
                raise ValueError(
                    'Cannot use clip_gradient_single_norm_to_value=%f and '
                    'clip_gradient_norm_to_value=%f.' %
                    (p.clip_gradient_single_norm_to_value,
                     p.clip_gradient_norm_to_value))
            final_var_grads = py_utils.ApplyGradNormCliping(
                var_grads, p.clip_gradient_single_norm_to_value)

        else:
            grad_scale = self._GetGlobalGradScale(all_grad_norm,
                                                  has_nan_or_inf)
            self._AddEvalMetric('grad_norm/all', all_grad_norm,
                                tf.constant(1.0))
            self._AddEvalMetric('var_norm/all', all_var_norm, tf.constant(1.0))
            self._AddEvalMetric('grad_scale_all', grad_scale, tf.constant(1.0))
            final_var_grads = py_utils.ApplyGradMultiplier(
                var_grads, grad_scale)
            return_values.grad_scale = grad_scale

        return_values.has_nan_or_inf = has_nan_or_inf
        return_values.final_var_grads = final_var_grads
        return return_values
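
_GetGlobalGradScale is a lingvo internal not shown here. Purely as an illustration of the idea (not lingvo's actual implementation), a global-norm gradient scale that also skips NaN/Inf updates could look roughly like this:

import tensorflow as tf

def global_grad_scale(all_grad_norm, clip_norm, has_nan_or_inf):
    # Scale gradients so that their global norm is at most clip_norm.
    scale = tf.minimum(1.0, clip_norm / tf.maximum(all_grad_norm, 1e-30))
    # Zero the scale (i.e. skip the update) when any gradient is NaN/Inf.
    return scale * (1.0 - tf.cast(has_nan_or_inf, tf.float32))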
Code Example #8
File: base_decoder.py  Project: xueyongfu/lingvo
        def PreBeamSearchStepCallback(theta, encoder_outputs, step_ids, states,
                                      num_hyps_per_beam, *args, **kwargs):
            """Wrapper for adding bias to _PreBeamSearchStateCallback.

      Biases results.log_probs towards provided encoder_outputs.targets.

      Args:
        theta: a NestedMap of parameters.
        encoder_outputs: a NestedMap computed by encoder.
        step_ids: A tensor of shape [tgt_batch, 1].
        states: A `.NestedMap` of tensors representing states that the clients
          would like to keep track of for each of the active hyps.
        num_hyps_per_beam: Beam size.
        *args: additional arguments to _PreBeamSearchStepCallback.
        **kwargs: additional arguments to _PreBeamSearchStepCallback.

      Returns:
        A tuple (results, out_states).
        results: A `.NestedMap` of beam search results.
          atten_probs:
            The updated attention probs, of shape [tgt_batch, src_len].
          log_probs:
            Log prob for each of the tokens in the target vocab. This is of
            shape [tgt_batch, vocab_size].
        out_states: a `.NestedMap` The updated states. The states relevant here
          are:
          time_step: A scalar indicating current step of decoder.  Must be
            provided and maintained by subclass.
          consistent: A boolean vector of shape [tgt_batch, ] which tracks
            whether each hypothesis has exactly matched encoder_outputs.targets
            so far.
      """
            p = self.params
            time_step = states.time_step
            bs_results, out_states = self._PreBeamSearchStepCallback(
                theta, encoder_outputs, step_ids, states, num_hyps_per_beam,
                *args, **kwargs)

            labels = encoder_outputs.targets.labels
            weights = encoder_outputs.targets.weights

            def TileForBeamAndFlatten(tensor):
                tensor = tf.reshape(tensor, [1, -1])  # [1, src_batch]
                tensor = tf.tile(
                    tensor,
                    [num_hyps_per_beam, 1])  # [num_hyps_per_beam, src_batch]
                tgt_batch = tf.shape(step_ids)[0]  # num_hyps_per_beam*src_batch
                return tf.reshape(tensor, [tgt_batch])

            # Consistent if step_ids == labels from previous step
            # TODO(navari): Consider updating consistent only if weights > 0. Then
            # re-evaluate the need for bias_only_if_consistent=True.
            # Note that prev_label is incorrect for step 0 but is overridden later
            prev_label = TileForBeamAndFlatten(
                tf.gather(labels, tf.maximum(time_step - 1, 0), axis=1))
            is_step0 = tf.equal(time_step, 0)
            local_consistence = tf.logical_or(
                is_step0, tf.equal(prev_label, tf.squeeze(step_ids, 1)))
            out_states.consistent = tf.logical_and(states.consistent,
                                                   local_consistence)

            # get label, weight slices corresponding to current time_step
            label = TileForBeamAndFlatten(tf.gather(labels, time_step, axis=1))
            weight = TileForBeamAndFlatten(
                tf.gather(weights, time_step, axis=1))
            if p.bias_only_if_consistent:
                weight = weight * tf.cast(out_states.consistent, p.dtype)

            # convert from dense label to sparse label probs
            vocab_size = tf.shape(bs_results.log_probs)[1]
            uncertainty = tf.constant(
                1e-10,
                p.dtype)  # avoid 0 probs which may cause issues with log
            label_probs = tf.one_hot(label,
                                     vocab_size,
                                     on_value=1 - uncertainty,
                                     off_value=uncertainty /
                                     tf.cast(vocab_size - 1, p.dtype),
                                     dtype=p.dtype)  # [tgt_batch, vocab_size]
            pred_probs = tf.exp(bs_results.log_probs)

            # interpolate predicted probs and label probs
            weight = tf.expand_dims(weight, 1)
            probs = py_utils.with_dependencies([
                py_utils.assert_less_equal(weight, 1.),
                py_utils.assert_greater_equal(weight, 0.)
            ], (1.0 - weight) * pred_probs + weight * label_probs)

            bs_results.log_probs = tf.log(probs)

            return bs_results, out_states
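
A NumPy shape check for TileForBeamAndFlatten (sketch, toy numbers): a per-source-sentence vector is repeated once per hypothesis, giving the beam-major ordering that matches step_ids of shape [num_hyps_per_beam * src_batch, 1].

import numpy as np

src_batch, num_hyps_per_beam = 3, 2
per_sentence = np.array([7, 8, 9])                       # [src_batch]
tiled = np.tile(per_sentence.reshape(1, -1), (num_hyps_per_beam, 1))
flat = tiled.reshape(-1)                                 # [num_hyps_per_beam * src_batch]
print(flat)                                              # [7 8 9 7 8 9]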