Example #1
def TrainAndDecodeEpoch(i, host_device):
  """Train and decode infeed for an epoch.

  Args:
    i: host index.
    host_device: host device string.

  Returns:
    Decode with control deps on train node.
  """
  train_infeed_fn = lambda: self._train_input.CreatePerHostEnqueueOp(i)
  decode_infeed_fn = lambda: self._decode_input.CreatePerHostEnqueueOp(i)
  tf.logging.info('self._train_steps_per_loop: %d',
                  self._train_steps_per_loop)
  tf.logging.info('self._decode_steps_per_loop: %d',
                  self._decode_steps_per_loop)
  train = wrap_computation_in_while_loop(train_infeed_fn,
                                         self._train_steps_per_loop,
                                         host_device)
  with tf.device(host_device):
    with tf.control_dependencies([train]):
      decode = wrap_computation_in_while_loop(decode_infeed_fn,
                                              self._decode_steps_per_loop,
                                              host_device)
  return decode
Example #2
def IncBy(self, delta):
  """Increment the counter by delta and return the new value."""
  # NOTE: We must ensure _value is computed (_var + 0) before
  # updating _var with delta.
  delta = tf.cast(delta, tf.int64)
  with tf.control_dependencies([self._value]):
    scalar(self._name, self._value)
    return tf.identity(tf.assign_add(self._var, delta))
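
The NOTE above is the classic read-before-write ordering problem: without the control dependency, TensorFlow is free to schedule the assign_add before any op that reads self._value. Below is a minimal standalone sketch of the same pattern, assuming TF1 graph mode via tensorflow.compat.v1; the variable and session setup are illustrative, not part of the snippet above.

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

var = tf.get_variable('counter', shape=[], dtype=tf.int64,
                      initializer=tf.zeros_initializer())
value = var + 0  # Forces a read of the current value.
with tf.control_dependencies([value]):
  # The update is now guaranteed to run after the read above.
  new_value = tf.identity(tf.assign_add(var, tf.constant(1, tf.int64)))

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(new_value))  # 1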
Example #3
def MaybeGuaranteeConstGetter(getter, name, *args, **kwargs):
  global _CONST_GUARANTEE
  if _CONST_GUARANTEE:
    with tf.control_dependencies(None):
      return tf.guarantee_const(
          getter(name, *args, **kwargs), name=name + '/GuaranteeConst')
  else:
    return getter(name, *args, **kwargs)
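
Passing None is deliberate here: tf.control_dependencies(None) clears any control dependencies inherited from enclosing contexts, so guarantee_const is created without stray control inputs. A small illustration of the clearing behavior (the no_op and constants are assumptions for the demo, not from the snippet):

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

dep = tf.no_op(name='outer_dep')
with tf.control_dependencies([dep]):
  a = tf.constant(1.0)    # Created in the context: picks up the control input.
  with tf.control_dependencies(None):
    b = tf.constant(2.0)  # Context cleared: created with no control inputs.

print([op.name for op in a.op.control_inputs])  # ['outer_dep']
print([op.name for op in b.op.control_inputs])  # []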
Example #4
def _ApplyAndReset():
  with tf.control_dependencies([
      self._opt.Apply(
          lr, py_utils.ApplyGradMultiplier(var_grad, 1. / p.accum_steps))
  ]):
    return tf.group(
        *[tf.assign(a, tf.zeros_like(a)) for _, a in var_grad.Flatten()])
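
The dependency guarantees the optimizer update consumes the accumulated gradients before they are zeroed for the next accumulation window. The same apply-then-reset shape in isolation, with a plain variable standing in for the Lingvo accumulators (all names below are illustrative):

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

accum = tf.get_variable('grad_accum', shape=[2],
                        initializer=tf.ones_initializer())
apply_op = tf.print('applying:', accum)  # Stand-in for self._opt.Apply(...).
with tf.control_dependencies([apply_op]):
  # The reset runs strictly after apply_op has read the accumulator.
  reset = tf.group(tf.assign(accum, tf.zeros_like(accum)))

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(reset)  # Prints the old values, then zeroes them.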
Example #5
def FProp(self, theta, inputs, paddings=None):
  """Apply batch normalization.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    inputs: The inputs tensor.  Shaped [..., dim].
    paddings: The paddings tensor.  Shaped [..., 1], with the same rank as the
      input tensor.

  Returns:
    Output after applying batch normalization, with the same shape as
    'inputs'.
  """
  p = self.params
  if paddings is None:
    paddings = self._GetDefaultPaddings(inputs)
  with tf.name_scope(p.name):
    norm_mean, norm_variance, beta, gamma = self.ComputeAndUpdateMoments(
        theta, inputs, paddings)
    with tf.control_dependencies([
        py_utils.assert_greater_equal(norm_variance,
                                      tf.zeros_like(norm_variance)),
        py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                    tf.shape(norm_mean)),
        py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                    tf.shape(norm_variance)),
    ]):
      if p.use_fused_batch_norm_for_eval and self.do_eval:
        bn_output, _, _ = nn.fused_batch_norm(
            inputs,
            gamma,
            beta,
            norm_mean,
            norm_variance,
            self._epsilon,
            is_training=False)
      else:
        bn_output = tf.nn.batch_normalization(inputs, norm_mean, norm_variance,
                                              beta, gamma, self._epsilon)

      if p.set_padded_output_to_zero:
        bn_output *= 1.0 - paddings

    return bn_output
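
The assertions here are pure guards: nothing in the batch-norm math consumes their outputs, so they must be attached as control dependencies or the runtime would simply prune them. A minimal sketch of the same assert-before-compute pattern, assuming TF1 graph mode (the placeholder and sqrt are toy stand-ins for the variance and shape checks):

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

x = tf.placeholder(tf.float32, shape=[None])
with tf.control_dependencies([tf.assert_non_negative(x)]):
  y = tf.sqrt(x)  # Created in the context, so it waits for the check.

with tf.Session() as sess:
  print(sess.run(y, {x: [4.0, 9.0]}))  # [2. 3.]
  # sess.run(y, {x: [-1.0]}) would raise InvalidArgumentError instead.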
Example #6
def FProp(self, theta, current_step):
  p = self.params
  with tf.name_scope(p.name):

    steps = self._best_step
    best_step = steps[0]
    last_step = steps[1]

    ref_step = tf.maximum(self._ref_step, best_step)
    f = self._cur_factor

    # Decay if no improvement within window.
    new_factor = tf.where(last_step - ref_step < p.window, f,
                          tf.maximum(p.min_factor, f * p.decay))
    # Update ref_step if we decayed.
    new_step = tf.where(tf.equal(new_factor, f), ref_step, last_step)
    update_step = tf.assign(self._ref_step, new_step)
    with tf.control_dependencies([update_step]):
      return tf.assign(self._cur_factor, new_factor)
Example #7
def Finalize(self):
  """Finishes creation of the overall figure, returning the image summary."""
  subplot_grid_shape = self._subplot_grid_shape
  if subplot_grid_shape is None:
    subplot_grid_shape = (len(self._subplots), 1)

  # AddMatplotlibFigureSummary (due to restrictions of py_func) only supports
  # a flattened list of tensors, so we must do some bookkeeping to maintain a
  # mapping from each _SubplotMetadata object to flattened_tensors.
  subplot_slices = []
  flattened_tensors = []
  for subplot in self._subplots:
    start = len(flattened_tensors)
    subplot_slices.append((start, start + len(subplot.tensor_list)))
    flattened_tensors.extend(subplot.tensor_list)

  def PlotFunc(fig, *numpy_data_list):
    gs = gridspec.GridSpec(*subplot_grid_shape, **self._gridspec_kwargs)
    for n, subplot in enumerate(self._subplots):
      axes = fig.add_subplot(gs[n])
      start, end = subplot_slices[n]
      subplot_data = numpy_data_list[start:end]
      subplot.plot_func(fig, axes, *subplot_data)

  func = functools.partial(_RenderMatplotlibFigures, self._figsize,
                           self._max_outputs, PlotFunc)
  batch_sizes = [tf.shape(t)[0] for t in flattened_tensors]
  num_tensors = len(flattened_tensors)
  with tf.control_dependencies([
      tf.assert_equal(batch_sizes, [batch_sizes[0]] * num_tensors,
                      summarize=num_tensors)
  ]):
    rendered = tf.py_func(
        func, flattened_tensors, tf.uint8, name='RenderMatplotlibFigures')
  return tf.summary.image(self._name, rendered, max_outputs=self._max_outputs)
Example #8
def TpuTrainStep(*args):
  """Train a shard of a batch on a single TPU core.

  Args:
    *args: metrics values from previous steps.

  Returns:
    New summed metrics values and a train_op.
  """
  self._model = self._task_params.Instantiate()
  self._task = self._model.GetTask()
  self._task.AddChild('input', self._input)
  self._model.ConstructFPropBPropGraph()
  per_step_eval_metrics = self._eval_metrics.SetMetrics(
      self._task.eval_metrics, args)
  outfeed_op = self._OutfeedEnqueue(self._task.per_example_tensors)
  summed_metrics = []
  assert len(per_step_eval_metrics) == len(args)
  with tf.control_dependencies([outfeed_op]):
    for x, y in zip(per_step_eval_metrics, args):
      summed_metrics.append(x + y)
  return summed_metrics + [self._model.GetTask().train_op]
Example #9
def TrainAndDecode():
  with tf.control_dependencies([TpuTrain()]):
    return DecodeLoopFn()
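
This three-liner is the same trick as Example #1 in miniature: DecodeLoopFn's ops are created inside the with block, so they carry the training op as a control input and decoding cannot start before training has run. A self-contained version with trivial stand-ins for TpuTrain and DecodeLoopFn (both bodies are assumptions for the demo):

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

step = tf.get_variable('step', shape=[], dtype=tf.int32,
                       initializer=tf.zeros_initializer())

def TpuTrain():  # Stand-in: one "training" update.
  return tf.assign_add(step, 1)

def DecodeLoopFn():  # Stand-in: read the post-training step.
  return step.read_value()

def TrainAndDecode():
  with tf.control_dependencies([TpuTrain()]):
    return DecodeLoopFn()

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(TrainAndDecode()))  # 1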
Example #10
def computation(i):
  ops = op_fn()
  if not isinstance(ops, list):
    ops = [ops]
  with tf.control_dependencies(ops):
    return tf.Print(i + 1, [i], 'while_loop:')
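
This body is meant for tf.while_loop (Example #1's wrap_computation_in_while_loop builds exactly such a loop): the control dependency forces op_fn's side effects to run on every iteration, and the returned i + 1 advances the loop counter. A runnable sketch of wrapping it, where op_fn and the loop bound are assumptions for the demo:

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

total = tf.get_variable('total', shape=[], initializer=tf.zeros_initializer())

def op_fn():  # Assumed side effect to run once per iteration.
  return tf.assign_add(total, 1.0)

def computation(i):
  ops = op_fn()
  if not isinstance(ops, list):
    ops = [ops]
  with tf.control_dependencies(ops):
    return tf.Print(i + 1, [i], 'while_loop:')

loop = tf.while_loop(lambda i: i < 3, computation, [tf.constant(0)])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(loop)
  print(sess.run(total))  # 3.0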