def TrainAndDecodeEpoch(i, host_device):
  """Train and decode infeed for an epoch.

  Args:
    i: host index.
    host_device: host device string.

  Returns:
    Decode with control deps on train node.
  """
  train_infeed_fn = lambda: self._train_input.CreatePerHostEnqueueOp(i)
  decode_infeed_fn = lambda: self._decode_input.CreatePerHostEnqueueOp(i)
  tf.logging.info('self._train_steps_per_loop: %d',
                  self._train_steps_per_loop)
  tf.logging.info('self._decode_steps_per_loop: %d',
                  self._decode_steps_per_loop)
  train = wrap_computation_in_while_loop(train_infeed_fn,
                                         self._train_steps_per_loop,
                                         host_device)
  with tf.device(host_device):
    with tf.control_dependencies([train]):
      decode = wrap_computation_in_while_loop(decode_infeed_fn,
                                              self._decode_steps_per_loop,
                                              host_device)
  return decode
def IncBy(self, delta):
  """Increment the counter by delta and return the new value."""
  # NOTE: We must ensure _value is computed (_var + 0) before
  # updating _var with delta.
  delta = tf.cast(delta, tf.int64)
  with tf.control_dependencies([self._value]):
    scalar(self._name, self._value)
    return tf.identity(tf.assign_add(self._var, delta))
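For context, a standalone sketch of the same read-before-update pattern, assuming TF1 graph mode; the variable and summary names below are illustrative, not part of the original counter class.

# Hedged sketch: force a read of the counter before incrementing it, so the
# logged value is the pre-increment one. Names here are illustrative.
import tensorflow.compat.v1 as tf

counter = tf.get_variable(
    'counter', shape=[], dtype=tf.int64, initializer=tf.zeros_initializer())
current = counter + 0  # Forces a read of the current value.
with tf.control_dependencies([current]):
  # Ops created here run only after `current` has been computed, so the
  # summary sees the pre-increment value.
  tf.summary.scalar('counter', tf.cast(current, tf.float32))
  new_value = tf.identity(tf.assign_add(counter, tf.constant(1, tf.int64)))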
def MaybeGuaranteeConstGetter(getter, name, *args, **kwargs):
  global _CONST_GUARANTEE
  if _CONST_GUARANTEE:
    with tf.control_dependencies(None):
      return tf.guarantee_const(
          getter(name, *args, **kwargs), name=name + '/GuaranteeConst')
  else:
    return getter(name, *args, **kwargs)
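A hedged usage sketch: a getter like the one above is typically installed through variable_scope's custom_getter in TF1. The scope and variable names below are illustrative assumptions, not taken from the original code.

# Hedged sketch: install the getter so variables created in this scope are
# wrapped in guarantee_const when _CONST_GUARANTEE is set.
with tf.variable_scope('proj', custom_getter=MaybeGuaranteeConstGetter):
  w = tf.get_variable('w', shape=[128, 128])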
def _ApplyAndReset():
  with tf.control_dependencies([
      self._opt.Apply(
          lr, py_utils.ApplyGradMultiplier(var_grad, 1. / p.accum_steps))
  ]):
    return tf.group(
        *[tf.assign(a, tf.zeros_like(a)) for _, a in var_grad.Flatten()])
def FProp(self, theta, inputs, paddings=None):
  """Apply batch normalization.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    inputs: The inputs tensor. Shaped [..., dim].
    paddings: The paddings tensor. Shaped [..., 1], with the same rank as the
      input tensor.

  Returns:
    Output after applying batch normalization, with the same shape as
    'inputs'.
  """
  p = self.params
  if paddings is None:
    paddings = self._GetDefaultPaddings(inputs)
  with tf.name_scope(p.name):
    norm_mean, norm_variance, beta, gamma = self.ComputeAndUpdateMoments(
        theta, inputs, paddings)
    with tf.control_dependencies([
        py_utils.assert_greater_equal(norm_variance,
                                      tf.zeros_like(norm_variance)),
        py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                    tf.shape(norm_mean)),
        py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                    tf.shape(norm_variance)),
    ]):
      if p.use_fused_batch_norm_for_eval and self.do_eval:
        bn_output, _, _ = nn.fused_batch_norm(
            inputs,
            gamma,
            beta,
            norm_mean,
            norm_variance,
            self._epsilon,
            is_training=False)
      else:
        bn_output = tf.nn.batch_normalization(inputs, norm_mean,
                                              norm_variance, beta, gamma,
                                              self._epsilon)
      if p.set_padded_output_to_zero:
        bn_output *= 1.0 - paddings
  return bn_output
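To make the argument order of the non-fused path explicit, here is a hedged standalone sketch of tf.nn.batch_normalization(x, mean, variance, offset, scale, variance_epsilon); the shapes and moment computation below are illustrative, not the layer's own statistics.

# Hedged sketch: normalize a [batch, dim] tensor with explicit moments.
import tensorflow.compat.v1 as tf

x = tf.random.normal([4, 16])
mean, variance = tf.nn.moments(x, axes=[0])
beta = tf.zeros([16])   # offset
gamma = tf.ones([16])   # scale
y = tf.nn.batch_normalization(x, mean, variance, beta, gamma, 1e-6)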
def FProp(self, theta, current_step):
  p = self.params
  with tf.name_scope(p.name):
    steps = self._best_step
    best_step = steps[0]
    last_step = steps[1]
    ref_step = tf.maximum(self._ref_step, best_step)
    f = self._cur_factor
    # Decay if no improvement within window.
    new_factor = tf.where(last_step - ref_step < p.window, f,
                          tf.maximum(p.min_factor, f * p.decay))
    # Update ref_step if we decayed.
    new_step = tf.where(tf.equal(new_factor, f), ref_step, last_step)
    update_step = tf.assign(self._ref_step, new_step)
    with tf.control_dependencies([update_step]):
      return tf.assign(self._cur_factor, new_factor)
def Finalize(self):
  """Finishes creation of the overall figure, returning the image summary."""
  subplot_grid_shape = self._subplot_grid_shape
  if subplot_grid_shape is None:
    subplot_grid_shape = (len(self._subplots), 1)

  # AddMatplotlibFigureSummary (due to restrictions of py_func) only supports
  # flattened list of tensors so we must do some bookkeeping to maintain a
  # mapping from _SubplotMetadata object to flattened_tensors.
  subplot_slices = []
  flattened_tensors = []
  for subplot in self._subplots:
    start = len(flattened_tensors)
    subplot_slices.append((start, start + len(subplot.tensor_list)))
    flattened_tensors.extend(subplot.tensor_list)

  def PlotFunc(fig, *numpy_data_list):
    gs = gridspec.GridSpec(*subplot_grid_shape, **self._gridspec_kwargs)
    for n, subplot in enumerate(self._subplots):
      axes = fig.add_subplot(gs[n])
      start, end = subplot_slices[n]
      subplot_data = numpy_data_list[start:end]
      subplot.plot_func(fig, axes, *subplot_data)

  func = functools.partial(_RenderMatplotlibFigures, self._figsize,
                           self._max_outputs, PlotFunc)
  batch_sizes = [tf.shape(t)[0] for t in flattened_tensors]
  num_tensors = len(flattened_tensors)
  with tf.control_dependencies([
      tf.assert_equal(
          batch_sizes, [batch_sizes[0]] * num_tensors, summarize=num_tensors)
  ]):
    rendered = tf.py_func(
        func, flattened_tensors, tf.uint8, name='RenderMatplotlibFigures')
  return tf.summary.image(self._name, rendered, max_outputs=self._max_outputs)
def TpuTrainStep(*args):
  """Train a shard of a batch on a single TPU core.

  Args:
    *args: metrics values from previous steps.

  Returns:
    New summed metrics values and a train_op.
  """
  self._model = self._task_params.Instantiate()
  self._task = self._model.GetTask()
  self._task.AddChild('input', self._input)
  self._model.ConstructFPropBPropGraph()
  per_step_eval_metrics = self._eval_metrics.SetMetrics(
      self._task.eval_metrics, args)
  outfeed_op = self._OutfeedEnqueue(self._task.per_example_tensors)
  summed_metrics = []
  assert len(per_step_eval_metrics) == len(args)
  with tf.control_dependencies([outfeed_op]):
    for x, y in zip(per_step_eval_metrics, args):
      summed_metrics.append(x + y)
  return summed_metrics + [self._model.GetTask().train_op]
def TrainAndDecode():
  with tf.control_dependencies([TpuTrain()]):
    return DecodeLoopFn()
def computation(i):
  ops = op_fn()
  if not isinstance(ops, list):
    ops = [ops]
  with tf.control_dependencies(ops):
    return tf.Print(i + 1, [i], 'while_loop:')
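A minimal sketch of how a body like `computation` might be driven by tf.while_loop; the helper name RepeatOp and its signature are assumptions, not the wrap_computation_in_while_loop helper used elsewhere in this file.

# Hedged sketch: run op_fn() n times, one iteration at a time; each iteration
# waits on the ops it created via the control dependency on the counter.
def RepeatOp(op_fn, n):

  def computation(i):
    ops = op_fn()
    if not isinstance(ops, list):
      ops = [ops]
    with tf.control_dependencies(ops):
      return tf.Print(i + 1, [i], 'while_loop:')

  return tf.while_loop(
      lambda i: i < n, computation, [tf.constant(0)], parallel_iterations=1)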