def _finish(self, update_ops, name_scope):
  """Finishes both wrapped optimizers and optionally applies global grafting.

  Args:
    update_ops: List of ops that must run before the wrapped optimizers are
      finished.
    name_scope: Name given to the returned group op; also used as the prefix
      of the names passed to the wrapped optimizers' `_finish`.

  Returns:
    When `self.use_global_norm` is set, an op grouping the per-variable
    grafted steps; otherwise an op grouping both `_finish` results together
    with `update_ops`.
  """
  with tf.control_dependencies(update_ops):
    ops1 = self.magnitude_optimizer._finish([], name_scope + "_m")  # pylint: disable=protected-access
    ops2 = self.direction_optimizer._finish([], name_scope + "_d")  # pylint: disable=protected-access

    if self.use_global_norm:  # apply global grafting
      with tf.control_dependencies([ops1, ops2]):
        # Sum the squared per-variable step norms as plain tensors. The
        # previous implementation accumulated into fresh tf.Variables via
        # assign_add ops that nothing depended on (so they never ran in graph
        # mode) and that were never reset between steps.
        m_global_norm = tf.constant(0.)
        d_global_norm = tf.constant(0.)
        for var in self._variables:
          m_global_norm += self.get_slot(var, "m_step_norm")**2
          d_global_norm += self.get_slot(var, "d_step_norm")**2

        multiplier = tf.sqrt(m_global_norm / tf.maximum(d_global_norm, 1e-30))

        step_ops = []
        for var in self._variables:
          d_step = self.get_slot(var, "scratch_copy")
          # Look up this variable's own direction norm; previously the value
          # left over from the last iteration of the loop above was reused
          # for every variable.
          d_step_norm = self.get_slot(var, "d_step_norm")
          step = tf.where(
              tf.greater(d_step_norm, 0), multiplier * d_step,
              tf.zeros_like(d_step))
          step_ops.append(
              tf.assign_add(var, self._learning_rate_tensor * step))
        return tf.group(*step_ops, name=name_scope)

  return tf.group(*([ops1, ops2] + update_ops), name=name_scope)
def _BPropForVariables(self, vmap):
  """Builds the backward graph and stores the final op in `self._train_op`.

  Args:
    vmap: Passed through to `self._BPropGenTrainOps`.
  """
  generated_ops = self._BPropGenTrainOps(vmap)
  grouped_train_ops = tf.group(
      *tf.nest.flatten(generated_ops), name='train_ops')

  # TODO(rpang): try to structure _train_op as:
  #   tf.cond(skip_step, <only update skip stats>, <all updates>)
  # so that we skip all other updates when a step is skipped.
  with tf.control_dependencies([grouped_train_ops]):
    # The post-train ops only run once every generated train op has finished.
    self._train_op = tf.group(self._post_train_ops, name='bprop')
def __init__(self,
             inference_graph,
             subgraph_name=None,
             checkpoint=None,
             device_type="gpu",
             tf_master="",
             session_config=None,
             clear_device_placement=False):
  """Loads an inference graph and creates the first session.

  Args:
    inference_graph: Either a path (string), which is loaded through
      `LoadInferenceGraph`, or an already loaded inference graph object
      exposing `saver_def`, `graph_def`, `subgraphs`, `fetches` and `feeds`.
    subgraph_name: Name of the subgraph to use; falls back to "default".
    checkpoint: Stored on the instance as `_checkpoint`.
    device_type: One of "cpu", "gpu" or "tpu".
    tf_master: Stored on the instance as `_tf_master`.
    session_config: Stored on the instance as `_session_config`.
    clear_device_placement: Forwarded to `LoadInferenceGraph` when
      `inference_graph` is given as a path.

  Raises:
    ValueError: If the inference graph defines subgraphs and `subgraph_name`
      is not one of them.
  """
  assert device_type in ["cpu", "gpu", "tpu"]
  subgraph_name = subgraph_name or "default"
  if isinstance(inference_graph, six.string_types):
    tf.logging.info("Reading inference graph from %s.", inference_graph)
    inference_graph = LoadInferenceGraph(inference_graph,
                                         clear_device_placement)
  self._inference_graph = inference_graph
  self._checkpoint = checkpoint
  self._device_type = device_type
  self._tf_master = tf_master
  self._session_config = session_config

  self._graph = tf.Graph()
  with self._graph.as_default():
    tf.logging.info(
        "Loading inference graph for prediction subgraph_name={}.".format(
            subgraph_name))
    self._saver = tf.train.Saver(saver_def=inference_graph.saver_def)
    # For TPU the graph is imported on the CPU. The parentheses matter:
    # without them `%` binds tighter than the conditional expression and
    # non-TPU device types were passed to tf.device() as the bare string
    # "gpu"/"cpu" rather than "/gpu:0"/"/cpu:0".
    import_device = "/%s:0" % ("cpu" if device_type == "tpu" else device_type)
    with tf.device(import_device):
      tf.import_graph_def(inference_graph.graph_def, name="")
    if device_type == "tpu":
      # If no tpu init op exists, create it here.
      try:
        self._graph.get_operation_by_name("tpu_init_op")
      except KeyError:
        tf.group(tf.tpu.initialize_system(), name="tpu_init_op")

  self._graph.finalize()

  if inference_graph.subgraphs:
    if subgraph_name not in inference_graph.subgraphs:
      raise ValueError(
          "Subgraph %s not defined. Valid subgraphs: %s" %
          (subgraph_name, list(inference_graph.subgraphs.keys())))
    subgraph = inference_graph.subgraphs[subgraph_name]
    self._fetches = subgraph.fetches
    self._feeds = subgraph.feeds
  else:
    self._fetches = inference_graph.fetches
    self._feeds = inference_graph.feeds

  # Lock for creating new sessions.
  self._sess_lock = threading.Lock()
  self._cur_sess_id = 0
  self._CreateNewSession()
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
  """Accumulates gradients and applies them every `_num_micro_batches` calls.

  With a single micro batch this simply delegates to the wrapped optimizer.
  Otherwise each gradient is divided by the number of micro batches and added
  to a per-variable 'grad_accum' slot; when the internal counter reaches a
  multiple of `_num_micro_batches` the accumulated gradients are applied,
  the accumulators are zeroed and `global_step` is incremented.

  Args:
    grads_and_vars: Iterable of (gradient, variable) pairs.
    global_step: Optional step variable. When None (and more than one micro
      batch is used) `py_utils.GetOrCreateGlobalStepVar()` is used.
    name: Not used by this implementation.

  Returns:
    The op returned by the wrapped optimizer (single micro batch), or an op
    grouping the conditional apply step with the counter increment.
  """
  if self._num_micro_batches == 1:
    return self._opt.apply_gradients(grads_and_vars, global_step)

  # The pairs are traversed twice below; materialize them so that an iterator
  # (e.g. a zip object) is not exhausted by the first pass.
  grads_and_vars = list(grads_and_vars)

  # Compare against None explicitly: truth-testing a variable/tensor is not
  # a reliable "was it provided" check and may raise in graph mode.
  if global_step is None:
    global_step = py_utils.GetOrCreateGlobalStepVar()

  with tf.init_scope():
    self._create_slots([v for (_, v) in grads_and_vars])

  accums = []
  variables = []

  for g, v in grads_and_vars:
    accum = self.get_slot(v, 'grad_accum')
    variables.append(v)
    # pytype: disable=attribute-error
    if isinstance(g, tf.IndexedSlices):
      scaled_grad = tf.IndexedSlices(
          g.values / self._num_micro_batches,
          g.indices,
          dense_shape=g.dense_shape)
    else:
      scaled_grad = g / self._num_micro_batches
    accum_tensor = accum.read_value()
    accums.append(accum.assign(accum_tensor + scaled_grad))
    # pytype: enable=attribute-error

  def _ApplyAndReset():
    normalized_accums = accums
    if self._apply_crs_to_grad:
      normalized_accums = [
          tf.tpu.cross_replica_sum(accum.read_value()) for accum in accums
      ]
    apply_op = self._opt.apply_gradients(
        list(zip(normalized_accums, variables)))
    # Only clear the accumulators after they have been consumed.
    with tf.control_dependencies([apply_op]):
      zero_op = [tf.assign(accum, tf.zeros_like(accum)) for accum in accums]
    return tf.group(zero_op, tf.assign_add(global_step, 1))

  def _Accum():
    return tf.no_op()

  accum_step = tf.cond(
      tf.equal(
          tf.math.floormod(self._counter + 1, self._num_micro_batches), 0),
      _ApplyAndReset,  # Apply the accumulated gradients and reset.
      _Accum)  # Accumulate gradients.

  with tf.control_dependencies([tf.group(accums)]):
    return tf.group(accum_step, tf.assign_add(self._counter, 1))
def partitioned_variable_assign(partitioned_var, new_value):
  """Builds one grouped op that assigns `new_value` to a partitioned variable.

  Args:
    partitioned_var: A partitioned tensorflow variable
    new_value: Value to be assigned to the variable var

  Returns:
    A tensorflow op that groups the assign ops for each of the variable slices
  """
  # Determine which axis was used to partition the variable. Currently
  # tensorflow allows partitioning variable only along 1 axis.
  if len(partitioned_var) == 1:
    axis = 0
  else:
    axis = determine_partitioned_axis(partitioned_var)

  partition_sizes = np.array(
      [partition.get_shape()[axis] for partition in partitioned_var])
  size_splits = tf.convert_to_tensor(partition_sizes, dtype=tf.int32)
  new_partitioned_values = tf.split(new_value, size_splits, axis=axis)

  # The i-th slice of the new value goes to the i-th partition.
  op_list = [
      variable_assign(partition, new_partitioned_values[index])
      for index, partition in enumerate(partitioned_var)
  ]
  group_name = partitioned_var.name + '_group_assign'
  return tf.group(*op_list, name=group_name)
def Apply(self, lr, var_grad):
  """Accumulates gradients and applies them once every `accum_steps` steps.

  Args:
    lr: Learning rate forwarded to the wrapped optimizer.
    var_grad: Structure of (variable, gradient) pairs supporting `Transform`
      and `Flatten`.

  Returns:
    A conditional op: it applies the averaged accumulated gradients and
    resets the accumulators when `global_step % accum_steps` equals
    `accum_steps - 1`, and is a no-op otherwise.
  """
  p = self.params

  def _Acc(vg):
    """Updating accumulators."""
    variable, grad = vg
    with tf.variable_scope(variable.op.name):
      accum_params = py_utils.WeightParams(
          variable.get_shape(), py_utils.WeightInit.Constant(0.0), p.dtype)
      accumulator = py_utils.CreateVariable(
          'grad_accumulator', accum_params, trainable=False)
      accumulator = tf.assign_add(accumulator, grad)
    return py_utils.VarGrad(variable, accumulator)

  var_grad = var_grad.Transform(_Acc)

  def _ApplyAndReset():
    averaged = py_utils.ApplyGradMultiplier(var_grad, 1. / p.accum_steps)
    apply_op = self._opt.Apply(lr, averaged)
    # Zero the accumulators only after the gradients have been applied.
    with tf.control_dependencies([apply_op]):
      resets = [
          tf.assign(acc, tf.zeros_like(acc)) for _, acc in var_grad.Flatten()
      ]
      return tf.group(*resets)

  if p.add_summary_in_apply:
    self.AddSummary(lr, self.GetOptimizer(lr), var_grad)

  is_last_accum_step = tf.equal(
      tf.math.floormod(self.global_step, p.accum_steps), p.accum_steps - 1)
  return tf.cond(is_last_accum_step, _ApplyAndReset,
                 lambda: tf.group(tf.no_op()))
def _TestHelper(self, params, test_data, enroll_data):
  """Runs the attentive scoring layer on the given data and returns scores.

  Args:
    params: Babelfish configuration parameters for setting up the
      attentive_scoring_layer.
    test_data: Test data related to 2 test utterances each with 2 key vectors
      of 2 dimensions and 2 value vectors of 3 dimensions. Each utterance
      representation contains a packed form of the key and value vectors.
      The result is a 2d tensor (of tf.float32 elements) of dimension
      [num_test_utts, representation_dim]. In this example, num_test_utts is
      2 and representation_dim is 10 (or 2 keys * (2 key_dim + 3 value_dim)).
    enroll_data: Enrollment data related to 2 speakers each with 2 enrollment
      utterances. Each utterance is composed of 2 key vectors of 2 dimensions
      and 2 value vectors of 3 dimensions. Each utterance is a packed form of
      the key and value vectors. The result is a 3d tensor (of tf.float32
      elements) of dimension
      [num_enroll_spks, num_enroll_utts_per_spk, representation_dim].

  Returns:
    The output of the attentive scoring. The result is a numpy np.float32
    tensor of shape [num_test_utts, num_enroll_spks].
  """
  with self.session() as sess:
    tf.random.set_seed(_TF_RANDOM_SEED)
    layer = params.Instantiate()
    scores = layer.FProp((test_data, enroll_data))
    # Variables and tables must be initialized before the scores are fetched.
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.tables_initializer())
    sess.run(init_op)
    scores_value = sess.run(scores)
    return scores_value
def _ApplyAndReset():
  """Applies the averaged accumulated gradients, then zeroes the accumulators.

  Returns:
    An op grouping the accumulator resets; it only runs after the wrapped
    optimizer's apply op.
  """
  averaged = py_utils.ApplyGradMultiplier(var_grad, 1. / p.accum_steps)
  apply_op = self._opt.Apply(lr, averaged)
  with tf.control_dependencies([apply_op]):
    resets = [
        tf.assign(acc, tf.zeros_like(acc)) for _, acc in var_grad.Flatten()
    ]
    return tf.group(*resets)
def _Apply():
  """Use the matched optimizer to apply the gradients."""
  non_default_regex = [
      pattern for pattern in self._optimizer_map
      if pattern != 'default_optimizer'
  ]

  def _MatchesAnyNonDefault(k, v):
    # NOTE(review): for 'default_optimizer' this selects the variables that
    # DO match one of the other regexes — confirm this is the intended set to
    # summarize (depends on FilterKeyVal's keep/drop semantics).
    return any([re.match(i, v.var.name) for i in non_default_regex])

  train_ops = []
  for regex in self._optimizer_map:
    grads_for_regex = var_grad_map[regex]
    if not grads_for_regex:
      # Nothing matched this optimizer; no train op and no summary.
      continue
    opt = tf_optimizer_map[regex]
    train_ops.append(opt.apply_gradients(grads_for_regex))
    # pylint: disable=cell-var-from-loop, g-long-lambda
    if regex == 'default_optimizer':
      selector = _MatchesAnyNonDefault
    else:
      selector = lambda k, v: (re.match(regex, v.var.name))
    filtered_var_grad = var_grad.FilterKeyVal(selector)
    # pylint: enable=cell-var-from-loop, g-long-lambda
    self._optimizer_map[regex].AddSummary(self._lr_map[regex], opt,
                                          filtered_var_grad)
  return tf.group(*train_ops, name='composite_optimizer_train_op')
def testAttentionNetworkFPropTrainableScaleFactor(self):
  """Verifies FProp output and the stored log scale with trainable scaling."""
  # Enable trainable scaling
  self.params.use_trainable_scale_factor = True
  self.params.scale_factor = 2.0
  with self.session() as sess:
    tf.random.set_seed(_TF_RANDOM_SEED)
    attention_network = self.params.Instantiate()
    fprop_inputs = (self.test_data, self.enroll_data)
    output = attention_network.FProp(fprop_inputs, attention_network.theta)
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.tables_initializer())
    sess.run(init_op)
    output_result = sess.run(output)
    log_scale_factor_result = sess.run(
        attention_network.theta.trainable_log_scale_factor)

    expected_output = np.array([[1.0, 0.434889], [-0.269374, -0.202443]])
    self.assertAllClose(
        expected_output,
        output_result,
        rtol=_REL_TOLERANCE,
        atol=_ABS_TOLERANCE)

    # Check that the log of the scale_factor=2 is as expected
    expected_log_scale = 0.6931471824645996
    self.assertAllClose(
        expected_log_scale,
        log_scale_factor_result,
        rtol=_REL_TOLERANCE,
        atol=_ABS_TOLERANCE)
def PostTrainingStepUpdate(self, global_step):
  """Groups the parent's post-step update with per-tensor recording ops.

  Args:
    global_step: Forwarded to the parent class implementation.

  Returns:
    An op grouping the parent's update with the ops returned by
    `_RecordTensor` for every name in `self._t_names`.
  """
  parent_update = super(PassiveAsymQDomain,
                        self).PostTrainingStepUpdate(global_step)
  update_ops = [parent_update]
  for tensor_name in self._t_names:
    update_ops += self._RecordTensor(tensor_name)
    self._SummarizeTensor(tensor_name)
  return tf.group(update_ops)
def CreateTpuEmbeddingEnqueueOps(self):
  """Creates the TpuEmbedding enqueue ops on the host.

  Note that this must be called after the instantiation of the
  monolithic TPUEmbeddingLayer.

  The first entry of the `py_utils.TPU_EMBEDDING` collection is used as the
  embedding object. If the collection is empty the method returns early
  without creating any op. Otherwise the enqueue ops generated for every
  infeed host are grouped into a single op that is appended to
  `self._tpu_infeed_op`. `self._batch` may be replaced by a copy with the
  '*bucket_keys' entries removed.
  """
  p = self.params
  cluster = self.cluster
  num_tpu_hosts = cluster.num_tpu_hosts
  # Without per-host infeed a single host feeds everything.
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

  tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
  tpu_embedding = (
      tpu_embedding_collection[0] if tpu_embedding_collection else None)

  enqueue_ops = []

  if num_tpu_hosts > 1 and tpu_embedding is not None:
    if not p.use_per_host_infeed:
      tf.logging.fatal(
          'TPU Embedding must be used with per_host_infeed with multiple '
          'TPU host topologies.')
  tpu_emb_input_keys = (
      list(tpu_embedding.feature_to_config_dict.keys())
      if tpu_embedding is not None else [])
  tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)
  # Nothing to enqueue when no TPU embedding has been registered.
  if not tpu_embedding:
    return

  for task_id in range(num_infeed_hosts):
    host_device = '/task:{}/device:CPU:0'.format(task_id)
    with tf.device(host_device):
      if isinstance(self._batch, py_utils.NestedMap):
        # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
        # Note that when MultiTaskData is used, bucket_keys will be at the
        # second level of the dictionary.
        self._batch = self._batch.FilterKeyVal(
            lambda k, _: not k.endswith('bucket_keys'))
      tf.logging.info('host_device: %s, batch: %r', host_device, self._batch)

      # One feature-name -> EnqueueData dict per core on this host.
      enqueue_dict_per_core = [
          {} for _ in range(tpu_embedding.num_cores_per_host)
      ]
      num_cores_per_host = tpu_embedding.num_cores_per_host
      for key in tpu_emb_input_keys:
        feat = self._batch[key]
        # Split the feature evenly across the cores of this host.
        tpu_emb_feat_splitted = tf.split(feat, num_cores_per_host)
        for core, split in enumerate(tpu_emb_feat_splitted):
          # Dense to sparse. Note the assumption of a padding id.
          sample_indices = tf.where(tf.not_equal(split, -1))
          embedding_indices = tf.gather_nd(split, sample_indices)
          enqueue_data = tpu_embedding_lib.EnqueueData(
              embedding_indices, sample_indices)
          enqueue_dict_per_core[core][key] = enqueue_data
      enqueue_ops += tpu_embedding.generate_enqueue_ops(enqueue_dict_per_core)
  self._tpu_infeed_op.append(tf.group(*enqueue_ops))
def PostTrainingLoop(self, outfeed=None):
  """Construct the post training loop op.

  The op groups `ApplyPostTrainingLoop()` of every learner and is stored in
  `self._post_training_loop_op`.

  Args:
    outfeed: a dict of tensors dequeued from TPU outfeed queue. Not read by
      this implementation.
  """
  with py_utils.GlobalStepContext(self._global_step_var):
    learner_ops = [
        learner.ApplyPostTrainingLoop() for learner in self.learners
    ]
    self._post_training_loop_op = tf.group(*learner_ops)
def ApplyPostTrainingLoop(self):
  """Applies any computation to run after each tpu training loop.

  Returns:
    Ops to run after training loop ends: the asynchronous preconditioner
    computation and the assignment of preconditioners to host variables,
    grouped into one op.
  """
  current_step = tf.cast(py_utils.GetGlobalStep(), tf.int32)
  invoke_async_ops = (
      self._optimizer.invoke_async_preconditioner_computation(current_step))
  assign_ops = self._optimizer.assign_preconditioner_to_host_vars()
  return tf.group(invoke_async_ops, assign_ops)
def ApplyPostTrainingLoop(self):
  """Runs the post-training-loop computation of every wrapped optimizer.

  Returns:
    Ops to run after training loop ends.
  """
  post_training_ops = []
  for opt in self._optimizer_map.values():
    post_training_ops.append(opt.ApplyPostTrainingLoop())
  return tf.group(*post_training_ops)
def Apply(self, metrics, vmap, gradient_mask=None, gradient_adjuster=None):
  """Computes updates on 'vmap' to optimize 'loss'.

  TODO(rpang): explore merging gradient_mask and gradient_adjuster.

  Args:
    metrics: A Dict[str, (value, weight)], from which loss can be extracted
      according to p.loss_name.
    vmap: A `.NestedMap` object containing variables to optimize.
    gradient_mask: if not None, a dict mapping variable names to a 0/1 scalar.
    gradient_adjuster: if not None, a function that mutates a given var_grads.

  Returns:
    (losses, op, eval_metrics), where
      - losses is a list of scalar tensors;
      - op is a tf.Operation to update variables;
      - eval_metrics is a Dict[str, (value, weight)], where each value/weight
        is a scalar tensor.
  """
  # We apply gradients outside the name_scope to maintain backwards
  # compatibility on variables created by self.optimizer.Apply().
  losses, var_grads, eval_metrics = self._ComputeLossesAndGradients(
      metrics, vmap)

  if 'tpu_embedding_var_grads' not in var_grads:
    tpu_emb_update_op = tf.no_op()
  else:
    # TPU embedding gradients are split off and applied by the embedding
    # collection rather than by the regular optimizer.
    embedding_var_grads = var_grads.tpu_embedding_var_grads
    del var_grads.tpu_embedding_var_grads

    embedding_collection = py_utils.GetTpuEmbeddingGraphCollection()[0]
    assert embedding_collection
    embedding_grads = embedding_var_grads.Transform(
        lambda var_grad: var_grad.grad)
    tpu_emb_update_op, embedding_stats = embedding_collection.ApplyGradients(
        py_utils.GetTaskCallScope(), embedding_grads)
    eval_metrics.update(embedding_stats)

  assert py_utils.GetGlobalStep() is not None
  lr = self.LearningRate()

  var_grads, adjust_stats = self.AdjustGradients(
      var_grads,
      gradient_mask=gradient_mask,
      gradient_adjuster=gradient_adjuster)
  eval_metrics.update(adjust_stats)
  self._var_grads = var_grads

  lr_value = tf.convert_to_tensor(lr)
  lr_weight = tf.convert_to_tensor(1.)
  eval_metrics['learning_rate'] = (lr_value, lr_weight)

  optimizer_update_op = self.optimizer.Apply(lr, var_grads)
  var_update_op = tf.group([tpu_emb_update_op, optimizer_update_op])
  return losses, var_update_op, eval_metrics
def _ApplyAndReset():
  """Applies the accumulated gradients, zeroes them and bumps global_step.

  Returns:
    An op grouping the accumulator resets with the `global_step` increment.
  """
  if self._apply_crs_to_grad:
    grads_to_apply = [
        tf.tpu.cross_replica_sum(accum.read_value()) for accum in accums
    ]
  else:
    grads_to_apply = accums
  apply_op = self._opt.apply_gradients(list(zip(grads_to_apply, variables)))
  # The accumulators may only be cleared once they have been applied.
  with tf.control_dependencies([apply_op]):
    zero_op = []
    for accum in accums:
      zero_op.append(tf.assign(accum, tf.zeros_like(accum)))
  step_increment = tf.assign_add(global_step, 1)
  return tf.group(zero_op, step_increment)
def PostTrainingStepUpdate(self):
  """Returns a TF op which will be invoked at each training step.

  Subclasses of `BaseLayer` can implement this method. The method should
  return a TF op to be invoked during training after gradients are applied.

  This implementation groups the ops returned by `PostTrainingStepUpdate()`
  of every entry in `self._private_children.Flatten()`.
  """
  update_ops = []
  for child in self._private_children.Flatten():
    update_ops.append(child.PostTrainingStepUpdate())
  return tf.group(*update_ops)
def _BuildRestore(self):
  """Builds restore ops.

  For every variable in `self._vars` a restore op reading from
  `self._restore_prefix_ph` is created and assigned to the variable. The
  assignments are grouped into `self._restore_op`.
  """

  def _RestoreInto(variable):
    # restore_v2 returns a list with one tensor per requested name.
    restored, = io_ops.restore_v2(
        prefix=self._restore_prefix_ph,
        tensor_names=[_VarKey(variable)],
        shape_and_slices=[""],
        dtypes=[variable.dtype])
    return variable.assign(restored)

  self._restore_op = tf.group(*[_RestoreInto(v) for v in self._vars])
def ApplyPostTrainingLoop(self, global_step):
  """Applies any computation to run after each tpu training loop.

  Args:
    global_step: Global step variable.

  Returns:
    Ops to run after training loop ends: the asynchronous preconditioner
    computation and the assignment of preconditioners to host variables,
    grouped into one op.
  """
  current_step = tf.cast(global_step, tf.int32)
  invoke_async_ops = (
      self._optimizer.invoke_async_preconditioner_computation(current_step))
  assign_ops = self._optimizer.assign_preconditioner_to_host_vars()
  return tf.group(invoke_async_ops, assign_ops)
def ApplyPostTrainingLoop(self, global_step):
  """Runs the post-training-loop computation of every wrapped optimizer.

  Args:
    global_step: Global step variable.

  Returns:
    Ops to run after training loop ends.
  """
  post_training_ops = []
  for opt in self._optimizer_map.values():
    post_training_ops.append(opt.ApplyPostTrainingLoop(global_step))
  return tf.group(*post_training_ops)
def assign_preconditioner_to_host_vars(self):
  """Assign/Grab latest copy of preconditioners.

  Collects the preconditioner slot of every (variable, partition, dimension)
  for which a preconditioner is available, fetches the latest values through
  `x_ops.get_preconditioners` and writes each one into its slot only when
  the corresponding success flag is set.

  Returns:
    An op grouping all slot assignments, or a no-op when there is no
    preconditioner to update.
  """
  keys_shapes_and_preconditioner_vars = []
  for var in self._all_vars_for_preconditioning:
    shape = var.get_shape()
    if self._fallback_to_diagonal_for_shape(shape):
      continue
    partitioned_v = TensorPartitioner.partition_tensor(
        var, self._partition_info)
    num_partitions = len(partitioned_v)
    for pt_idx, pt in enumerate(partitioned_v):
      pt_shape = pt.get_shape()
      preconditioner_exists_for_dim = (
          self._preconditioner_available_for_dims(pt_shape))
      var_rank = len(pt_shape)
      for i in range(var_rank):
        if not preconditioner_exists_for_dim[i]:
          continue
        key = self._key_for_var(var, i, pt_idx)
        preconditioner = self.get_slot(
            var,
            self._preconditioner_key_for_partition_and_dim(
                i, pt_idx, num_partitions))
        keys_shapes_and_preconditioner_vars.append(
            (key, tf.shape(preconditioner), preconditioner))

  if not keys_shapes_and_preconditioner_vars:
    return tf.no_op()

  keys, shapes, preconditioner_vars = zip(
      *keys_shapes_and_preconditioner_vars)

  preconditioner_vals, successes = x_ops.get_preconditioners(
      shapes,
      keys=keys,
      preconditioner_compute_graphdef=(
          self._preconditioner_compute_graphdef))

  assign_ops = []
  for preconditioner_var, preconditioner_val, success in zip(
      preconditioner_vars, preconditioner_vals, successes):
    # Cast with the dtype of the slot being updated. The previous code used
    # `preconditioner`, a leftover from the collection loop above that always
    # referred to the last slot visited.
    success_mult = tf.cast(success, preconditioner_var.dtype)
    # Keep the old value on failure, take the new one on success.
    assign_ops.append(
        state_ops.assign(
            preconditioner_var, (1.0 - success_mult) * preconditioner_var +
            success_mult * preconditioner_val))
  return tf.group(*assign_ops)
def _BuildRestore(self):
  """Builds restore ops.

  Variables are bucketed by their device so that each restore/assign pair is
  created under the device of the variable it restores. All assignments are
  grouped into `self._restore_op`.
  """
  assign_ops = []
  with self._var_graph.as_default():
    vars_by_device = collections.defaultdict(list)
    for variable in self._vars:
      vars_by_device[variable.device].append(variable)

    for device, device_vars in vars_by_device.items():
      with self._var_graph.device(device):
        for variable in device_vars:
          # restore_v2 returns a list with one tensor per requested name.
          restored, = io_ops.restore_v2(
              prefix=self._restore_prefix_ph,
              tensor_names=[_VarKey(variable)],
              shape_and_slices=[""],
              dtypes=[variable.dtype])
          assign_ops.append(variable.assign(restored))

    self._restore_op = tf.group(*assign_ops)
def PostTrainingStepUpdate(self, global_step):
  """Updates moving_mean, moving_variance after each training step.

  Args:
    global_step: Not read by this implementation.

  Returns:
    An op grouping the moving mean and moving variance updates.
  """
  p = self.params
  # Get sufficient stats that accumulates over microbatches.
  counts = self.accumulators.counts.GetValue()
  mean_ss = self.accumulators.mean_ss.GetValue()
  variance_ss = self.accumulators.variance_ss.GetValue()
  # Compute batch mean and batch variance from sufficient stats
  mean, variance = tf.nn.normalize_moments(counts, mean_ss, variance_ss, None)
  decay = tf.convert_to_tensor(1.0 - p.decay, p.dtype)
  # Update moving_mean, moving_variance from batch mean and batch variance.
  with tf.name_scope(p.name) as scope:
    with tf.colocate_with(self.vars.moving_mean):
      # When no statistics were accumulated (counts <= 0.5) the subtracted
      # delta is zero, so the moving average is left unchanged.
      mean_update = tf.assign_sub(
          self.vars.moving_mean,
          tf.where(
              tf.greater(counts, 0.5),
              (self.vars.moving_mean - tf.cast(mean, p.dtype)) * decay,
              tf.zeros_like(self.vars.moving_mean)),
          name='moving_mean_update')
    with tf.colocate_with(self.vars.moving_variance):
      var_update = tf.assign_sub(
          self.vars.moving_variance,
          tf.where(
              tf.greater(counts, 0.5),
              (self.vars.moving_variance - tf.cast(variance, p.dtype)) * decay,
              tf.zeros_like(self.vars.moving_variance)),
          name='moving_variance_update')
    # NOTE(review): the values returned by CheckNumerics are discarded here —
    # confirm the check takes effect without being wired into the returned op.
    py_utils.CheckNumerics(
        self.vars.moving_mean,
        'moving mean of {} failed numeric check'.format(scope))
    py_utils.CheckNumerics(
        self.vars.moving_variance,
        'moving variance of {} failed numeric check'.format(scope))
  # Clear the accumulated statistics for the next step.
  self.accumulators.counts.Reset()
  self.accumulators.mean_ss.Reset()
  self.accumulators.variance_ss.Reset()
  return tf.group(mean_update, var_update)
def _TestHelperWithState(self, params, list_of_batches):
  """Feeds the batches through the layer in order, threading the state.

  Args:
    params: Babelfish configuration parameters for setting up the
      cumulative_statistics_layer.
    list_of_batches: A list of padded batches of examples. The structure is a
      list of the following:
      {
        'features': tf.tensor(float32) of shape(len, batch, dim)
        'paddings': tf.tensor(float32) of shape(len, batch)
      }

  Returns:
    A dictionary containing numpy arrays of the expected test outputs. The
    structure is as follows:
      {
        'features': np.array(float32) of shape(len, batch, dim)
        'paddings': np.array(float32) of shape(len, batch)
      }
    Only the output of the last batch is evaluated and returned.
  """
  with self.session() as sess:
    tf.random.set_seed(_TF_RANDOM_SEED)
    network = params.Instantiate()
    first_batch = list_of_batches[0]
    batch_size = first_batch.features.shape[1]
    state = network.zero_state(network.theta, batch_size)
    for current_batch in list_of_batches:
      output = network.FProp(network.theta, current_batch, state)
      # Pass the output state over to the next batch as input state.
      state = output.state
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.tables_initializer())
    sess.run(init_op)
    output_value = sess.run(output)
    return output_value
def _BPropForVariables(self, vmap):
  """Constructs the backward graph.

  Builds one train op per learner plus the batch norm, mask, post-step,
  global-step and (optional) TPU embedding ops, and groups all of them into
  `self._train_op`.

  Args:
    vmap: Variables to optimize; forwarded to each learner's `Apply`.

  Raises:
    ValueError: If a learner's name has no entry in `self._metrics`.
  """
  bprop_variable_filters = self.input_generator.GetBpropVariableFilters()
  # Only compute the mask if the variable filters are not empty.
  if bprop_variable_filters != [''] * len(bprop_variable_filters):
    self._ComputeGradientMask(bprop_variable_filters)
  train_ops = {}  # mapping from op name to op.
  gradient_mask = None
  if self._per_input_gradient_mask:
    # TODO(neerajgaur): Change this to use source_selected from input_batch.
    onehot = self.input_generator.GetInputSourceOneHot()
    gradient_mask = {
        k: tf.tensordot(v, onehot, 1)
        for k, v in six.iteritems(self._per_input_gradient_mask)
    }
  all_losses = []
  for optimization in self.learners:
    # The learner's name doubles as the key of its loss in the metrics.
    loss_name = optimization.params.name
    metric = self._metrics.get(loss_name, None)
    if metric is None:
      raise ValueError('Loss %s not found in metrics %s' %
                       (loss_name, list(self._metrics.keys())))
    loss = metric[0]
    all_losses.append(loss)
    train_ops['train/%s' % loss_name], eval_metrics = optimization.Apply(
        loss,
        vmap,
        gradient_mask=gradient_mask,
        gradient_adjuster=self.AdjustGradients)
    for key, (value, weight) in six.iteritems(eval_metrics):
      self.AddEvalMetric(key + '/' + loss_name, value, weight)

  relevant_bn_updates, _ = py_utils.FindRelevantBatchNormUpdates(
      all_losses, tf.get_collection(py_utils.BATCH_NORM_UPDATES))
  train_ops['bn_updates'] = relevant_bn_updates

  # Get the op to update the weight masks and thresholds
  train_ops['mask_updates'] = self._GetMaskUpdateOp()

  # Post training step update.
  train_ops['post_step'] = self.PostTrainingStepUpdate(self.global_step)

  # The global step is only incremented after all other train ops have run.
  with tf.control_dependencies(tf.nest.flatten(train_ops)):
    true_global_step = py_utils.GetOrCreateGlobalStepVar()
    with tf.colocate_with(true_global_step):
      increment_global_steps = tf.assign_add(true_global_step, 1)
    # When this task keeps its own step variable, increment it as well.
    if self._global_step_var != true_global_step:
      with tf.colocate_with(self._global_step_var):
        increment_global_steps = tf.group(
            increment_global_steps, tf.assign_add(self._global_step_var, 1))
    train_ops['global_step'] = increment_global_steps

  # If we are using Tpu Embeddings, generate the monolithic send
  # gradient op.
  tpu_embedding_activations = tf.get_collection(
      py_utils.TPU_EMBEDDING_ACTIVATIONS)
  if tpu_embedding_activations:
    tpu_embedding_activations_dict = tpu_embedding_activations[0]
    tpu_embedding = tf.get_collection(py_utils.TPU_EMBEDDING)[0]
    tpu_embedding_send_gradient_op = py_utils.ComputeTpuEmbeddingGradients(
        self.loss, tpu_embedding_activations_dict, tpu_embedding)
    train_ops['tpu_embedding'] = tpu_embedding_send_gradient_op
  for op_name, op in six.iteritems(train_ops):
    assert op is not None, op_name

  # TODO(rpang): try to structure _train_op as:
  #   tf.cond(skip_step, <only update skip stats>, <all updates>)
  # so that we skip all other updates when a step is skipped.
  self._train_op = tf.group(*tf.nest.flatten(train_ops), name='bprop')
def CreateTpuFeeds(self):
  """Creates the TPU infeed queue from preprocessed batch.

  For every infeed host the preprocessed batch is fetched, the optional TPU
  embedding enqueue ops are generated and an infeed queue with matching
  enqueue ops is created. All enqueue ops are grouped into
  `self._tpu_infeed_op`, which is also added to the
  `py_utils.ENQUEUE_OPS` collection `p.tpu_infeed_parallelism` times.

  Returns:
    The batch structure packed with the tensors dequeued from the first
    infeed queue on TPU core 0.
  """
  p = self.params
  cluster = self.cluster
  num_tpu_hosts = cluster.num_tpu_hosts
  num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
  tf.logging.info(
      'CreateTPUFeeds num_splits_per_client={} '
      'num_devices_per_split={} num_tpu_hosts={} use_per_host_infeed={}'.format(
          cluster.num_splits_per_client, cluster.num_devices_per_split,
          num_tpu_hosts, p.use_per_host_infeed))

  assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
  if (cluster.num_devices_per_split > num_cores_per_host and
      p.use_per_host_infeed):
    tf.logging.fatal(
        'Doesn\'t support per host infeed mode when '
        'num_devices_per_split({}) > num_cores_per_host({})'.format(
            cluster.num_devices_per_split, num_cores_per_host))
  # Without per-host infeed a single host feeds everything.
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

  with py_utils.outside_all_rewrites():
    assert py_utils.use_tpu()
    assert not self._made_tpu_infeed

    shards = tpu_function.get_tpu_context(
    ).number_of_shards // num_infeed_hosts
    tf.logging.info('shards {}'.format(shards))

    input_ops_list = []
    queues = []
    tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
    tpu_embedding = (
        tpu_embedding_collection[0] if tpu_embedding_collection else None)

    if num_tpu_hosts > 1 and tpu_embedding is not None:
      if not p.use_per_host_infeed:
        tf.logging.fatal(
            'TPU Embedding must be used with per_host_infeed with multiple '
            'TPU host topologies.')
    tpu_emb_input_keys = (
        list(tpu_embedding.feature_to_config_dict.keys())
        if tpu_embedding is not None else [])
    tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)

    batch = None
    for task_id in range(num_infeed_hosts):
      host_device = '/task:{}/device:CPU:0'.format(task_id)
      with tf.device(host_device):
        batch = self.GetPreprocessedInputBatch()
        if isinstance(batch, py_utils.NestedMap):
          # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
          # Note that when MultiTaskData is used, bucket_keys will be at the
          # second level of the dictionary.
          batch = batch.FilterKeyVal(
              lambda k, _: not k.endswith('bucket_keys'))
        tf.logging.info('host_device: %s, batch: %r', host_device, batch)

        if tpu_embedding is not None:
          # One feature-name -> EnqueueData dict per core on this host.
          enqueue_dict_per_core = [
              {} for _ in range(tpu_embedding.num_cores_per_host)
          ]
          num_cores_per_host = tpu_embedding.num_cores_per_host
          for key in tpu_emb_input_keys:
            feat = batch[key]
            tpu_emb_feat_splitted = tf.split(feat, num_cores_per_host)
            for core, split in enumerate(tpu_emb_feat_splitted):
              # Dense to sparse. Note the assumption of a padding id.
              sample_indices = tf.where(tf.not_equal(split, -1))
              embedding_indices = tf.gather_nd(split, sample_indices)
              enqueue_data = tpu_embedding_lib.EnqueueData(
                  embedding_indices, sample_indices)
              enqueue_dict_per_core[core][key] = enqueue_data
          input_ops_list += tpu_embedding.generate_enqueue_ops(
              enqueue_dict_per_core)

        for k, x in batch.FlattenItems():
          assert x.shape.is_fully_defined(), (
              'Shape must be fully defined: %s: %s' % (k, x))
          # TODO(cwhipkey): if it's a string (or other type not supported on
          # TPU), drop it from feeding and on the other end add in an op that
          # fails if used.
        shapes = batch.Transform(lambda x: x.shape).Flatten()
        dtypes = batch.Transform(lambda x: x.dtype).Flatten()
        tf.logging.info('host_device: %s infeed shapes: %r', host_device,
                        shapes)
        tf.logging.info('host_device: %s infeed dtypes: %r', host_device,
                        dtypes)
        if p.use_partitioned_infeed_queue:
          device_assignment = py_utils.GetTpuDeviceAssignment()
          host_device = device_assignment.host_device(
              replica=0, job=tf.flags.FLAGS.tf_master)
          host_id = int(host_device.split('/task:')[1].split('/device:')[0])
          tf.logging.info('host_id: {} host_device: {}'.format(
              host_id, host_device))
          # Partition only the leading dimension and give every tensor a
          # partition spec of its own rank. The previous hard-coded
          # [num_partitions, 1] was only valid for rank-2 tensors; for those
          # the result is unchanged.
          q = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
              number_of_tuple_elements=len(dtypes),
              device_assignment=device_assignment,
              host_id=host_id,
              input_partition_dims=[
                  [p.num_partitions] + [1] * (len(s) - 1) for s in shapes
              ],
              tuple_types=dtypes,
              tuple_shapes=shapes)
        else:
          q = tpu_feed.InfeedQueue(tuple_types=dtypes, tuple_shapes=shapes)
          assert shards is not None
          q.set_number_of_shards(shards)

        queues.append(q)
        tf.logging.info('q=%r', q)

        if p.use_partitioned_infeed_queue:
          input_ops = q.generate_enqueue_ops([batch.Flatten()])
        elif p.use_per_host_infeed:
          # TODO(ylc/zhifengc): Add this to a policy module and test it.
          def TPUOrdinalFunction(shard_index_in_host):
            device_assignment = py_utils.GetTpuDeviceAssignment()
            if device_assignment:
              # We put both enqueue/dequeue ops at core 0 in each replica.
              replica = device_assignment.lookup_replicas(
                  task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
              return device_assignment.tpu_ordinal(replica=replica)
            else:
              return shard_index_in_host

          input_ops = q.split_inputs_and_generate_enqueue_ops(
              batch.Flatten(),
              placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
              tpu_ordinal_function=TPUOrdinalFunction)
        else:
          input_ops = q.split_inputs_and_generate_enqueue_ops(
              batch.Flatten(),
              device_assignment=py_utils.GetTpuDeviceAssignment())

        input_ops_list += input_ops
    tf.logging.info('input_ops_list %s', input_ops_list)
    tpu_infeed_op = tf.group(*input_ops_list)
  self._made_tpu_infeed = True
  # Let trainer.py use multiple threads to drive the infeed op.
  for _ in range(p.tpu_infeed_parallelism):
    tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

  self._tpu_infeed_op = tpu_infeed_op

  with tf.device(tf.tpu.core(0)):
    tensors = queues[0].generate_dequeue_op()
  return batch.Pack(tensors)
def CreateTpuEnqueueOps(self):
  """Create the host-side enqueue ops.

  This should be called in an outer non-TPU context.

  For every infeed host the preprocessed batch is fetched into `self._batch`
  and an infeed queue is created and appended to `self._tpu_queues`. The
  enqueue ops of all hosts are grouped into one op, and
  `self._tpu_infeed_op` is set to a list holding that op
  `p.tpu_infeed_parallelism` times.
  """
  assert not self._tpu_queues, ('CreateTpuEnqueueOps should only be called '
                                'once.')
  self._tpu_queues = []
  p = self.params
  cluster = self.cluster
  num_tpu_hosts = cluster.num_tpu_hosts
  num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
  tf.logging.info(
      'CreateTpuEnqueueOps num_splits_per_client={} '
      'num_devices_per_split={} num_tpu_hosts={} use_per_host_infeed={}'.format(
          cluster.num_splits_per_client, cluster.num_devices_per_split,
          num_tpu_hosts, p.use_per_host_infeed))

  assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
  if (cluster.num_devices_per_split > num_cores_per_host and
      p.use_per_host_infeed):
    tf.logging.fatal(
        'Doesn\'t support per host infeed mode when '
        'num_devices_per_split({}) > num_cores_per_host({})'.format(
            cluster.num_devices_per_split, num_cores_per_host))
  # Without per-host infeed a single host feeds everything.
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

  shards = (cluster.total_worker_devices //
            num_infeed_hosts) // cluster.num_devices_per_split
  tf.logging.info('shards {}'.format(shards))

  input_ops_list = []
  tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
  tpu_embedding = (
      tpu_embedding_collection[0] if tpu_embedding_collection else None)

  if num_tpu_hosts > 1 and tpu_embedding is not None:
    if not p.use_per_host_infeed:
      tf.logging.fatal(
          'TPU Embedding must be used with per_host_infeed with multiple '
          'TPU host topologies.')
  # Only logged here; the keys are not used further in this method.
  tpu_emb_input_keys = (
      list(tpu_embedding.feature_to_config_dict.keys())
      if tpu_embedding is not None else [])
  tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)
  tf.logging.info('num_infeed_hosts: %d', num_infeed_hosts)

  for task_id in range(num_infeed_hosts):
    host_device = '/task:{}/device:CPU:0'.format(task_id)
    with tf.device(host_device):
      self._batch = self.GetPreprocessedInputBatch()
      if isinstance(self._batch, py_utils.NestedMap):
        # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
        # Note that when MultiTaskData is used, bucket_keys will be at the
        # second level of the dictionary.
        self._batch = self._batch.FilterKeyVal(
            lambda k, _: not k.endswith('bucket_keys'))
      tf.logging.info('host_device: %s, batch: %r', host_device, self._batch)

      for k, x in self._batch.FlattenItems():
        assert x.shape.is_fully_defined(), (
            'Shape must be fully defined: %s: %s' % (k, x))
        # TODO(cwhipkey): if it's a string (or other type not supported on
        # TPU), drop it from feeding and on the other end add in an op that
        # fails if used.
      shapes = self._batch.Transform(lambda x: x.shape).Flatten()
      dtypes = self._batch.Transform(lambda x: x.dtype).Flatten()

      tf.logging.info('host_device: %s infeed shapes: %r', host_device,
                      shapes)
      tf.logging.info('host_device: %s infeed dtypes: %r', host_device,
                      dtypes)

      if p.use_partitioned_infeed_queue:
        device_assignment = py_utils.GetTpuDeviceAssignment()

        # Note: this overrides the per-task host_device chosen above.
        host_device = device_assignment.host_device(
            replica=0, job=tf.flags.FLAGS.tf_master)
        host_id = int(host_device.split('/task:')[1].split('/device:')[0])
        tf.logging.info('host_id: {} host_device: {}'.format(
            host_id, host_device))
        # Only the leading dimension of each tensor is partitioned; the
        # partition spec has one entry per dimension of the tensor.
        q = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
            number_of_tuple_elements=len(dtypes),
            device_assignment=device_assignment,
            host_id=host_id,
            input_partition_dims=[
                [p.num_partitions] + [1] * (len(s) - 1) for s in shapes
            ],
            tuple_types=dtypes,
            tuple_shapes=shapes)
      else:
        q = tpu_feed.InfeedQueue(tuple_types=dtypes, tuple_shapes=shapes)
        assert shards is not None
        q.set_number_of_shards(shards)

      self._tpu_queues.append(q)

      if p.use_partitioned_infeed_queue:
        input_ops = q.generate_enqueue_ops([self._batch.Flatten()])
      elif p.use_per_host_infeed:
        # TODO(ylc/zhifengc): Add this to a policy module and test it.
        def TPUOrdinalFunction(shard_index_in_host):
          device_assignment = py_utils.GetTpuDeviceAssignment()
          if device_assignment:
            # We put both enqueue/dequeue ops at core 0 in each replica.
            replica = device_assignment.lookup_replicas(
                task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
            return device_assignment.tpu_ordinal(replica=replica)
          else:
            return shard_index_in_host

        input_ops = q.split_inputs_and_generate_enqueue_ops(
            self._batch.Flatten(),
            placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
            tpu_ordinal_function=TPUOrdinalFunction)
      else:
        input_ops = q.split_inputs_and_generate_enqueue_ops(
            self._batch.Flatten(),
            device_assignment=py_utils.GetTpuDeviceAssignment())
      input_ops_list += input_ops

  tf.logging.info('input_ops_list %s', input_ops_list)
  grouped_infeed_op = tf.group(*input_ops_list)
  # The same grouped op is listed once per requested infeed thread.
  self._tpu_infeed_op = []
  for _ in range(p.tpu_infeed_parallelism):
    self._tpu_infeed_op.append(grouped_infeed_op)
def Export(cls,
           model_cfg,
           model_task_name=None,
           device_options=InferenceDeviceOptions(
               device='',
               retain_device_placement=False,
               var_options=None,
               gen_init_op=True,
               dtype_override=None),
           freeze_checkpoint=None,
           freeze_defaults=False,
           export_path=None,
           subgraph_filter=None,
           random_seed=None,
           disable_packed_input=True):
  """Exports a InferenceGraph proto with piecewise subgraphs.

  Sets FLAGS.enable_asserts to False unless user explicitly sets it to True.

  Note that `model_cfg` is modified in place: `random_seed` and
  `is_inference` are always set, packed input may be disabled on the
  encoder/decoder params, and the cluster params are passed to
  `cls._SetClusterParams`.

  Args:
    model_cfg: a Params instance as returned by
      model_registry.GetParams(modelname, 'Test') or model_params.Model().
    model_task_name: The task to generate an inference graph for. Should be
      None for single-task models.
    device_options: Device options for the accelerator used for serving.
    freeze_checkpoint: The checkpoint to load. Loads and freezes the model if
      given.
    freeze_defaults: Default initializes the graph and freeze. Useful for
      early testing of downstream tools without having a checkpoint.
    export_path: If not None, write the inference graph in ASCII to this path.
    subgraph_filter: A list of subgraph names. If not None or empty, export
      only this list of inference subgraphs.
    random_seed: Fixes the random seed in the exported inference graph.
    disable_packed_input: Disable packed input for inference writing purposes.

  Returns:
    InferenceGraph proto.

  Raises:
    ValueError: if the model does not support the listed subgraphs, or if
      freezing is requested and the `cls._DeviceSupportsFreezing` check on
      `device_options` fires.
  """
  assert issubclass(model_cfg.cls, base_model.BaseModel)

  # Disable assertions unless user explicitly enables it.
  if FLAGS['enable_asserts'].using_default_value:
    FLAGS.enable_asserts = False

  # TODO(laurenzo): Work out how much we need to specify here in terms of
  # cluster configuration.
  cls._SetClusterParams(model_cfg.cluster, device_options)

  # Configure the model.
  model_cfg.random_seed = random_seed
  model_cfg.is_inference = True

  if disable_packed_input:

    def _DisablePackedInput(task):
      """Turns off `packed_input` on the task's encoder/decoder if present."""
      if (_ParamExists(task, 'encoder') and
          _ParamExists(task.encoder, 'packed_input')):
        task.encoder.packed_input = False
      if (_ParamExists(task, 'decoder') and
          _ParamExists(task.decoder, 'packed_input')):
        task.decoder.packed_input = False

    if issubclass(model_cfg.cls, base_model.MultiTaskModel):
      for _, task_param in model_cfg.task_params.IterParams():
        _DisablePackedInput(task_param)
    else:
      _DisablePackedInput(model_cfg.task)

  tf.logging.info('Model %s params:', model_cfg.name)
  for line in model_cfg.ToText().split('\n'):
    tf.logging.info('%s', line)

  # Instantiate the graph.
  graph = tf.Graph()
  with graph.as_default():
    tf.random.set_seed(random_seed)
    cluster = model_cfg.cluster.Instantiate()
    device = cluster.GetPlacer()
    tpu_const_scope = _DummyScope()
    if (IsTpu(device_options) and
        device_options.var_options == 'AS_CONSTANTS'):
      # Do not specify devices for variables if we are marking them as
      # constants.
      device = ''
      tpu_const_scope = ConstGuaranteeScope()

    with cluster, tf.device(device), tpu_const_scope:
      bfloat16_override = ShouldForceBfloat16ForWeightsAndActivations(
          device_options)
      if bfloat16_override:
        py_utils.UpdateDtype(model_cfg, tf.bfloat16)
        py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

      # Remember the process-wide flag values so they can be restored no
      # matter how graph construction ends. The try block starts right after
      # the snapshot so that a failure anywhere below (including global step
      # creation) cannot leave the TPU overrides behind.
      old_enable_asserts = FLAGS.enable_asserts
      old_xla_device = FLAGS.xla_device
      try:
        # Hard-code TPU-related flags prior to instantiating model.
        if IsTpu(device_options):
          FLAGS.enable_asserts = False
          FLAGS.xla_device = 'tpu'

        # Ensure the global_step variable is created.
        py_utils.GetOrCreateGlobalStepVar()

        mdl = model_cfg.Instantiate()
        task = mdl.GetTask(model_task_name)

        # With EMA enabled, restore from the EMA variable mapping instead of
        # the plain global variables.
        variables_to_restore = (
            _MakeVariableDictionary(tf.global_variables()) if not mdl.ema else
            mdl.ema.variables_to_restore(mdl.variables_for_ema))

        if bfloat16_override:
          saver_var_spec = (
              bfloat16_variables
              .get_saver_spec_for_variables_with_bf16_overrides(
                  variables_to_restore))
        else:
          saver_var_spec = variables_to_restore

        saver = tf.train.Saver(saver_var_spec)
        tf.variables_initializer(
            tf.global_variables(), name='init_all_variables')
        if IsTpu(device_options) and device_options.gen_init_op:
          tf.group(tf.tpu.initialize_system(), name='tpu_init_op')

        inference_graph_proto = inference_graph_pb2.InferenceGraph()
        subgraphs_proto = task.Inference()
        if isinstance(subgraphs_proto, dict):
          subgraphs_proto = ConvertSubgraphDictToProto(subgraphs_proto)
        for name, subgraph in subgraphs_proto.subgraphs.items():
          # An empty/None filter keeps every subgraph.
          if not subgraph_filter or name in subgraph_filter:
            inference_graph_proto.subgraphs[name].CopyFrom(subgraph)

        # Add a table init op and global variable init op to the graph.
        # Tables can be declared anywhere in the graph, so this op has to be
        # added last.
        tf.tables_initializer(name='init_all_tables')
      finally:
        # Reset TPU-related flags after model instantiation.
        FLAGS.enable_asserts = old_enable_asserts
        FLAGS.xla_device = old_xla_device

  tf.logging.info('Graph contains ops: %r',
                  [op.name for op in graph.get_operations()])

  inference_graph_proto.saver_def.CopyFrom(saver.as_saver_def())

  # Freezing.
  if freeze_defaults or freeze_checkpoint:
    output_op_names = GetOutputOpNames(
        graph, inference_graph_proto, preserve_colocation_nodes=False)
    # NOTE(review): this raises when _DeviceSupportsFreezing() is truthy,
    # which reads inverted relative to the helper's name. The helper is not
    # visible here; the condition is kept as-is -- confirm against its
    # definition.
    if cls._DeviceSupportsFreezing(device_options):
      raise ValueError('freeze_checkpoint cannot be used with device ' +
                       device_options.device)
    if freeze_checkpoint:
      tf.logging.info('Freezing graph from checkpoint: %s', freeze_checkpoint)
      graph_def = _FreezeGraphFromCheckpoint(graph, saver, freeze_checkpoint,
                                             output_op_names)
    else:
      # Only reachable when freeze_defaults is set (see enclosing condition).
      tf.logging.info('Default initializing graph and freezing.')
      graph_def = _FreezeDefaults(graph, output_op_names)
  else:
    output_op_names = GetOutputOpNames(graph, inference_graph_proto)

    # Prune the graph to just the parts we need.
    # To support restoring, we have to not prune out the restore node.
    output_op_names.extend([
        'init_all_tables',
        'init_all_variables',
        'save/control_dependency',
        'save/restore_all',
    ])
    if IsTpu(device_options) and device_options.gen_init_op:
      output_op_names.append('tpu_init_op')
    graph_def = graph.as_graph_def()
    tf.logging.info('Pruning graph to output ops: %r', output_op_names)
    graph_def = tf.graph_util.extract_sub_graph(graph_def, output_op_names)

  if not device_options.retain_device_placement:
    # Clear the device so that the runtime can choose.
    tf.logging.info('Clearing device placement for: %s',
                    device_options.device)
    for node in graph_def.node:
      node.ClearField('device')
    # Nodes inside library functions carry their own device fields.
    for function in graph_def.library.function:
      for node_def in function.node_def:
        node_def.ClearField('device')

  inference_graph_proto.graph_def.CopyFrom(graph_def)

  if export_path:
    with tf.io.gfile.GFile(export_path, 'w') as f:
      f.write(text_format.MessageToString(inference_graph_proto))
  return inference_graph_proto
def _resource_apply_dense(self, grad, var):
  """Builds the Adafactor update ops for a single dense variable.

  The second-moment estimate is kept either in the factored slots 'vr'/'vc'
  (when `self._factored_dims` returns a pair of axes for the variable's
  shape) or in the single slot 'v'. A momentum slot 'm' is used when
  `self._beta1` is truthy, and the scaled gradient is divided by an RMS-based
  clipping factor when `self._clipping_threshold` is not None.

  When `self._cond_is_finite` is truthy, every assignment is wrapped in a
  `tf.cond` keyed on a flag that records whether all values inspected *up to
  that point* were finite; otherwise assignments are emitted unconditionally.

  Args:
    grad: Gradient for `var`, or None.
    var: The variable to update.

  Returns:
    An empty list if `grad` is None, otherwise a `tf.group` of all slot and
    variable update ops.
  """
  if grad is None:
    tf.logging.warning('Gradient is None for variable %s' % var.name)
    return []

  compute_dtype = var.dtype  # TODO(lepikhin): add to params
  grad = tf.cast(grad, compute_dtype)
  factored_dims = self._factored_dims(var.shape.as_list())
  if factored_dims:
    row_slot = self.get_slot(var, 'vr')
    col_slot = self.get_slot(var, 'vc')
  else:
    full_slot = self.get_slot(var, 'v')
  if self._beta1:
    momentum_slot = self.get_slot(var, 'm')

  all_finite = tf.constant(True)

  def _TrackFinite(flag, tensor):
    """Returns `flag` AND-ed with "tensor has no NaN/Inf" when enabled."""
    if not self._cond_is_finite:
      return flag
    flag = tf.math.logical_and(flag,
                               tf.reduce_all(tf.math.is_finite(tensor)))
    return tf.math.logical_and(
        flag, tf.reduce_all(tf.math.logical_not(tf.math.is_inf(tensor))))

  def _Guarded(flag, assign_fn, target, value):
    """Applies assign_fn(target, value), gated on `flag` when enabled."""
    if self._cond_is_finite:
      return tf.cond(flag, lambda: assign_fn(target, value), lambda: target)
    return assign_fn(target, value)

  with tf.variable_scope(var.name[:-2] + '/Adafactor'):
    grad_sq = tf.math.square(grad) + tf.cast(self._epsilon1, compute_dtype)
    all_finite = _TrackFinite(all_finite, grad_sq)
    decay = tf.cast(self._decay_rate, var.dtype)
    # TODO(lepikhin): introduce gradient dtype
    var_snapshot = tf.identity(var)
    lr = GetLrValue(self._learning_rate)
    if self._multiply_by_parameter_scale:
      step_scale = self._parameter_scale(var_snapshot) * tf.cast(
          lr, compute_dtype)
    else:
      step_scale = lr
    mix = tf.cast(1.0 - decay, compute_dtype)
    step_scale = tf.cast(step_scale, compute_dtype)

    update_ops = []
    if factored_dims:
      row_axis, col_axis = factored_dims
      row_mean = tf.reduce_mean(grad_sq, axis=row_axis)
      col_mean = tf.reduce_mean(grad_sq, axis=col_axis)
      # Exponential moving averages of the row/column means.
      new_row = row_slot * decay + row_mean * mix
      new_col = col_slot * decay + col_mean * mix
      # Both estimates are checked before either slot is written, so the two
      # slot assignments share the same gate.
      all_finite = _TrackFinite(all_finite, new_row)
      all_finite = _TrackFinite(all_finite, new_col)
      update_ops.append(_Guarded(all_finite, tf.assign, row_slot, new_row))
      update_ops.append(_Guarded(all_finite, tf.assign, col_slot, new_col))
      row_long_term_mean = tf.reduce_mean(new_row, -1, keepdims=True)
      row_factor = tf.math.rsqrt(new_row / row_long_term_mean)
      col_factor = tf.math.rsqrt(new_col)
      scaled_grad = grad * tf.expand_dims(
          row_factor, row_axis) * tf.expand_dims(col_factor, col_axis)
    else:
      new_full = full_slot * decay + grad_sq * mix
      all_finite = _TrackFinite(all_finite, new_full)
      update_ops.append(_Guarded(all_finite, tf.assign, full_slot, new_full))
      scaled_grad = grad * tf.math.rsqrt(new_full)

    if self._clipping_threshold is not None:
      clip_denom = tf.maximum(
          tf.constant(1.0, compute_dtype),
          py_utils.ReduceRms(scaled_grad) /
          tf.constant(self._clipping_threshold, compute_dtype))
      scaled_grad = scaled_grad / clip_denom

    delta = scaled_grad * step_scale
    if self._beta1:
      # With momentum, the applied step is the updated momentum itself.
      delta = (
          momentum_slot * tf.constant(self._beta1, dtype=compute_dtype) +
          delta * tf.constant(1.0 - self._beta1, dtype=compute_dtype))
      all_finite = _TrackFinite(all_finite, delta)
      update_ops.append(
          _Guarded(all_finite, tf.assign, momentum_slot, delta))

    # It is critical to use assign_sub instead of tf.assign(var - delta) for
    # the case of bfloat16 activations, so as to avoid repeatedly rounding
    # the slice value, which results in poor quality.
    all_finite = _TrackFinite(all_finite, delta)
    update_ops.append(_Guarded(all_finite, tf.assign_sub, var, delta))
    return tf.group(*update_ops)