def Run(self, sess):
  tf.logging.info('Executing decode program for %s.', self._task_name)
  gsteps = py_utils.GetGlobalStep()
  global_step = sess.run(gsteps)
  if self._ml_perf_log:
    steps_per_epoch = self._ml_perf.steps_per_epoch
    epoch = int(global_step) // steps_per_epoch
    mlp_log.mlperf_print(
        'eval_start', None, metadata={'epoch_num': (epoch + 1)})
  infeed_future = self._infeed_pool.apply_async(
      self._InfeedLoop, args=(sess,))
  dec_metrics = self._model_task.CreateDecoderMetrics()
  start_time = time.time()
  buffered_decode_out = []
  for i in range(self._steps_per_loop):
    metrics_values = sess.run(self.metrics)
    decode_out = self._model_task.PostProcessDecodeOut(
        metrics_values, dec_metrics)
    tf.logging.info('step: %d %f' %
                    (i, dec_metrics['num_samples_in_batch'].total_value))
    if decode_out:
      buffered_decode_out.extend(decode_out)
  infeed_future.wait()

  if self._ml_perf_log:
    mlp_log.mlperf_print(
        'eval_stop', None, metadata={'epoch_num': (epoch + 1)})

  num_examples_metric = dec_metrics['num_samples_in_batch']
  summaries = {k: v.Summary(k) for k, v in six.iteritems(dec_metrics)}
  elapsed_secs = time.time() - start_time
  example_rate = num_examples_metric.total_value / elapsed_secs
  summaries['examples/sec'] = tf.Summary(
      value=[tf.Summary.Value(tag='examples/sec', simple_value=example_rate)])
  self._WriteSummaries(
      os.path.basename(self._program_dir), global_step, summaries)

  decode_out_path = os.path.join(self._program_dir,
                                 'decoder_out_%09d' % global_step)
  decode_finalize_args = base_model.DecodeFinalizeArgs(
      decode_out_path=decode_out_path, decode_out=buffered_decode_out)
  self._model_task.DecodeFinalize(decode_finalize_args)

  if self._ml_perf_log:
    mlperf_metric = self._ml_perf.decoder_metric_name
    mlperf_metric_value = dec_metrics[mlperf_metric].value
    mlp_log.mlperf_print(
        'eval_accuracy', mlperf_metric_value, metadata={'epoch_num': epoch})
    if mlperf_metric_value > self._ml_perf.decoder_metric_success_threshold:
      tf.logging.info('ml_perf_final_threshold: %f exceeded',
                      self._ml_perf.decoder_metric_success_threshold)
      mlp_log.mlperf_print('run_stop', None, metadata={'status': 'success'})
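# NOTE: Illustrative sketch, not part of the original code. Run() above
# overlaps host-side infeed with device-side decoding: a background thread
# enqueues batches while the main thread repeatedly dequeues results. The
# helper below shows the same producer/consumer pattern in isolation;
# `feed_batch` and `decode_step` are hypothetical stand-ins for the sess.run
# calls on the infeed op and on self.metrics.
from multiprocessing.dummy import Pool  # thread-backed pool


def _OverlapFeedAndDecode(feed_batch, decode_step, steps_per_loop):
  """Feeds batches on one thread while decoding on the caller's thread."""
  pool = Pool(1)  # a single worker keeps the enqueue order serial
  infeed_future = pool.apply_async(
      lambda: [feed_batch(i) for i in range(steps_per_loop)])
  # Each decode step blocks until the corresponding batch has been infed.
  results = [decode_step(i) for i in range(steps_per_loop)]
  infeed_future.wait()
  return results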
def Run(self, sess):
  gsteps = py_utils.GetGlobalStep()
  global_step = sess.run(gsteps)
  self.dec_metrics = self._decode_model_task.CreateDecoderMetrics()
  # Start TPU program thread.
  train_future = self._train_pool.apply_async(
      self._TrainAndDecode, args=(sess,))

  if self._warmup_seconds > 0:
    # The first execution of the TPU program has a warm-up, so we delay
    # feeding data, as that's when the MLPerf timing starts. This way, when
    # we actually infeed, the TPU program is immediately ready to
    # execute/dequeue data.
    tf.logging.info('Waiting before first infeed.')
    time.sleep(self._warmup_seconds)
    self._warmup_seconds = 0

  if self._ml_perf_log:
    if not self._run_start:
      self._run_start = mlp_log.mlperf_print(key='run_start', value=None)
    steps_per_epoch = self._ml_perf.steps_per_epoch
    epoch = int(global_step) // steps_per_epoch
    if epoch > self._ml_perf_epoch:
      self._ml_perf_epoch = epoch
      mlp_log.mlperf_print(
          'block_start',
          None,
          metadata={
              'first_epoch_num': epoch + 1,
              'epoch_count': 1
          })
    self.SetStatusMessage('MLPerf epoch: %d' % self._ml_perf_epoch)

  # Start infeed thread.
  infeed_future = self._infeed_pool.apply_async(
      self._InfeedLoop, args=(sess,))

  infeed_future.wait()
  train_future.wait()

  if self._ml_perf_log:
    mlp_log.mlperf_print(
        'eval_stop', None, metadata={'epoch_num': (epoch + 1)})
    mlperf_metric = self._ml_perf.decoder_metric_name
    mlperf_metric_value = float(self.dec_metrics[mlperf_metric].value)
    mlp_log.mlperf_print(
        'eval_accuracy', mlperf_metric_value, metadata={'epoch_num': epoch})
    if mlperf_metric_value > self._ml_perf.decoder_metric_success_threshold:
      tf.logging.info('ml_perf_final_threshold: %f exceeded',
                      self._ml_perf.decoder_metric_success_threshold)
      if not self._run_stop:
        self._run_stop = mlp_log.mlperf_print(
            'run_stop', None, metadata={'status': 'success'})
        self.SetStatusMessage('MLPerf run_time: %.2f' %
                              (self._run_stop - self._run_start))
        return True
  return False
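# NOTE: Illustrative sketch, not part of the original code. Run() above
# assumes a _TrainAndDecode method that executes the fused train+decode TPU
# program and folds the decoder output into self.dec_metrics. A minimal body
# consistent with that usage (the real implementation may differ):
def _TrainAndDecode(self, sess):
  """Runs the combined TPU program and post-processes decode output."""
  # self.metrics holds the decode output tensors of the TrainAndDecode
  # computation built in BuildTpuSubgraph below.
  metrics_values = sess.run(self.metrics)
  self._decode_model_task.PostProcessDecodeOut(metrics_values,
                                               self.dec_metrics)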
def _Loop(self):
  with tf.container(self._container_id), self._GetSession(
      cluster_def=self._cluster_def) as sess:
    # Initialize the variables first, if needed.
    for program in self._programs:
      program.RestoreIfNeeded(sess)
      program.Compile(sess)
    sess.run(self._initialize_tables)
    sess.run(self._initialize_local_vars)

    if self._ml_perf_log:
      # Log run_start only after initialization/compilation has finished.
      mlp_log.mlperf_print(key='run_start', value=None)

    while True:
      global_step = sess.run(py_utils.GetGlobalStep())
      if self._ShouldStop(sess, global_step):
        tf.logging.info('Training finished.')
        self.save_only_checkpointer.Save(sess, global_step)
        return

      # If a task is explicitly selected, only run the programs associated
      # with that task.
      if self._single_task_mode or self._model_task_name:
        tf.logging.info('Single task mode: %s', self._model_task_name)
        program_schedule = self._program_schedule_dict[self._model_task_name]
      else:
        # Otherwise, sample a task.
        model_task = self.task_scheduler.Sample(global_step)
        tf.logging.info('Sampled %s', model_task)
        program_schedule = self._program_schedule_dict[model_task]

      program_schedule.Run(sess)

      # TODO(blee): More complex saving rules. Currently, we assume
      # we save after every task's program schedule execution.
      #
      # Note: the `global_step` local above is the result of a sess.run, not
      # a tf variable. By the time save_only_checkpointer.Save(...) runs
      # here, py_utils.GetGlobalStep() is already ahead of that local by
      # (train_executions_per_eval * train_steps_per_loop) steps, due to
      # program_schedule.Run(sess).
      self.save_only_checkpointer.Save(sess, py_utils.GetGlobalStep())
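# NOTE: Illustrative sketch, not part of the original code. _ShouldStop is
# typically a budget check on the global step; a minimal version, assuming a
# max_steps train param, could look like:
def _ShouldStop(self, sess, global_step):
  """Returns True once training has reached the configured step budget."""
  del sess  # A plain step bound needs no extra session reads.
  return global_step >= self.params.train.max_steps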
def Run(self, sess):
  gsteps = py_utils.GetGlobalStep()
  global_step = sess.run(gsteps)

  if self._ml_perf_log:
    if not self._run_start:
      self._run_start = mlp_log.mlperf_print(key='run_start', value=None)
    steps_per_epoch = self._ml_perf.steps_per_epoch
    epoch = int(global_step) // steps_per_epoch
    if epoch > self._ml_perf_epoch:
      self._ml_perf_epoch = epoch
      mlp_log.mlperf_print(
          'block_start',
          None,
          metadata={
              'first_epoch_num': epoch + 1,
              'epoch_count': 1
          })
    self.SetStatusMessage('MLPerf epoch: %d' % self._ml_perf_epoch)

  infeed_future = self._infeed_pool.apply_async(
      self._InfeedLoop, args=(sess,))
  dec_metrics = self._decode_model_task.CreateDecoderMetrics()
  buffered_decode_out = []
  for i in range(self._decode_steps_per_loop):
    metrics_values = sess.run(self.metrics)
    decode_out = self._decode_model_task.PostProcessDecodeOut(
        metrics_values, dec_metrics)
    tf.logging.info('step: %d %f' %
                    (i, dec_metrics['num_samples_in_batch'].total_value))
    if decode_out:
      buffered_decode_out.extend(decode_out)
  infeed_future.wait()

  if self._ml_perf_log:
    mlp_log.mlperf_print(
        'eval_stop', None, metadata={'epoch_num': (epoch + 1)})
    mlperf_metric = self._ml_perf.decoder_metric_name
    mlperf_metric_value = dec_metrics[mlperf_metric].value
    mlp_log.mlperf_print(
        'eval_accuracy', mlperf_metric_value, metadata={'epoch_num': epoch})
    if mlperf_metric_value > self._ml_perf.decoder_metric_success_threshold:
      tf.logging.info('ml_perf_final_threshold: %f exceeded',
                      self._ml_perf.decoder_metric_success_threshold)
      if not self._run_stop:
        self._run_stop = mlp_log.mlperf_print(
            'run_stop', None, metadata={'status': 'success'})
        self.SetStatusMessage('MLPerf run_time: %.2f' %
                              (self._run_stop - self._run_start))
        return True
  return False
def Run(self, sess):
  tf.logging.info('Executing train program for %s.', self._task_name)
  gsteps = py_utils.GetGlobalStep()
  global_step = sess.run(gsteps)

  if self._ml_perf_log:
    steps_per_epoch = self._ml_perf.steps_per_epoch
    epoch = int(global_step) // steps_per_epoch
    if epoch > self._ml_perf_epoch:
      self._ml_perf_epoch = epoch
      mlp_log.mlperf_print(
          'block_start',
          None,
          metadata={
              'first_epoch_num': epoch + 1,
              'epoch_count': 1
          })

  infeed_future = self._infeed_pool.apply_async(
      self._InfeedLoop, args=(sess,))
  ary = sess.run(self.tpu_ops)
  infeed_future.wait()

  values = ary[0]
  outfeeds = ary[1]

  self._eval_metrics.PackMetricsValues(values)
  eval_metrics = self._eval_metrics.metrics

  global_step = sess.run(gsteps)
  step_rate, example_rate, total_examples = (
      self._step_rate_tracker.ComputeStepRate(
          global_step,
          eval_metrics['num_samples_in_batch'][0] * self._steps_per_loop))
  self._SummarizeValue(global_step, 'global_step/sec', step_rate)
  self._SummarizeValue(global_step, 'examples/sec', example_rate)
  self._SummarizeValue(global_step, 'total_samples', total_examples)

  for key, (val, _) in sorted(six.iteritems(eval_metrics)):
    self._SummarizeValue(global_step, key, val)

  self._model.GetTask().ProcessFPropResults(sess, global_step, eval_metrics,
                                            outfeeds)
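# NOTE: Illustrative sketch, not part of the original code. _SummarizeValue
# above wraps a scalar into a tf.Summary proto for the program's summary
# writer; a minimal version, assuming a self._summary_writer
# (tf.summary.FileWriter), could be:
def _SummarizeValue(self, global_step, tag, value):
  """Writes a single scalar summary at the given step."""
  self._summary_writer.add_summary(
      tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]),
      global_step)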
def _InfeedLoop(self, sess):
  tf.logging.info('_InfeedLoop start')
  try:
    for i in range(self._train_steps_per_loop):
      tf.logging.vlog(1, '_InfeedLoop %d', i)
      sess.run(self._train_task.input.tpu_infeed_op)
    if self._ml_perf_log:
      mlp_log.mlperf_print(
          'eval_start',
          None,
          metadata={
              'first_epoch_num': self._ml_perf_epoch + 1,
              'epoch_count': 1
          })
    for i in range(self._decode_steps_per_loop):
      tf.logging.vlog(1, '_InfeedLoop %d', i)
      sess.run(self._decode_task.input.tpu_infeed_op)
    tf.logging.info('_InfeedLoop done')
  except Exception as e:
    tf.logging.info('_InfeedLoop exception %r %s', e, e)
    raise
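# NOTE: Illustrative sketch, not part of the original code. The two loop
# bounds above must match the step counts compiled into the device loops
# (see the tpu_training_loop.repeat calls in BuildTpuSubgraph): if the host
# enqueues fewer batches than the device dequeues, the device hangs, and if
# it enqueues more, the infeed thread blocks. A cheap guard, using
# hypothetical attributes that record the compiled loop counts, could be:
def _CheckInfeedSteps(self):
  """Asserts host-side infeed counts match the compiled device loop counts."""
  assert self._train_steps_per_loop == self._compiled_train_steps_per_loop
  assert self._decode_steps_per_loop == self._compiled_decode_steps_per_loop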
def __init__(self, train_cfg, ps_params_dict, model_task_name, logdir,
             tf_master, **kwargs):
  """Construct an ExecutorTpu BaseRunner.

  Args:
    train_cfg: SingleTaskModelParams or MultiTaskModelParams.
    ps_params_dict: A dict of top-level task name -> ProgramSchedule params.
      If train_cfg is a SingleTaskModelParams, we expect only one entry.
    model_task_name: An override for multi-task models, currently unused.
    logdir: String path to the log directory to output to.
    tf_master: String path to the master job, e.g. 'local'.
    **kwargs: keyword args to pass through to BaseRunner.
  """
  super().__init__(train_cfg, model_task_name, logdir, tf_master, **kwargs)

  self._cluster_def = self._cluster.worker_cluster_def

  # There is a single Executor task.
  assert self._cluster.num_replicas == 1
  data_parallelism = self._cluster.num_splits_per_client
  assert data_parallelism
  num_devices_per_split = self._cluster.num_devices_per_split
  tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                  data_parallelism, num_devices_per_split)

  self.task_scheduler = None
  self._checkpoint_dir = os.path.join(logdir, 'train')
  self._variable_renaming_rules = []
  self._ml_perf = None

  # If this is a multi-task model, grab the params for the TaskScheduler.
  if issubclass(train_cfg.cls, base_model.SingleTaskModel):
    tf.logging.info('single_task_model')
    assert len(ps_params_dict) == 1
    self._model_task_name = list(ps_params_dict.keys())[0]
    self._single_task_mode = True
  elif issubclass(train_cfg.cls, base_model.MultiTaskModel):
    tf.logging.info('multi_task_model')
    if issubclass(train_cfg.cls, multitask_model.RegExSharedVariableModel):
      self._variable_renaming_rules = train_cfg.variable_renaming_rules

    if train_cfg.task_schedule is None:
      task_schedule_params = task_scheduler.ConstantScheduler.Params()
      task_schedule_params.task_probs = sorted(
          list(train_cfg.task_probs.IterParams()))
    else:
      task_schedule_params = train_cfg.task_schedule
    self.task_scheduler = task_schedule_params.Instantiate()
    self._single_task_mode = False
  else:
    tf.logging.fatal(
        'Model %s is not a sub-class of SingleTaskModel or MultiTaskModel',
        train_cfg.cls)

  tf.logging.info('train_cfg.cls: %s', train_cfg.cls)

  self._WriteToLog(train_cfg.ToText(), self._checkpoint_dir,
                   'trainer_params.txt')
  if self._ml_perf is not None:
    self._ml_perf_log = True
    mlp_log.mlperf_print(key='benchmark', value=self._ml_perf.benchmark_name)
  else:
    self._ml_perf_log = False

  # BaseRunner legacy.
  self.enqueue_ops = None

  @py_utils.RetryOnTransientTfError()
  def _WaitTillInit():
    """Wait until the model is ready."""
    try:
      with self._graph.as_default(), self._GetSession(
          cluster_def=self._cluster_def,
          disable_meta_optimizer=FLAGS.disable_meta_optimizer_in_executor
      ) as sess:
        topology = sess.run(
            tf.tpu.initialize_system(embedding_config=None, job=None))
        device_assignment = device_assignment_lib.device_assignment(
            topology,
            computation_shape=py_utils.ComputationShape(
                num_devices_per_split, topology),
            num_replicas=data_parallelism)
        py_utils.SetTpuDeviceAssignment(device_assignment)
        tf.logging.info('device_assignment.core_assignment: %s',
                        str(device_assignment.core_assignment))
        tf.logging.info('device_assignment.topology.device_coordinates: %s',
                        str(device_assignment.topology.device_coordinates))
    except py_utils.transient_tf_errors as e:
      tf.logging.info('TPU initialization failed: %s', e)
      raise

  if self._ml_perf_log:
    mlp_log.mlperf_print(key='init_start', value=None)
  _WaitTillInit()

  train_cfg = self.params
  shared_model = self._MaybeConstructSharedModel(train_cfg)

  self._program_schedule_dict = {}
  self._programs = []

  for task_string, program_schedule_params in ps_params_dict.items():
    program_schedule_params.logdir = logdir
    program_schedule_params.num_splits_per_client = data_parallelism
    program_schedule_params.task_name = task_string
    # If the model was created above, we'll inject it here as a shared_model.
    ps = program_schedule_params.Instantiate(shared_model=shared_model)
    self._program_schedule_dict[task_string] = ps
    tf.logging.info('program_schedule_params: %s',
                    program_schedule_params.ToText())
    self._programs += ps.Programs()
    if program_schedule_params.ml_perf.benchmark_name is not None:
      self._ml_perf = program_schedule_params.ml_perf

  tf.logging.info('num_programs: %d', len(self._programs))

  with self._graph.as_default(), tf.container(self._container_id):
    with self._cluster, tf.device(
        self._cluster.job_spec.name if not FLAGS.cluster_placer_in_executor
        else self._cluster.GetPlacer()):
      with py_utils.VariableRenameScope(self._variable_renaming_rules):
        _ = py_utils.GetOrCreateGlobalStepVar()
        for program in self._programs:
          program.BuildTpuSubgraph()
          py_utils.ClearTpuSummaryTensors()

      for program in self._programs:
        program.SetStatusMessageFn(self._SetStatusMessage)
        program.CreateCheckpointer()

      self._initialize_tables = tf.tables_initializer()
      self._initialize_local_vars = tf.local_variables_initializer()

      self.save_only_checkpointer = checkpointer.Checkpointer(
          self._checkpoint_dir,
          model=None,
          train_params=train_cfg.train,
          save_only=True)
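# NOTE: Illustrative sketch, not part of the original code.
# py_utils.RetryOnTransientTfError retries the decorated function while it
# raises a transient TF error (e.g. UnavailableError while the TPU system is
# still coming up). A generic version of the pattern, with hypothetical
# max_retries/delay_secs parameters, looks like:
import functools
import time


def RetryOnError(exc_types, max_retries=100, delay_secs=10):
  """Retries the wrapped function while it raises one of exc_types."""

  def Decorator(func):

    @functools.wraps(func)
    def Wrapper(*args, **kwargs):
      for attempt in range(max_retries):
        try:
          return func(*args, **kwargs)
        except exc_types:
          if attempt == max_retries - 1:
            raise
          time.sleep(delay_secs)

    return Wrapper

  return Decorator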
def BuildTpuSubgraph(self):
  if self._ml_perf_log:
    mlp_log.mlperf_print('global_batch_size', self._ml_perf.global_batch_size)
    mlp_log.mlperf_print('max_sequence_length',
                         self._ml_perf.max_sequence_length)
    mlp_log.mlperf_print('opt_name', self._ml_perf.optimizer_name)
    mlp_log.mlperf_print('opt_base_learning_rate',
                         self._ml_perf.base_learning_rate)
    mlp_log.mlperf_print('opt_learning_rate_warmup_steps',
                         self._ml_perf.warmup_steps)

  with py_utils.OpportunisticVariableReuseScope(True):
    self._eval_metrics = metrics.TpuEvalMetrics()
    data_parallelism = self.data_parallelism

    def TpuTrainStep():
      """Train a shard of a batch on a single TPU core.

      Does not calculate loss metrics.

      Returns:
        [train_op].
      """
      self._train_model = self._train_task_params.Instantiate()
      self._model = self._train_model
      self._train_model.ConstructFPropBPropGraph()
      return [self._train_model.GetTask().train_op]

    def TpuTrain():
      loop_result = tpu_training_loop.repeat(
          self._train_steps_per_loop, TpuTrainStep, inputs=[],
          name='train_loop')
      return loop_result

  py_utils.ResetStepSeed()

  def _DecodeFn():
    """Decode call to be compiled for TPU."""
    with py_utils.OpportunisticVariableReuseScope(True):
      with cluster_factory.SetEval(True):
        self._decode_model = self._decode_task_params.Instantiate()
        self._decode_model_task = self._decode_model.GetTask()
        if py_utils.use_tpu():
          input_batch = (
              self._decode_model_task.input_generator.CreateTpuFeeds())
        else:
          input_batch = (
              self._decode_model_task.input_generator.SplitInputBatch(
                  self.cluster.num_splits_per_client))
        metrics_dict = self._decode_model_task.Decode(input_batch)
        self.metrics_nm = py_utils.NestedMap(metrics_dict)
        return self.metrics_nm.Flatten()

  @tpu_function.on_device_training_loop
  def TrainAndDecode():
    with tf.control_dependencies([TpuTrain()]):
      return _DecodeFn()

  self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
      TrainAndDecode,
      num_shards=data_parallelism,
      device_assignment=py_utils.GetTpuDeviceAssignment())

  self.metrics = py_utils.NestedMap(self.metrics_nm)
  self.metrics = self.metrics.Pack(batch_parallel_res)
  return None
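# NOTE: Illustrative sketch, not part of the original code. The Flatten()/
# Pack() pair above moves between a structured NestedMap and the flat tensor
# list a TPU computation returns: _DecodeFn flattens the decode metrics for
# the device, and Pack() restores the structure around the results of
# tpu.split_compile_and_shard. The round trip in miniature, assuming
# Lingvo's NestedMap semantics:
nm = py_utils.NestedMap(loss=tf.constant(0.0), acc=tf.constant(1.0))
flat = nm.Flatten()     # flat list of leaves, in a deterministic key order
packed = nm.Pack(flat)  # same leaves, NestedMap structure restored
assert packed.loss is nm.loss and packed.acc is nm.acc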
def __init__(self, train_cfg, ps_params_dict, model_task_name, logdir,
             tf_master, **kwargs):
  """Construct an ExecutorTpu BaseRunner.

  Args:
    train_cfg: SingleTaskModelParams or MultiTaskModelParams.
    ps_params_dict: A dict of top-level task name -> ProgramSchedule params.
      If train_cfg is a SingleTaskModelParams, we expect only one entry.
    model_task_name: An override for multi-task models, currently unused.
    logdir: String path to the log directory to output to.
    tf_master: String path to the master job, e.g. 'local'.
    **kwargs: keyword args to pass through to BaseRunner.
  """
  super().__init__(train_cfg, model_task_name, logdir, tf_master, **kwargs)

  data_parallelism = self._cluster.num_splits_per_client
  assert data_parallelism
  num_devices_per_split = self._cluster.num_devices_per_split
  tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                  data_parallelism, num_devices_per_split)

  self.task_scheduler = None
  self._checkpoint_dir = os.path.join(logdir, 'train')
  self._variable_renaming_rules = []
  self._ml_perf = None

  # If this is a multi-task model, grab the params for the TaskScheduler.
  if issubclass(train_cfg.cls, base_model.SingleTaskModel):
    tf.logging.info('single_task_model')
    assert len(ps_params_dict) == 1
    self._model_task_name = list(ps_params_dict.keys())[0]
    self._single_task_mode = True
  elif issubclass(train_cfg.cls, base_model.MultiTaskModel):
    tf.logging.info('multi_task_model')
    if issubclass(train_cfg.cls, multitask_model.RegExSharedVariableModel):
      self._variable_renaming_rules = train_cfg.variable_renaming_rules

    if train_cfg.task_schedule is None:
      task_schedule_params = task_scheduler.ConstantScheduler.Params()
      task_schedule_params.task_probs = sorted(
          list(train_cfg.task_probs.IterParams()))
    else:
      task_schedule_params = train_cfg.task_schedule
    self.task_scheduler = task_schedule_params.Instantiate()
    self._single_task_mode = False
  else:
    tf.logging.fatal(
        'Model %s is not a sub-class of SingleTaskModel or MultiTaskModel',
        train_cfg.cls)

  tf.logging.info('train_cfg.cls: %s', train_cfg.cls)

  self._WriteToLog(train_cfg.ToText(), self._checkpoint_dir,
                   'trainer_params.txt')
  if self._ml_perf is not None:
    self._ml_perf_log = True
    mlp_log.mlperf_print(key='benchmark', value=self._ml_perf.benchmark_name)
  else:
    self._ml_perf_log = False

  # BaseRunner legacy.
  self.enqueue_ops = None

  train_cfg = self.params

  @py_utils.RetryOnTransientTfError()
  def _WaitTillInit(job=None):
    """Wait until the model is ready."""
    try:
      # tpu.initialize_system() is called with None as embedding_config, as
      # embedding_config is not available yet. Later in _Loop, it is called
      # with the correct embedding_config. Since it cannot be called twice
      # in the same graph with different embedding_config, we use a
      # dummy_graph here.
      dummy_graph = tf.Graph()
      with dummy_graph.as_default():
        tpu_initialize_system_op = tf.tpu.initialize_system(
            embedding_config=None, job=job)

      with self._GetSession(graph=dummy_graph) as sess:
        topology = sess.run(tpu_initialize_system_op)

      if train_cfg.train.tpu_device_order_mode is None:
        device_assignment = device_assignment_lib.device_assignment(
            topology,
            computation_shape=py_utils.ComputationShape(
                num_devices_per_split, topology),
            num_replicas=data_parallelism)
      else:
        device_assignment = device_assignment_lib.device_assignment(
            topology,
            computation_shape=py_utils.ComputationShape(
                num_devices_per_split, topology),
            num_replicas=data_parallelism,
            device_order_mode=train_cfg.train.tpu_device_order_mode)
      py_utils.SetTpuDeviceAssignment(device_assignment, job)
      tf.logging.info('device_assignment.core_assignment: %s',
                      str(device_assignment.core_assignment))
      tf.logging.info('device_assignment.topology.device_coordinates: %s',
                      str(device_assignment.topology.device_coordinates))
    except py_utils.transient_tf_errors as e:
      tf.logging.info('TPU initialization failed: %s', e)
      raise

  if self._ml_perf_log:
    mlp_log.mlperf_print(key='init_start', value=None)
  if len(self._cluster.all_worker_names) > 1:
    for worker in self._cluster.all_worker_names:
      _WaitTillInit(worker)
  else:
    _WaitTillInit(None)

  shared_model = self._MaybeConstructSharedModel(train_cfg)

  self._program_schedule_dict = {}
  self._programs = []

  for task_string, program_schedule_params in ps_params_dict.items():
    program_schedule_params.logdir = logdir
    program_schedule_params.num_splits_per_client = data_parallelism
    program_schedule_params.task_name = task_string
    # If the model was created above, we'll inject it here as a shared_model.
    ps = program_schedule_params.Instantiate(
        shared_model=shared_model, tf_master=self._tf_master)
    self._program_schedule_dict[task_string] = ps
    tf.logging.info('program_schedule_params: %s',
                    program_schedule_params.ToText())
    self._programs += ps.Programs()
    if program_schedule_params.ml_perf.benchmark_name is not None:
      self._ml_perf = program_schedule_params.ml_perf

  tf.logging.info('num_programs: %d', len(self._programs))

  with self._graph.as_default(), tf.container(self._container_id):
    with self._cluster, tf.device(self._cluster.GetPlacer()):
      with py_utils.VariableRenameScope(self._variable_renaming_rules):
        _ = py_utils.GetOrCreateGlobalStepVar()
        for program in self._programs:
          program.BuildTpuSubgraph()
          py_utils.ClearTpuSummaryTensors()

      self._initialize_tables = tf.tables_initializer()
      self._initialize_local_vars = tf.local_variables_initializer()
      self._initialize_global_vars = tf.global_variables_initializer()

      for program in self._programs:
        program.SetStatusMessageFn(self._SetStatusMessage)
        program.CreateCheckpointer(init_op=self._initialize_global_vars)

      self.save_only_checkpointer = checkpointer.Checkpointer(
          self._checkpoint_dir,
          model=None,
          init_op=self._initialize_global_vars,
          train_params=train_cfg.train,
          save_only=True)

      self._load_ops = tf.get_collection(py_utils.TPU_EMBEDDING_LOAD_OPS)
      self._retrieve_ops = tf.get_collection(
          py_utils.TPU_EMBEDDING_RETRIEVE_OPS)
      tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
      self._tpu_embedding = (
          tpu_embedding_collection[0] if tpu_embedding_collection else None)

  tf.io.write_graph(self._graph.as_graph_def(), self._checkpoint_dir,
                    'train.pbtxt')
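# NOTE: Illustrative sketch, not part of the original code. The device
# assignment above carves the TPU topology into `data_parallelism` replicas
# of `num_devices_per_split` cores each, and the request has to fit the
# machine. A rough sanity check on the serialized topology proto returned by
# tf.tpu.initialize_system:
def _CheckAssignmentFits(serialized_topology, num_devices_per_split,
                         data_parallelism):
  """Verifies the requested replica layout fits in the TPU topology."""
  topology = tf.tpu.experimental.Topology(serialized=serialized_topology)
  num_cores = topology.num_tasks * topology.num_tpus_per_task
  assert num_devices_per_split * data_parallelism <= num_cores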
def __init__(self, train_cfg, ps_params_dict, *args, **kwargs):
  """Construct an ExecutorTpu BaseRunner.

  Args:
    train_cfg: SingleTaskModelParams or MultiTaskModelParams.
    ps_params_dict: A dict of top-level task name -> ProgramSchedule params.
      If train_cfg is a SingleTaskModelParams, we expect only one entry.
    *args: List args to pass through to BaseRunner.
    **kwargs: keyword args to pass through to BaseRunner.
  """
  if py_utils.IsEagerMode():
    assert tf.executing_eagerly()
    tf.logging.info(f'FLAGS.tf_master: {FLAGS.tf_master}')
    # Connect to the TPU runtime.
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        FLAGS.tf_master, job_name=FLAGS.worker_job[len('/job:'):])
    tf.config.experimental_connect_to_cluster(resolver)

  super().__init__(train_cfg, *args, **kwargs)

  data_parallelism = self._cluster.num_splits_per_client
  assert data_parallelism
  num_devices_per_split = self._cluster.num_devices_per_split
  tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                  data_parallelism, num_devices_per_split)

  self.task_scheduler = None
  self._checkpoint_dir = os.path.join(self._logdir, 'train')
  self._variable_renaming_rules = []
  self._ml_perf = None

  # If this is a multi-task model, grab the params for the TaskScheduler.
  if issubclass(train_cfg.cls, base_model.SingleTaskModel):
    tf.logging.info('single_task_model')
    assert len(ps_params_dict) == 1
    self._model_task_name = list(ps_params_dict.keys())[0]
    self._single_task_mode = True
  elif issubclass(train_cfg.cls, base_model.MultiTaskModel):
    tf.logging.info('multi_task_model')
    if issubclass(train_cfg.cls, multitask_model.RegExSharedVariableModel):
      self._variable_renaming_rules = train_cfg.variable_renaming_rules

    if train_cfg.task_schedule is None:
      task_schedule_params = task_scheduler.ConstantScheduler.Params()
      task_schedule_params.task_probs = sorted(
          list(train_cfg.task_probs.IterParams()))
    else:
      task_schedule_params = train_cfg.task_schedule
    self.task_scheduler = task_schedule_params.Instantiate()
    self._single_task_mode = False
  else:
    tf.logging.fatal(
        'Model %s is not a sub-class of SingleTaskModel or MultiTaskModel',
        train_cfg.cls)

  tf.logging.info('train_cfg.cls: %s', train_cfg.cls)

  self._WriteToLog(train_cfg.ToText(), self._checkpoint_dir,
                   'trainer_params.txt')
  self._WriteToLog(
      text_format.MessageToString(train_cfg.ToProto(), as_utf8=True),
      self._checkpoint_dir, 'trainer_params.pbtxt')
  if self._ml_perf is not None:
    self._ml_perf_log = True
    mlp_log.mlperf_print(key='benchmark', value=self._ml_perf.benchmark_name)
  else:
    self._ml_perf_log = False

  train_cfg = self.params

  @py_utils.RetryOnTransientTfError()
  def _WaitTillInit(job=None):
    """Wait until the model is ready."""
    try:
      if py_utils.IsEagerMode():
        topology = tf.tpu.experimental.initialize_tpu_system(resolver)
      else:
        # tpu.initialize_system() is called with None as embedding_config,
        # as embedding_config is not available yet. Later in _Loop, it is
        # called with the correct embedding_config. Since it cannot be
        # called twice in the same graph with different embedding_config,
        # we use a dummy_graph here.
        dummy_graph = tf.Graph()
        with dummy_graph.as_default():
          tpu_initialize_system_op = tf.tpu.initialize_system(
              embedding_config=None, job=job)

        with self._GetSession(graph=dummy_graph) as sess:
          topology = sess.run(tpu_initialize_system_op)

      if train_cfg.train.tpu_computation_shape is None:
        computation_shape = py_utils.ComputationShape(
            num_devices_per_split, topology)
      else:
        computation_shape = train_cfg.train.tpu_computation_shape
      assert num_devices_per_split == np.prod(computation_shape)

      if train_cfg.train.tpu_device_order_mode is None:
        self.device_assignment = device_assignment_lib.device_assignment(
            topology,
            computation_shape=computation_shape,
            num_replicas=data_parallelism)
      else:
        self.device_assignment = device_assignment_lib.device_assignment(
            topology,
            computation_shape=computation_shape,
            num_replicas=data_parallelism,
            device_order_mode=train_cfg.train.tpu_device_order_mode)
      py_utils.SetTpuDeviceAssignment(self.device_assignment, job)
      tf.logging.info('device_assignment.core_assignment: %s',
                      str(self.device_assignment.core_assignment))
      tf.logging.info('device_assignment.topology.device_coordinates: %s',
                      str(self.device_assignment.topology.device_coordinates))
    except py_utils.transient_tf_errors as e:
      tf.logging.info('TPU initialization failed: %s', e)
      raise

  if self._ml_perf_log:
    mlp_log.mlperf_print(key='init_start', value=None)
  if len(self._cluster.all_worker_names) > 1:
    for worker in self._cluster.all_worker_names:
      _WaitTillInit(worker)
  else:
    _WaitTillInit(None)

  shared_model = self._MaybeConstructSharedModel(train_cfg)
  self._program_schedule_dict = {}
  self._programs = []
  self._ckpt_programs = []

  self._checkpoint_to_load = None
  with self._cluster:
    # Create the ExponentialMovingAverage singleton shared by all programs,
    # if applicable.
    ema = py_utils.CreateEMAForModel(train_cfg, self._global_step_var)
    for task_string, program_schedule_params in ps_params_dict.items():
      program_schedule_params.logdir = self._logdir
      program_schedule_params.num_splits_per_client = data_parallelism
      program_schedule_params.task_name = task_string
      # If the model was created above, we'll inject it here as a
      # shared_model.
      ps = program_schedule_params.Instantiate(
          shared_model=shared_model,
          trial=self._trial,
          ema=ema,
          tf_master=self._tf_master)
      self._program_schedule_dict[task_string] = ps
      tf.logging.info('program_schedule_params: %s',
                      program_schedule_params.ToText())
      self._programs += ps.Programs()
      if ps.train_program:
        self._ckpt_programs.append(ps.train_program)
      else:
        self._ckpt_programs += ps.Programs()
      if program_schedule_params.ml_perf.benchmark_name is not None:
        self._ml_perf = program_schedule_params.ml_perf
      if ('checkpoint_to_load' in program_schedule_params and
          program_schedule_params.checkpoint_to_load):
        if (self._checkpoint_to_load and
            (self._checkpoint_to_load !=
             program_schedule_params.checkpoint_to_load)):
          raise ValueError(f'Multiple values found for checkpoint_to_load: '
                           f'{self._checkpoint_to_load}, '
                           f'{program_schedule_params.checkpoint_to_load}.')
        self._checkpoint_to_load = program_schedule_params.checkpoint_to_load

  tf.logging.info('num_programs: %d', len(self._programs))

  # When running in a vizier trainer, the executor reports infeasible runs
  # in case of errors. The programs report metrics and normal completions.
  for program in self._programs:
    if program._should_report_metrics:
      self._should_report_metrics = True

  with self._cluster, tf.container(
      self._container_id), contextlib.ExitStack() as stack:
    if not py_utils.IsEagerMode():
      stack.enter_context(self._graph.as_default())

    if FLAGS.use_tpu_mirrored_vars:
      resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
          FLAGS.tf_master, job_name=FLAGS.worker_job[len('/job:'):])
      self._tpu_strategy = tf.distribute.experimental.TPUStrategy(
          resolver, device_assignment=self.device_assignment)
      stack.enter_context(self._tpu_strategy.scope())
      stack.enter_context(
          tpu_strategy._TPUReplicaContext(self._tpu_strategy))
    else:
      stack.enter_context(tf.device(self._cluster.GetPlacer()))

    if FLAGS.pdb_on_exception:
      stack.enter_context(pdb_wrapper.catch_post_mortem())

    with py_utils.VariableStore(), py_utils.VariableRenameScope(
        self._variable_renaming_rules):
      # `BuildTpuSubgraph` has to be called before checkpoint restore, so
      # that the optimizer slot variables are guaranteed to be initialized
      # before they get loaded. Otherwise, the optimizers' slot variables
      # will not be properly loaded when a V1 checkpoint is used.
      for program in self._programs:
        program.BuildTpuSubgraph()
        py_utils.ClearTpuSummaryTensors()

    if not py_utils.IsEagerMode():
      self._initialize_tables = tf.tables_initializer()
      self._initialize_local_vars = tf.local_variables_initializer()
      self._initialize_global_vars = tf.global_variables_initializer()

    checkpointer_models = [
        program.GetModel() for program in self._ckpt_programs
    ]

    if py_utils.IsEagerMode():
      if FLAGS.use_v2_checkpoints_in_eager:
        self._checkpointer = checkpointer.EagerCheckpointerV2(
            self._checkpoint_dir,
            models=checkpointer_models,
            init_op=None,
            train_params=train_cfg.train,
            save_only=False)
      else:
        self._checkpointer = checkpointer.EagerCheckpointerV1(
            self._checkpoint_dir,
            models=checkpointer_models,
            init_op=None,
            train_params=train_cfg.train,
            save_only=False)
    else:
      self._checkpointer = checkpointer.Checkpointer(
          self._checkpoint_dir,
          models=checkpointer_models,
          init_op=self._initialize_global_vars,
          train_params=train_cfg.train,
          save_only=False)

    for program in self._programs:
      program.SetStatusMessageFn(self._SetStatusMessage)

    tpu_embedding_collection = (
        tpu_embedding_layers.TpuEmbeddingCollection.Get())
    self._load_ops = tpu_embedding_collection.load_ops
    self._retrieve_ops = tpu_embedding_collection.retrieve_ops
    self._tpu_embedding = tpu_embedding_collection.tpu_embedding
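# NOTE: Illustrative sketch, not part of the original code. The
# contextlib.ExitStack above lets the constructor enter a different set of
# contexts depending on eager vs. graph mode and on the TPU-strategy flag.
# The pattern in miniature, with dummy contexts and a hypothetical flag:
import contextlib


@contextlib.contextmanager
def _Tag(name, log):
  log.append('enter ' + name)
  yield
  log.append('exit ' + name)


log = []
use_graph = False  # hypothetical flag selecting the optional context
with contextlib.ExitStack() as stack:
  if use_graph:
    stack.enter_context(_Tag('graph', log))
  stack.enter_context(_Tag('device', log))
# Entered contexts exit in reverse order when the stack unwinds.
assert log == ['enter device', 'exit device']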
def BuildTpuSubgraph(self):
  tf.logging.info('TrainProgram BuildTpuSubGraph')
  if self._ml_perf_log:
    mlp_log.mlperf_print('global_batch_size', self._ml_perf.global_batch_size)
    mlp_log.mlperf_print('max_sequence_length',
                         self._ml_perf.max_sequence_length)
    mlp_log.mlperf_print('opt_name', self._ml_perf.optimizer_name)
    mlp_log.mlperf_print('opt_base_learning_rate',
                         self._ml_perf.base_learning_rate)
    mlp_log.mlperf_print('opt_learning_rate_warmup_steps',
                         self._ml_perf.warmup_steps)

  with py_utils.OpportunisticVariableReuseScope(True):
    self._eval_metrics = metrics.TpuEvalMetrics()
    data_parallelism = self.data_parallelism

    def TpuTrainStep(*args):
      """Train a shard of a batch on a single TPU core.

      Args:
        *args: metrics values from previous steps.

      Returns:
        New summed metrics values and a train_op.
      """
      self._model = self._task_params.Instantiate()
      self._model.ConstructFPropBPropGraph()
      per_step_eval_metrics = self._eval_metrics.SetMetrics(
          self._model.GetTask().eval_metrics, args)
      outfeed_op = self._OutfeedEnqueue(
          self._model.GetTask().per_example_tensors)
      summed_metrics = []
      assert len(per_step_eval_metrics) == len(args)
      with tf.control_dependencies([outfeed_op]):
        for x, y in zip(per_step_eval_metrics, args):
          summed_metrics.append(x + y)
      return summed_metrics + [self._model.GetTask().train_op]

    @tpu_function.on_device_training_loop
    def TpuTrain():
      loop_result = tpu_training_loop.repeat(
          self._steps_per_loop,
          TpuTrainStep,
          inputs=self._eval_metrics.initial_values,
          name='train_loop')
      # Final metrics are the average across self._steps_per_loop steps.
      return self._eval_metrics.FinalizeMetrics(loop_result)

    self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
        TpuTrain,
        num_shards=data_parallelism,
        device_assignment=py_utils.GetTpuDeviceAssignment())
    outfeed_dequeue_op = self._OutfeedDequeueLoop(
        self._model.GetTask().per_example_tensors, self._steps_per_loop,
        self.num_splits_per_client)
    # Get the metric results from a single replica; they are identical
    # across replicas here.
    self.tpu_ops = [[t[0] for t in batch_parallel_res], outfeed_dequeue_op]

  return self.tpu_ops
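# NOTE: Illustrative sketch, not part of the original code. TpuTrainStep
# above threads running sums of per-step metric values through the device
# loop, and FinalizeMetrics converts the accumulated (weighted value sum,
# weight sum) pairs into averages over the self._steps_per_loop steps. The
# host-side arithmetic is equivalent to:
def _WeightedAverage(summed_metrics):
  """Maps {name: (sum(value * weight), sum(weight))} to weighted means."""
  return {
      name: value_sum / weight_sum if weight_sum else 0.0
      for name, (value_sum, weight_sum) in summed_metrics.items()
  }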