# Imports assumed by the excerpts below (Lingvo-style module paths); the
# original file header is not part of this snippet, so this header is a
# best-guess reconstruction.
import multiprocessing.dummy
import os
import time

from lingvo import compat as tf
from lingvo.core import base_model
from lingvo.core import checkpointer
from lingvo.core import ml_perf_log as mlp_log
from lingvo.core import multitask_model
from lingvo.core import py_utils
from lingvo.core import task_scheduler
from lingvo.core import tpu_embedding_layers
from tensorflow.python.tpu import device_assignment as device_assignment_lib

FLAGS = tf.flags.FLAGS


def _Loop(self):
  with self._cluster, tf.container(self._container_id), self._GetSession(
      disable_meta_optimizer=FLAGS.disable_meta_optimizer_in_executor
  ) as sess:
    config_proto = (
        self._tpu_embedding.config_proto
        if self._tpu_embedding is not None else None)
    sess.reset(self._tf_master)
    for worker in self._cluster.all_worker_names:
      sess.run(
          tf.tpu.initialize_system(embedding_config=config_proto, job=worker))

    # Initialize the variables first, if needed.
    compile_fns = []
    for program in self._programs:
      program.RestoreIfNeeded(sess)
      compile_fns += [program.Compile]

    # Run the compiles in parallel.
    threadpool = multiprocessing.dummy.Pool(len(compile_fns))
    futures = []
    tf.logging.info(f'Compiling {len(compile_fns)} programs in parallel.')
    for fn in compile_fns:
      futures += [threadpool.apply_async(fn, args=(sess,))]
    for future in futures:
      future.wait()

    sess.run(self._initialize_tables)
    sess.run(self._initialize_local_vars)
    sess.run(self._load_ops)

    program_schedule = None
    while True:
      global_step = sess.run(py_utils.GetGlobalStep())
      if self._ShouldStop(sess, global_step):
        tf.logging.info('Training finished.')
        if not self._ml_perf_log:
          self.save_only_checkpointer.Save(sess, global_step)
          for program in self._programs:
            program.SaveProgramState(sess, global_step)
        if program_schedule:
          tf.logging.info('Shutting down programs.')
          program_schedule.Shutdown()
        return

      if not self._ml_perf_log and self.save_only_checkpointer.ShouldSave():

        def RunSave(sess, global_step):
          # Run TPU embedding retrieve ops.
          # NOTE: this is expensive, so only run it when we're checkpointing.
          tf.logging.info('Retrieve params.')
          sess.run(self._retrieve_ops)
          tf.logging.info('Retrieve params done.')
          # Save program state first, so it's recoverable after we restore
          # from checkpoint.
          for program in self._programs:
            program.SaveProgramState(sess, global_step)
          # Save the checkpoints.
          self.save_only_checkpointer.Save(sess, global_step)

        if self.save_only_checkpointer.async_checkpointing:
          tf.logging.info('Save checkpoint asynchronously AT YOUR OWN RISK.')
          threadpool = multiprocessing.dummy.Pool(1)
          saver_future = threadpool.apply_async(
              RunSave, args=(sess, global_step))
        else:
          RunSave(sess, global_step)

      # If a task is explicitly selected, only run the programs associated
      # with that task.
      if self._single_task_mode or self._model_task_name:
        tf.logging.info('Single task mode: %s', self._model_task_name)
        program_schedule = self._program_schedule_dict[self._model_task_name]
      else:
        # Otherwise, sample a task.
        model_task = self.task_scheduler.Sample(global_step)
        tf.logging.info('Sampled %s', model_task)
        program_schedule = self._program_schedule_dict[model_task]

      done = program_schedule.Run(sess)

      if (not self._ml_perf_log and
          self.save_only_checkpointer.async_checkpointing):
        saver_future.wait()

      if done:
        tf.logging.info('Program schedule told us to stop.\n'
                        'Shutting down programs.')
        program_schedule.Shutdown()
        return
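# A minimal, self-contained sketch of the parallel-compile pattern used in
# _Loop above. multiprocessing.dummy.Pool is a thread (not process) pool, so
# every worker can share the same session object. `Compile` and the fake
# session below are hypothetical stand-ins, not Lingvo APIs.
def _ParallelCompileSketch(num_programs=3):
  fake_sess = object()  # stand-in for the shared tf.Session

  def Compile(sess, program_id):
    # Each compile runs on its own thread against the shared session.
    return (program_id, id(sess))

  compile_fns = [lambda sess, i=i: Compile(sess, i)
                 for i in range(num_programs)]
  threadpool = multiprocessing.dummy.Pool(len(compile_fns))
  futures = [threadpool.apply_async(fn, args=(fake_sess,))
             for fn in compile_fns]
  # .get() blocks like .wait() but also re-raises any exception from the
  # worker thread, which .wait() silently swallows.
  return [f.get() for f in futures]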
def __init__(self, train_cfg, ps_params_dict, model_task_name, logdir,
             tf_master, **kwargs):
  """Construct an ExecutorTpu BaseRunner.

  Args:
    train_cfg: SingleTaskModelParams or MultiTaskModelParams
    ps_params_dict: A dict of top-level task name -> ProgramSchedule params,
      if train_cfg is a SingleTaskModelParams, we expect only one entry.
    model_task_name: An override for multi-task models, currently unused.
    logdir: String path to the log directory to output to.
    tf_master: String path to the master job, e.g. 'local'.
    **kwargs: keyword args to pass through to BaseRunner.
  """
  super().__init__(train_cfg, model_task_name, logdir, tf_master, **kwargs)

  self._cluster_def = self._cluster.worker_cluster_def

  # There is a single Executor task.
  assert self._cluster.num_replicas == 1
  data_parallelism = self._cluster.num_splits_per_client
  assert data_parallelism
  num_devices_per_split = self._cluster.num_devices_per_split
  tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                  data_parallelism, num_devices_per_split)

  self.task_scheduler = None
  self._checkpoint_dir = os.path.join(logdir, 'train')

  self._variable_renaming_rules = []

  self._ml_perf = None

  # If this is a multi-task model, grab the params for the TaskScheduler.
  if issubclass(train_cfg.cls, base_model.SingleTaskModel):
    tf.logging.info('single_task_model')
    assert len(ps_params_dict) == 1
    self._model_task_name = list(ps_params_dict.keys())[0]
    self._single_task_mode = True
  elif issubclass(train_cfg.cls, base_model.MultiTaskModel):
    tf.logging.info('multi_task_model')
    if issubclass(train_cfg.cls, multitask_model.RegExSharedVariableModel):
      self._variable_renaming_rules = train_cfg.variable_renaming_rules

    if train_cfg.task_schedule is None:
      task_schedule_params = task_scheduler.ConstantScheduler.Params()
      task_schedule_params.task_probs = sorted(
          list(train_cfg.task_probs.IterParams()))
    else:
      task_schedule_params = train_cfg.task_schedule
    self.task_scheduler = task_schedule_params.Instantiate()
    self._single_task_mode = False
  else:
    tf.logging.fatal(
        'Model %s is not a sub-class of SingleTaskModel or MultiTaskModel',
        train_cfg.cls)

  tf.logging.info('train_cfg.cls: %s', train_cfg.cls)

  self._WriteToLog(train_cfg.ToText(), self._checkpoint_dir,
                   'trainer_params.txt')

  self._program_schedule_dict = {}
  self._programs = []

  for task_string, program_schedule_params in ps_params_dict.items():
    program_schedule_params.logdir = logdir
    program_schedule_params.num_splits_per_client = data_parallelism
    program_schedule_params.task_name = task_string
    ps = program_schedule_params.Instantiate()
    self._program_schedule_dict[task_string] = ps
    tf.logging.info('program_schedule_params: %s',
                    program_schedule_params.ToText())
    self._programs += ps.Programs()
    if program_schedule_params.ml_perf.benchmark_name is not None:
      self._ml_perf = program_schedule_params.ml_perf

  tf.logging.info('num_programs: %d', len(self._programs))

  if self._ml_perf is not None:
    self._ml_perf_log = True
    mlp_log.mlperf_print(key='benchmark', value=self._ml_perf.benchmark_name)
  else:
    self._ml_perf_log = False

  # BaseRunner legacy
  self.enqueue_ops = None

  @py_utils.RetryOnTransientTfError()
  def _WaitTillInit():
    """Wait until the model is ready."""
    try:
      with self._graph.as_default(), self._GetSession(
          cluster_def=self._cluster_def,
          disable_meta_optimizer=FLAGS.disable_meta_optimizer_in_executor
      ) as sess:
        topology = sess.run(
            tf.tpu.initialize_system(embedding_config=None, job=None))
        device_assignment = device_assignment_lib.device_assignment(
            topology,
            computation_shape=py_utils.ComputationShape(
                num_devices_per_split),
            num_replicas=data_parallelism)
        py_utils.SetTpuDeviceAssignment(device_assignment)
        tf.logging.info('device_assignment.core_assignment: %s',
                        str(device_assignment.core_assignment))
        tf.logging.info('device_assignment.topology.device_coordinates: %s',
                        str(device_assignment.topology.device_coordinates))
    except py_utils.transient_tf_errors as e:
      tf.logging.info('TPU initialization failed: %s', e)
      raise

  if self._ml_perf_log:
    mlp_log.mlperf_print(key='init_start', value=None)
  _WaitTillInit()

  with self._graph.as_default(), tf.container(self._container_id):
    with self._cluster, tf.device(
        self._cluster.job_spec.name
        if not FLAGS.cluster_placer_in_executor
        else self._cluster.GetPlacer()):
      with py_utils.VariableRenameScope(self._variable_renaming_rules):
        _ = py_utils.GetOrCreateGlobalStepVar()
        for program in self._programs:
          program.BuildTpuSubgraph()
      for program in self._programs:
        program.SetStatusMessageFn(self._SetStatusMessage)
        program.CreateCheckpointer()
      self._initialize_tables = tf.tables_initializer()
      self._initialize_local_vars = tf.local_variables_initializer()

      self.save_only_checkpointer = checkpointer.Checkpointer(
          self._checkpoint_dir,
          model=None,
          train_params=train_cfg.train,
          save_only=True)
def _Loop(self):
  with self._cluster, tf.container(self._container_id), self._GetSession(
      disable_meta_optimizer=FLAGS.disable_meta_optimizer_in_executor
  ) as sess:
    config_proto = (
        self._tpu_embedding.config_proto
        if self._tpu_embedding is not None else None)
    sess.reset(self._tf_master)
    for worker in self._cluster.all_worker_names:
      sess.run(
          tf.tpu.initialize_system(embedding_config=config_proto, job=worker))

    # Initialize the variables first, if needed.
    compile_fns = []
    for program in self._programs:
      program.RestoreIfNeeded(sess)
      compile_fns += [program.Compile]

    # Run the compiles in parallel.
    threadpool = multiprocessing.dummy.Pool(len(compile_fns))
    futures = []
    tf.logging.info(f'Compiling {len(compile_fns)} programs in parallel.')
    for fn in compile_fns:
      futures += [threadpool.apply_async(fn, args=(sess,))]
    for future in futures:
      future.wait()

    sess.run(self._initialize_tables)
    sess.run(self._initialize_local_vars)
    sess.run(self._load_ops)

    program_schedule = None
    while True:
      global_step = sess.run(py_utils.GetGlobalStep())
      if self._ShouldStop(sess, global_step):
        tf.logging.info('Training finished.')
        if not self._ml_perf_log:
          self.save_only_checkpointer.Save(sess, global_step)
        if program_schedule:
          tf.logging.info('Shutting down programs.')
          program_schedule.Shutdown()
        return

      # If a task is explicitly selected, only run the programs associated
      # with that task.
      if self._single_task_mode or self._model_task_name:
        tf.logging.info('Single task mode: %s', self._model_task_name)
        program_schedule = self._program_schedule_dict[self._model_task_name]
      else:
        # Otherwise, sample a task.
        model_task = self.task_scheduler.Sample(global_step)
        tf.logging.info('Sampled %s', model_task)
        program_schedule = self._program_schedule_dict[model_task]

      done = program_schedule.Run(sess)
      if done:
        tf.logging.info('Program schedule told us to stop.\n'
                        'Shutting down programs.')
        program_schedule.Shutdown()
        return

      # The `global_step` local variable above is the result of a sess.run,
      # not a tf variable, so by the time we call
      # save_only_checkpointer.MaybeSave(...) below,
      # py_utils.GetGlobalStep() is already ahead of it by
      #   (train_executions_per_eval * train_steps_per_loop)
      # steps, due to program_schedule.Run(sess).
      if not self._ml_perf_log:
        tf.logging.info('Retrieve params.')
        sess.run(self._retrieve_ops)
        tf.logging.info('Retrieve params done.')
        self.save_only_checkpointer.MaybeSave(sess, py_utils.GetGlobalStep())
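# A small numeric illustration of the checkpoint-lag comment above, with
# assumed values: if train_executions_per_eval=2 and train_steps_per_loop=500,
# one program_schedule.Run(sess) advances the global step by 1000, so the
# checkpoint is tagged 1000 steps past the `global_step` local read at the
# top of the loop iteration.
def _CheckpointLagSketch(step_at_loop_top=3000,
                         train_executions_per_eval=2,
                         train_steps_per_loop=500):
  steps_run_by_schedule = train_executions_per_eval * train_steps_per_loop
  return step_at_loop_top + steps_run_by_schedule  # 4000 for these defaults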
def __init__(self, train_cfg, ps_params_dict, model_task_name, logdir,
             tf_master, **kwargs):
  """Construct an ExecutorTpu BaseRunner.

  Args:
    train_cfg: SingleTaskModelParams or MultiTaskModelParams
    ps_params_dict: A dict of top-level task name -> ProgramSchedule params,
      if train_cfg is a SingleTaskModelParams, we expect only one entry.
    model_task_name: An override for multi-task models, currently unused.
    logdir: String path to the log directory to output to.
    tf_master: String path to the master job, e.g. 'local'.
    **kwargs: keyword args to pass through to BaseRunner.
  """
  super().__init__(train_cfg, model_task_name, logdir, tf_master, **kwargs)

  data_parallelism = self._cluster.num_splits_per_client
  assert data_parallelism
  num_devices_per_split = self._cluster.num_devices_per_split
  tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                  data_parallelism, num_devices_per_split)

  self.task_scheduler = None
  self._checkpoint_dir = os.path.join(logdir, 'train')

  self._variable_renaming_rules = []

  self._ml_perf = None

  # If this is a multi-task model, grab the params for the TaskScheduler.
  if issubclass(train_cfg.cls, base_model.SingleTaskModel):
    tf.logging.info('single_task_model')
    assert len(ps_params_dict) == 1
    self._model_task_name = list(ps_params_dict.keys())[0]
    self._single_task_mode = True
  elif issubclass(train_cfg.cls, base_model.MultiTaskModel):
    tf.logging.info('multi_task_model')
    if issubclass(train_cfg.cls, multitask_model.RegExSharedVariableModel):
      self._variable_renaming_rules = train_cfg.variable_renaming_rules

    if train_cfg.task_schedule is None:
      task_schedule_params = task_scheduler.ConstantScheduler.Params()
      task_schedule_params.task_probs = sorted(
          list(train_cfg.task_probs.IterParams()))
    else:
      task_schedule_params = train_cfg.task_schedule
    self.task_scheduler = task_schedule_params.Instantiate()
    self._single_task_mode = False
  else:
    tf.logging.fatal(
        'Model %s is not a sub-class of SingleTaskModel or MultiTaskModel',
        train_cfg.cls)

  tf.logging.info('train_cfg.cls: %s', train_cfg.cls)

  self._WriteToLog(train_cfg.ToText(), self._checkpoint_dir,
                   'trainer_params.txt')

  if self._ml_perf is not None:
    self._ml_perf_log = True
    mlp_log.mlperf_print(key='benchmark', value=self._ml_perf.benchmark_name)
  else:
    self._ml_perf_log = False

  # BaseRunner legacy
  self.enqueue_ops = None

  train_cfg = self.params

  @py_utils.RetryOnTransientTfError()
  def _WaitTillInit(job=None):
    """Wait until the model is ready."""
    try:
      # tpu.initialize_system() is called with None as embedding_config, as
      # embedding_config is not available yet. Later in _Loop, it is called
      # with the correct embedding_config. Since it cannot be called twice
      # in the same graph with different embedding_config, we use a
      # dummy_graph here.
      dummy_graph = tf.Graph()
      with dummy_graph.as_default():
        tpu_initialize_system_op = tf.tpu.initialize_system(
            embedding_config=None, job=job)

      with self._GetSession(graph=dummy_graph) as sess:
        topology = sess.run(tpu_initialize_system_op)

      if train_cfg.train.tpu_device_order_mode is None:
        device_assignment = device_assignment_lib.device_assignment(
            topology,
            computation_shape=py_utils.ComputationShape(
                num_devices_per_split, topology),
            num_replicas=data_parallelism)
      else:
        device_assignment = device_assignment_lib.device_assignment(
            topology,
            computation_shape=py_utils.ComputationShape(
                num_devices_per_split, topology),
            num_replicas=data_parallelism,
            device_order_mode=train_cfg.train.tpu_device_order_mode)
      py_utils.SetTpuDeviceAssignment(device_assignment, job)
      tf.logging.info('device_assignment.core_assignment: %s',
                      str(device_assignment.core_assignment))
      tf.logging.info('device_assignment.topology.device_coordinates: %s',
                      str(device_assignment.topology.device_coordinates))
    except py_utils.transient_tf_errors as e:
      tf.logging.info('TPU initialization failed: %s', e)
      raise

  if self._ml_perf_log:
    mlp_log.mlperf_print(key='init_start', value=None)

  if len(self._cluster.all_worker_names) > 1:
    for worker in self._cluster.all_worker_names:
      _WaitTillInit(worker)
  else:
    _WaitTillInit(None)

  shared_model = self._MaybeConstructSharedModel(train_cfg)

  self._program_schedule_dict = {}
  self._programs = []

  for task_string, program_schedule_params in ps_params_dict.items():
    program_schedule_params.logdir = logdir
    program_schedule_params.num_splits_per_client = data_parallelism
    program_schedule_params.task_name = task_string
    # If the model was created above, we'll inject it here as a shared_model.
    ps = program_schedule_params.Instantiate(
        shared_model=shared_model, tf_master=self._tf_master)
    self._program_schedule_dict[task_string] = ps
    tf.logging.info('program_schedule_params: %s',
                    program_schedule_params.ToText())
    self._programs += ps.Programs()
    if program_schedule_params.ml_perf.benchmark_name is not None:
      self._ml_perf = program_schedule_params.ml_perf

  tf.logging.info('num_programs: %d', len(self._programs))

  with self._graph.as_default(), tf.container(self._container_id):
    with self._cluster, tf.device(self._cluster.GetPlacer()):
      with py_utils.VariableRenameScope(self._variable_renaming_rules):
        _ = py_utils.GetOrCreateGlobalStepVar()
        for program in self._programs:
          program.BuildTpuSubgraph()
          py_utils.ClearTpuSummaryTensors()

      self._initialize_tables = tf.tables_initializer()
      self._initialize_local_vars = tf.local_variables_initializer()
      self._initialize_global_vars = tf.global_variables_initializer()

      for program in self._programs:
        program.SetStatusMessageFn(self._SetStatusMessage)
        program.CreateCheckpointer(init_op=self._initialize_global_vars)

      self.save_only_checkpointer = checkpointer.Checkpointer(
          self._checkpoint_dir,
          model=None,
          init_op=self._initialize_global_vars,
          train_params=train_cfg.train,
          save_only=True)

      tpu_embedding_collection = (
          tpu_embedding_layers.TpuEmbeddingCollection.Get())
      self._load_ops = tpu_embedding_collection.load_ops
      self._retrieve_ops = tpu_embedding_collection.retrieve_ops
      self._tpu_embedding = tpu_embedding_collection.tpu_embedding

  tf.io.write_graph(self._graph.as_graph_def(), self._checkpoint_dir,
                    'train.pbtxt')
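# Sketch of the dummy-graph trick used in _WaitTillInit above: build the
# initialize_system op in a throwaway tf.Graph so the main graph stays free
# to call tf.tpu.initialize_system later with a real embedding_config (the
# op cannot be added twice to one graph with different configs). Uses the
# header imports; it only runs for real against a TPU worker, and the plain
# tf.Session here is an assumed simplification of self._GetSession.
def _InitTpuInDummyGraphSketch(session_target=''):
  dummy_graph = tf.Graph()
  with dummy_graph.as_default():
    init_op = tf.tpu.initialize_system(embedding_config=None, job=None)
  with tf.Session(target=session_target, graph=dummy_graph) as sess:
    return sess.run(init_op)  # serialized TopologyProto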
def _LoopEnqueue(self, op, session_override=None):
  """Runs the enqueue op in a loop."""
  p = self.params
  sess = session_override or self._GetSession()

  with tf.container(self._container_id), sess:
    if self._initialize_tables is not None:
      sess.run(self._initialize_tables)
    for task in self._model.tasks:
      task.input.Initialize(sess)
    gsteps = py_utils.GetGlobalStep()
    local_enqueue_steps = 0

    # Global enqueue steps measures how many global steps have data enqueued
    # for already. We use this to terminate; note that the enqueue op may
    # hang in session.run if we do not terminate with this check.
    global_enqueue_steps = None

    tf.logging.info('params.train.max_steps: %d, enqueue_max_steps: %d',
                    p.train.max_steps, p.train.enqueue_max_steps)
    while True:
      if self._dequeue_thread_complete:
        tf.logging.info('LoopEnqueue done since consuming thread is done.')
        return

      global_step = sess.run(gsteps)
      if global_enqueue_steps is None:
        global_enqueue_steps = global_step
      if local_enqueue_steps % 1000 == 0:
        tf.logging.info(
            'Current global_enqueue_steps: %d, '
            'local_enqueue_steps: %d, global_step: %d', global_enqueue_steps,
            local_enqueue_steps, global_step)

      if py_utils.use_tpu():
        global_steps_with_available_data = int(
            global_enqueue_steps // p.train.tpu_steps_per_loop *
            p.train.tpu_steps_per_loop)
      else:
        global_steps_with_available_data = global_enqueue_steps

      if (self._ShouldStop(sess, global_steps_with_available_data) or
          self._ShouldStop(sess, global_step)):
        tf.logging.info('Done. ShouldStop is True.')
        tf.logging.info('Enqueue loop sleeping')
        time.sleep(15)
        continue
      if (p.train.enqueue_max_steps > 0 and
          local_enqueue_steps >= p.train.enqueue_max_steps):
        tf.logging.info('Done. train.enqueue_max_steps reached.')
        tf.logging.info('Enqueue loop sleeping')
        time.sleep(15)
        continue
      local_enqueue_steps += 1

      # There are tpu_infeed_parallelism parallel threads enqueuing.
      # We account for all of them when updating global_enqueue_steps.
      global_enqueue_steps += p.input.tpu_infeed_parallelism

      # Input data stats generated during training are collected and logged
      # in input generators. The merged summary op for input data stats
      # merges all the scalar summaries for the stats logged from the input
      # generators. If merged scalar summaries for input data stats are
      # available, write them to the training directory along with processing
      # the TPU infeed op.
      if self._merged_input_data_summary_op is not None:
        summary_str, _ = sess.run([self._merged_input_data_summary_op, op])
        self._WriteInputDataStatSummaries(summary_str, global_enqueue_steps)
      else:
        sess.run([op])
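# The rounding in _LoopEnqueue above floors global_enqueue_steps to a whole
# multiple of tpu_steps_per_loop, since enqueued data is only consumable in
# complete TPU loop iterations. A minimal check of that arithmetic:
def _FloorToLoopMultiple(global_enqueue_steps, tpu_steps_per_loop):
  return int(global_enqueue_steps // tpu_steps_per_loop * tpu_steps_per_loop)


assert _FloorToLoopMultiple(1234, 100) == 1200
assert _FloorToLoopMultiple(99, 100) == 0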
def _Loop(self):
  # Evaler/Controller jobs may find that the trial is infeasible and report
  # done earlier. This is an important check since the trainer may retry
  # indefinitely without it.
  if self._trial.ShouldStop():
    tf.logging.info('Training skipped (trial requested to stop).')
    return
  with tf.container(
      self._container_id), self._cluster, self._GetSession() as sess:
    # This initializes local tables.
    sess.run(self._initialize_tables)
    # This initializes local variables.
    sess.run(self._initialize_local_vars)
    global_step = self._WaitUntilInit(sess, self._start_up_delay_steps)

    status_interval_steps = 100
    next_status_step = 1
    eval_metrics = None
    while True:
      if (self._trial.ShouldStopAndMaybeReport(global_step, eval_metrics) or
          self._ShouldStop(sess, global_step)):
        tf.logging.info('Training finished.')
        if self._early_stop:
          time.sleep(300)  # controller hangs if it doesn't finish first
        self._DequeueThreadComplete()
        return

      # If a task is explicitly specified, only train that task.
      if self._model_task_name:
        task = self._model.GetTask(self._model_task_name)
      else:
        # Note: This is a slightly stale global_step value from the previous
        # sess.run() call.
        # For multi-task models, `self._model.task_schedule.cur_probs` will
        # be updated.
        task = self._model.SampleTask(global_step)
        if self._task_probs_summary_writers:
          for index, prob in enumerate(self._model.task_schedule.cur_probs):
            self._SummarizeValue(global_step, 'task_probability', prob,
                                 self._task_probs_summary_writers[index])
          try:
            for index, task in enumerate(self._model.tasks):
              self._SummarizeValue(global_step, 'task_weight',
                                   sess.run(task.vars.task_weight),
                                   self._task_probs_summary_writers[index])
          except AttributeError:
            pass

      (_, eval_metrics, per_example_tensors) = sess.run([
          task.train_op,
          task.eval_metrics,
          task.per_example_tensors,
      ])
      # Explicitly fetch global_step after running train_op.
      # TODO(b/151181934): Investigate this behavior further.
      task_global_step = sess.run(task.global_step)
      task.ProcessFPropResults(sess, task_global_step, eval_metrics,
                               per_example_tensors)

      global_step = sess.run(self._model.global_step)
      step_rate, example_rate, total_examples = (
          self._step_rate_tracker.ComputeStepRate(
              global_step, eval_metrics['num_samples_in_batch'][0]))
      self._SummarizeValue(global_step, 'global_step/sec', step_rate)
      self._SummarizeValue(global_step, 'examples/sec', example_rate)
      self._SummarizeValue(global_step, 'total_samples', total_examples)

      msg = 'step:%6d, steps/sec: %0.2f, examples/sec: %0.2f' % (
          global_step, step_rate, example_rate)
      for key, (val, _) in sorted(eval_metrics.items()):
        msg += ' %s:%.8g' % (key, val)
        self._SummarizeValue(global_step, key, val)
      if global_step >= next_status_step:
        self._SetStatusMessage(msg)
        self._ExportMetrics(
            # Metrics expects python int, but global_step is numpy.int64.
            global_step=int(global_step),
            step_rate=step_rate,
            example_rate=example_rate)
        next_status_step = global_step + status_interval_steps
      else:
        tf.logging.info(msg)

      self._model.ProcessFPropResults(sess, global_step, eval_metrics,
                                      per_example_tensors)
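# A rough sketch of the bookkeeping that _step_rate_tracker.ComputeStepRate
# performs in the loop above. The real StepRateTracker lives in Lingvo; this
# simplified version (including the assumption of a constant batch size when
# deriving example_rate) is illustrative only.
class SimpleStepRateTracker:

  def __init__(self):
    self._last_time = None
    self._last_step = 0
    self._total_examples = 0

  def ComputeStepRate(self, global_step, batch_size):
    now = time.time()
    self._total_examples += batch_size
    if self._last_time is None:
      # First call: no interval to rate over yet.
      self._last_time, self._last_step = now, global_step
      return 0.0, 0.0, self._total_examples
    dt = max(now - self._last_time, 1e-6)
    step_rate = (global_step - self._last_step) / dt
    example_rate = step_rate * batch_size
    self._last_time, self._last_step = now, global_step
    return step_rate, example_rate, self._total_examples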
def __init__(self, train_cfg, ps_params_dict, model_task_name, logdir,
             tf_master, **kwargs):
  """Construct an ExecutorTpu BaseRunner.

  Args:
    train_cfg: SingleTaskModelParams or MultiTaskModelParams
    ps_params_dict: A dict of top-level task name -> ProgramSchedule params,
      if train_cfg is a SingleTaskModelParams, we expect only one entry.
    model_task_name: An override for multi-task models, currently unused.
    logdir: String path to the log directory to output to.
    tf_master: String path to the master job, e.g. 'local'.
    **kwargs: keyword args to pass through to BaseRunner.
  """
  super(ExecutorTpu, self).__init__(train_cfg, model_task_name, logdir,
                                    tf_master, **kwargs)

  self._cluster_def = self._cluster.worker_cluster_def

  # There is a single Executor task.
  assert self._cluster.num_replicas == 1
  data_parallelism = self._cluster.num_splits_per_client
  assert data_parallelism
  num_devices_per_split = self._cluster.num_devices_per_split
  tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                  data_parallelism, num_devices_per_split)

  self.task_scheduler = None

  # If this is a multi-task model, grab the params for the TaskScheduler.
  if issubclass(train_cfg.cls, base_model.SingleTaskModel):
    tf.logging.info('single_task_model')
    assert len(ps_params_dict) == 1
    # NOTE: dict.keys() is not subscriptable in Python 3; wrap it in list()
    # before indexing.
    self._model_task_name = list(ps_params_dict.keys())[0]
    self._single_task_mode = True
  elif issubclass(train_cfg.cls, base_model.MultiTaskModel):
    tf.logging.info('multi_task_model')
    if train_cfg.task_schedule is None:
      task_schedule_params = task_scheduler.ConstantScheduler.Params()
      task_schedule_params.task_probs = sorted(
          list(train_cfg.task_probs.IterParams()))
    else:
      task_schedule_params = train_cfg.task_schedule
    self.task_scheduler = task_schedule_params.Instantiate()
    self._single_task_mode = False
  else:
    tf.logging.fatal(
        'Model %s is not a sub-class of SingleTaskModel or MultiTaskModel',
        train_cfg.cls)

  tf.logging.info('train_cfg.cls: %s', train_cfg.cls)

  self._program_schedule_dict = {}
  self._programs = []

  for task_string, program_schedule_params in ps_params_dict.items():
    program_schedule_params.logdir = logdir
    program_schedule_params.num_splits_per_client = data_parallelism
    program_schedule_params.task_name = task_string
    ps = program_schedule_params.Instantiate()
    self._program_schedule_dict[task_string] = ps
    tf.logging.info('program_schedule_params: %s',
                    program_schedule_params.ToText())
    self._programs += ps.Programs()

  tf.logging.info('num_programs: %d', len(self._programs))

  # BaseRunner legacy
  self.enqueue_ops = None

  def ComputationShape(split_size):
    """Decides the computation shape based on the split_size."""
    computation_shape = None
    if split_size == 1:
      computation_shape = [1, 1, 1]
    elif split_size == 2:
      computation_shape = [1, 1, 2]
    elif split_size == 4:
      computation_shape = [1, 2, 2]
    elif split_size == 8:
      computation_shape = [2, 2, 2]
    elif split_size == 16:
      computation_shape = [4, 2, 2]
    else:
      assert False, ('Model parallelism with %d devices is currently not '
                     'supported.' % split_size)
    assert computation_shape is not None
    return computation_shape

  @py_utils.RetryOnTransientTfError()
  def _WaitTillInit():
    """Wait until the model is ready."""
    try:
      with self._GetSession(cluster_def=self._cluster_def) as sess:
        topology = sess.run(
            tf.tpu.initialize_system(embedding_config=None, job=None))
        device_assignment = device_assignment_lib.device_assignment(
            topology,
            computation_shape=ComputationShape(num_devices_per_split),
            num_replicas=data_parallelism)
        py_utils.SetTpuDeviceAssignment(device_assignment)
        tf.logging.info('device_assignment.core_assignment: %s',
                        str(device_assignment.core_assignment))
        tf.logging.info('device_assignment.topology.device_coordinates: %s',
                        str(device_assignment.topology.device_coordinates))
    except py_utils.transient_tf_errors as e:
      tf.logging.info('TPU initialization failed: %s', e)
      raise

  _WaitTillInit()

  with self._graph.as_default(), tf.container(self._container_id):
    with self._cluster, tf.device(self._cluster.job_spec.name):
      for program in self._programs:
        program.BuildTpuSubgraph()
      self.initialize_tables = tf.tables_initializer()
      self._initialize_local_vars = tf.local_variables_initializer()

  self._checkpoint_dir = os.path.join(logdir, 'train')
  self.save_only_checkpointer = checkpointer.Checkpointer(
      self._checkpoint_dir,
      model=None,
      train_params=train_cfg.train,
      save_only=True)
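# The ComputationShape helper above is a fixed lookup; an equivalent,
# more compact table-driven sketch (same supported split sizes, same shapes,
# and each shape's dimensions multiply to the split size):
_SPLIT_SIZE_TO_SHAPE = {
    1: [1, 1, 1],
    2: [1, 1, 2],
    4: [1, 2, 2],
    8: [2, 2, 2],
    16: [4, 2, 2],
}


def ComputationShapeFromTable(split_size):
  assert split_size in _SPLIT_SIZE_TO_SHAPE, (
      'Model parallelism with %d devices is currently not supported.' %
      split_size)
  return _SPLIT_SIZE_TO_SHAPE[split_size]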
def _LoopEnqueue(self, op, session_override=None):
  """Runs the enqueue op in a loop."""
  p = self.params
  sess = session_override or self._GetSession()

  with tf.container(self._container_id), sess:
    if self._initialize_tables is not None:
      sess.run(self._initialize_tables)
    for task in self._model.tasks:
      task.input.Initialize(sess)
    gsteps = py_utils.GetGlobalStep()
    local_enqueue_steps = 0

    # Global enqueue steps measures how many global steps have data enqueued
    # for already. We use this to terminate; note that the enqueue op may
    # hang in session.run if we do not terminate with this check.
    global_enqueue_steps = None

    tf.logging.info('params.train.max_steps: %d, enqueue_max_steps: %d',
                    p.train.max_steps, p.train.enqueue_max_steps)
    while True:
      if self._dequeue_thread_complete:
        tf.logging.info('LoopEnqueue done since consuming thread is done.')
        return

      global_step = sess.run(gsteps)
      if global_enqueue_steps is None:
        global_enqueue_steps = global_step
      if local_enqueue_steps % 1000 == 0:
        tf.logging.info(
            'Current global_enqueue_steps: %d, '
            'local_enqueue_steps: %d, global_step: %d', global_enqueue_steps,
            local_enqueue_steps, global_step)

      if py_utils.use_tpu():
        global_steps_with_available_data = int(
            global_enqueue_steps // p.train.tpu_steps_per_loop *
            p.train.tpu_steps_per_loop)
      else:
        global_steps_with_available_data = global_enqueue_steps

      if (self._ShouldStop(sess, global_steps_with_available_data) or
          self._ShouldStop(sess, global_step)):
        tf.logging.info('Done. ShouldStop is True.')
        tf.logging.info('Enqueue loop sleeping')
        time.sleep(15)
        continue
      if (p.train.enqueue_max_steps > 0 and
          local_enqueue_steps >= p.train.enqueue_max_steps):
        tf.logging.info('Done. train.enqueue_max_steps reached.')
        tf.logging.info('Enqueue loop sleeping')
        time.sleep(15)
        continue
      local_enqueue_steps += 1

      # There are tpu_infeed_parallelism parallel threads enqueuing.
      # We account for all of them when updating global_enqueue_steps.
      global_enqueue_steps += p.input.tpu_infeed_parallelism

      sess.run([op])
def __init__(self, task_dict, program_schedule_params, model_task_name,
             logdir, tf_master, **kwargs):
  """Construct an ExecutorTpu BaseRunner.

  Args:
    task_dict: A dict of dataset_name -> task params.
    program_schedule_params: A ProgramSchedule params.
    model_task_name: An override for multi-task models, currently unused.
    logdir: String path to the log directory to output to.
    tf_master: String path to the master job, e.g. 'local'.
    **kwargs: keyword args to pass through to BaseRunner.
  """
  # TODO(blee): fix this.
  train_params = task_dict['Train']
  super(ExecutorTpu, self).__init__(train_params, model_task_name, logdir,
                                    tf_master, **kwargs)

  self._cluster_def = self._cluster.worker_cluster_def

  # There is a single Executor task.
  assert self._cluster.num_replicas == 1
  data_parallelism = self._cluster.num_splits_per_client
  assert data_parallelism
  num_devices_per_split = self._cluster.num_devices_per_split
  tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                  data_parallelism, num_devices_per_split)

  # Update run-time params.
  program_schedule_params.task_dict = task_dict
  program_schedule_params.logdir = logdir
  program_schedule_params.num_splits_per_client = data_parallelism

  self._programs = []
  self._program_schedule = program_schedule_params.Instantiate()
  tf.logging.info('program_schedule_params: %s',
                  program_schedule_params.ToText())
  self._programs += self._program_schedule.Programs()

  # BaseRunner legacy
  self.enqueue_ops = None

  def ComputationShape(split_size):
    """Decides the computation shape based on the split_size."""
    computation_shape = None
    if split_size == 1:
      computation_shape = [1, 1, 1]
    elif split_size == 2:
      computation_shape = [1, 1, 2]
    elif split_size == 4:
      computation_shape = [1, 2, 2]
    elif split_size == 8:
      computation_shape = [2, 2, 2]
    elif split_size == 16:
      computation_shape = [4, 2, 2]
    else:
      assert False, ('Model parallelism with %d devices is currently not '
                     'supported.' % split_size)
    assert computation_shape is not None
    return computation_shape

  @py_utils.RetryOnTransientTfError()
  def _WaitTillInit():
    """Wait until the model is ready."""
    try:
      with self._GetSession(cluster_def=self._cluster_def) as sess:
        topology = sess.run(
            tf.tpu.initialize_system(embedding_config=None, job=None))
        device_assignment = device_assignment_lib.device_assignment(
            topology,
            computation_shape=ComputationShape(num_devices_per_split),
            num_replicas=data_parallelism)
        py_utils.SetTpuDeviceAssignment(device_assignment)
        tf.logging.info('device_assignment.core_assignment: %s',
                        str(device_assignment.core_assignment))
        tf.logging.info('device_assignment.topology.device_coordinates: %s',
                        str(device_assignment.topology.device_coordinates))
    except py_utils.transient_tf_errors as e:
      tf.logging.info('TPU initialization failed: %s', e)
      raise

  _WaitTillInit()

  with self._graph.as_default(), tf.container(self._container_id):
    with self._cluster, tf.device(self._cluster.job_spec.name):
      for program in self._programs:
        program.BuildTpuSubgraph()
      self.initialize_tables = tf.tables_initializer()
      self._initialize_local_vars = tf.local_variables_initializer()