def __call__(self, start_epoch, experiment, verbose=0):
    '''
    Execute training

    Calls `train_epoch` once per epoch, starting from `start_epoch`, until
    `is_done` reports completion. If training did not stop early,
    `checkpoint` is called afterwards, preceded by `evaluate` when
    `config.early_stopping` is set.
    '''
    with chunked_scattering(), experiment.train():
        resuming = start_epoch > 0 or experiment.curr_step > 0
        if resuming:
            # TODO: Hacky approach to decide if the metric store should be
            # loaded. Revisit later
            self.metric_store = self.metric_store.load()

        epoch = start_epoch
        experiment.log_current_epoch(epoch)
        while not self.is_done(experiment, epoch):
            experiment.log_current_epoch(epoch)
            self.train_epoch(epoch, experiment, verbose)
            experiment.log_epoch_end(epoch)
            epoch += 1

        if self.stopped_early:
            print('Stopping early!')
            return

        # Only run an evaluation when early stopping is configured;
        # otherwise the checkpoint is never flagged as a new best.
        new_best = (
            self.evaluate(experiment, epoch, verbose)
            if self.config.early_stopping
            else False
        )
        self.checkpoint(epoch, experiment.curr_step, new_best)
def __call__(self, start_epoch, experiment, verbose=0):
    '''
    Execute training

    Calls `train_epoch` once per epoch, starting from `start_epoch`, until
    `is_done` reports completion, then passes the open stats file to
    `save_stats`. If training did not stop early, `checkpoint` is called
    afterwards, preceded by `evaluate` when `config.early_stopping` is set.

    The stats file lives in `config.stats_directory` and is named
    `config.stats_filename`, falling back to 'train_stats.pickle'.
    '''
    with ExitStack() as stack:
        stack.enter_context(chunked_scattering())
        stack.enter_context(experiment.train())

        if start_epoch > 0 or experiment.curr_step > 0:
            self.metric_store = self.metric_store.load()

        epoch = start_epoch
        experiment.log_current_epoch(epoch)

        # Plain literal: the fallback name has no placeholders, so it does
        # not need to be an f-string.
        stats_filename = self.config.stats_filename or 'train_stats.pickle'
        stats_path = os.path.join(self.config.stats_directory, stats_filename)
        # NOTE(review): 'wb' truncates an existing file at this path, also
        # when resuming from a later epoch -- confirm that is intended.
        stats_file = stack.enter_context(open(stats_path, 'wb'))

        while not self.is_done(experiment, epoch):
            experiment.log_current_epoch(epoch)
            self.train_epoch(epoch, experiment, verbose)
            experiment.log_epoch_end(epoch)
            epoch += 1

        # NOTE(review): placed after the epoch loop (one dump per run);
        # confirm it was not meant to run at the end of every epoch.
        self.save_stats(stats_file)

        if self.stopped_early:
            print('Stopping early!')
        else:
            new_best = False
            if self.config.early_stopping:
                new_best = self.evaluate(experiment, epoch, verbose)

            self.checkpoint(epoch, experiment.curr_step, new_best)
def __call__(self) -> float:
    """
    Run the evaluation!

    Feeds every batch from the dataloader through `eval_step` while showing
    a progress bar whose description is refreshed after each batch.

    Returns:
        `self.metric_store["nll"].average` after all batches were processed.

    Raises:
        RuntimeError: re-raised from `eval_step` unless its message contains
            "out of memory"; those are logged as warnings and counted in the
            "oom" metric instead.
    """
    dataloader = get_dataloader(
        self.args.data,
        self.dataset,
        num_devices=len(self.model.device_ids),
    )

    def get_description():
        return f"Eval {self.metric_store}"

    batch_iterator = tqdm(
        dataloader,
        unit="batch",
        initial=1,
        dynamic_ncols=True,
        desc=get_description(),
        file=sys.stdout,  # needed to make tqdm_wrap_stdout work
    )

    # Close the progress bar on every exit path, not only on success or on
    # a non-OOM RuntimeError.
    try:
        with ExitStack() as stack:
            # pylint:disable=no-member
            stack.enter_context(tqdm_wrap_stdout())
            stack.enter_context(chunked_scattering())
            # pylint:enable=no-member

            for batch in batch_iterator:
                try:
                    self.eval_step(batch)
                except RuntimeError as rte:
                    if "out of memory" not in str(rte):
                        raise

                    # Skip the batch that ran out of memory, but keep count
                    self.metric_store["oom"].update(1)
                    logging.warning(str(rte))

                batch_iterator.set_description_str(get_description())
    finally:
        batch_iterator.close()

    return self.metric_store["nll"].average
def __call__(self):
    """
    Run the training!

    Wraps the model in `nn.DataParallel` and loops over batches from the
    dataloader, accumulating gradients over
    `args.optim.gradient_accumulation_steps` batches between optimizer
    steps. The loop ends when `self.step` reaches `args.optim.max_steps` or
    when `prune_checkpoints` returns a falsy value.

    Whenever `self.step` is a multiple of `args.save_steps` (and
    `save_steps` is positive) a checkpoint is saved; with
    `args.optim.early_stopping` enabled, the "validation" dataset is then
    evaluated and its result tracked in the "vnll" metric.

    Raises:
        RuntimeError: re-raised from the training step unless its message
            contains "out of memory"; those are logged as warnings and
            counted in the "oom" metric instead.
    """
    # Must be called first
    self.try_init_amp()

    model = self.modules["model"]
    optimizer = self.modules["optimizer"]
    scheduler = self.modules["scheduler"]

    if self.args.optim.use_gradient_checkpointing:
        model.enable_gradient_checkpointing()

    model = nn.DataParallel(model)
    dataloader = get_dataloader(
        self.args.data,
        self.dataset,
        num_devices=len(model.device_ids),
        shuffle=True,
    )

    def get_description():
        return f"Train {self.metric_store}"

    max_steps = self.args.optim.max_steps
    accumulation_steps = self.args.optim.gradient_accumulation_steps
    progress = tqdm(
        unit="step",
        initial=self.step,
        dynamic_ncols=True,
        desc=get_description(),
        total=max_steps,
        file=sys.stdout,  # needed to make tqdm_wrap_stdout work
    )

    # Close the progress bar on every exit path, not only on success or on
    # a non-OOM RuntimeError.
    try:
        with ExitStack() as stack:
            # pylint:disable=no-member
            stack.enter_context(tqdm_wrap_stdout())
            stack.enter_context(chunked_scattering())
            stack.enter_context(self.experiment.train())
            # pylint:enable=no-member

            if self.args.optim.early_stopping:
                # If using early stopping, must evaluate regularly to
                # determine if training should stop early, so setup an
                # Evaluator
                eval_args = copy.deepcopy(self.args)
                eval_args.data.batch_size = self.args.optim.eval_batch_size

                evaluator = Evaluator(eval_args)
                evaluator.model = model
                evaluator.load_dataset("validation")
                evaluator.initialize_experiment(experiment=self.experiment)

                # Make sure we are tracking validation nll
                self.metric_store.add(
                    metrics.Metric("vnll", "format_float", "g(m)")
                )

                # And store a local variable for easy access
                vnll_metric = self.metric_store["vnll"]

            loss = 0
            num_tokens = 0
            for step, batch in enumerate(cycle(dataloader), 1):
                try:
                    step_loss = self.compute_gradients_and_loss(
                        batch, model, optimizer
                    )
                    run_optimizer = (step % accumulation_steps) == 0

                    if run_optimizer:
                        # Run an optimization step
                        optimizer.step()
                        scheduler.step()  # Update learning rate schedule
                        model.zero_grad()

                    # Update loss and num tokens after running an
                    # optimization step, in case it results in an out of
                    # memory error
                    loss += step_loss
                    num_tokens += batch["num_tokens"]

                    if run_optimizer:
                        # Since we ran the optimizer, increment current step
                        self.step += 1
                        self.experiment.set_step(self.step)
                        progress.update()

                        # update our metrics as well
                        self.update_metrics(
                            loss / accumulation_steps,
                            num_tokens,
                            scheduler.get_lr()[0],
                        )
                        num_tokens = 0
                        loss = 0

                        # and finally check if we should save
                        if (
                            self.args.save_steps > 0
                            and self.step % self.args.save_steps == 0
                        ):
                            # First save the current checkpoint
                            self.save()

                            # Then if we are implementing early stopping,
                            # see if we achieved a new best
                            if self.args.optim.early_stopping:
                                evaluator.reset_metrics()
                                with ExitStack() as eval_stack:
                                    # pylint:disable=no-member
                                    eval_stack.enter_context(
                                        tqdm_unwrap_stdout()
                                    )
                                    eval_stack.enter_context(
                                        release_cuda_memory(
                                            collect_tensors(optimizer.state)
                                        )
                                    )
                                    # pylint:enable=no-member

                                    vnll = evaluator()
                                    vnll_metric.update(vnll)

                                    # Save the updated metrics
                                    self.save_metrics()

                                    if vnll == vnll_metric.min:
                                        self.on_new_best()

                                # Try to combat OOM errors caused by doing
                                # evaluation in the same loop with training.
                                # This manifests in out of memory errors
                                # after the first or second evaluation run.
                                refresh_cuda_memory()

                            # NOTE(review): assumed to run only after a
                            # save; confirm the intended nesting.
                            if not self.prune_checkpoints():
                                logging.info("Stopping early")
                                break

                        # NOTE(review): assumed to be checked after every
                        # optimizer step, independent of `save_steps`;
                        # confirm the intended nesting.
                        if self.step >= max_steps:
                            logging.info("Finished training")
                            break

                except RuntimeError as rte:
                    if "out of memory" not in str(rte):
                        raise

                    # Skip the batch that ran out of memory, but keep count
                    self.metric_store["oom"].update(1)
                    logging.warning(str(rte))

                progress.set_description_str(get_description())
    finally:
        progress.close()