def testBatchNormUpdatesWithUpdateUseGlobalStatsForTraining(self):
  tf.random.set_seed(398847392)
  np.random.seed(12345)
  params = layers.BatchNormLayer.Params()
  params.name = 'bn'
  params.dim = 3
  params.use_moving_avg_in_training = True
  params.params_init = py_utils.WeightInit.Gaussian(0.1)

  bn_layer = layers.BatchNormLayer(params)
  in_padding1 = tf.zeros([2, 8, 1], dtype=tf.float32)
  bn_in1 = tf.constant(
      np.random.normal(0.1, 0.5, [2, 8, 3]), dtype=tf.float32)

  bn_out = bn_layer.FPropDefaultTheta(bn_in1, in_padding1)
  sig1 = tf.reduce_sum(bn_out)
  sig2 = tf.reduce_sum(bn_out * bn_out)

  # IMPORTANT: Keep these values consistent with the corresponding
  # test in layers_test.py.
  self.assertAllClose(2.6575434, sig1, atol=1e-5)
  self.assertAllClose(15.473802, sig2)

  updates_collection = tf.get_collection(py_utils.BATCH_NORM_UPDATES)
  l1, l2 = py_utils.FindRelevantBatchNormUpdates(bn_out, updates_collection)
  self.assertEqual(l1, [])
  self.assertEqual(l2, [])
def _BPropForVariables(self, vmap): """Constructs the backward graph for the given variables. Args: vmap: a `.NestedMap` of variables. """ p = self.params tp = p.train # Compute gradients. self._var_grads = py_utils.ComputeGradients(self.loss, vmap) # L2 regularizer. if tp.l2_regularizer_weight is not None: l2_loss, self._var_grads = py_utils.AdjustGradientsWithLpLoss( self._var_grads, tp.l2_regularizer_weight, p=2.0) summary_utils.scalar(p, 'l2_loss', l2_loss) # L1 regularizer. if tp.l1_regularizer_weight is not None: l1_loss, self._var_grads = py_utils.AdjustGradientsWithLpLoss( self._var_grads, tp.l1_regularizer_weight, p=1.0) summary_utils.scalar(p, 'l1_loss', l1_loss) # Mask gradients only if the mask is set. if self._per_input_gradient_mask: bprop_onehot = self.input_generator.GetInputSourceOneHot() self._var_grads = py_utils.MaskGradients( self._var_grads, self._per_input_gradient_mask, bprop_onehot) # Apply gradient clipping. has_nan_or_inf, _, self._var_grads = self.ScaleGradients(self._var_grads) # Histogram summary. summary_utils.CollectVarHistogram(p, self._var_grads) lrs = self.lr_schedule.Value(self._global_step) summary_utils.scalar(p, 'lr_schedule', lrs) lr = tp.learning_rate * lrs var_update_op = self.optimizer.Apply(lr, self._var_grads) increment_global_step_ops = [] with tf.colocate_with(self._shared_global_step): increment_global_step_ops.append( tf.assign_add(self._shared_global_step, 1)) if self._task_global_step: with tf.colocate_with(self._task_global_step): increment_global_step_ops.append( tf.assign_add(self._task_global_step, 1)) increment_global_steps = tf.group(*increment_global_step_ops) relevant_bn_updates, _ = py_utils.FindRelevantBatchNormUpdates( self.loss, tf.get_collection(py_utils.BATCH_NORM_UPDATES)) batch_norm_updates = tf.group(*relevant_bn_updates) # Update stats. stats_updates = tf.group( self.IncrementTotalSamples(), self.IncrementTotalNans(tf.to_int32(has_nan_or_inf))) # Post training step update. post_training_step_updates = self.PostTrainingStepUpdate(self._global_step) # Get the op to update the weight masks and thresholds mask_update_op = self._GetMaskUpdateOp() # TODO(rpang): try to structure _train_op as: # tf.cond(skip_step, <only update skip stats>, <all updates>) # so that we skip all other updates when a step is skipped. # if p.contiguous: var_update_op = tf.group(var_update_op, self.last_state_group_op) self._train_op = tf.group( var_update_op, batch_norm_updates, stats_updates, post_training_step_updates, increment_global_steps, mask_update_op, name='train')
def _BPropForVariables(self, vmap): """Constructs the backward graph.""" bprop_variable_filters = self.input_generator.GetBpropVariableFilters() # Only compute the mask if the variable filters are not empty. if bprop_variable_filters != [''] * len(bprop_variable_filters): self._ComputeGradientMask(bprop_variable_filters) train_ops = {} # mapping from op name to op. gradient_mask = None if self._per_input_gradient_mask: # TODO(neerajgaur): Change this to use source_selected from input_batch. onehot = self.input_generator.GetInputSourceOneHot() gradient_mask = { k: tf.tensordot(v, onehot, 1) for k, v in six.iteritems(self._per_input_gradient_mask) } all_losses = [] for optimization in self.learners: loss_name = optimization.params.name metric = self._metrics.get(loss_name, None) if metric is None: raise ValueError('Loss %s not found in metrics %s' % (loss_name, list(self._metrics.keys()))) loss = metric[0] all_losses.append(loss) train_ops['train/%s' % loss_name], eval_metrics = optimization.Apply( loss, vmap, gradient_mask=gradient_mask, gradient_adjuster=self.AdjustGradients) for key, (value, weight) in six.iteritems(eval_metrics): self.AddEvalMetric(key + '/' + loss_name, value, weight) relevant_bn_updates, _ = py_utils.FindRelevantBatchNormUpdates( all_losses, tf.get_collection(py_utils.BATCH_NORM_UPDATES)) train_ops['bn_updates'] = relevant_bn_updates # Get the op to update the weight masks and thresholds train_ops['mask_updates'] = self._GetMaskUpdateOp() # Post training step update. train_ops['post_step'] = self.PostTrainingStepUpdate(self.global_step) with tf.control_dependencies(tf.nest.flatten(train_ops)): true_global_step = py_utils.GetOrCreateGlobalStepVar() with tf.colocate_with(true_global_step): increment_global_steps = tf.assign_add(true_global_step, 1) if self._global_step_var != true_global_step: with tf.colocate_with(self._global_step_var): increment_global_steps = tf.group( increment_global_steps, tf.assign_add(self._global_step_var, 1)) train_ops['global_step'] = increment_global_steps # If we are using Tpu Embeddings, generate the monolithic send # gradient op. tpu_embedding_activations = tf.get_collection( py_utils.TPU_EMBEDDING_ACTIVATIONS) if tpu_embedding_activations: tpu_embedding_activations_dict = tpu_embedding_activations[0] tpu_embedding = tf.get_collection(py_utils.TPU_EMBEDDING)[0] tpu_embedding_send_gradient_op = py_utils.ComputeTpuEmbeddingGradients( self.loss, tpu_embedding_activations_dict, tpu_embedding) train_ops['tpu_embedding'] = tpu_embedding_send_gradient_op for op_name, op in six.iteritems(train_ops): assert op is not None, op_name # TODO(rpang): try to structure _train_op as: # tf.cond(skip_step, <only update skip stats>, <all updates>) # so that we skip all other updates when a step is skipped. self._train_op = tf.group(*tf.nest.flatten(train_ops), name='bprop')
def _BPropGenTrainOps(self, vmap, metrics=None, add_summary=True):
  """Populates the train_ops dictionary in a backwards pass."""
  metrics = metrics or self._metrics

  bprop_variable_filters = self.input_generator.GetBpropVariableFilters()
  # Only compute the mask if the variable filters are not empty.
  if bprop_variable_filters != [''] * len(bprop_variable_filters):
    self._ComputeGradientMask(bprop_variable_filters)
  train_ops = {}  # mapping from op name to op.
  gradient_mask = None
  if self._per_input_gradient_mask:
    # TODO(neerajgaur): Change this to use source_selected from input_batch.
    onehot = self.input_generator.GetInputSourceOneHot()
    gradient_mask = {
        k: tf.tensordot(v, onehot, 1)
        for k, v in self._per_input_gradient_mask.items()
    }
  all_losses = []
  for optimization in self.learners:
    learner_name = optimization.params.name
    loss_name = optimization.params.loss_name or learner_name
    metric = metrics.get(loss_name, None)
    if metric is None:
      raise ValueError('Loss %s not found in metrics %s' %
                       (loss_name, list(metrics.keys())))
    loss = metric[0]
    all_losses.append(loss)
    train_ops['train/%s' % learner_name], eval_metrics = optimization.Apply(
        loss,
        vmap,
        gradient_mask=gradient_mask,
        gradient_adjuster=self.AdjustGradients)
    if add_summary:
      for key, (value, weight) in eval_metrics.items():
        self.AddEvalMetric(key + '/' + learner_name, value, weight)

  relevant_bn_updates, _ = py_utils.FindRelevantBatchNormUpdates(
      all_losses, tf.get_collection(py_utils.BATCH_NORM_UPDATES))
  train_ops['bn_updates'] = relevant_bn_updates

  var_update_ops = [
      tf.group(*tf.nest.flatten(train_ops), name='var_update_ops')
  ]
  # Post training step update.
  with tf.control_dependencies(var_update_ops):
    post_step_op = self.PostTrainingStepUpdate(self.global_step)

  train_ops = {}
  with tf.control_dependencies([post_step_op]):
    # Get the op to update the weight masks and thresholds.
    mask_update_op = self._GetMaskUpdateOp()
    train_ops['mask_updates'] = mask_update_op
    with tf.control_dependencies([mask_update_op]):
      true_global_step = py_utils.GetOrCreateGlobalStepVar()
      with tf.ops.colocate_with(true_global_step):
        increment_global_steps = tf.assign_add(true_global_step, 1)
      if self._global_step_var != true_global_step:
        with tf.ops.colocate_with(self._global_step_var):
          increment_global_steps = tf.group(
              increment_global_steps, tf.assign_add(self._global_step_var, 1))
      train_ops['global_step'] = increment_global_steps

      # If we are using Tpu Embeddings, generate the monolithic send
      # gradient op.
      tpu_embedding_activations = tf.get_collection(
          py_utils.TPU_EMBEDDING_ACTIVATIONS)
      if tpu_embedding_activations:
        tpu_embedding_activations_dict = tpu_embedding_activations[0]
        tpu_embedding = tf.get_collection(py_utils.TPU_EMBEDDING)[0]
        tpu_embedding_send_gradient_op = py_utils.ComputeTpuEmbeddingGradients(
            self.loss, tpu_embedding_activations_dict, tpu_embedding)
        train_ops['tpu_embedding'] = tpu_embedding_send_gradient_op

  for op_name, op in train_ops.items():
    assert op is not None, op_name
  return train_ops
def _BPropGenTrainOps(self, vmap, metrics=None, add_summary=True):
  """Populates the train_ops dictionary in a backwards pass."""
  metrics = metrics or self._metrics

  bprop_variable_filters = self.input_generator.GetBpropVariableFilters()
  # Only compute the mask if the variable filters are not empty.
  if bprop_variable_filters != [''] * len(bprop_variable_filters):
    self._ComputeGradientMask(bprop_variable_filters)
  train_ops = {}  # mapping from op name to op.
  gradient_mask = None
  if self._per_input_gradient_mask:
    # TODO(neerajgaur): Change this to use source_selected from input_batch.
    onehot = self.input_generator.GetInputSourceOneHot()
    gradient_mask = {
        k: tf.tensordot(v, onehot, 1)
        for k, v in self._per_input_gradient_mask.items()
    }
  all_losses = []
  for optimization in self.learners:
    learner_name = optimization.params.name
    (losses, train_ops['train/%s' % learner_name],
     eval_metrics) = optimization.Apply(
         metrics,
         vmap,
         gradient_mask=gradient_mask,
         gradient_adjuster=self.AdjustGradients)
    all_losses.extend(losses)
    if add_summary:
      for key, (value, weight) in eval_metrics.items():
        self.AddEvalMetric(key + '/' + learner_name, value, weight)

  relevant_bn_updates, _ = py_utils.FindRelevantBatchNormUpdates(
      all_losses, tf.get_collection(py_utils.BATCH_NORM_UPDATES))
  train_ops['bn_updates'] = relevant_bn_updates

  var_update_ops = [
      tf.group(*tf.nest.flatten(train_ops), name='var_update_ops')
  ]
  # Post training step update.
  with tf.control_dependencies(var_update_ops):
    post_step_op = self.PostTrainingStepUpdate()

  train_ops = {}
  with tf.control_dependencies([post_step_op]):
    # Get the op to update the weight masks and thresholds.
    mask_update_op = self._GetMaskUpdateOp()
    train_ops['mask_updates'] = mask_update_op
    with tf.control_dependencies([mask_update_op]):
      true_global_step = py_utils.GetOrCreateGlobalStepVar()
      with tf.ops.colocate_with(true_global_step):
        if self.params.defer_global_step_update:
          increment_global_steps = true_global_step
        else:
          increment_global_steps = tf.assign_add(true_global_step, 1)
      if self._global_step_var != true_global_step:
        with tf.ops.colocate_with(self._global_step_var):
          increment_global_steps = tf.group(
              increment_global_steps, tf.assign_add(self._global_step_var, 1))
      train_ops['global_step'] = increment_global_steps

      # If we are using Tpu Embeddings, generate the monolithic send
      # gradient op.
      if tf.get_collection(py_utils.TPU_EMBEDDING):
        tpu_embedding = tf.get_collection(py_utils.TPU_EMBEDDING)[0]
        sparse_grads = (
            tpu_embedding_gradient.get_gradients_through_dummy_table_variables(
                tpu_embedding))
        tpu_embedding_send_gradient_op = tpu_embedding.generate_send_gradients_op(
            sparse_grads, py_utils.GetGlobalStep())
        train_ops['tpu_embedding'] = tpu_embedding_send_gradient_op

        tpu_embedding_summary_tensors = tf.get_collection(
            py_utils.TPU_EMBEDDING_SUMMARY_TENSORS)
        if add_summary:
          for name, value, weight in tpu_embedding_summary_tensors:
            self.AddEvalMetric(
                name, value, weight, raise_if_already_added=False)

  for op_name, op in train_ops.items():
    assert op is not None, op_name
  return train_ops
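# A minimal caller sketch (not part of the original source above): it assumes
# that a wrapper method consumes the dictionary returned by _BPropGenTrainOps
# and groups everything into a single training op, mirroring how the earlier
# _BPropForVariables built `self._train_op` with tf.group. The method body
# below is an illustration under that assumption, not the library's actual
# implementation.
def _BPropForVariables(self, vmap):
  """Constructs the backward graph (illustrative sketch)."""
  train_ops = self._BPropGenTrainOps(vmap)
  # Flatten the op dictionary and group it into one op a training loop can run.
  self._train_op = tf.group(*tf.nest.flatten(train_ops), name='bprop')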