def _finish(self, update_ops, name_scope):
  """Finishes both wrapped optimizers and applies the (grafted) step.

  Args:
    update_ops: list of per-variable update ops that must run first.
    name_scope: name for the returned group op; also used (with suffixes)
      for the sub-optimizers' finish ops.

  Returns:
    An op that completes the grafted update for this step.
  """
  with tf.control_dependencies(update_ops):
    ops1 = self.magnitude_optimizer._finish([], name_scope + "_m")  # pylint: disable=protected-access
    ops2 = self.direction_optimizer._finish([], name_scope + "_d")  # pylint: disable=protected-access

    if self.use_global_norm:  # apply global grafting
      with tf.control_dependencies([ops1, ops2]):
        # BUG FIX: the original accumulated the squared step norms with
        # `tf.assign_add` on scratch Variables, but those assign ops were
        # never added to any control dependency, so they never ran and the
        # multiplier was computed from zeros. Accumulate as plain tensors
        # instead; building them inside this control-dependency scope keeps
        # the reads ordered after the sub-optimizers' finish ops.
        m_global_norm = tf.add_n(
            [self.get_slot(var, "m_step_norm")**2 for var in self._variables])
        d_global_norm = tf.add_n(
            [self.get_slot(var, "d_step_norm")**2 for var in self._variables])
        # Scale the direction step so its global norm matches the magnitude
        # optimizer's global step norm; guard against division by ~0.
        multiplier = tf.sqrt(m_global_norm / tf.maximum(d_global_norm, 1e-30))

        step_ops = []
        for var in self._variables:
          # BUG FIX: fetch this variable's own d_step_norm; the original
          # reused the value left over from the last iteration of the
          # accumulation loop above for every variable.
          d_step_norm = self.get_slot(var, "d_step_norm")
          d_step = self.get_slot(var, "scratch_copy")
          # Skip variables whose direction step vanished to avoid NaN steps.
          step = tf.where(
              tf.greater(d_step_norm, 0), multiplier * d_step,
              tf.zeros_like(d_step))
          step_ops.append(
              tf.assign_add(var, self._learning_rate_tensor * step))
        return tf.group(*step_ops, name=name_scope)

  return tf.group(*([ops1, ops2] + update_ops), name=name_scope)
def _Apply2(proj_layer, opt):
  """Computes grads for two inputs, applies them one global step apart.

  Returns a tuple of (variable snapshots after the first Apply, variables
  after the second Apply, grads for input 1, grads for input 2).
  """

  def _ForwardGrads(batch):
    # Forward pass, scalar loss, and gradients w.r.t. the layer variables.
    out = proj_layer.FPropDefaultTheta(batch, in_padding1)
    var_grads = py_utils.ComputeGradients(tf.reduce_sum(out), proj_layer.vars)
    return var_grads, var_grads.Transform(tuple)

  var_grads2_1, grads2_1 = _ForwardGrads(np_input1)
  var_grads2_2, grads2_2 = _ForwardGrads(np_input2)

  with cluster_factory.ForTestingWorker(add_summary=True):
    _ = opt.Apply(lr, var_grads2_1)
  # Snapshot the variable values between the two Apply calls.
  vars2_intermediate = [v.read_value() for v in proj_layer.vars.Flatten()]
  tf.assign_add(py_utils.GetOrCreateGlobalStepVar(), 1)
  with cluster_factory.ForTestingWorker(add_summary=True):
    _ = opt.Apply(lr, var_grads2_2)

  return vars2_intermediate, proj_layer.vars.Flatten(), grads2_1, grads2_2
def testDecoderFPropDeterministicAttentionDropout(self):
  """Verify that attention dropout is deterministic given fixed seeds."""
  with self.session(use_gpu=False) as sess:
    tf.set_random_seed(8372749040)
    # Fixed-seed variational noise so the dropout masks are reproducible.
    p = self._DecoderParams(
        py_utils.VariationalNoiseParams(None, True, False, seed=1792))
    p.use_while_loop_based_unrolling = False
    p.attention.atten_dropout_prob = 0.5
    p.attention.atten_dropout_deterministic = True

    loss, per_sequence_loss = self._testDecoderFPropHelper(params=p)
    global_step = py_utils.GetGlobalStep()
    tf.global_variables_initializer().run()
    loss_val, per_sequence_loss_val, global_steps_val = sess.run(
        [loss, per_sequence_loss, global_step])

    print('loss = ', loss_val, 'per sequence loss = ', per_sequence_loss_val)
    # Golden values at global step 0.
    self.assertAllClose([3.587372, 15.0], loss_val)
    self.assertAllClose([14.171288, 9.965696, 10.221684, 19.451914],
                        per_sequence_loss_val)
    self.assertEqual(0, global_steps_val)

    # Run another step to test global_step and time_step are incremented
    # correctly.
    sess.run(tf.assign_add(global_step, 1))
    loss_val, per_sequence_loss_val, global_steps_val = sess.run(
        [loss, per_sequence_loss, global_step])

    print('loss = ', loss_val, 'per sequence loss = ', per_sequence_loss_val)
    # Golden values at global step 1 (losses differ from step 0, which shows
    # the dropout mask depends on the step while staying deterministic).
    self.assertAllClose([3.626164, 15.0], loss_val)
    self.assertAllClose([14.70993, 10.572938, 10.516836, 18.592758],
                        per_sequence_loss_val)
    self.assertEqual(1, global_steps_val)
def ComputePredictions(self, theta, input_batch):
  """Runs the FFN on the input plus random noise and a counter increment.

  Side effect: increments `self.vars.counter1` by 1 on every call; the
  updated counter value is added into the network input.
  """
  noisy = tf.cast(input_batch, tf.float32) + tf.random.normal(
      [1, 10], dtype=tf.float32)
  # Bump the call counter; its new value is folded into the input.
  bumped = tf.assign_add(self.vars.counter1, 1.)
  return {'result': self.ffn.FProp(theta.ffn, noisy + bumped)}
def testWeightSpecificSparsity(self):
  """Per-weight sparsity map entries override the global target sparsity."""
  spec = ",".join([
      "begin_pruning_step=1", "pruning_frequency=1", "end_pruning_step=100",
      "target_sparsity=0.5",
      "weight_sparsity_map=[layer1:0.6,layer2/weights:0.75,.*kernel:0.6]",
      "threshold_decay=0.0"
  ])
  pruning_hparams = pruning.get_pruning_hparams().parse(spec)

  # Three masked tensors, each matching a different weight_sparsity_map
  # entry (scope prefix, scope/name, and name regex respectively).
  for scope, var_name in (("layer1", "weights"), ("layer2", "weights"),
                          ("layer3", "kernel")):
    with tf.variable_scope(scope):
      w = tf.Variable(tf.linspace(1.0, 100.0, 100), name=var_name)
      _ = pruning.apply_mask(w)

  p = pruning.Pruning(pruning_hparams)
  mask_update_op = p.conditional_mask_update_op()
  increment_global_step = tf.assign_add(self.global_step, 1)

  with self.cached_session() as session:
    tf.global_variables_initializer().run()
    # Step past end_pruning_step so each mask reaches its final sparsity.
    for _ in range(110):
      session.run(mask_update_op)
      session.run(increment_global_step)

    self.assertAllClose(
        session.run(pruning.get_weight_sparsity()), [0.6, 0.75, 0.6])
def testConditionalMaskUpdate(self):
  """Masks only update inside [begin, end] steps at the configured frequency."""
  spec = ",".join([
      "pruning_frequency=2", "begin_pruning_step=1", "end_pruning_step=6",
      "nbins=100"
  ])
  pruning_hparams = pruning.get_pruning_hparams().parse(spec)

  weights = tf.Variable(tf.linspace(1.0, 100.0, 100), name="weights")
  masked_weights = pruning.apply_mask(weights)
  sparsity = tf.Variable(0.00, name="sparsity")

  # Pruning driven by the externally-assigned `sparsity` variable; no
  # threshold decay so each update lands exactly on the requested target.
  p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
  p._spec.threshold_decay = 0.0
  mask_update_op = p.conditional_mask_update_op()
  sparsity_val = tf.linspace(0.0, 0.9, 10)
  increment_global_step = tf.assign_add(self.global_step, 1)

  non_zero_count = []
  with self.cached_session() as session:
    tf.global_variables_initializer().run()
    for step in range(10):
      session.run(tf.assign(sparsity, sparsity_val[step]))
      session.run(mask_update_op)
      session.run(increment_global_step)
      non_zero_count.append(np.count_nonzero(masked_weights.eval()))
  # Weights pruned at steps 0,2,4,and,6
  self.assertAllEqual([100, 100, 80, 80, 60, 60, 40, 40, 40, 40],
                      non_zero_count)
def IncBy(self, delta): """Increment the counter by delta and return the new value.""" # NOTE: We must ensure _value is computed (_var + 0) before # updating _var with delta. delta = tf.cast(delta, tf.int64) with tf.control_dependencies([self._value]): scalar(self._name, self._value) return tf.identity(tf.assign_add(self._var, delta))
def _ApplyAndReset():
  """Applies accumulated grads, zeroes the accumulators, bumps global_step."""
  grads_to_apply = accums
  if self._apply_crs_to_grad:
    # Sum the per-replica accumulators across TPU replicas first.
    grads_to_apply = [
        tf.tpu.cross_replica_sum(acc.read_value()) for acc in accums
    ]
  apply_op = self._opt.apply_gradients(list(zip(grads_to_apply, variables)))
  with tf.control_dependencies([apply_op]):
    # Reset only after the apply has run.
    reset_ops = [tf.assign(acc, tf.zeros_like(acc)) for acc in accums]
    return tf.group(reset_ops, tf.assign_add(global_step, 1))
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
  """Accumulates scaled grads; applies them every `num_micro_batches` calls.

  Each call adds `grad / num_micro_batches` into per-variable accumulator
  slots. On every num_micro_batches-th call (tracked by `self._counter`)
  the accumulated gradients are applied via the wrapped optimizer, the
  accumulators are zeroed, and `global_step` is incremented.

  Args:
    grads_and_vars: list of (gradient, variable) pairs.
    global_step: optional variable incremented once per *apply* (not per
      call); defaults to the global step variable.
    name: unused; kept for optimizer interface compatibility.

  Returns:
    An op that performs the accumulation and, conditionally, the apply.
  """
  if self._num_micro_batches == 1:
    # No accumulation requested; delegate directly.
    return self._opt.apply_gradients(grads_and_vars, global_step)
  global_step = global_step or py_utils.GetOrCreateGlobalStepVar()
  with tf.init_scope():
    self._create_slots([v for (_, v) in grads_and_vars])

  accums = []
  variables = []
  for g, v in grads_and_vars:
    accum = self.get_slot(v, 'grad_accum')
    variables.append(v)
    # pytype: disable=attribute-error
    # Scale before accumulating so the final sum is the mean gradient.
    if isinstance(g, tf.IndexedSlices):
      scaled_grad = tf.IndexedSlices(
          g.values / self._num_micro_batches,
          g.indices,
          dense_shape=g.dense_shape)
    else:
      scaled_grad = g / self._num_micro_batches
    accum_tensor = accum.read_value()
    accums.append(accum.assign(accum_tensor + scaled_grad))
    # pytype: enable=attribute-error

  def _ApplyAndReset():
    # Optionally cross-replica-sum the accumulators, apply them through the
    # wrapped optimizer, then zero the accumulators and bump global_step.
    normalized_accums = accums
    if self._apply_crs_to_grad:
      normalized_accums = [
          tf.tpu.cross_replica_sum(accum.read_value()) for accum in accums
      ]
    apply_op = self._opt.apply_gradients(
        list(zip(normalized_accums, variables)))
    with tf.control_dependencies([apply_op]):
      zero_op = [tf.assign(accum, tf.zeros_like(accum)) for accum in accums]
      return tf.group(zero_op, tf.assign_add(global_step, 1))

  def _Accum():
    # Accumulation-only step: the accumulator assigns above already run via
    # the control dependency below, so nothing extra is needed here.
    return tf.no_op()

  # Apply on every num_micro_batches-th call, otherwise just accumulate.
  accum_step = tf.cond(
      tf.equal(
          tf.math.floormod(self._counter + 1, self._num_micro_batches), 0),
      _ApplyAndReset,  # Apply the accumulated gradients and reset.
      _Accum)  # Accumulate gradients.

  with tf.control_dependencies([tf.group(accums)]):
    return tf.group(accum_step, tf.assign_add(self._counter, 1))
def _Acc(vg):
  """Adds the gradient into a per-variable accumulator variable."""
  var, grad = vg
  # One accumulator per variable, created under the variable's own scope.
  with tf.variable_scope(var.op.name):
    _, accum = py_utils.CreateVariable(
        'grad_accumulator',
        py_utils.WeightParams(var.get_shape(),
                              py_utils.WeightInit.Constant(0.0),
                              self.params.dtype),
        trainable=False)
    accum = tf.assign_add(accum, grad)
  return py_utils.VarGrad(var, accum)
def _Acc(vg):
  """Accumulates the gradient into a per-variable accumulator."""
  var, grad = vg
  # Drop the ':0' tensor suffix so the accumulator scope mirrors the
  # variable's name.
  full_name = var.name
  scope_name = full_name[:-2] if full_name.endswith(':0') else full_name
  with tf.variable_scope(scope_name):
    accum = py_utils.CreateVariable(
        'grad_accumulator',
        py_utils.WeightParams(var.get_shape(),
                              py_utils.WeightInit.Constant(0.0),
                              self.params.dtype),
        trainable=False)
    accum = tf.assign_add(accum, grad)
  return py_utils.VarGrad(var, accum)
def testDecoderFPropDeterministicAttentionDropout(self):
  """Verify that attention dropout is deterministic given fixed seeds."""
  with self.session(use_gpu=False):
    tf.random.set_seed(8372749040)
    # Fixed-seed variational noise so the dropout masks are reproducible.
    p = _DecoderParams(
        py_utils.VariationalNoiseParams(None, True, False, seed=1792))
    p.use_while_loop_based_unrolling = False
    p.attention.atten_dropout_prob = 0.5
    p.attention.atten_dropout_deterministic = True

    loss, per_sequence_loss = self._testDecoderFPropHelper(params=p)
    global_step = py_utils.GetGlobalStep()
    self.evaluate(tf.global_variables_initializer())
    loss_val, per_sequence_loss_val, global_steps_val = self.evaluate(
        [loss, per_sequence_loss, global_step])

    print('loss = ', loss_val, 'per sequence loss = ', per_sequence_loss_val)
    # Golden values at global step 0.
    self.assertAllClose([3.332992, 15.0], loss_val)
    self.assertAllClose([13.942583, 9.632538, 9.677502, 16.742266],
                        per_sequence_loss_val)
    self.assertEqual(0, global_steps_val)

    # Run another step to test global_step and time_step are incremented
    # correctly.
    self.evaluate(tf.assign_add(global_step, 1))
    loss_val, per_sequence_loss_val, global_steps_val = self.evaluate(
        [loss, per_sequence_loss, global_step])

    print('loss = ', loss_val, 'per sequence loss = ', per_sequence_loss_val)
    # Golden values at global step 1 (losses differ from step 0, which shows
    # the dropout mask depends on the step while staying deterministic).
    self.assertAllClose([3.565631, 15.0], loss_val)
    self.assertAllClose([14.560061, 10.566417, 10.554007, 17.803982],
                        per_sequence_loss_val)
    self.assertEqual(1, global_steps_val)
def testAccumulator(self):
  # testAccumulator compares
  # - explicit averaging of independently computed var_grads1 and
  #   var_grads2,
  # - Accumulator(SGD) optimizer effectively doing this over 2 steps.
  np.random.seed(12345)
  np_input1 = np.random.normal(0.1, 0.5, [2, 4, 3])
  np.random.seed(12346)
  np_input2 = np.random.normal(0.1, 0.5, [2, 4, 3])

  # Reference graph: two forward/backward passes, gradients explicitly
  # halved and applied sequentially with plain SGD.
  with self.session(use_gpu=True, graph=tf.Graph()) as sess:
    tf.random.set_seed(123456)
    params = layers.ProjectionLayer.Params()
    params.name = 'proj'
    params.dtype = tf.float64
    params.input_dim = 3
    params.output_dim = 2
    params.params_init = py_utils.WeightInit.Gaussian(0.01, 123456)
    params.batch_norm = False
    proj_layer = layers.ProjectionLayer(params)
    inputs1 = tf.placeholder(shape=[2, 4, 3], dtype=tf.float64)
    in_padding1 = tf.zeros([2, 4, 1], dtype=tf.float64)
    inputs2 = tf.placeholder(shape=[2, 4, 3], dtype=tf.float64)
    in_padding2 = tf.zeros([2, 4, 1], dtype=tf.float64)
    output1 = proj_layer.FPropDefaultTheta(inputs1, in_padding1)
    output2 = proj_layer.FPropDefaultTheta(inputs2, in_padding2)
    loss1 = tf.reduce_sum(output1)
    loss2 = tf.reduce_sum(output2)
    var_grads1 = py_utils.ComputeGradients(loss1, proj_layer.vars)
    var_grads2 = py_utils.ComputeGradients(loss2, proj_layer.vars)
    op = optimizer.SGD.Params()
    opt = op.Instantiate()
    lr = 1e-1
    # Each update applies half of its gradient, so two updates together
    # equal one update with the average gradient.
    with tf.control_dependencies([loss1, loss2]):
      var_update_op1 = opt.Apply(
          lr, py_utils.ApplyGradMultiplier(var_grads1, 1. / 2.))
      with tf.control_dependencies([var_update_op1]):
        var_update_op2 = opt.Apply(
            lr, py_utils.ApplyGradMultiplier(var_grads2, 1. / 2.))

    self.evaluate(tf.global_variables_initializer())
    vars1 = self.evaluate(proj_layer.vars.Flatten())
    loss1_1, grads1_1, loss1_2, grads1_2 = sess.run(
        [
            loss1,
            var_grads1.Transform(tuple), loss2,
            var_grads2.Transform(tuple)
        ],
        feed_dict={
            inputs1: np_input1,
            inputs2: np_input2,
        },
    )
    sess.run([var_update_op2],
             feed_dict={
                 inputs1: np_input1,
                 inputs2: np_input2,
             })
    vars1_1 = self.evaluate(proj_layer.vars.Flatten())

  # Accumulator graph: same layer/init, one Accumulator(SGD) optimizer
  # applied over two micro-batch steps.
  with self.session(use_gpu=True, graph=tf.Graph()) as sess:
    tf.random.set_seed(123456)
    params = layers.ProjectionLayer.Params()
    params.name = 'proj'
    params.dtype = tf.float64
    params.input_dim = 3
    params.output_dim = 2
    params.params_init = py_utils.WeightInit.Gaussian(0.01, 123456)
    params.batch_norm = False
    proj_layer = layers.ProjectionLayer(params)
    in_padding1 = tf.zeros([2, 4, 1], dtype=tf.float64)
    inputs1 = tf.placeholder(shape=[2, 4, 3], dtype=tf.float64)
    output1 = proj_layer.FPropDefaultTheta(inputs1, in_padding1)
    loss = tf.reduce_sum(output1)
    var_grads = py_utils.ComputeGradients(loss, proj_layer.vars)
    op = optimizer.Accumulator.Params().Set(
        accum_steps=2, dtype=tf.float64, optimizer_tpl=optimizer.SGD.Params())
    opt = op.Instantiate()
    lr = 1e-1
    with cluster_factory.ForTestingWorker(add_summary=True):
      var_update_op = opt.Apply(lr, var_grads)
    increment_global_step_op = tf.assign_add(
        py_utils.GetOrCreateGlobalStepVar(), 1)

    self.evaluate(tf.global_variables_initializer())
    vars2 = self.evaluate(proj_layer.vars.Flatten())
    loss2_1, grads2_1 = sess.run(
        [loss, var_grads.Transform(tuple)], feed_dict={
            inputs1: np_input1,
        })
    loss2_2, grads2_2 = sess.run(
        [loss, var_grads.Transform(tuple)], feed_dict={
            inputs1: np_input2,
        })
    # Accumulator snapshots: empty before step 1, holds grad after step 1,
    # reset to zero after the apply at step 2.
    acc_0 = self.evaluate(
        [v for v in tf.global_variables() if 'grad_accumulator' in v.name])[0]
    sess.run([var_update_op], feed_dict={
        inputs1: np_input1,
    })
    acc_1 = self.evaluate(
        [v for v in tf.global_variables() if 'grad_accumulator' in v.name])[0]
    vars2_intermediate = self.evaluate(proj_layer.vars.Flatten())
    self.evaluate(increment_global_step_op)
    sess.run([var_update_op], feed_dict={
        inputs1: np_input2,
    })
    acc_2 = self.evaluate(
        [v for v in tf.global_variables() if 'grad_accumulator' in v.name])[0]
    vars2_1 = self.evaluate(proj_layer.vars.Flatten())

    summary = tf.Summary.FromString(self.evaluate(tf.summary.merge_all()))
    tf.logging.info(f'summary: {summary}')
    self.assertEqual(summary.value[0].tag, 'sgd_lr')

    # Both graphs start from identical variables and gradients.
    self.assertAllClose(vars1, vars2)
    self.assertAllClose(acc_0, np.zeros_like(acc_0))
    self.assertAllClose(acc_1, grads2_1['w'][1])
    self.assertAllClose(acc_2, np.zeros_like(acc_0))
    self.assertAllClose(loss1_1, loss2_1)
    self.assertAllClose(loss1_2, loss2_2)
    self.assertAllClose(grads1_1, grads2_1)
    self.assertAllClose(grads1_2, grads2_2)
    # No apply happened before the accumulation completed.
    self.assertAllClose(vars1, vars2_intermediate)
    self.assertAllClose(vars2[0], grads2_1['w'][0])
    self.assertAllClose(vars2[0], grads2_2['w'][0])
    # Both paths end at the same averaged-gradient SGD update.
    self.assertAllClose(
        vars1[0] - 0.5 * lr * (grads1_1['w'][1] + grads1_2['w'][1]),
        vars1_1[0])
    self.assertAllClose(
        vars2[0] - 0.5 * lr * (grads2_1['w'][1] + grads2_2['w'][1]),
        vars2_1[0])
    self.assertAllClose(vars2, vars2_intermediate)
    self.assertAllClose(vars1_1, vars2_1)
def _BPropGenTrainOps(self, vmap, metrics=None, add_summary=True):
  """Populates the train_ops dictionary in a backwards pass.

  Args:
    vmap: variable map passed to each learner's Apply.
    metrics: optional metrics dict consumed by the learners; defaults to
      self._metrics.
    add_summary: whether to emit eval-metric summaries.

  Returns:
    A dict mapping op names to train ops; every value is a non-None op.
  """
  metrics = metrics or self._metrics
  bprop_variable_filters = self.input_generator.GetBpropVariableFilters()
  # Only compute the mask if the variable filters are not empty.
  if bprop_variable_filters != [''] * len(bprop_variable_filters):
    self._ComputeGradientMask(bprop_variable_filters)
  train_ops = {}  # mapping from op name to op.
  gradient_mask = None
  if self._per_input_gradient_mask:
    # TODO(neerajgaur): Change this to use source_selected from input_batch.
    onehot = self.input_generator.GetInputSourceOneHot()
    gradient_mask = {
        k: tf.tensordot(v, onehot, 1)
        for k, v in self._per_input_gradient_mask.items()
    }
  all_losses = []
  for optimization in self.learners:
    learner_name = optimization.params.name
    (losses, train_ops['train/%s' % learner_name],
     eval_metrics) = optimization.Apply(
         metrics,
         vmap,
         gradient_mask=gradient_mask,
         gradient_adjuster=self.AdjustGradients)
    all_losses.extend(losses)
    if add_summary:
      for key, (value, weight) in eval_metrics.items():
        self.AddEvalMetric(key + '/' + learner_name, value, weight)
  # Keep only the batch-norm updates relevant to the losses computed above.
  relevant_bn_updates, _ = py_utils.FindRelevantBatchNormUpdates(
      all_losses, tf.get_collection(py_utils.BATCH_NORM_UPDATES))
  train_ops['bn_updates'] = relevant_bn_updates
  var_update_ops = [
      tf.group(*tf.nest.flatten(train_ops), name='var_update_ops')
  ]
  # Post training step update.
  with tf.control_dependencies(var_update_ops):
    post_step_op = self.PostTrainingStepUpdate()

  # Rebuilt from scratch: the ops below must run after the variable updates
  # via the control-dependency chain, not alongside them.
  train_ops = {}
  with tf.control_dependencies([post_step_op]):
    # Get the op to update the weight masks and thresholds
    mask_update_op = self._GetMaskUpdateOp()
    train_ops['mask_updates'] = mask_update_op
    with tf.control_dependencies([mask_update_op]):
      true_global_step = py_utils.GetOrCreateGlobalStepVar()
      with tf.ops.colocate_with(true_global_step):
        if self.params.defer_global_step_update:
          increment_global_steps = true_global_step
        else:
          increment_global_steps = tf.assign_add(true_global_step, 1)
      # A task-local step variable may shadow the true global step; keep the
      # two in sync.
      if self._global_step_var != true_global_step:
        with tf.ops.colocate_with(self._global_step_var):
          increment_global_steps = tf.group(
              increment_global_steps, tf.assign_add(self._global_step_var, 1))
      train_ops['global_step'] = increment_global_steps
      # If we are using Tpu Embeddings, generate the monolithic send
      # gradient op.
      if tf.get_collection(py_utils.TPU_EMBEDDING):
        tpu_embedding = tf.get_collection(py_utils.TPU_EMBEDDING)[0]
        sparse_grads = (
            tpu_embedding_gradient.get_gradients_through_dummy_table_variables(
                tpu_embedding))
        tpu_embedding_send_gradient_op = tpu_embedding.generate_send_gradients_op(
            sparse_grads, py_utils.GetGlobalStep())
        train_ops['tpu_embedding'] = tpu_embedding_send_gradient_op
        tpu_embedding_summary_tensors = tf.get_collection(
            py_utils.TPU_EMBEDDING_SUMMARY_TENSORS)
        if add_summary:
          for name, value, weight in tpu_embedding_summary_tensors:
            self.AddEvalMetric(
                name, value, weight, raise_if_already_added=False)
  for op_name, op in train_ops.items():
    assert op is not None, op_name
  return train_ops
def FProp(self, theta, x, paddings=None, update=False):
  """Computes distances of the given input 'x' to all centroids.

  This implementation applies layer normalization on 'x' internally first,
  and the returned 'dists' is computed using the normalized 'x'.

  Args:
    theta: A `.NestedMap` of weights' values of this layer.
    x: A tensor of shape [B, L, N, H].
    paddings: If not None, a tensor of shape [B, L].
    update: bool, whether to update centroids using x.

  Returns:
    dists: "distances" of the given input 'x' to all centroids.
      Shape [B, L, N, K].
    k_means_loss: the average squared Euclidean distances to the closest
      centroid, a scalar.
  """
  p = self.params
  if paddings is None:
    paddings = tf.zeros_like(x[:, :, 0, 0])
  # Shape [B, L, 1, 1]
  paddings_4d = paddings[:, :, None, None]

  if p.apply_layer_norm:
    x = KMeansClusteringForAtten.LayerNorm(x, p.epsilon)

  # 'x' is normalized (but theta.means is not), we use negative dot product
  # to approximate the Euclidean distance here.
  dists = -tf.einsum('BLNH, NKH -> BLNK', x, theta.means)

  # For padded positions we update the distances to very large numbers.
  very_large_dists = tf.ones_like(dists) * tf.constant(
      0.1, dtype=dists.dtype) * dists.dtype.max
  paddings_tiled = tf.tile(paddings_4d, [1, 1, p.num_heads, p.num_clusters])
  dists = tf.where(paddings_tiled > 0.0, very_large_dists, dists)

  # Shape [B, L, N, K], the same as 'dists' above.
  nearest_one_hot = tf.one_hot(
      tf.math.argmin(dists, axis=-1),
      p.num_clusters,
      dtype=py_utils.FPropDtype(p))
  # Same shape as the input 'x'.
  nearest_centroid = tf.einsum('BLNK, NKH -> BLNH', nearest_one_hot,
                               theta.means)

  diff = tf.math.squared_difference(x, tf.stop_gradient(nearest_centroid))
  diff = py_utils.ApplyPadding(paddings_4d, diff)
  diff = tf.math.reduce_mean(diff, axis=2)

  # The commitment loss which when back proped against encourages the 'x'
  # values to commit to their chosen centroids.
  k_means_loss = tf.math.reduce_sum(diff) / tf.math.reduce_sum(1.0 - paddings)
  summary_utils.scalar('k_means/squared_distance_loss', k_means_loss)

  # TODO(zhouwk): investigate normalizing theta.means after each update.
  means_norm = tf.norm(theta.means)
  summary_utils.scalar('k_means/centroid_l2_norm/min',
                       tf.math.reduce_min(means_norm))
  summary_utils.scalar('k_means/centroid_l2_norm/mean',
                       tf.math.reduce_mean(means_norm))

  if not update:
    return dists, k_means_loss

  # To update the centroids (self.vars.means), we apply gradient descent on
  # the mini-batch of input 'x', which yields the following:
  #   new_centroid = centroid + (1 - decay) * (x_mean - centroid)
  # where x_mean is the average over all the input vectors closest to this
  # centroid.
  #
  # Note that this approach is equivalent with backprop via
  #    loss = tf.math.reduce_mean(
  #        tf.math.squared_difference(tf.stop_gradient(x), nearest_centroid)))
  # , except that here the learning rate is independently set via 'decay'.

  # Ensure that the padded positions are not used to update the centroids.
  nearest_one_hot = py_utils.ApplyPadding(paddings_4d, nearest_one_hot)

  # Sum away batch and sequence length dimensions to get per cluster count.
  #   Shape: [N, K]
  per_cluster_count = tf.reduce_sum(nearest_one_hot, axis=[0, 1])
  summary_utils.histogram('k_means/per_cluster_vec_count', per_cluster_count)

  # Sum of the input 'x' per each closest centroid.
  sum_x = tf.einsum('BLNK, BLNH -> NKH', nearest_one_hot, x)

  if py_utils.use_tpu():
    per_cluster_count = tf.tpu.cross_replica_sum(per_cluster_count)
    sum_x = tf.tpu.cross_replica_sum(sum_x)

  # If per_cluster_count for a cluster is 0, then 'nearest_one_hot' in that
  # cluster's position will always be 0, hence 'sum_x' in that dimension
  # will be 0.
  new_means = sum_x / tf.maximum(
      tf.constant(1.0, dtype=per_cluster_count.dtype),
      tf.expand_dims(per_cluster_count, axis=-1))

  # We use exponential moving average. TODO(zhouwk): investigate smoothing
  # this over an exponentially moving averaged per cluster count.
  #
  # Note that we intentionally do not normalize the means after this update
  # as empirically this works better.
  update_means_diff = tf.cast((1.0 - p.decay) * (new_means - theta.means),
                              self.vars.means.dtype)
  return py_utils.with_dependencies(
      [tf.assign_add(self.vars.means, update_means_diff)],
      dists), k_means_loss
def _BPropGenTrainOps(self, vmap, metrics=None, add_summary=True):
  """Populates the train_ops dictionary in a backwards pass.

  Args:
    vmap: variable map passed to each learner's Apply.
    metrics: optional dict mapping loss names to (value, weight) tuples;
      defaults to self._metrics.
    add_summary: whether to emit eval-metric summaries.

  Returns:
    A dict mapping op names to train ops; every value is a non-None op.

  Raises:
    ValueError: if a learner's loss name is missing from `metrics`.
  """
  metrics = metrics or self._metrics
  bprop_variable_filters = self.input_generator.GetBpropVariableFilters()
  # Only compute the mask if the variable filters are not empty.
  if bprop_variable_filters != [''] * len(bprop_variable_filters):
    self._ComputeGradientMask(bprop_variable_filters)
  train_ops = {}  # mapping from op name to op.
  gradient_mask = None
  if self._per_input_gradient_mask:
    # TODO(neerajgaur): Change this to use source_selected from input_batch.
    onehot = self.input_generator.GetInputSourceOneHot()
    gradient_mask = {
        k: tf.tensordot(v, onehot, 1)
        for k, v in self._per_input_gradient_mask.items()
    }
  all_losses = []
  for optimization in self.learners:
    learner_name = optimization.params.name
    # Each learner optimizes the metric named by loss_name (falling back to
    # the learner's own name).
    loss_name = optimization.params.loss_name or learner_name
    metric = metrics.get(loss_name, None)
    if metric is None:
      raise ValueError('Loss %s not found in metrics %s' %
                       (loss_name, list(metrics.keys())))
    loss = metric[0]
    all_losses.append(loss)
    train_ops['train/%s' % learner_name], eval_metrics = optimization.Apply(
        loss,
        vmap,
        gradient_mask=gradient_mask,
        gradient_adjuster=self.AdjustGradients)
    if add_summary:
      for key, (value, weight) in eval_metrics.items():
        self.AddEvalMetric(key + '/' + learner_name, value, weight)
  # Keep only the batch-norm updates relevant to the losses computed above.
  relevant_bn_updates, _ = py_utils.FindRelevantBatchNormUpdates(
      all_losses, tf.get_collection(py_utils.BATCH_NORM_UPDATES))
  train_ops['bn_updates'] = relevant_bn_updates
  var_update_ops = [
      tf.group(*tf.nest.flatten(train_ops), name='var_update_ops')
  ]
  # Post training step update.
  with tf.control_dependencies(var_update_ops):
    post_step_op = self.PostTrainingStepUpdate(self.global_step)

  # Rebuilt from scratch: the ops below must run after the variable updates
  # via the control-dependency chain, not alongside them.
  train_ops = {}
  with tf.control_dependencies([post_step_op]):
    # Get the op to update the weight masks and thresholds
    mask_update_op = self._GetMaskUpdateOp()
    train_ops['mask_updates'] = mask_update_op
    with tf.control_dependencies([mask_update_op]):
      true_global_step = py_utils.GetOrCreateGlobalStepVar()
      with tf.ops.colocate_with(true_global_step):
        increment_global_steps = tf.assign_add(true_global_step, 1)
      # A task-local step variable may shadow the true global step; keep the
      # two in sync.
      if self._global_step_var != true_global_step:
        with tf.ops.colocate_with(self._global_step_var):
          increment_global_steps = tf.group(
              increment_global_steps, tf.assign_add(self._global_step_var, 1))
      train_ops['global_step'] = increment_global_steps
      # If we are using Tpu Embeddings, generate the monolithic send
      # gradient op.
      tpu_embedding_activations = tf.get_collection(
          py_utils.TPU_EMBEDDING_ACTIVATIONS)
      if tpu_embedding_activations:
        tpu_embedding_activations_dict = tpu_embedding_activations[0]
        tpu_embedding = tf.get_collection(py_utils.TPU_EMBEDDING)[0]
        tpu_embedding_send_gradient_op = py_utils.ComputeTpuEmbeddingGradients(
            self.loss, tpu_embedding_activations_dict, tpu_embedding)
        train_ops['tpu_embedding'] = tpu_embedding_send_gradient_op
  for op_name, op in train_ops.items():
    assert op is not None, op_name
  return train_ops
def _BPropForVariables(self, vmap):
  """Constructs the backward graph.

  Applies each learner to its named loss, collects the resulting update
  ops plus batch-norm/mask/global-step/TPU-embedding ops, and stores the
  grouped result in `self._train_op`.

  Args:
    vmap: variable map passed to each learner's Apply.
  """
  bprop_variable_filters = self.input_generator.GetBpropVariableFilters()
  # Only compute the mask if the variable filters are not empty.
  if bprop_variable_filters != [''] * len(bprop_variable_filters):
    self._ComputeGradientMask(bprop_variable_filters)
  train_ops = {}  # mapping from op name to op.
  gradient_mask = None
  if self._per_input_gradient_mask:
    # TODO(neerajgaur): Change this to use source_selected from input_batch.
    onehot = self.input_generator.GetInputSourceOneHot()
    gradient_mask = {
        k: tf.tensordot(v, onehot, 1)
        for k, v in six.iteritems(self._per_input_gradient_mask)
    }
  all_losses = []
  for optimization in self.learners:
    # Each learner optimizes the metric bearing its own name.
    loss_name = optimization.params.name
    metric = self._metrics.get(loss_name, None)
    if metric is None:
      raise ValueError('Loss %s not found in metrics %s' %
                       (loss_name, list(self._metrics.keys())))
    loss = metric[0]
    all_losses.append(loss)
    train_ops['train/%s' % loss_name], eval_metrics = optimization.Apply(
        loss,
        vmap,
        gradient_mask=gradient_mask,
        gradient_adjuster=self.AdjustGradients)
    for key, (value, weight) in six.iteritems(eval_metrics):
      self.AddEvalMetric(key + '/' + loss_name, value, weight)
  # Keep only the batch-norm updates relevant to the losses computed above.
  relevant_bn_updates, _ = py_utils.FindRelevantBatchNormUpdates(
      all_losses, tf.get_collection(py_utils.BATCH_NORM_UPDATES))
  train_ops['bn_updates'] = relevant_bn_updates

  # Get the op to update the weight masks and thresholds
  train_ops['mask_updates'] = self._GetMaskUpdateOp()

  # Post training step update.
  train_ops['post_step'] = self.PostTrainingStepUpdate(self.global_step)

  with tf.control_dependencies(tf.nest.flatten(train_ops)):
    true_global_step = py_utils.GetOrCreateGlobalStepVar()
    with tf.colocate_with(true_global_step):
      increment_global_steps = tf.assign_add(true_global_step, 1)
    # A task-local step variable may shadow the true global step; keep the
    # two in sync.
    if self._global_step_var != true_global_step:
      with tf.colocate_with(self._global_step_var):
        increment_global_steps = tf.group(
            increment_global_steps, tf.assign_add(self._global_step_var, 1))
    train_ops['global_step'] = increment_global_steps

    # If we are using Tpu Embeddings, generate the monolithic send
    # gradient op.
    tpu_embedding_activations = tf.get_collection(
        py_utils.TPU_EMBEDDING_ACTIVATIONS)
    if tpu_embedding_activations:
      tpu_embedding_activations_dict = tpu_embedding_activations[0]
      tpu_embedding = tf.get_collection(py_utils.TPU_EMBEDDING)[0]
      tpu_embedding_send_gradient_op = py_utils.ComputeTpuEmbeddingGradients(
          self.loss, tpu_embedding_activations_dict, tpu_embedding)
      train_ops['tpu_embedding'] = tpu_embedding_send_gradient_op

  for op_name, op in six.iteritems(train_ops):
    assert op is not None, op_name

  # TODO(rpang): try to structure _train_op as:
  #   tf.cond(skip_step, <only update skip stats>, <all updates>)
  # so that we skip all other updates when a step is skipped.
  self._train_op = tf.group(*tf.nest.flatten(train_ops), name='bprop')