def Apply(self, lr, var_grad):
  """Applies the gradient to the variable.

  Args:
    lr: A scalar or a callable that returns the learning rate. Must be a
      callable in Eager mode.
    var_grad: A `.NestedMap` of (var, grad) pairs.

  Returns:
    The variable update op, or `tf.no_op()` when `var_grad` holds no pairs.

  Raises:
    RuntimeError: When `lr` is not a callable in Eager mode.
  """
  if py_utils.IsEagerMode() and not callable(lr):
    # Ensure that the learning rate is always updated in Eager mode: the
    # optimizer is created once and reused below, so a plain scalar would be
    # frozen at its first value.
    raise RuntimeError('In Eager mode, `lr` must be a callable.')

  # In Graph mode, always re-create the optimizer to remain consistent with
  # the old logic for the Graph trainer.
  # TODO(jiaweix): Recreating optimizers in Graph mode seems unnecessary.
  if self._optimizer is None or not py_utils.IsEagerMode():
    self._optimizer = self.GetOptimizer(lr)

  def _Apply():
    # Flatten once and reuse the (var, grad) pairs.
    var_grad_pairs = var_grad.Flatten()
    if not var_grad_pairs:
      tf.logging.warning(
          'No gradients are available for optimizer.Apply(). '
          'Make sure this is expected.')
      return tf.no_op()
    if self.params.use_bf16_gradients_ar:
      # Gradients are cast to float32 before being applied; presumably they
      # may arrive in bfloat16 when this option is on.
      grads_and_vars = [
          (tf.cast(g, tf.float32), v) for (v, g) in var_grad_pairs
      ]
    else:
      grads_and_vars = [(g, v) for (v, g) in var_grad_pairs]
    return self._optimizer.apply_gradients(
        grads_and_vars, name='meta_backprop')

  clear_variable_scope = self.params.clear_variable_scope
  if clear_variable_scope is None:
    clear_variable_scope = not py_utils.IsEagerMode()
  if clear_variable_scope:
    # Many optimizers, e.g., Adam, Adagrad, etc., create
    # variables. We need to ensure name scope and variable scope are
    # cleared. Otherwise, tpu.batch_parallel does not work.
    with tf.name_scope(None):
      with tf.variable_scope(
          tf.VariableScope(
              use_resource=True, reuse=self.VarReuseForSlotVars())):
        var_update_op = _Apply()
  else:
    var_update_op = _Apply()

  if self.params.add_summary_in_apply:
    lr_value = GetLrValue(lr)
    self.AddSummary(lr_value, self._optimizer, var_grad)
  return var_update_op
def ReplicatedGenericInput(processor, num_replicas, replica_device_fn,
                           **kwargs):
  """Builds a replicated input pipeline.

  This is similar to GenericInput, except that the input processing can be
  distributed across devices and then concatenated at the current device.

  Args:
    processor: see comments for GenericInput.
    num_replicas: the number of input processing replicas. Usually set to number
      of infeed hosts. Must be at least 1.
    replica_device_fn: a int -> string function that takes the replica index in
      range [0, num_replicas) and returns a TF device string, e.g.,
      lambda i: '/task:{}/device:CPU:0'.format(i)
    **kwargs: additional keyword args for x_ops.generic_input.

  Returns:
    A tuple of (outputs, bucket_keys):

    - outputs: a NestedMap or a list of tensors, similar to `processor`'s
      return,  except every tensor will have an additional dimension 0 that
      represents the batch dimension. The batch size will be
      (num_replicas * bucket_batch_limit[...]), i.e.,
      kwargs['bucket_batch_limit'] specifies the per-replica batch size.
    - bucket_keys: a tf.int32 vector.

  Raises:
    RuntimeError: If called in pure Eager/tf.function mode without
      `generic_input_v2_key` defined.
    ValueError: If `num_replicas` is less than 1.
  """
  if num_replicas > 1 and 'bucket_batch_limit' in kwargs:
    # Every bucket must share one limit so that the concatenated batch size is
    # num_replicas * limit regardless of which bucket each replica emits.
    bucket_batch_limit = list(kwargs['bucket_batch_limit'])
    assert all(b == bucket_batch_limit[0] for b in bucket_batch_limit)

  eager_mode = py_utils.IsEagerMode()
  if eager_mode:
    current_key = kwargs.pop('generic_input_v2_key', None)
    if current_key is None:
      raise RuntimeError(_MISSING_KEY_ERR)

  if num_replicas < 1:
    raise ValueError(
        'num_replicas must be at least 1, got {}'.format(num_replicas))

  replica_outputs = []
  for replica_i in range(num_replicas):
    if eager_mode:
      # Blend `replica_i` into the key for _GENERIC_CACHE_V2 to distinguish
      # different GenericInputV2 ops in the same Datasource object.
      kwargs['generic_input_v2_key'] = (current_key, replica_i)
    replica_device = replica_device_fn(replica_i)
    with tf.device(replica_device):
      replica_outputs.append(GenericInput(processor, **kwargs))

  output_nmaps, output_bucket_keys = zip(*replica_outputs)
  concat_nmap = tf.nest.map_structure(lambda *t: tf.concat(t, axis=0),
                                      *output_nmaps)
  concat_bucket_keys = tf.concat(output_bucket_keys, axis=0)
  return concat_nmap, concat_bucket_keys
def scalar(name, value, while_loop_reduce='mean'):
  """Records a scalar summary in the active tpu_summary context.

  Does nothing when called outside of tpu_summary.context().

  Args:
    name: Summary name; stored as `str(name)`.
    value: Value convertible to a rank-0 tensor.
    while_loop_reduce: How the summary is reduced when it is recorded inside a
      tf.while_loop; either 'mean' or 'sum'.

  Raises:
    RuntimeError: If the function is called in Eager mode.
    ValueError: If `value` is not a scalar.
  """
  if py_utils.IsEagerMode():
    raise RuntimeError(EAGER_MODE_EXCEPTION_STR)
  assert while_loop_reduce in ('mean', 'sum')

  active_ctx = TpuSummaryContext.current()
  if active_ctx is None:
    return

  entry = TpuSummaryScalar()
  entry.name = str(name)
  entry.value = tf.convert_to_tensor(value)
  if entry.value.shape != ():  # pylint: disable=g-explicit-bool-comparison
    raise ValueError('use tpu_summary.tensor() instead: %r' % value)
  entry.name_scope = tf.get_default_graph().get_name_scope()
  entry.while_loop_reduce = while_loop_reduce
  active_ctx.summary_tensors.append(entry)
def _CreateVariableInternal(self, name: str, meta: CreateVariableMeta) -> None:
  """Immediately creates the variable described by `meta`.

  DO NOT OVERRIDE. For internal use only. Subclasses of BaseLayer should use
  self.CreateVariable() to create variables.

  Registers the variable in `self._private_vars` and its theta value in
  `self._private_theta`; `meta.theta_fn`, if any, goes into
  `self._private_theta_fn`.

  Args:
    name: The variable name.
    meta: A CreateVariableMeta describing the variable to be created.
  """
  meta.kwargs.setdefault('default_seed', self.params.random_seed)
  var = py_utils.CreateVariable(name, meta.var_params, **meta.kwargs)
  self._private_vars[name] = var

  if py_utils.IsEagerMode():
    # With eager trainer, always use the variable directly.
    theta_value = var
  elif self.cluster.params.worker.gpus_per_replica > 0:
    # On GPU (which always trains a single step per session.run()), reference
    # a tensor in FProp to cache it on device and avoid extraneous sends from
    # reading variables from ps multiple times.
    with tf.device(var.device):
      theta_value = tf.identity(var, name=name)
  elif self.params.add_name_to_theta:
    theta_value = tf.identity(var, name=name)
  else:
    theta_value = var

  if meta.theta_fn is not None:
    self._private_theta_fn[name] = meta.theta_fn
  self._private_theta[name] = theta_value
def _MaybeConstructSharedModel(self, train_cfg):
  """Construct a single shared copy of the model if this is a MultiTaskModel.

  Only applies when `train_cfg` describes a MultiTaskModel with
  `share_model_object` set: each task then gets a MultiTaskSubModel, but the
  underlying model is instantiated only once, here.

  Args:
    train_cfg: The params for a SingleTaskModel or MultiTaskModel.

  Returns:
    The instantiated MultiTaskModel, or None when no shared model is wanted.
  """
  wants_shared_model = (
      issubclass(train_cfg.cls, base_model.MultiTaskModel) and
      train_cfg.share_model_object)
  if not wants_shared_model:
    return None

  with self._cluster, tf.container(
      self._container_id), contextlib.ExitStack() as stack:
    if not py_utils.IsEagerMode():
      stack.enter_context(self._graph.as_default())
    stack.enter_context(tf.device(self._cluster.GetPlacer()))
    with py_utils.VariableStore(), py_utils.VariableRenameScope(
        self._variable_renaming_rules):
      py_utils.GetOrCreateGlobalStepVar()
      return train_cfg.Instantiate()
def available_devices(self):
  """Returns all compute devices available in a 2D array.

  Returns:
    A 2D array (python list of python lists) of strings. ret[i, j]
    is the j-th visible device on i-th visible replica.

  Raises:
    AssertionError: If the job/mode combination is not handled.
  """
  from lingvo.core import py_utils  # pylint: disable=g-import-not-at-top
  if self.job_spec.tpus_per_replica and not py_utils.IsEagerMode():
    # `np.object` was removed in NumPy 1.24; the builtin `object` is the
    # equivalent dtype for an array holding arbitrary Python objects.
    ret = np.empty((1, self.num_devices_per_split), object)
    for i in range(self.num_devices_per_split):
      ret[0, i] = tf.tpu.core(i)
    return ret

  if self.job == 'trainer' and self.asynchronous:
    # In async mode, each trainer task can only use its own devices.
    return self.ListDevices(self.job_spec)[self.task:(self.task + 1), :]

  if self.job == 'trainer_client' and self.synchronous:
    # In sync mode, trainer_client can use every device.
    return self.ListDevices(self.job_spec)

  if self.job == 'executor_tpu' and self.synchronous:
    # executor_tpu can use every device.
    return self.ListDevices(self.job_spec)

  if self.job in ('controller', 'train_summaries', 'evaler', 'decoder'):
    # Our current policy is that each controller/evaler/decoder task
    # only uses 1 replica.
    return self.ListDevices(self.job_spec)[self.task:(self.task + 1), :]

  # An explicit raise (rather than `assert False`) still fires under `-O`
  # instead of silently returning None.
  raise AssertionError((self.job, self.mode))
def _GetSession(self, **kwargs):
  """Creates a `tf.Session` connected to `self._tf_master`.

  Args:
    **kwargs: `graph` selects the session graph (default `self._graph`); every
      other keyword arg is forwarded to `py_utils.SessionConfig`.

  Returns:
    A new `tf.Session`.

  Raises:
    ValueError: In Eager mode, where sessions are not supported.
  """
  if py_utils.IsEagerMode():
    raise ValueError('_GetSession is not supported in eager mode.')
  session_graph = kwargs.pop('graph', self._graph)
  session_config = py_utils.SessionConfig(**kwargs)
  return tf.Session(self._tf_master, graph=session_graph, config=session_config)
def _CreateCheckpointer(self, train_dir, model, init_op=None):
  """Wrapper method for override purposes.

  Graph mode uses `checkpointer.Checkpointer`. Eager mode uses
  `EagerCheckpointerV2` when FLAGS.write_v2_checkpoints is set and
  `EagerCheckpointerV1` otherwise.
  """
  if not py_utils.IsEagerMode():
    return checkpointer.Checkpointer(train_dir, model, init_op)
  eager_checkpointer_cls = (
      checkpointer.EagerCheckpointerV2
      if FLAGS.write_v2_checkpoints else checkpointer.EagerCheckpointerV1)
  return eager_checkpointer_cls(train_dir, model, init_op)
def UpdateClusterParamsFromFlags(self, cluster, job_name):
  """Update `cluster` with a training cluster configuration from flags.

  Args:
    cluster: The cluster params, updated in place.
    job_name: The job name of this process; stored as `cluster.job` and used
      to set `cluster.do_eval`.
  """
  cluster.mode = FLAGS.mode
  cluster.job = job_name
  cluster.task = FLAGS.task
  cluster.do_eval = job_name in ['evaler', 'decoder']
  cluster.logdir = FLAGS.logdir

  cluster.controller.name = FLAGS.controller_job
  cluster.controller.gpus_per_replica = FLAGS.controller_gpus

  cluster.worker.name = FLAGS.worker_job
  cluster.worker.replicas = FLAGS.worker_replicas
  cluster.worker.gpus_per_replica = FLAGS.worker_gpus
  cluster.worker.tpus_per_replica = FLAGS.worker_tpus
  cluster.worker.num_tpu_hosts = FLAGS.worker_num_tpu_hosts
  cluster.worker.devices_per_split = FLAGS.worker_split_size
  if FLAGS.additional_worker_jobs:
    for additional_job in FLAGS.additional_worker_jobs:
      cluster.worker.additional_worker_names.append(additional_job)

  if FLAGS.tpu:
    cluster_spec = _GetClusterSpecDict()
    # A separate name avoids rebinding the `job_name` argument.
    worker_job_name = cluster.worker.name.replace('/job:', '', 1)
    # Copy the host list so that the extend() below does not mutate the list
    # stored inside the cluster spec dict.
    worker_hosts = list(cluster_spec[worker_job_name])
    if FLAGS.additional_worker_jobs:
      for additional_job in cluster.worker.additional_worker_names:
        additional_job_name = additional_job.replace('/job:', '', 1)
        worker_hosts.extend(cluster_spec[additional_job_name])
    cluster.worker.targets = ','.join(
        'grpc://{}'.format(host) for host in worker_hosts)

  cluster.ps.name = FLAGS.ps_job
  cluster.ps.replicas = FLAGS.ps_replicas
  cluster.ps.gpus_per_replica = FLAGS.ps_gpus

  cluster.input.name = FLAGS.input_job
  cluster.input.replicas = FLAGS.input_replicas
  cluster.input.targets = FLAGS.input_targets

  if py_utils.IsEagerMode():
    cluster.evaler.name = '/job:localhost'
    cluster.decoder.name = '/job:localhost'
  else:
    cluster.evaler.name = FLAGS.evaler_job
    cluster.decoder.name = FLAGS.decoder_job

  cluster.evaler.replicas = FLAGS.evaler_replicas
  cluster.evaler.gpus_per_replica = FLAGS.evaler_gpus
  cluster.decoder.replicas = FLAGS.decoder_replicas
  cluster.decoder.gpus_per_replica = FLAGS.decoder_gpus

  cluster.tf_data_service_address = FLAGS.tf_data_service_address

  cluster.add_summary = FLAGS.add_summary
  cluster.reporting_job = FLAGS.vizier_reporting_job
def Apply(self, lr, var_grad):
  """Applies the gradient to the variable.

  Args:
    lr: A scalar or callable that returns the base learning rate. Must be a
      callable in Eager mode.
    var_grad: A `.NestedMap` of (var, grad) pairs.

  Returns:
    The variable update op.

  Raises:
    RuntimeError: When `lr` is not a callable in Eager mode.
  """
  if py_utils.IsEagerMode() and not callable(lr):
    # In Eager mode the optimizer is created once and cached below, so a
    # plain scalar would freeze the learning rate at its first value.
    raise RuntimeError('In Eager mode, `lr` must be a callable.')

  # In Graph mode, always re-create the optimizer to remain consistent with
  # the old logic for the Graph trainer.
  # TODO(jiaweix): Recreating optimizers in Graph mode seems unnecessary.
  if self._optimizer is None or not py_utils.IsEagerMode():
    self._optimizer = self.GetOptimizer(lr)

  def _Apply():
    return self._optimizer.apply_gradients(
        [(g, v) for (v, g) in var_grad.Flatten()], name='meta_backprop')

  clear_variable_scope = self.params.clear_variable_scope
  if clear_variable_scope is None:
    clear_variable_scope = not py_utils.IsEagerMode()
  if clear_variable_scope:
    # Many optimizers, e.g., Adam, Adagrad, etc., create
    # variables. We need to ensure name scope and variable scope are
    # cleared. Otherwise, tpu.batch_parallel does not work.
    with tf.name_scope(None):
      with tf.variable_scope(
          tf.VariableScope(
              use_resource=True, reuse=self.VarReuseForSlotVars())):
        var_update_op = _Apply()
  else:
    var_update_op = _Apply()

  if self.params.add_summary_in_apply:
    lr_value = GetLrValue(lr)
    self.AddSummary(lr_value, self._optimizer, var_grad)
  return var_update_op
def Finalize(self):
  """Finishes creation of the overall figure, returning the image summary.

  Uses the TF2 summary API, stamped with the global step, in Eager mode and
  the TF1 `tf.summary.image` otherwise.
  """
  rendered = self.FinalizeImage()
  if not py_utils.IsEagerMode():
    return tf.summary.image(
        self._name, rendered, max_outputs=self._max_outputs)
  return tf.compat.v2.summary.image(
      self._name,
      rendered,
      max_outputs=self._max_outputs,
      step=py_utils.GetGlobalStep())
def pw_tensor(name, value):
  """Adds summary tensor.

  Appends a PwTpuSummaryTensor for `value` to the active TPU summary context;
  does nothing when no context is active.

  Raises:
    RuntimeError: If called in Eager mode.
  """
  if py_utils.IsEagerMode():
    raise RuntimeError(EAGER_MODE_EXCEPTION_STR)
  active_ctx = TpuSummaryContext.current()
  if active_ctx is not None:
    entry = PwTpuSummaryTensor()
    entry.name = str(name)
    entry.value = tf.convert_to_tensor(value)
    entry.name_scope = tf.get_default_graph().get_name_scope()
    active_ctx.summary_tensors.append(entry)
def tensor(name, value):
  """Adds summary tensor. Similar to scalar() but allows other shapes.

  Inside a tf.while_loop the values are stacked. Does nothing when no TPU
  summary context is active.

  Raises:
    RuntimeError: If called in Eager mode.
  """
  if py_utils.IsEagerMode():
    raise RuntimeError(EAGER_MODE_EXCEPTION_STR)
  active_ctx = TpuSummaryContext.current()
  if active_ctx is not None:
    entry = TpuSummaryScalar()
    entry.name = str(name)
    entry.value = tf.convert_to_tensor(value)
    entry.name_scope = tf.get_default_graph().get_name_scope()
    entry.while_loop_reduce = 'stack'
    active_ctx.summary_tensors.append(entry)
def CreateVariable(self, name: str, var_params: hyperparams.Params,
                   **kwargs) -> None:
  """Create a variable of this layer according to the parameter `var_params`.

  E.g.::

      def __init__(self, ...):    # A layer's constructor
        self.CreateVariable(
            'weight', py_utils.WeightParams(shape=[100, 100]))

  Args:
    name: Variable name which is used as the key into vars/theta.
    var_params: `Params` used to create the variable. Not modified.
    **kwargs: Keyword args passed to `.py_utils.CreateVariable`.
  """
  kwargs.setdefault('default_seed', self.params.random_seed)
  if self.params.device_mesh is not None:
    num_nontrivial_dims = sum(1 for dim in var_params.shape if dim > 1)
    if (num_nontrivial_dims > 1 and
        var_params.tensor_split_dims_mapping is None):
      tf.logging.warning(
          'tensor_split_dims_mapping missing for %s.%s: shape=%s', self.path,
          name, var_params.shape)
  self._CheckName(name)
  if (self.params.skip_lp_regularization and
      py_utils.SKIP_LP_REGULARIZATION not in var_params.collections):
    # Copy the params instead of rebuilding them with py_utils.WeightParams:
    # a rebuild dropped every other field (e.g. device_mesh and
    # tensor_split_dims_mapping). The caller's params stay untouched.
    var_params = var_params.Copy()
    var_params.collections = (
        var_params.collections + [py_utils.SKIP_LP_REGULARIZATION])
  self._var_symbolic_shape_map[name] = var_params.shape
  var = py_utils.CreateVariable(name, var_params, **kwargs)
  self._private_vars[name] = var
  if py_utils.IsEagerMode():
    # With eager trainer, always use the variable directly.
    value = var
  elif self.cluster.params.worker.gpus_per_replica > 0:
    # On GPU (which always trains a single step per session.run()), reference
    # a tensor in FProp to cache it on device and avoid extraneous sends from
    # reading variables from ps multiple times.
    with tf.device(var.device):
      value = tf.identity(var, name=name)
  else:
    value = var
  self._private_theta[name] = value
def RunSave(sess, global_step):
  """Saves every program's state, then writes a checkpoint asynchronously.

  In Graph mode the TPU embedding retrieve ops run first.
  """
  if not py_utils.IsEagerMode():
    # Pulling TPU embedding params back is expensive, which is why it only
    # happens when we are about to checkpoint.
    tf.logging.info('Retrieve params.')
    sess.run(self._retrieve_ops)
    tf.logging.info('Retrieve params done.')
  # Program state is written before the checkpoint so that it can be
  # recovered after restoring from that checkpoint.
  for prog in self._programs:
    prog.SaveProgramState(sess, global_step)
  self._checkpointer.Save(sess, global_step, sync=False)
def __init__(self,
             train_dir,
             models,
             init_op=None,
             train_params=None,
             save_only=False):
  """Initialize Checkpointer.

  Args:
    train_dir: Training directory for saving checkpoints.
    models: One or a list of BaseModel instances. Cannot be empty. If there are
      more than one models and `train_params` is None, the save intervals will
      be only determined by the first model.
    init_op: The initialize variables op. If None, it will call
      tf.global_variables_initializer().
    train_params: If specified, use these training params instead of those in
      the `model`.
    save_only: This checkpointer is only intended for saving checkpoints.
  """
  self._train_dir = train_dir
  self._save_only = save_only
  # Compare against None instead of relying on truthiness: using a graph-mode
  # tf.Tensor as a Python bool raises a TypeError.
  if init_op is None:
    init_op = tf.global_variables_initializer()
  self._init_op = init_op

  self._save_path = os.path.join(self._train_dir, 'ckpt')

  if not isinstance(models, (list, tuple)):
    models = [models]
  self._models = models

  if train_params:
    self._train_params = train_params
  else:
    self._train_params = models[0].params.train

  self._next_checkpoint_seconds = 0
  self._save_interval_seconds = self._train_params.save_interval_seconds
  self._save_interval_steps = self._train_params.save_interval_steps
  self._prev_ckpt_step = None

  self._saver = self._GetSaver()

  if not py_utils.IsEagerMode():
    self._uninitialized_vars = tf.report_uninitialized_variables(
        tf.global_variables())

  self._BuildInitFromCheckpointRules()
def _WaitTillInit(job=None):
  """Wait until the model is ready.

  Initializes the TPU system, determines the computation shape and registers
  the resulting device assignment with `py_utils.SetTpuDeviceAssignment`.

  Args:
    job: Optional job name. It is passed to `tf.tpu.initialize_system` in
      Graph mode and to `py_utils.SetTpuDeviceAssignment`.
  """
  try:
    if py_utils.IsEagerMode():
      topology = tf.tpu.experimental.initialize_tpu_system(resolver)
    else:
      # tpu.initialize_system() is called with None as embedding_config, as
      # embedding_config is not available yet. Later in _Loop, it is called
      # with the correct embedding_config. Since it cannot be called twice
      # in the same graph with different embedding_config, we use a
      # dummy_graph here.
      dummy_graph = tf.Graph()
      with dummy_graph.as_default():
        init_system_op = tf.tpu.initialize_system(
            embedding_config=None, job=job)
      with self._GetSession(graph=dummy_graph) as sess:
        topology = sess.run(init_system_op)

    if train_cfg.train.tpu_computation_shape is None:
      computation_shape = py_utils.ComputationShape(num_devices_per_split,
                                                    topology)
    else:
      computation_shape = train_cfg.train.tpu_computation_shape
      assert num_devices_per_split == np.prod(computation_shape)

    assignment_kwargs = dict(
        computation_shape=computation_shape, num_replicas=data_parallelism)
    if train_cfg.train.tpu_device_order_mode is not None:
      assignment_kwargs['device_order_mode'] = (
          train_cfg.train.tpu_device_order_mode)
    self.device_assignment = device_assignment_lib.device_assignment(
        topology, **assignment_kwargs)

    py_utils.SetTpuDeviceAssignment(self.device_assignment, job)
    tf.logging.info('device_assignment.core_assignment: %s',
                    str(self.device_assignment.core_assignment))
    tf.logging.info('device_assignment.topology.device_coordinates: %s',
                    str(self.device_assignment.topology.device_coordinates))
  except py_utils.transient_tf_errors as e:
    tf.logging.info('TPU initialization failed: %s', e)
    raise
def Stop(self, session=None):
  """Returns true if stop criterion is met.

  Evaluates `self._node` to refresh `self._best_step` and `self._last_step`.
  Stopping requires both that more than `params.window` steps have passed
  since the best step and that `params.min_steps` has been reached.

  Args:
    session: Session used to run the node outside Eager mode.

  Returns:
    Whether to stop; always False when there is no criterion node.
  """
  if self._node is None:
    return False

  if py_utils.IsEagerMode():
    steps = self._node()
  else:
    steps = session.run(self._node())
  self._best_step, self._last_step = steps

  past_window = self._last_step - self._best_step > self.params.window
  should_stop = past_window and self._last_step >= self.params.min_steps
  if self.params.verbose:
    tf.logging.info('early stop check: best_step=%d, last_step=%d, stop=%d',
                    self._best_step, self._last_step, should_stop)
  return should_stop
def CollectVarHistogram(vs_gs):
  """Adds histogram summaries for variables and gradients.

  Args:
    vs_gs: A `.NestedMap` whose `FlattenItems()` yields (name, (var, grad)).
      For an `IndexedSlices` gradient only the gathered variable rows are
      summarized; complex values are summarized by magnitude.
  """
  for raw_name, (var, grad) in vs_gs.FlattenItems():
    scope_key = py_utils.SanitizeScopeKey(raw_name)
    with tf.device(var.device), tf.name_scope(scope_key + '/summary'):
      if isinstance(grad, tf.IndexedSlices):
        var = tf.gather(var, grad.indices)
        grad = grad.values
      if var.dtype.is_complex:
        var = tf.abs(var)
        grad = tf.abs(grad)
      add_histogram = histogram_v2 if py_utils.IsEagerMode() else histogram
      add_histogram(f'var_hist/{scope_key}', var)
      add_histogram(f'grad_hist/{scope_key}', grad)
def AddNormSummary(name, vs_gs):
  """Returns and creates summary for norms of vs and their gradients gs.

  Args:
    name: A name string for summary.
    vs_gs: A `.NestedMap` or a list of `.NestedMap` of (variable, gradient).

  Returns:
    A (var_norm, grad_norm) tuple. Each is the square root of
    `py_utils.SumSquared` over all variables (respectively gradients).
  """
  pairs = py_utils.Flatten(vs_gs)
  var_norm = tf.sqrt(py_utils.SumSquared([var for var, _ in pairs]))
  grad_norm = tf.sqrt(py_utils.SumSquared([grad for _, grad in pairs]))
  add_scalar = scalar_v2 if py_utils.IsEagerMode() else scalar
  add_scalar(f'var_norm/{name}', var_norm)
  add_scalar(f'grad_norm/{name}', grad_norm)
  return var_norm, grad_norm
def _ShouldStop(self, sess=None, step=None, check_early_stop=True):
  """Check if the runner should stop.

  Args:
    sess: tf.Session.
    step: The current GlobalStep. Read from the global step variable when
      None.
    check_early_stop: Whether or not we want to check the EarlyStop condition.
      In TPU-training, we don't want to check this at the every step
      granularity in the enqueue thread, as this may starve the TPU training
      loop which by default operates at the 1000 steps granularity.

  Returns:
    Whether runner should stop.
  """
  if step is None:
    if py_utils.IsEagerMode():
      step = py_utils.GetGlobalStep().numpy()
    else:
      step = sess.run(py_utils.GetGlobalStep())

  max_steps = self.params.train.max_steps
  if step >= max_steps:
    tf.logging.info('ShouldStop: step:%6d params.train.max_steps:%6d', step,
                    max_steps)
    return True

  early_stop_limit = self._max_steps_for_early_stop
  if early_stop_limit and step >= early_stop_limit:
    tf.logging.info('ShouldStop: step:%6d _max_steps_for_early_stop:%6d',
                    step, early_stop_limit)
    return True

  if check_early_stop and self._ShouldEarlyStop(sess):
    tf.logging.info('ShouldStop: Early stopping.')
    return True

  if self._trial and self._trial.ShouldStop():
    tf.logging.info('ShouldStop: Trial finished.')
    return True

  return False
def RetryLoop():
  """Reads the global step and checks it against the start-up delay.

  Returns:
    The current global step.

  Raises:
    tf.errors.FailedPreconditionError: If the global step cannot be read, or it
      is still below `start_up_delay_steps`.
  """
  try:
    if py_utils.IsEagerMode():
      # numpy() must be called; the bare attribute is a bound method, not the
      # step value, and would break the '%6d' formatting below.
      global_step = py_utils.GetGlobalStep().numpy()
    else:
      global_step = sess.run(py_utils.GetGlobalStep())
  except tf.errors.FailedPreconditionError as e:
    tf.logging.info('%s: Probably the expected race on global_step: %s',
                    self._job_name, e)
    raise
  msg = 'step:%6d' % global_step
  self._SetStatusMessage(msg)
  if start_up_delay_steps and global_step < start_up_delay_steps:
    # Report the same threshold the condition above is checked against.
    msg = 'global step (%d) has not reached start up delay steps (%d)' % (
        global_step, start_up_delay_steps)
    tf.logging.info('%s: %s', self._job_name, msg)
    raise tf.errors.FailedPreconditionError(
        node_def=None, op=None, message=msg)
  return global_step
def GetNext(self):
  """Return input batch from p.file_patterns list weighted by p.weights.

  Examples in the batch will be mixed together from different file_pattern
  source proportionally to the weights.

  Returns:
    An input batch.
  """
  p = self.params
  file_patterns = ','.join(p.file_pattern)
  if p.file_type:
    file_patterns = f'{p.file_type}:{file_patterns}'

  extra_args = {}
  if p.source_id_offset != 0:
    extra_args['input_source_id_offset'] = p.source_id_offset

  # In TF1 mode `GetNext` is only called once per DataSource object during
  # graph construction, but in TF2 mode it can be called many times. The
  # DataSource object itself is therefore used as the key that identifies its
  # `GenericInputV2` resource, so the resource is properly reused.
  if py_utils.IsEagerMode():
    extra_args['generic_input_v2_key'] = self

  if p.weights:
    # Within-batch mixing.
    extra_args['input_source_weights'] = p.weights

  return self._input_generator._DataSourceFromFilePattern(  # pylint: disable=protected-access
      file_patterns, **extra_args)
def Apply(self, metrics, vmap, gradient_mask=None, gradient_adjuster=None):
  """Computes updates on 'vmap' to optimize 'loss'.

  TODO(rpang): explore merging gradient_mask and gradient_adjuster.

  Args:
    metrics: A Dict[str, (value, weight)], from which loss can be extracted
      according to p.loss_name.
    vmap: A `.NestedMap` object containing variables to optimize.
    gradient_mask: if not None, a dict mapping variable names to a 0/1 scalar.
    gradient_adjuster: if not None, a function that mutates a given var_grads.

  Returns:
    (losses, op, eval_metrics), where
      - losses is a list of scalar tensors;
      - op is a tf.Operation to update variables;
      - eval_metrics is a Dict[str, (value, weight)], where each value/weight
        is a scalar tensor.
  """
  # We apply gradients outside the name_scope to maintain backwards
  # compatibility on variables created by self.optimizer.Apply().
  losses, var_grads, eval_metrics = self._ComputeLossesAndGradients(
      metrics, vmap)

  if 'tpu_embedding_var_grads' not in var_grads:
    tpu_emb_update_op = tf.no_op()
  else:
    # TPU embedding gradients are applied by the TPU embedding collection, so
    # they are removed from `var_grads` before self.optimizer sees them.
    tpu_embedding_var_grads = var_grads.tpu_embedding_var_grads
    del var_grads.tpu_embedding_var_grads
    tpu_embedding_collection = py_utils.GetTpuEmbeddingGraphCollection()[0]
    assert tpu_embedding_collection
    embedding_grads = tpu_embedding_var_grads.Transform(
        lambda var_grad: var_grad.grad)
    tpu_emb_update_op, stats = tpu_embedding_collection.ApplyGradients(
        py_utils.GetTaskCallScope(), embedding_grads)
    eval_metrics.update(stats)

  assert py_utils.GetGlobalStep() is not None
  lr = self.LearningRate()

  var_grads, stats = self.AdjustGradients(
      var_grads,
      gradient_mask=gradient_mask,
      gradient_adjuster=gradient_adjuster)
  eval_metrics.update(stats)
  self._var_grads = var_grads

  eval_metrics['learning_rate'] = (tf.convert_to_tensor(lr),
                                   tf.convert_to_tensor(1.))

  # In Eager mode pass the bound method itself, so the single optimizer kept
  # for the whole trainer lifetime always sees the current learning rate.
  lr_or_callable = self.LearningRate if py_utils.IsEagerMode() else lr

  with self._SelfVariableScope():
    optimizer_update_op = self.optimizer.Apply(lr_or_callable, var_grads)
    var_update_op = tf.group([tpu_emb_update_op, optimizer_update_op])

  return losses, var_update_op, eval_metrics
def Apply(self, lr, var_grad):
  """For each optimizer, apply the gradient to the variable.

  Each variable is routed to the sub-optimizer whose regex matches its name;
  variables matched by no regex go to 'default_optimizer'.

  Args:
    lr: A scalar. The base learning rate. Unused here: the sub-optimizers are
      built via `self.GetOptimizer(0)`, and summaries use `self._lr_map`.
    var_grad: A `.NestedMap` of (var, grad) pairs.

  Returns:
    The variable update op.

  Raises:
    Exception: When a variable is matched by more than one regex.
  """
  # Override inherited GetOptimizer even though learning rate is unused.
  if not self._tf_optimizer_map or not py_utils.IsEagerMode():
    self._tf_optimizer_map = self.GetOptimizer(0)

  var_grad_map = {regex: [] for regex in self._lingvo_optimizer_map}
  for (v, g) in var_grad.Flatten():
    regex_match = 0
    for regex in self._lingvo_optimizer_map:
      if re.match(regex, v.name):
        var_grad_map[regex].append((g, v))
        regex_match += 1
    if regex_match == 0:
      var_grad_map['default_optimizer'].append((g, v))
    if regex_match > 1:
      raise Exception('Variable {} is matched {} times by regex {}'.format(
          v.name, regex_match, list(self._lingvo_optimizer_map.keys())))

  def _Apply():
    """Use the matched optimizer to apply the gradients."""
    train_ops = []
    non_default_regex = [
        regex for regex in self._lingvo_optimizer_map
        if regex != 'default_optimizer'
    ]

    def _MatchesNonDefault(var_name):
      return any(re.match(i, var_name) for i in non_default_regex)

    for regex in self._lingvo_optimizer_map:
      if not var_grad_map[regex]:
        continue
      opt = self._tf_optimizer_map[regex]
      train_ops.append(opt.apply_gradients(var_grad_map[regex]))
      # Summaries must cover exactly the variables this optimizer updates.
      if regex == 'default_optimizer':
        # The default optimizer owns the variables NOT matched by any other
        # regex (the previous filter selected the opposite set).
        filtered_var_grad = var_grad.FilterKeyVal(
            lambda k, v: not _MatchesNonDefault(v.var.name))
      else:
        filtered_var_grad = var_grad.FilterKeyVal(
            lambda k, v, pattern=regex: re.match(pattern, v.var.name))
      self._lingvo_optimizer_map[regex].AddSummary(self._lr_map[regex], opt,
                                                   filtered_var_grad)
    return tf.group(*train_ops, name='composite_optimizer_train_op')

  # Many optimizers, e.g., Adam, Adagrad, etc., create
  # variables. We need to ensure name scope and variable scope are
  # cleared. Otherwise, tpu.batch_parallel does not work.
  var_reuse = False
  if py_utils.GetOpportunisticVariableReuse():
    var_reuse = tf.AUTO_REUSE
  with tf.name_scope(None):
    with tf.variable_scope(
        tf.VariableScope(use_resource=True, reuse=var_reuse)):
      var_update_op = _Apply()
  return var_update_op
def _Loop(self):
  """Runs the executor main loop.

  In Graph mode, this opens a session and initializes the TPU system for every
  worker. It then restores a checkpoint and compiles all programs in parallel.
  After that it repeats cycles of:
  (optionally) saving a checkpoint, running one program schedule, and
  checking for completion.
  """
  with self._cluster, tf.container(
      self._container_id), contextlib.ExitStack() as stack:
    if py_utils.IsEagerMode():
      sess = None
    else:
      sess = self._GetSession(
          disable_meta_optimizer=FLAGS.disable_meta_optimizer_in_executor)
      stack.enter_context(sess)
      sess.reset(self._tf_master)
      config_proto = (
          self._tpu_embedding.config_proto
          if self._tpu_embedding is not None else None)
      for worker in self._cluster.all_worker_names:
        sess.run(
            tf.tpu.initialize_system(
                embedding_config=config_proto, job=worker))

    # Initialize the variables first, if needed.
    # Need to call create global step again because this is run in a thread.
    py_utils.GetOrCreateGlobalStepVar()

    if self._checkpoint_to_load:
      path = self._checkpointer.RestoreFromPath(
          sess, checkpoint_path=self._checkpoint_to_load)
    else:
      path = self._checkpointer.Restore(sess)

    # Run the compiles in parallel.
    compile_fns = []
    for program in self._programs:
      program.LoadProgramState(path, sess)
      compile_fns += [program.Compile]
    threadpool = multiprocessing.dummy.Pool(len(compile_fns))
    futures = []
    tf.logging.info(f'Compiling {len(compile_fns)} programs in parallel.')
    for fn in compile_fns:
      futures += [threadpool.apply_async(fn, args=(sess,))]
    for future in futures:
      future.get()
    # All compiles are done; release the pool's worker threads instead of
    # leaking them for the rest of the process lifetime.
    threadpool.close()
    threadpool.join()

    if not py_utils.IsEagerMode():
      sess.run(self._initialize_tables)
      sess.run(self._initialize_local_vars)
      sess.run(self._load_ops)

    program_schedule = None
    # Threadpool to run code in programs async with TF Sessions (on TPUs).
    # This unblocks TPU from waiting for CPU processing on "main" thread, and
    # saves time for time-consuming CPU steps (e.g. PostProcessDecodeOut).
    program_threadpool = multiprocessing.dummy.Pool(1)
    start_time = time.time()

    def RunSave(sess, global_step):
      # Run TPU embedding retrieve ops.
      # NOTE: this is expensive, so only run it when we're checkpointing.
      if not py_utils.IsEagerMode():
        tf.logging.info('Retrieve params.')
        sess.run(self._retrieve_ops)
        tf.logging.info('Retrieve params done.')

      # Save program state first, so it's recoverable after we restore
      # from checkpoint.
      for program in self._programs:
        program.SaveProgramState(sess, global_step)
      # Save the checkpoints asynchronously.
      self._checkpointer.Save(sess, global_step, sync=False)

    def _ShutDown():
      # Wait for the save ops to finish before exit. `program_schedule` is
      # read at call time, i.e. the schedule run in the latest cycle.
      self._checkpointer.Sync()
      program_threadpool.close()
      program_threadpool.join()
      tf.logging.info(
          'Program schedule told us to stop.\n'
          'Shutting down programs after running %f seconds.',
          time.time() - start_time)
      program_schedule.Shutdown()

    while True:
      cycle_start_time = time.time()
      if py_utils.IsEagerMode():
        global_step = py_utils.GetGlobalStep().numpy()
      else:
        global_step = sess.run(py_utils.GetGlobalStep())

      checkpoint_write_secs = 0.0
      if not self._ml_perf_log and self._checkpointer.ShouldSave(global_step):
        checkpoint_write_start = time.time()
        RunSave(sess, global_step)
        checkpoint_write_secs = time.time() - checkpoint_write_start

      # If a task is explicitly selected, only run the programs associated
      # with that task.
      if self._single_task_mode or self._model_task_name:
        tf.logging.info('Single task mode: %s', self._model_task_name)
        program_schedule = self._program_schedule_dict[self._model_task_name]
      else:
        # Otherwise, sample a task.
        model_task = self.task_scheduler.Sample(global_step)
        tf.logging.info('Sampled %s', model_task)
        program_schedule = self._program_schedule_dict[model_task]

      done, train_time_in_secs, eval_time_in_secs = program_schedule.Run(
          sess, program_threadpool)

      executor_cycle_in_secs = time.time() - cycle_start_time
      self._ExportMetrics(
          executor_cycle_secs=executor_cycle_in_secs,
          executor_train_time_secs=train_time_in_secs,
          executor_eval_time_secs=eval_time_in_secs,
          checkpoint_write_secs=checkpoint_write_secs,
      )

      if done:
        tf.logging.info(
            'Program done after %f seconds. Waiting for threads to end.',
            time.time() - start_time)
        _ShutDown()
        return

      if py_utils.IsEagerMode():
        global_step = py_utils.GetGlobalStep().numpy()
      else:
        global_step = sess.run(py_utils.GetGlobalStep())

      if self._ShouldStop(sess, global_step):
        tf.logging.info('Training finished.')
        if not self._ml_perf_log:
          RunSave(sess, global_step)
        tf.logging.info(
            'Program finished after %f seconds. Waiting for threads to end.',
            time.time() - start_time)
        _ShutDown()
        return
def Evaler(self):
  """Returns the evaler runner class to use for the current execution mode.

  Returns:
    `eager_runners.Evaler` when running in Eager mode, otherwise
    `runners.Evaler`.
  """
  if py_utils.IsEagerMode():
    return eager_runners.Evaler
  return runners.Evaler
def _GetSession(self, **kwargs):
  """Returns a parent-class session configured with the worker cluster def.

  Only available in Graph mode; all keyword args are forwarded to the parent
  class's `_GetSession` together with `cluster_def=self._worker_cluster_def`.

  Raises:
    ValueError: When called in Eager mode.
  """
  if not py_utils.IsEagerMode():
    return super()._GetSession(cluster_def=self._worker_cluster_def, **kwargs)
  raise ValueError('Eager mode does not support _GetSession.')
def Decoder(self):
  """Returns the decoder runner class to use for the current execution mode.

  Returns:
    `eager_runners.Decoder` in Eager mode, `runners.Decoder` otherwise.
  """
  use_eager = py_utils.IsEagerMode()
  return eager_runners.Decoder if use_eager else runners.Decoder
def __init__(self, train_cfg, ps_params_dict, *args, **kwargs):
  """Construct an ExecutorTpu BaseRunner.

  In Eager mode, first connects to the TPU cluster. Then initializes the TPU
  system and device assignment (once per worker when there are several
  workers), instantiates one ProgramSchedule per entry of `ps_params_dict`,
  builds the TPU subgraph of every program, and creates the checkpointer and
  the TPU embedding load/retrieve ops.

  Args:
    train_cfg: SingleTaskModelParams or MultiTaskModelParams
    ps_params_dict: A dict of top-level task name -> ProgramSchedule params, if
      train_cfg is a SingleTaskModelParams, we expect only one entry.
    *args: List args to pass through to BaseRunner.
    **kwargs: keyword args to pass through to BaseRunner.

  Raises:
    ValueError: If two program schedules set different non-empty
      `checkpoint_to_load` values.
  """
  if py_utils.IsEagerMode():
    assert tf.executing_eagerly()
    tf.logging.info(f'FLAGS.tf_master: {FLAGS.tf_master}')

    # Connect to the TPU runtime.
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        FLAGS.tf_master, job_name=FLAGS.worker_job[len('/job:'):])
    tf.config.experimental_connect_to_cluster(resolver)

  super().__init__(train_cfg, *args, **kwargs)

  data_parallelism = self._cluster.num_splits_per_client
  assert data_parallelism
  num_devices_per_split = self._cluster.num_devices_per_split
  tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                  data_parallelism, num_devices_per_split)

  self.task_scheduler = None
  self._checkpoint_dir = os.path.join(self._logdir, 'train')

  self._variable_renaming_rules = []

  # Resolve the MLPerf params up front. `self._ml_perf_log` and the MLPerf
  # 'benchmark' / 'init_start' log lines are decided below, before the
  # program schedules are instantiated; previously `self._ml_perf` was only
  # assigned inside that later loop, so it was always None at the check and
  # MLPerf logging could never be enabled. As before, the last schedule with
  # a non-None benchmark_name wins.
  self._ml_perf = None
  for ps_params in ps_params_dict.values():
    if ps_params.ml_perf.benchmark_name is not None:
      self._ml_perf = ps_params.ml_perf

  # If this is a multi-task model, grab the params for the TaskScheduler.
  if issubclass(train_cfg.cls, base_model.SingleTaskModel):
    tf.logging.info('single_task_model')
    assert len(ps_params_dict) == 1
    self._model_task_name = list(ps_params_dict.keys())[0]
    self._single_task_mode = True
  elif issubclass(train_cfg.cls, base_model.MultiTaskModel):
    tf.logging.info('multi_task_model')

    if issubclass(train_cfg.cls, multitask_model.RegExSharedVariableModel):
      self._variable_renaming_rules = train_cfg.variable_renaming_rules

    if train_cfg.task_schedule is None:
      task_schedule_params = task_scheduler.ConstantScheduler.Params()
      task_schedule_params.task_probs = sorted(
          train_cfg.task_probs.IterParams())
    else:
      task_schedule_params = train_cfg.task_schedule
    self.task_scheduler = task_schedule_params.Instantiate()
    self._single_task_mode = False
  else:
    tf.logging.fatal(
        'Model %s is not a sub-class of SingleTaskModel or MultiTaskModel',
        train_cfg.cls)
  tf.logging.info('train_cfg.cls: %s', train_cfg.cls)

  self._WriteToLog(train_cfg.ToText(), self._checkpoint_dir,
                   'trainer_params.txt')
  self._WriteToLog(
      text_format.MessageToString(train_cfg.ToProto(), as_utf8=True),
      self._checkpoint_dir, 'trainer_params.pbtxt')
  if self._ml_perf is not None:
    self._ml_perf_log = True
    mlp_log.mlperf_print(key='benchmark', value=self._ml_perf.benchmark_name)
  else:
    self._ml_perf_log = False

  train_cfg = self.params

  @py_utils.RetryOnTransientTfError()
  def _WaitTillInit(job=None):
    """Wait until the model is ready."""
    try:
      if py_utils.IsEagerMode():
        # `resolver` is only bound in Eager mode (top of __init__).
        topology = tf.tpu.experimental.initialize_tpu_system(resolver)
      else:
        # tpu.initialize_system() is called with None as embedding_config, as
        # embedding_config is not available yet. Later in _Loop, it is called
        # with the correct embedding_config. Since it cannot be called twice
        # in the same graph with different embedding_config, we use a
        # dummy_graph here.
        dummy_graph = tf.Graph()
        with dummy_graph.as_default():
          tpu_initialize_system_op = tf.tpu.initialize_system(
              embedding_config=None, job=job)

        with self._GetSession(graph=dummy_graph) as sess:
          topology = sess.run(tpu_initialize_system_op)

      if train_cfg.train.tpu_computation_shape is None:
        computation_shape = py_utils.ComputationShape(num_devices_per_split,
                                                      topology)
      else:
        computation_shape = train_cfg.train.tpu_computation_shape
        assert num_devices_per_split == np.prod(computation_shape)

      if train_cfg.train.tpu_device_order_mode is None:
        self.device_assignment = device_assignment_lib.device_assignment(
            topology,
            computation_shape=computation_shape,
            num_replicas=data_parallelism)
      else:
        self.device_assignment = device_assignment_lib.device_assignment(
            topology,
            computation_shape=computation_shape,
            num_replicas=data_parallelism,
            device_order_mode=train_cfg.train.tpu_device_order_mode)
      py_utils.SetTpuDeviceAssignment(self.device_assignment, job)
      tf.logging.info('device_assignment.core_assignment: %s',
                      str(self.device_assignment.core_assignment))
      tf.logging.info(
          'device_assignment.topology.device_coordinates: %s',
          str(self.device_assignment.topology.device_coordinates))
    except py_utils.transient_tf_errors as e:
      tf.logging.info('TPU initialization failed: %s', e)
      raise

  if self._ml_perf_log:
    mlp_log.mlperf_print(key='init_start', value=None)
  if len(self._cluster.all_worker_names) > 1:
    for worker in self._cluster.all_worker_names:
      _WaitTillInit(worker)
  else:
    _WaitTillInit(None)

  shared_model = self._MaybeConstructSharedModel(train_cfg)

  self._program_schedule_dict = {}
  self._programs = []
  self._ckpt_programs = []

  self._checkpoint_to_load = None
  with self._cluster:
    # Create the ExponentialMovingAverage singleton shared by all programs, if
    # applicable.
    ema = py_utils.CreateEMAForModel(train_cfg, self._global_step_var)
    for task_string, program_schedule_params in ps_params_dict.items():
      program_schedule_params.logdir = self._logdir
      program_schedule_params.num_splits_per_client = data_parallelism
      program_schedule_params.task_name = task_string
      # If the model was created above, we'll inject it here as a
      # shared_model.
      ps = program_schedule_params.Instantiate(
          shared_model=shared_model,
          trial=self._trial,
          ema=ema,
          tf_master=self._tf_master)
      self._program_schedule_dict[task_string] = ps
      tf.logging.info('program_schedule_params: %s',
                      program_schedule_params.ToText())
      self._programs += ps.Programs()
      # Only the train program is checkpointed when one exists; otherwise all
      # of the schedule's programs are.
      if ps.train_program:
        self._ckpt_programs.append(ps.train_program)
      else:
        self._ckpt_programs += ps.Programs()
      if ('checkpoint_to_load' in program_schedule_params and
          program_schedule_params.checkpoint_to_load):
        if (self._checkpoint_to_load and
            (self._checkpoint_to_load !=
             program_schedule_params.checkpoint_to_load)):
          raise ValueError(
              f'Multiple values found for checkpoint_to_load: '
              f'{self._checkpoint_to_load}, '
              f'{program_schedule_params.checkpoint_to_load}.')
        self._checkpoint_to_load = program_schedule_params.checkpoint_to_load

  tf.logging.info('num_programs: %d', len(self._programs))

  # When running in a vizier trainer, the executor reports infeasible runs
  # in case of errors. The programs report metrics and normal completions.
  for program in self._programs:
    if program._should_report_metrics:
      self._should_report_metrics = True

  with self._cluster, tf.container(
      self._container_id), contextlib.ExitStack() as stack:
    if not py_utils.IsEagerMode():
      stack.enter_context(self._graph.as_default())

      if FLAGS.use_tpu_mirrored_vars:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            FLAGS.tf_master, job_name=FLAGS.worker_job[len('/job:'):])
        self._tpu_strategy = tf.distribute.experimental.TPUStrategy(
            resolver, device_assignment=self.device_assignment)
        stack.enter_context(self._tpu_strategy.scope())
        stack.enter_context(
            tpu_strategy._TPUReplicaContext(self._tpu_strategy))
      else:
        # Device functions are only accepted by tf.device in Graph mode.
        stack.enter_context(tf.device(self._cluster.GetPlacer()))

    if FLAGS.pdb_on_exception:
      stack.enter_context(pdb_wrapper.catch_post_mortem())
    with py_utils.VariableStore(), py_utils.VariableRenameScope(
        self._variable_renaming_rules):
      # `BuildTpuSubgraph` has to be called before checkpoint restore, so that
      # the optimizer slot variables are guaranteed to be initialized before
      # they get loaded. Otherwise, the optimizers' slot variables will not
      # be properly loaded when V1 checkpoint is used.
      for program in self._programs:
        program.BuildTpuSubgraph()
        py_utils.ClearTpuSummaryTensors()

    if not py_utils.IsEagerMode():
      self._initialize_tables = tf.tables_initializer()
      self._initialize_local_vars = tf.local_variables_initializer()
      self._initialize_global_vars = tf.global_variables_initializer()

    checkpointer_models = [
        program.GetModel() for program in self._ckpt_programs
    ]

    if py_utils.IsEagerMode():
      if FLAGS.use_v2_checkpoints_in_eager:
        self._checkpointer = checkpointer.EagerCheckpointerV2(
            self._checkpoint_dir,
            models=checkpointer_models,
            init_op=None,
            train_params=train_cfg.train,
            save_only=False)
      else:
        self._checkpointer = checkpointer.EagerCheckpointerV1(
            self._checkpoint_dir,
            models=checkpointer_models,
            init_op=None,
            train_params=train_cfg.train,
            save_only=False)
    else:
      self._checkpointer = checkpointer.Checkpointer(
          self._checkpoint_dir,
          models=checkpointer_models,
          init_op=self._initialize_global_vars,
          train_params=train_cfg.train,
          save_only=False)

    for program in self._programs:
      program.SetStatusMessageFn(self._SetStatusMessage)

    tpu_embedding_collection = (
        tpu_embedding_layers.TpuEmbeddingCollection.Get())
    self._load_ops = tpu_embedding_collection.load_ops
    self._retrieve_ops = tpu_embedding_collection.retrieve_ops
    self._tpu_embedding = tpu_embedding_collection.tpu_embedding