def __init__(self, *args, **kwargs):
  """Builds the trainer's TF graph, summary writers, and init/enqueue ops.

  Constructs the model's forward/backward graph inside this runner's graph
  and container, sets up per-task and main summary writers, collects the
  table/local-variable initializers and enqueue ops, and (on task 0) dumps
  the params and graph def to the train directory.
  """
  super().__init__(*args, **kwargs)
  self._job_name = 'trainer'
  with self._graph.as_default(), tf.container(self._container_id):
    try:
      # Per-task summary writers, one per task in the task schedule.
      # NOTE(review): `self._model` is assigned further below inside this
      # method; if it does not already exist here (set by the base runner,
      # presumably — TODO confirm), the attribute access raises
      # AttributeError, which the handler below treats as the expected
      # single-task case.
      self._task_probs_summary_writers = []
      for task in self._model.task_schedule.tasks:
        path = os.path.join(os.path.join(self._train_dir, task))
        tf.io.gfile.makedirs(path)
        self._task_probs_summary_writers.append(
            self._CreateSummaryWriter(path))
    except AttributeError:
      # Single-task models have no task_schedule; fall back to no
      # per-task writers.
      tf.logging.info(
          'AttributeError. Expected for single task models.')
      self._task_probs_summary_writers = []
    # Only the chief (task 0) writes summaries.
    if self.params.cluster.task == 0:
      self._summary_writer = self._CreateSummaryWriter(
          self._train_dir)
      self._CreateTF2SummaryWriter(self._train_dir)
    else:
      self._summary_writer = None
    # Build the model and its FProp/BProp graph under the cluster's
    # device placement.
    with self._cluster, tf.device(
        self._cluster.GetPlacer()), self._TF2SummaryContext():
      self._model = self.params.Instantiate()
      self._params = self._model.params
      self._model.ConstructFPropBPropGraph()
    self._CreateTF2SummaryOps()
    self._initialize_tables = tf.tables_initializer()
    self._initialize_local_vars = tf.local_variables_initializer()
    self.enqueue_ops = tf.get_collection(py_utils.ENQUEUE_OPS)
    tf.logging.info('Trainer number of enqueue ops: %d',
                    len(self.enqueue_ops))
    self._step_rate_tracker = summary_utils.StepRateTracker()

    # Saves the graph def.
    if self.params.cluster.task == 0:
      self._WriteToLog(self.params.ToText(), self._train_dir,
                       'trainer_params.txt')
      tf.io.write_graph(self._graph.as_graph_def(), self._train_dir,
                        'train.pbtxt')
    # Stagger worker start-up: worker k waits k*(k+1)/2 * delay steps,
    # so higher-numbered workers start progressively later.
    worker_id = self.params.cluster.task
    self._start_up_delay_steps = (((worker_id + 1) * worker_id / 2) *
                                  self.params.train.start_up_delay_steps)
def __init__(self, params):
  """Initializes the train program.

  Args:
    params: The program params used to construct the base program; also
      consumed by the base class initializer.
  """
  # Modernized from `super(TrainProgram, self).__init__(params)`; the
  # zero-argument form is equivalent in Python 3 and matches the style of
  # the sibling TrainProgram initializer in this file.
  super().__init__(params)
  # Tracks steps/sec and examples/sec across training steps.
  self._step_rate_tracker = summary_utils.StepRateTracker()
  self._program_name = 'TrainProgram'
def __init__(self, params, shared_model=None):
  """Initializes the train program.

  Args:
    params: The program params, forwarded to the base class.
    shared_model: Optional model shared across programs, forwarded to the
      base class.
  """
  super().__init__(params, shared_model=shared_model)
  self._program_name = 'TrainProgram'
  # Step-rate tracker used to report steps/sec and examples/sec.
  self._step_rate_tracker = summary_utils.StepRateTracker()
def Start(self):
  """Run training.

  Builds the model eagerly, restores the latest checkpoint if any, then
  loops: run one train step, write step-rate/metric summaries, and save
  checkpoints periodically and once more at the end.
  """
  super().Start()
  with self._cluster:
    model = self._params.Instantiate()
    ckptr = self._CreateCheckpointer(self._train_dir, model)
    task = model.GetTask(self._model_task_name)

    @tf.function(autograph=False)
    def TrainFunc():
      # One training step: forward + backward under a persistent tape.
      with py_utils.GradientTape(persistent=True):
        model.ConstructFPropBPropGraph()
      return task.eval_metrics, task.per_example_tensors

    step_rate_tracker = summary_utils.StepRateTracker()
    summary_writer = tf.compat.v2.summary.create_file_writer(self._train_dir)
    # Attempt to restore the checkpoint
    # A 'dummy run' to initialize the optimizer and related slot variables
    # This is also needed for V2 checkpoint even though it supports delayed
    # loading, in case the checkpoint already exceeds max_steps. In that
    # scenario the slot variables will be lost without a dummy run due to
    # checkpoint overwrites.
    _, _ = TrainFunc()
    path = ckptr.Restore()
    if path:
      tf.logging.info(f'Loaded checkpoints from {path}.')
    else:
      tf.logging.info('Did not find any checkpoints. Starting fresh.')
      # Reset global_step manually if we could not load any checkpoints
      # (the dummy run above advanced it by one step).
      global_step = py_utils.GetOrCreateGlobalStepVar()
      global_step.assign(0)
    global_step = model.global_step.numpy()
    while True:
      if self._ShouldStop(global_step):
        break
      tf.logging.info('Starting train function.')
      metrics_dict, outfeed = TrainFunc()
      tf.logging.info('Train function complete.')
      global_step = model.global_step.numpy()
      if not task.per_example_tensors:
        # No per-example tensors configured, so the step must not have
        # produced any outfeed.
        assert not outfeed
      else:
        # TODO(laigd): debugging only, remove later.
        tf.logging.info(f'outfeed: {outfeed}')
      ckptr.MaybeSave(gsteps=global_step)
      step_rate, example_rate, total_examples = (
          step_rate_tracker.ComputeStepRate(
              global_step,
              metrics_dict['num_samples_in_batch'][0].numpy()))
      msg = 'step:%6d, steps/sec: %0.2f, examples/sec: %0.2f' % (
          global_step, step_rate, example_rate)
      # Write summaries.
      with summary_writer.as_default():
        tf.compat.v2.summary.scalar(
            'global_step/sec', step_rate, step=global_step)
        tf.compat.v2.summary.scalar(
            'examples/sec', example_rate, step=global_step)
        tf.compat.v2.summary.scalar(
            'total_samples', total_examples, step=global_step)
        # Each metrics_dict value is a (value, weight) pair; only the value
        # is reported here.
        for key, (val, _) in sorted(metrics_dict.items()):
          msg += ' %s:%.8g' % (key, val)
          tf.compat.v2.summary.scalar(key, val, step=global_step)
        summary_writer.flush()
      # Log training progress.
      self._SetStatusMessage(msg)
    # Also save at the end of training
    ckptr.Save(gsteps=global_step)