def apply_gradients(self, grads_and_vars, global_step=None, name=None):
  if self._num_micro_batches == 1:
    return self._opt.apply_gradients(grads_and_vars, global_step)
  global_step = global_step or py_utils.GetOrCreateGlobalStepVar()
  with tf.init_scope():
    self._create_slots([v for (_, v) in grads_and_vars])

  accums = []
  variables = []

  for g, v in grads_and_vars:
    accum = self.get_slot(v, 'grad_accum')
    variables.append(v)
    # pytype: disable=attribute-error
    if isinstance(g, tf.IndexedSlices):
      scaled_grad = tf.IndexedSlices(
          g.values / self._num_micro_batches,
          g.indices,
          dense_shape=g.dense_shape)
    else:
      scaled_grad = g / self._num_micro_batches
    accum_tensor = accum.read_value()
    accums.append(accum.assign(accum_tensor + scaled_grad))
    # pytype: enable=attribute-error

  def _ApplyAndReset():
    normalized_accums = accums
    if self._apply_crs_to_grad:
      normalized_accums = [
          tf.tpu.cross_replica_sum(accum.read_value()) for accum in accums
      ]
    apply_op = self._opt.apply_gradients(
        list(zip(normalized_accums, variables)))
    with tf.control_dependencies([apply_op]):
      zero_op = [tf.assign(accum, tf.zeros_like(accum)) for accum in accums]
      return tf.group(zero_op, tf.assign_add(global_step, 1))

  def _Accum():
    return tf.no_op()

  accum_step = tf.cond(
      tf.equal(
          tf.math.floormod(self._counter + 1, self._num_micro_batches), 0),
      _ApplyAndReset,  # Apply the accumulated gradients and reset.
      _Accum)  # Accumulate gradients.

  with tf.control_dependencies([tf.group(accums)]):
    return tf.group(accum_step, tf.assign_add(self._counter, 1))
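# Hedged usage sketch for the accumulating apply_gradients above. The wrapper
# instance `accum_opt`, the session loop and `num_steps` are illustrative
# assumptions, not part of this file: the op built once by apply_gradients()
# only applies the averaged gradient (and advances global_step) on every
# `num_micro_batches`-th run; the other runs just add grad / num_micro_batches
# into the 'grad_accum' slots.
#
#   train_op = accum_opt.apply_gradients(grads_and_vars)   # built once
#   for _ in range(num_steps):
#     sess.run(train_op)  # accumulates; applies every num_micro_batches-th call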
def testVariableThetaValue(self):
  with self.session():
    layer_p = TestLayer.Params().Set(name='test')
    layer = layer_p.Instantiate()
    tf.global_variables_initializer().run()
    self.assertAllClose(layer.vars.w.eval(), layer.theta.w.eval())
    b_eval = layer.vars.b.eval()
    self.assertAllClose(b_eval, layer.theta.b.eval())
    self.assertAllClose(b_eval, layer._private_theta['b'].eval())

    # theta reflects the vars change.
    new_b = layer.vars.b.assign(tf.ones_like(layer.vars.b) * 3.)
    with tf.control_dependencies([new_b]):
      self.assertAllClose(b_eval * 3., new_b.eval())
      self.assertAllClose(layer.vars.b.eval(), new_b.eval())
      self.assertAllClose(layer.vars.b.eval(), layer.theta.b.eval())
def IsWithinBBox(points, bbox):
  """Checks if points are within a 2-d bbox.

  The function returns true if points are strictly inside the box. It also
  returns true when the points are exactly on the box edges.

  Args:
    points: a float Tensor of shape [..., 2] of points to be tested. The last
      coordinates are (x, y).
    bbox: a float Tensor of shape [..., 4, 2] of bboxes. The last coordinates
      are the four corners of the bbox and (x, y). The corners are assumed to
      be given in counter-clockwise order.

  Returns:
    Tensor: If ``pshape = tf.shape(points)[:-1]`` and
    ``bshape = tf.shape(bbox)[:-2]``, returns a boolean tensor of shape
    ``tf.concat(pshape, bshape)``, where each element is true if the point is
    inside the corresponding box. If a point falls exactly on an edge of the
    bbox, it is also true.
  """
  bshape = py_utils.GetShape(bbox)[:-2]
  pshape = py_utils.GetShape(points)[:-1]
  bbox = py_utils.HasShape(bbox, bshape + [4, 2])
  points = py_utils.HasShape(points, pshape + [2])

  # Enumerate all 4 edges:
  v1, v2, v3, v4 = (bbox[..., 0, :], bbox[..., 1, :], bbox[..., 2, :],
                    bbox[..., 3, :])
  v1v2v3_check = tf.reduce_all(_IsCounterClockwiseDirection(v1, v2, v3))
  v2v3v4_check = tf.reduce_all(_IsCounterClockwiseDirection(v2, v3, v4))
  v4v1v2_check = tf.reduce_all(_IsCounterClockwiseDirection(v4, v1, v2))
  v3v4v1_check = tf.reduce_all(_IsCounterClockwiseDirection(v3, v4, v1))

  with tf.control_dependencies([
      py_utils.Assert(v1v2v3_check, [v1, v2, v3]),
      py_utils.Assert(v2v3v4_check, [v2, v3, v4]),
      py_utils.Assert(v4v1v2_check, [v4, v1, v2]),
      py_utils.Assert(v3v4v1_check, [v3, v4, v1])
  ]):
    is_inside = tf.math.logical_and(
        tf.math.logical_and(
            _IsOnLeftHandSideOrOn(points, v1, v2),
            _IsOnLeftHandSideOrOn(points, v2, v3)),
        tf.math.logical_and(
            _IsOnLeftHandSideOrOn(points, v3, v4),
            _IsOnLeftHandSideOrOn(points, v4, v1)))
  # Swap the last two dimensions.
  ndims = is_inside.shape.ndims
  return tf.transpose(is_inside,
                      list(range(ndims - 2)) + [ndims - 1, ndims - 2])
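# Minimal usage sketch for IsWithinBBox, assuming a single axis-aligned unit
# square given counter-clockwise. The test points and the claimed output shape
# (pshape + bshape, i.e. [3, 1] here) follow the docstring above and are
# illustrative only.
def _IsWithinBBoxExample():
  bbox = tf.constant([[[0., 0.], [1., 0.], [1., 1.], [0., 1.]]])  # [1, 4, 2]
  points = tf.constant([
      [0.5, 0.5],  # strictly inside   -> True
      [2.0, 2.0],  # outside           -> False
      [1.0, 0.0],  # on a corner/edge  -> True (edges count as inside)
  ])
  return IsWithinBBox(points, bbox)  # boolean Tensor of shape [3, 1]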
def FProp(self, theta, inputs, paddings=None):
  """Apply batch normalization.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    inputs: The inputs tensor. Shaped [..., dim].
    paddings: The paddings tensor. Shaped [..., 1], with the same rank as the
      input tensor.

  Returns:
    Output after applying batch normalization, with the same shape as
    'inputs'.
  """
  p = self.params
  if paddings is None:
    paddings = self._GetDefaultPaddings(inputs)
  with tf.name_scope(p.name):
    norm_mean, norm_variance, beta, gamma = self.ComputeAndUpdateMoments(
        theta, inputs, paddings)
    with tf.control_dependencies([
        py_utils.assert_greater_equal(norm_variance,
                                      tf.zeros_like(norm_variance)),
        py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                    tf.shape(norm_mean)),
        py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                    tf.shape(norm_variance)),
    ]):
      if p.use_fused_batch_norm_for_eval and self.do_eval:
        bn_output, _, _ = nn.fused_batch_norm(
            inputs,
            gamma,
            beta,
            norm_mean,
            norm_variance,
            self._epsilon,
            is_training=False)
      else:
        bn_output = tf.nn.batch_normalization(inputs, norm_mean, norm_variance,
                                              beta, gamma, self._epsilon)
      if p.set_padded_output_to_zero:
        bn_output *= 1.0 - paddings
  return bn_output
def Value(self):
  p = self.params
  with tf.name_scope(p.name):
    steps = self._best_step
    best_step = steps[0]
    last_step = steps[1]
    ref_step = tf.maximum(self.theta.ref_step, best_step)
    f = self.theta.cur_factor

    # Decay if no improvement within window.
    new_factor = tf.where(last_step - ref_step < p.window, f,
                          tf.maximum(p.min_factor, f * p.decay))
    # Update ref_step if we decayed.
    new_step = tf.where(tf.equal(new_factor, f), ref_step, last_step)
    update_step = tf.assign(self.vars.ref_step, new_step)
    with tf.control_dependencies([update_step]):
      return tf.assign(self.vars.cur_factor, new_factor)
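# Pure-Python mirror of the decay rule in Value() above, added as a hedged
# illustration (the helper name and the default numbers are made up): the
# factor is multiplied by `decay` only when the best step has not advanced
# within the last `window` steps, and `ref_step` is moved forward only when
# the factor actually changed.
def _DecayFactorExample(best_step, last_step, ref_step, factor,
                        window=1000, decay=0.5, min_factor=1e-3):
  ref_step = max(ref_step, best_step)
  if last_step - ref_step < window:
    new_factor = factor  # recent improvement: keep the factor
  else:
    new_factor = max(min_factor, factor * decay)
  new_ref_step = ref_step if new_factor == factor else last_step
  return new_factor, new_ref_step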
def ExtractLogMelFeatures(wav_bytes_t):
  """Create Log-Mel Filterbank Features from raw bytes.

  Args:
    wav_bytes_t: Tensor representing raw wav file as a string of bytes. It is
      currently assumed that the wav file is encoded at 16KHz (see DecodeWav,
      below).

  Returns:
    A Tensor representing three stacked log-Mel filterbank energies,
    sub-sampled every three frames.
  """

  # We want to use these parameters exactly.
  def _CreateAsrFrontend():
    """Parameters corresponding to default ASR frontend."""
    p = asr_frontend.MelAsrFrontend.Params()
    p.sample_rate = 16000.
    p.frame_size_ms = 25.
    p.frame_step_ms = 10.
    p.num_bins = 80
    p.lower_edge_hertz = 125.
    p.upper_edge_hertz = 7600.
    p.preemph = 0.97
    p.noise_scale = 0.
    p.pad_end = False
    return p.Instantiate()

  sample_rate, audio = DecodeWav(wav_bytes_t)
  audio *= 32768
  # Remove channel dimension, since we have a single channel.
  audio = tf.squeeze(audio, axis=1)
  # TODO(drpng): make batches.
  audio = tf.expand_dims(audio, axis=0)
  static_sample_rate = 16000
  mel_frontend = _CreateAsrFrontend()
  with tf.control_dependencies(
      [tf.assert_equal(sample_rate, static_sample_rate)]):
    outputs = mel_frontend.FPropDefaultTheta(
        py_utils.NestedMap(src_inputs=audio, paddings=tf.zeros_like(audio)))
    log_mel = outputs.src_inputs
  return log_mel
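# Hedged usage sketch for ExtractLogMelFeatures; the file path is a
# placeholder and tf.io.read_file is just one way to obtain the raw wav bytes.
#
#   wav_bytes = tf.io.read_file('/path/to/utterance.wav')  # must be a 16 kHz wav
#   log_mel = ExtractLogMelFeatures(wav_bytes)              # batched log-Mel frames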
def GetNext(self):
  """Override of the root's GetNext to support checking repeat sentinel."""
  self._InitIterator()
  if py_utils.GetUnitTestSession():
    self.Initialize(py_utils.GetUnitTestSession())
  batch = self._iterator[self.host_id].get_next()
  # Sentinel check.
  if self._repeat_with_sentinel and not self._repeat_steps:
    assert_op = tf.debugging.assert_none_equal(
        batch[self.params.sentinel_key],
        tf.constant(self.params.sentinel_value),
        summarize=1,
        message='REPEAT_SENTINEL_')
    tf.logging.info('sentinel constant dtype %r',
                    tf.constant(self.params.sentinel_value))
    with tf.control_dependencies([assert_op]):
      # This identity transform will throw tf.errors.InvalidArgumentError
      # if assert_op fails (sentinel_key takes on sentinel_value).
      batch = batch.Transform(tf.identity)
  return batch
def TpuTrainStep(*args):
  """Train a shard of a batch on a single TPU core.

  Args:
    *args: metrics values from previous steps.

  Returns:
    New summed metrics values and a train_op.
  """
  self._model = self._task_params.Instantiate()
  self._model.ConstructFPropBPropGraph()
  per_step_eval_metrics = self._eval_metrics.SetMetrics(
      self._model.GetTask().eval_metrics, args)
  outfeed_op = self._OutfeedEnqueue(
      self._model.GetTask().per_example_tensors)
  summed_metrics = []
  assert len(per_step_eval_metrics) == len(args)
  with tf.control_dependencies([outfeed_op]):
    for x, y in zip(per_step_eval_metrics, args):
      summed_metrics.append(x + y)
  return summed_metrics + [self._model.GetTask().train_op]
def TpuTrainStep(*args):
  """Train a shard of a batch on a single TPU core.

  Args:
    *args: metrics values from previous steps.

  Returns:
    New summed metrics values and a train_op.
  """
  with py_utils.OpportunisticVariableReuseScope(True):
    self._model.InstantiateVariables()
    self._model.ConstructFPropBPropGraph()
  per_step_eval_metrics = self._eval_metrics.SetMetrics(
      self._task.eval_metrics, args)
  outfeed_op = self._OutfeedEnqueue(self._task.per_example_tensors)
  summed_metrics = []
  assert len(per_step_eval_metrics) == len(args)
  with tf.control_dependencies([outfeed_op]):
    for x, y in zip(per_step_eval_metrics, args):
      summed_metrics.append(x + y)
  return summed_metrics + [self._task.train_op]
def Finalize(self):
  """Finishes creation of the overall figure, returning the image summary."""
  subplot_grid_shape = self._subplot_grid_shape
  if subplot_grid_shape is None:
    subplot_grid_shape = (len(self._subplots), 1)

  # AddMatplotlibFigureSummary (due to restrictions of py_func) only supports
  # flattened list of tensors so we must do some bookkeeping to maintain a
  # mapping from _SubplotMetadata object to flattened_tensors.
  subplot_slices = []
  flattened_tensors = []
  for subplot in self._subplots:
    start = len(flattened_tensors)
    subplot_slices.append((start, start + len(subplot.tensor_list)))
    flattened_tensors.extend(subplot.tensor_list)

  def PlotFunc(fig, *numpy_data_list):
    gs = gridspec.GridSpec(*subplot_grid_shape, **self._gridspec_kwargs)
    for n, subplot in enumerate(self._subplots):
      axes = fig.add_subplot(gs[n])
      start, end = subplot_slices[n]
      subplot_data = numpy_data_list[start:end]
      subplot.plot_func(fig, axes, *subplot_data)

  func = functools.partial(_RenderMatplotlibFigures, self._figsize,
                           self._max_outputs, PlotFunc)
  batch_sizes = [tf.shape(t)[0] for t in flattened_tensors]
  num_tensors = len(flattened_tensors)
  with tf.control_dependencies([
      tf.assert_equal(
          batch_sizes, [batch_sizes[0]] * num_tensors, summarize=num_tensors)
  ]):
    rendered = tf.py_func(
        func, flattened_tensors, tf.uint8, name='RenderMatplotlibFigures')
  return tf.summary.image(
      self._name, rendered, max_outputs=self._max_outputs)
def DecodeWithTheta(self, theta, input_batch):
  """Constructs the inference graph."""
  p = self.params
  with tf.name_scope('decode'), tf.name_scope(p.name):
    with tf.name_scope('encoder'):
      encoder_outputs = self._FrontendAndEncoderFProp(theta, input_batch.src)
    if p.inference_compute_only_log_softmax:
      global_step = tf.train.get_global_step()
      increment_global_step = tf.assign(global_step, global_step + 1)
      with tf.control_dependencies([increment_global_step]):
        log_probabilities = tf.transpose(
            tf.nn.log_softmax(encoder_outputs.encoded, axis=2),
            perm=(1, 0, 2))
      with tf.name_scope('decoder'):
        decoder_outs = self._DecodeCTC(encoder_outputs)
      # encoder_outputs's shape is [T, B, F].
      return {
          'log_probabilities': log_probabilities,
          'log_probabilities_lengths':
              py_utils.LengthsFromBitMask(encoder_outputs.padding, 0),
          'int64_uttid': input_batch.sample_ids,
          'int64_audio_document_id': input_batch.audio_document_ids,
          'num_utterances_in_audio_document':
              input_batch.num_utterances_in_audio_document,
          'transcripts': decoder_outs.transcripts,
      }
    with tf.name_scope('decoder'):
      decoder_outs = self._DecodeCTC(encoder_outputs)
    decoder_metrics = py_utils.RunOnTpuHost(self._CalculateErrorRates,
                                            decoder_outs, input_batch)
    return decoder_metrics
def GetNext(self):
  """Override of the root's GetNext to support checking repeat sentinel."""
  # Use `init_scope()` to ensure that the datasets and iterators are created
  # outside of the function-building graph. This ensures that these creation
  # operations are not repeated in subsequent `tf.function` calls.
  with tf.init_scope():
    self._InitIterator()
    if py_utils.GetUnitTestSession():
      self.Initialize(py_utils.GetUnitTestSession())
  batch = self._iterator[self.host_id].get_next()
  # Sentinel check.
  if self._repeat_with_sentinel and not self._repeat_steps:
    assert_op = tf.debugging.assert_none_equal(
        batch[self.params.sentinel_key],
        tf.constant(self.params.sentinel_value),
        summarize=1,
        message='REPEAT_SENTINEL_')
    tf.logging.info('sentinel constant dtype %r',
                    tf.constant(self.params.sentinel_value))
    with tf.control_dependencies([assert_op]):
      # This identity transform will throw tf.errors.InvalidArgumentError
      # if assert_op fails (sentinel_key takes on sentinel_value).
      batch = batch.Transform(tf.identity)
  return batch
def _BPropGenTrainOps(self, vmap, metrics=None, add_summary=True):
  """Populates the train_ops dictionary in a backwards pass."""
  metrics = metrics or self._metrics

  bprop_variable_filters = self.input_generator.GetBpropVariableFilters()
  # Only compute the mask if the variable filters are not empty.
  if bprop_variable_filters != [''] * len(bprop_variable_filters):
    self._ComputeGradientMask(bprop_variable_filters)

  train_ops = {}  # mapping from op name to op.
  gradient_mask = None
  if self._per_input_gradient_mask:
    # TODO(neerajgaur): Change this to use source_selected from input_batch.
    onehot = self.input_generator.GetInputSourceOneHot()
    gradient_mask = {
        k: tf.tensordot(v, onehot, 1)
        for k, v in self._per_input_gradient_mask.items()
    }

  all_losses = []
  for optimization in self.learners:
    learner_name = optimization.params.name
    loss_name = optimization.params.loss_name or learner_name
    metric = metrics.get(loss_name, None)
    if metric is None:
      raise ValueError('Loss %s not found in metrics %s' %
                       (loss_name, list(metrics.keys())))
    loss = metric[0]
    all_losses.append(loss)
    train_ops['train/%s' % learner_name], eval_metrics = optimization.Apply(
        loss,
        vmap,
        gradient_mask=gradient_mask,
        gradient_adjuster=self.AdjustGradients)
    if add_summary:
      for key, (value, weight) in eval_metrics.items():
        self.AddEvalMetric(key + '/' + learner_name, value, weight)

  relevant_bn_updates, _ = py_utils.FindRelevantBatchNormUpdates(
      all_losses, tf.get_collection(py_utils.BATCH_NORM_UPDATES))
  train_ops['bn_updates'] = relevant_bn_updates

  var_update_ops = [
      tf.group(*tf.nest.flatten(train_ops), name='var_update_ops')
  ]
  # Post training step update.
  with tf.control_dependencies(var_update_ops):
    post_step_op = self.PostTrainingStepUpdate(self.global_step)

  train_ops = {}
  with tf.control_dependencies([post_step_op]):
    # Get the op to update the weight masks and thresholds
    mask_update_op = self._GetMaskUpdateOp()
    train_ops['mask_updates'] = mask_update_op
    with tf.control_dependencies([mask_update_op]):
      true_global_step = py_utils.GetOrCreateGlobalStepVar()
      with tf.ops.colocate_with(true_global_step):
        increment_global_steps = tf.assign_add(true_global_step, 1)
      if self._global_step_var != true_global_step:
        with tf.ops.colocate_with(self._global_step_var):
          increment_global_steps = tf.group(
              increment_global_steps, tf.assign_add(self._global_step_var, 1))
      train_ops['global_step'] = increment_global_steps

  # If we are using Tpu Embeddings, generate the monolithic send
  # gradient op.
  tpu_embedding_activations = tf.get_collection(
      py_utils.TPU_EMBEDDING_ACTIVATIONS)
  if tpu_embedding_activations:
    tpu_embedding_activations_dict = tpu_embedding_activations[0]
    tpu_embedding = tf.get_collection(py_utils.TPU_EMBEDDING)[0]
    tpu_embedding_send_gradient_op = py_utils.ComputeTpuEmbeddingGradients(
        self.loss, tpu_embedding_activations_dict, tpu_embedding)
    train_ops['tpu_embedding'] = tpu_embedding_send_gradient_op

  for op_name, op in train_ops.items():
    assert op is not None, op_name
  return train_ops
def testAccumulator(self):
  # testAccumulator compares
  # - explicit averaging of independently computed var_grads1 and
  #   var_grads2,
  # - Accumulator(SGD) optimizer effectively doing this over 2 steps.
  np.random.seed(12345)
  np_input1 = np.random.normal(0.1, 0.5, [2, 4, 3])
  np.random.seed(12346)
  np_input2 = np.random.normal(0.1, 0.5, [2, 4, 3])

  with self.session(use_gpu=True, graph=tf.Graph()) as sess:
    tf.random.set_seed(123456)
    params = layers.ProjectionLayer.Params()
    params.name = 'proj'
    params.dtype = tf.float64
    params.input_dim = 3
    params.output_dim = 2
    params.params_init = py_utils.WeightInit.Gaussian(0.01, 123456)
    params.batch_norm = False
    proj_layer = layers.ProjectionLayer(params)
    inputs1 = tf.placeholder(shape=[2, 4, 3], dtype=tf.float64)
    in_padding1 = tf.zeros([2, 4, 1], dtype=tf.float64)
    inputs2 = tf.placeholder(shape=[2, 4, 3], dtype=tf.float64)
    in_padding2 = tf.zeros([2, 4, 1], dtype=tf.float64)
    output1 = proj_layer.FPropDefaultTheta(inputs1, in_padding1)
    output2 = proj_layer.FPropDefaultTheta(inputs2, in_padding2)
    loss1 = tf.reduce_sum(output1)
    loss2 = tf.reduce_sum(output2)
    var_grads1 = py_utils.ComputeGradients(loss1, proj_layer.vars)
    var_grads2 = py_utils.ComputeGradients(loss2, proj_layer.vars)
    op = optimizer.SGD.Params()
    opt = op.Instantiate()
    lr = 1e-1
    with tf.control_dependencies([loss1, loss2]):
      var_update_op1 = opt.Apply(
          lr, py_utils.ApplyGradMultiplier(var_grads1, 1. / 2.))
      with tf.control_dependencies([var_update_op1]):
        var_update_op2 = opt.Apply(
            lr, py_utils.ApplyGradMultiplier(var_grads2, 1. / 2.))

    self.evaluate(tf.global_variables_initializer())
    vars1 = self.evaluate(proj_layer.vars.Flatten())
    loss1_1, grads1_1, loss1_2, grads1_2 = sess.run(
        [
            loss1,
            var_grads1.Transform(tuple),
            loss2,
            var_grads2.Transform(tuple)
        ],
        feed_dict={
            inputs1: np_input1,
            inputs2: np_input2,
        },
    )
    sess.run([var_update_op2], feed_dict={
        inputs1: np_input1,
        inputs2: np_input2,
    })
    vars1_1 = self.evaluate(proj_layer.vars.Flatten())

  with self.session(use_gpu=True, graph=tf.Graph()) as sess:
    tf.random.set_seed(123456)
    params = layers.ProjectionLayer.Params()
    params.name = 'proj'
    params.dtype = tf.float64
    params.input_dim = 3
    params.output_dim = 2
    params.params_init = py_utils.WeightInit.Gaussian(0.01, 123456)
    params.batch_norm = False
    proj_layer = layers.ProjectionLayer(params)
    in_padding1 = tf.zeros([2, 4, 1], dtype=tf.float64)
    inputs1 = tf.placeholder(shape=[2, 4, 3], dtype=tf.float64)
    output1 = proj_layer.FPropDefaultTheta(inputs1, in_padding1)
    loss = tf.reduce_sum(output1)
    var_grads = py_utils.ComputeGradients(loss, proj_layer.vars)
    op = optimizer.Accumulator.Params().Set(
        accum_steps=2, dtype=tf.float64, optimizer_tpl=optimizer.SGD.Params())
    opt = op.Instantiate()
    lr = 1e-1
    with cluster_factory.ForTestingWorker(add_summary=True):
      var_update_op = opt.Apply(lr, var_grads)
    increment_global_step_op = tf.assign_add(
        py_utils.GetOrCreateGlobalStepVar(), 1)

    self.evaluate(tf.global_variables_initializer())
    vars2 = self.evaluate(proj_layer.vars.Flatten())
    loss2_1, grads2_1 = sess.run(
        [loss, var_grads.Transform(tuple)], feed_dict={
            inputs1: np_input1,
        })
    loss2_2, grads2_2 = sess.run(
        [loss, var_grads.Transform(tuple)], feed_dict={
            inputs1: np_input2,
        })
    acc_0 = self.evaluate(
        [v for v in tf.global_variables() if 'grad_accumulator' in v.name])[0]
    sess.run([var_update_op], feed_dict={
        inputs1: np_input1,
    })
    acc_1 = self.evaluate(
        [v for v in tf.global_variables() if 'grad_accumulator' in v.name])[0]
    vars2_intermediate = self.evaluate(proj_layer.vars.Flatten())
    self.evaluate(increment_global_step_op)
    sess.run([var_update_op], feed_dict={
        inputs1: np_input2,
    })
    acc_2 = self.evaluate(
        [v for v in tf.global_variables() if 'grad_accumulator' in v.name])[0]
    vars2_1 = self.evaluate(proj_layer.vars.Flatten())

    summary = tf.Summary.FromString(self.evaluate(tf.summary.merge_all()))
    tf.logging.info(f'summary: {summary}')
    self.assertEqual(summary.value[0].tag, 'sgd_lr')

  self.assertAllClose(vars1, vars2)
  self.assertAllClose(acc_0, np.zeros_like(acc_0))
  self.assertAllClose(acc_1, grads2_1['w'][1])
  self.assertAllClose(acc_2, np.zeros_like(acc_0))
  self.assertAllClose(loss1_1, loss2_1)
  self.assertAllClose(loss1_2, loss2_2)
  self.assertAllClose(grads1_1, grads2_1)
  self.assertAllClose(grads1_2, grads2_2)
  self.assertAllClose(vars1, vars2_intermediate)
  self.assertAllClose(vars2[0], grads2_1['w'][0])
  self.assertAllClose(vars2[0], grads2_2['w'][0])
  self.assertAllClose(
      vars1[0] - 0.5 * lr * (grads1_1['w'][1] + grads1_2['w'][1]), vars1_1[0])
  self.assertAllClose(
      vars2[0] - 0.5 * lr * (grads2_1['w'][1] + grads2_2['w'][1]), vars2_1[0])
  self.assertAllClose(vars2, vars2_intermediate)
  self.assertAllClose(vars1_1, vars2_1)
def TrainAndDecode():
  with tf.control_dependencies([TpuTrain()]):
    return _DecodeFn()
def _internal_apply_dense(self, grad, var, magnitude_optimizer_apply_fn,
                          direction_optimizer_apply_fn):  # pylint: disable=g-doc-args
  """Main optimization logic of AdaGraft, which calls the child optimizers.

  Args:
    grad: Tensor containing gradients.
    var: Tensor containing parameter values.
    magnitude_optimizer_apply_fn: Apply magnitude optimizer.
    direction_optimizer_apply_fn: Apply direction optimizer.

  Returns:
    The final update op, which increments var by the grafted step.

  Pseudocode:
  - Copy weights into scratch space 'scratch_copy'.
  - Run magnitude_optimizer in-place.
  - Use scratch copy to figure out how far we moved ('magnitude_step').
  - Copy weights back.
  - Run direction_optimizer in-place.
  - Move weights along the line segment with scratch_copy.
  """

  if self.use_global_norm:
    self._variables.append(var)

  # Slot with current parameter values
  scratch_slot = self.get_slot(var, "scratch_copy")
  old_var = tf.assign(scratch_slot, var)

  with tf.control_dependencies([old_var]):
    m_updated_var = magnitude_optimizer_apply_fn(grad, var)  # pylint: disable=protected-access

  # Run magnitude optimizer and compute the norm of the update.
  with tf.control_dependencies([m_updated_var]):
    m_step = var - old_var
    m_step_norm = tf.norm(m_step)
    if self.diagnostic or self.use_global_norm:
      m_step_norm = tf.assign(self.get_slot(var, "m_step_norm"), m_step_norm)

  # Run direction optimizer and compute its norm, and the direction.
  with tf.control_dependencies([m_step_norm]):
    flushed_var = tf.assign(var, old_var)
  with tf.control_dependencies([flushed_var]):
    d_updated_var = direction_optimizer_apply_fn(grad, var)  # pylint: disable=protected-access

  # Run an update of the direction optimizer with magnitude optimizer norm.
  with tf.control_dependencies([d_updated_var]):
    d_step = var - old_var
    d_step_norm = tf.norm(d_step)
    if self.diagnostic or self.use_global_norm:
      d_step_norm = tf.assign(self.get_slot(var, "d_step_norm"), d_step_norm)
    if self.use_global_norm:
      flushed_var = tf.assign(var, old_var)
      with tf.control_dependencies([d_step_norm, flushed_var]):
        return tf.assign(scratch_slot, d_step)
    step = tf.where(
        tf.greater(d_step_norm, 0),
        (m_step_norm / tf.maximum(d_step_norm, 1e-30)) * d_step,
        tf.zeros_like(d_step))
    return tf.assign(var, old_var + self._learning_rate_tensor * step)
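# Hedged numeric illustration of the grafting rule used in the final branch
# above (standalone sketch, not part of the optimizer): the grafted step keeps
# the *direction* of the direction-optimizer step and rescales it to the
# *norm* of the magnitude-optimizer step.
def _GraftStepExample():
  m_step = tf.constant([0.3, 0.4])  # magnitude optimizer step, norm 0.5
  d_step = tf.constant([4.0, 3.0])  # direction optimizer step, norm 5.0
  m_step_norm = tf.norm(m_step)
  d_step_norm = tf.norm(d_step)
  # Same scaling as above, minus the zero-norm guard for brevity.
  return (m_step_norm / tf.maximum(d_step_norm, 1e-30)) * d_step  # ~[0.4, 0.3]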
def FProp(self, theta, inputs, paddings=None):
  """Apply group normalization.

  Args:
    theta: A NestedMap object containing weights' values of this layer and
      its children layers.
    inputs: The inputs tensor with shape [batch_size, height, width, channel].
    paddings: The paddings tensor with shape [batch_size, height]. Intended to
      be used for sequence processing where `height` is `time`.

  Returns:
    A single tensor as the output after applying group normalization, with
    the same shape as 'inputs'. Or an (output, output_paddings) pair if input
    paddings is not None.
  """
  p = self.params
  n, h, w, c = tf.unstack(tf.shape(inputs), axis=0, num=4)
  group_size = p.dim // p.num_groups
  num_groups = p.num_groups
  min_group_size = p.min_group_size if p.dim > p.min_group_size else p.dim
  if group_size <= min_group_size:
    group_size = min_group_size
    num_groups = p.dim // group_size

  with tf.name_scope(p.name):
    x = tf.reshape(inputs, [n, h, w, num_groups, group_size])
    if paddings is None:
      counts, means_ss, variance_ss, _, = tf.nn.sufficient_statistics(
          x, axes=[1, 2, 4], keepdims=True)
      norm_mean, norm_variance = tf.nn.normalize_moments(
          counts, means_ss, variance_ss, None)
    else:
      expanded_paddings = tf.reshape(paddings, [n, h, 1, 1, 1])
      if p.cumulative:
        norm_mean, norm_variance = ComputeMomentsWithPadding(
            x,
            expanded_paddings,
            reduce_over_dims=[2, 4],
            cumulative_axis=1,
            keepdims=True)
      else:
        norm_mean, norm_variance = ComputeMomentsWithPadding(
            x, expanded_paddings, [1, 2, 4], keepdims=True)

    norm_mean = py_utils.CheckNumerics(
        norm_mean, 'mean of %s failed numeric check' % p.name)
    norm_variance = py_utils.CheckNumerics(
        norm_variance, 'variance of %s failed numeric check' % p.name)

    beta = theta.beta
    gamma = theta.gamma
    t = h if p.cumulative else 1

    with tf.control_dependencies([
        py_utils.assert_greater_equal(norm_variance,
                                      tf.cast(0., norm_variance.dtype)),
        py_utils.assert_shape_match([n, t, 1, num_groups, 1],
                                    tf.shape(norm_mean)),
        py_utils.assert_shape_match([n, t, 1, num_groups, 1],
                                    tf.shape(norm_variance)),
    ]):
      x = (x - norm_mean) / tf.sqrt(norm_variance + self._epsilon)
      x = tf.reshape(x, [n, h, w, c])
      gn_output = x * gamma + beta
      gn_output = tf.reshape(gn_output, [n, h, w, c])
      if paddings is None:
        return gn_output
      else:
        return gn_output, paddings
def _BPropForVariables(self, vmap):
  """Constructs the backward graph."""
  bprop_variable_filters = self.input_generator.GetBpropVariableFilters()
  # Only compute the mask if the variable filters are not empty.
  if bprop_variable_filters != [''] * len(bprop_variable_filters):
    self._ComputeGradientMask(bprop_variable_filters)

  train_ops = {}  # mapping from op name to op.
  gradient_mask = None
  if self._per_input_gradient_mask:
    # TODO(neerajgaur): Change this to use source_selected from input_batch.
    onehot = self.input_generator.GetInputSourceOneHot()
    gradient_mask = {
        k: tf.tensordot(v, onehot, 1)
        for k, v in six.iteritems(self._per_input_gradient_mask)
    }

  all_losses = []
  for optimization in self.learners:
    loss_name = optimization.params.name
    metric = self._metrics.get(loss_name, None)
    if metric is None:
      raise ValueError('Loss %s not found in metrics %s' %
                       (loss_name, list(self._metrics.keys())))
    loss = metric[0]
    all_losses.append(loss)
    train_ops['train/%s' % loss_name], eval_metrics = optimization.Apply(
        loss,
        vmap,
        gradient_mask=gradient_mask,
        gradient_adjuster=self.AdjustGradients)
    for key, (value, weight) in six.iteritems(eval_metrics):
      self.AddEvalMetric(key + '/' + loss_name, value, weight)

  relevant_bn_updates, _ = py_utils.FindRelevantBatchNormUpdates(
      all_losses, tf.get_collection(py_utils.BATCH_NORM_UPDATES))
  train_ops['bn_updates'] = relevant_bn_updates

  # Get the op to update the weight masks and thresholds
  train_ops['mask_updates'] = self._GetMaskUpdateOp()

  # Post training step update.
  train_ops['post_step'] = self.PostTrainingStepUpdate(self.global_step)

  with tf.control_dependencies(tf.nest.flatten(train_ops)):
    true_global_step = py_utils.GetOrCreateGlobalStepVar()
    with tf.colocate_with(true_global_step):
      increment_global_steps = tf.assign_add(true_global_step, 1)
    if self._global_step_var != true_global_step:
      with tf.colocate_with(self._global_step_var):
        increment_global_steps = tf.group(
            increment_global_steps, tf.assign_add(self._global_step_var, 1))
  train_ops['global_step'] = increment_global_steps

  # If we are using Tpu Embeddings, generate the monolithic send
  # gradient op.
  tpu_embedding_activations = tf.get_collection(
      py_utils.TPU_EMBEDDING_ACTIVATIONS)
  if tpu_embedding_activations:
    tpu_embedding_activations_dict = tpu_embedding_activations[0]
    tpu_embedding = tf.get_collection(py_utils.TPU_EMBEDDING)[0]
    tpu_embedding_send_gradient_op = py_utils.ComputeTpuEmbeddingGradients(
        self.loss, tpu_embedding_activations_dict, tpu_embedding)
    train_ops['tpu_embedding'] = tpu_embedding_send_gradient_op

  for op_name, op in six.iteritems(train_ops):
    assert op is not None, op_name

  # TODO(rpang): try to structure _train_op as:
  #   tf.cond(skip_step, <only update skip stats>, <all updates>)
  # so that we skip all other updates when a step is skipped.
  self._train_op = tf.group(*tf.nest.flatten(train_ops), name='bprop')
def FProp(self, theta, inputs, paddings=None):
  """Apply group normalization.

  Args:
    theta: A NestedMap object containing weights' values of this layer and
      its children layers.
    inputs: The inputs tensor with shape [batch_size, height, width, channel].
    paddings: The paddings tensor with shape [batch_size, height]. Intended to
      be used for sequence processing where `height` is `time`.

  Returns:
    A single tensor as the output after applying group normalization, with
    the same shape as 'inputs'. Or an (output, output_paddings) pair if input
    paddings is not None.
  """
  p = self.params
  inputs = py_utils.with_dependencies([
      py_utils.assert_greater_equal(py_utils.GetRank(inputs), p.input_rank)
  ], inputs)

  min_group_size = min(p.min_group_size, p.dim)
  group_size = max(p.dim // p.num_groups, min_group_size)
  num_groups = p.dim // group_size

  input_shape = py_utils.GetShape(inputs)
  with tf.name_scope(p.name):
    x = tf.reshape(inputs, input_shape[:-1] + [num_groups, group_size])
    expanded_rank = p.input_rank + 1
    all_dims = list(range(expanded_rank))
    if paddings is None:
      # Skip d0, d[-2]
      axes = all_dims[1:-2] + all_dims[-1:]
      counts, means_ss, variance_ss, _, = tf.nn.sufficient_statistics(
          x, axes=axes, keepdims=True)
      norm_mean, norm_variance = tf.nn.normalize_moments(
          counts, means_ss, variance_ss, None)
    else:
      expanded_paddings = tf.reshape(
          paddings, input_shape[:2] + [1] * (expanded_rank - 2))
      # skip the batching and group dim
      if p.cumulative:
        # Skip d0, d1 and d[-2]
        reduce_over_dims = all_dims[2:-2] + all_dims[-1:]
        norm_mean, norm_variance = ComputeMomentsWithPadding(
            x,
            expanded_paddings,
            reduce_over_dims=reduce_over_dims,
            cumulative_axis=1,
            keepdims=True)
      else:
        # Skip d0, d[-2]
        reduce_over_dims = all_dims[1:-2] + all_dims[-1:]
        norm_mean, norm_variance = ComputeMomentsWithPadding(
            x, expanded_paddings, reduce_over_dims, keepdims=True)

    norm_mean = py_utils.CheckNumerics(
        norm_mean, 'mean of %s failed numeric check' % p.name)
    norm_variance = py_utils.CheckNumerics(
        norm_variance, 'variance of %s failed numeric check' % p.name)

    beta = theta.beta
    gamma = theta.gamma
    n = input_shape[0]
    t = input_shape[1] if p.cumulative else 1
    norm_shape = [n, t, 1, num_groups, 1] if p.input_rank == 4 else [
        n, t, num_groups, 1
    ]
    with tf.control_dependencies([
        py_utils.assert_greater_equal(norm_variance,
                                      tf.cast(0., norm_variance.dtype)),
        py_utils.assert_shape_match(norm_shape, tf.shape(norm_mean)),
        py_utils.assert_shape_match(norm_shape, tf.shape(norm_variance)),
    ]):
      x = (x - norm_mean) / tf.sqrt(norm_variance + self._epsilon)
      x = tf.reshape(x, input_shape)
      gn_output = x * gamma + beta
      gn_output = tf.reshape(gn_output, input_shape)
      if paddings is None:
        return gn_output
      else:
        return gn_output, paddings
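# Hedged arithmetic check of the grouping rule above (pure Python; the helper
# name and default numbers are illustrative): with dim=16, num_groups=8 and
# min_group_size=4, the requested groups of size 2 are widened to the minimum
# size 4, leaving 4 groups.
def _ResolveGroupsExample(dim=16, num_groups=8, min_group_size=4):
  min_group_size = min(min_group_size, dim)
  group_size = max(dim // num_groups, min_group_size)
  return dim // group_size, group_size  # -> (4, 4) for the defaults above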
def GuaranteeConstGetter(next_creator, **kwargs):
  if _CONST_GUARANTEE:
    with tf.control_dependencies(None):
      name = kwargs['var_name'] + '/GuaranteeConst'
      return tf.guarantee_const(next_creator(**kwargs), name=name)
  return next_creator(**kwargs)
def _BPropGenTrainOps(self, vmap, metrics=None, add_summary=True):
  """Populates the train_ops dictionary in a backwards pass."""
  metrics = metrics or self._metrics

  bprop_variable_filters = self.input_generator.GetBpropVariableFilters()
  # Only compute the mask if the variable filters are not empty.
  if bprop_variable_filters != [''] * len(bprop_variable_filters):
    self._ComputeGradientMask(bprop_variable_filters)

  train_ops = {}  # mapping from op name to op.
  gradient_mask = None
  if self._per_input_gradient_mask:
    # TODO(neerajgaur): Change this to use source_selected from input_batch.
    onehot = self.input_generator.GetInputSourceOneHot()
    gradient_mask = {
        k: tf.tensordot(v, onehot, 1)
        for k, v in self._per_input_gradient_mask.items()
    }

  all_losses = []
  for optimization in self.learners:
    learner_name = optimization.params.name
    (losses, train_ops['train/%s' % learner_name],
     eval_metrics) = optimization.Apply(
         metrics,
         vmap,
         gradient_mask=gradient_mask,
         gradient_adjuster=self.AdjustGradients)
    all_losses.extend(losses)
    if add_summary:
      for key, (value, weight) in eval_metrics.items():
        self.AddEvalMetric(key + '/' + learner_name, value, weight)

  relevant_bn_updates, _ = py_utils.FindRelevantBatchNormUpdates(
      all_losses, tf.get_collection(py_utils.BATCH_NORM_UPDATES))
  train_ops['bn_updates'] = relevant_bn_updates

  var_update_ops = [
      tf.group(*tf.nest.flatten(train_ops), name='var_update_ops')
  ]
  # Post training step update.
  with tf.control_dependencies(var_update_ops):
    post_step_op = self.PostTrainingStepUpdate()

  train_ops = {}
  with tf.control_dependencies([post_step_op]):
    # Get the op to update the weight masks and thresholds
    mask_update_op = self._GetMaskUpdateOp()
    train_ops['mask_updates'] = mask_update_op
    with tf.control_dependencies([mask_update_op]):
      true_global_step = py_utils.GetOrCreateGlobalStepVar()
      with tf.ops.colocate_with(true_global_step):
        if self.params.defer_global_step_update:
          increment_global_steps = true_global_step
        else:
          increment_global_steps = tf.assign_add(true_global_step, 1)
      if self._global_step_var != true_global_step:
        with tf.ops.colocate_with(self._global_step_var):
          increment_global_steps = tf.group(
              increment_global_steps, tf.assign_add(self._global_step_var, 1))
      train_ops['global_step'] = increment_global_steps

  # If we are using Tpu Embeddings, generate the monolithic send
  # gradient op.
  if tf.get_collection(py_utils.TPU_EMBEDDING):
    tpu_embedding = tf.get_collection(py_utils.TPU_EMBEDDING)[0]
    sparse_grads = (
        tpu_embedding_gradient.get_gradients_through_dummy_table_variables(
            tpu_embedding))
    tpu_embedding_send_gradient_op = tpu_embedding.generate_send_gradients_op(
        sparse_grads, py_utils.GetGlobalStep())
    train_ops['tpu_embedding'] = tpu_embedding_send_gradient_op

    tpu_embedding_summary_tensors = tf.get_collection(
        py_utils.TPU_EMBEDDING_SUMMARY_TENSORS)
    if add_summary:
      for name, value, weight in tpu_embedding_summary_tensors:
        self.AddEvalMetric(name, value, weight, raise_if_already_added=False)

  for op_name, op in train_ops.items():
    assert op is not None, op_name
  return train_ops
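# Summary of the control-dependency chain built above, added as a hedged
# reading aid (comment only, no new ops): variable updates run first, then
# PostTrainingStepUpdate, then the mask updates, and only after those does the
# global step increment, so a step only "counts" once every update attached to
# it has run.
#
#   var_update_ops -> PostTrainingStepUpdate -> mask_updates -> global_step += 1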
def control_after_assigns(self):
  if not self._assign_ops:
    return tf.no_op()
  with tf.control_dependencies(self._assign_ops):
    return tf.no_op()