def train_one_batch(self, batch):
    """A normal training core, without fetching data from the iterator."""
    model = self.elements["model"]
    model_forward = self.elements["model_forward"]
    optimizer = self.elements["optimizer"]

    if not model.training:
        model.train()

    if self.params["nan_debug"]:
        # Reload the saved NaN-producing batch and parameters to reproduce the problem.
        device = utils.get_device(self.elements["model"])
        inputs = torch.load("{0}/nan.batch".format(self.params["model_dir"])).to(device)
        targets = torch.load("{0}/nan.targets".format(self.params["model_dir"])).to(device)
        self.elements["model"].load_state_dict(
            torch.load("{0}/nan.params".format(self.params["model_dir"]), map_location="cpu"))
        self.elements["model"].to(device)
    else:
        inputs, targets = batch

    optimizer.zero_grad()

    loss = model.get_loss(model_forward(inputs), targets)
    loss.backward()
    loss.detach()  # For safety.

    if self.params["max_change"] > 0:
        # Reference: https://github.com/horovod/horovod/blob/master/horovod/torch/__init__.py:420~423.
        # Synchronize the gradients before computing grad_norm when using horovod.
        if utils.use_horovod():
            optimizer.synchronize()

        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), self.params["max_change"])

        if math.isnan(grad_norm):
            if self.params["nan_debug"]:
                raise RuntimeError("[NOT OK] NaN is still found in this debug.")
            # Save the offending batch and parameters so the NaN can be reproduced with nan_debug.
            torch.save(inputs.cpu(), "{0}/nan.batch".format(self.params["model_dir"]))
            torch.save(targets.cpu(), "{0}/nan.targets".format(self.params["model_dir"]))
            torch.save(self.elements["model"].state_dict(),
                       "{0}/nan.params".format(self.params["model_dir"]))
            raise RuntimeError("There is a NaN problem in iter/epoch: {0}/{1} "
                               "(the NaN batch and params are saved in {2}).".format(
                                   self.training_point[1] + 1, self.training_point[0] + 1,
                                   "{0}/nan.*".format(self.params["model_dir"])))
        else:
            if self.params["nan_debug"]:
                raise RuntimeError("[OK] No NaN was found in this debug.")

        if utils.use_horovod():
            with optimizer.skip_synchronize():
                optimizer.step()
        else:
            optimizer.step()
    else:
        optimizer.step()

    accuracy = model.get_accuracy(targets) if self.params["compute_accuracy"] else None

    return loss.item(), accuracy
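# Illustrative only (not part of the original file): a minimal sketch of how
# train_one_batch() might be driven by an outer loop. The data-iterator name
# (self.elements["data"]) and the snapshot handling are assumptions for illustration.
#
#   def train(self):
#       data = self.elements["data"]
#       for epoch in range(self.params["start_epoch"], self.params["epochs"]):
#           for batch in data:
#               loss, accuracy = self.train_one_batch(batch)
#               # ...update self.training_point and report loss/accuracy here...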
def __init__(self, trainer):
    default_params = {
        "report_times_every_epoch": None,
        "report_interval_iters": 100,
        "record_file": "train.csv",
        "use_tensorboard": False
    }

    self.trainer = trainer
    default_params = utils.assign_params_dict(default_params, self.trainer.params)

    if default_params["report_times_every_epoch"] is not None:
        self.report_interval_iters = max(
            1, self.trainer.training_point[2] // default_params["report_times_every_epoch"])
    else:
        self.report_interval_iters = default_params["report_interval_iters"]

    if not self.trainer.params["debug"] and default_params["use_tensorboard"]:
        # from tensorboardX import SummaryWriter
        from torch.utils.tensorboard import SummaryWriter

        model_name = os.path.basename(self.trainer.params["model_dir"])
        # time_string = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
        # time_string = self.trainer.params["time_string"]
        # self.board_writer = SummaryWriter("{}/log/{}-{}-tensorboard".format(
        #     self.trainer.params["model_dir"], model_name, time_string))
        # self.board_writer = SummaryWriter("{}/log/{}-{}-tensorboard".format(
        #     self.trainer.params["model_dir"], time_string, model_name))
        self.board_writer = SummaryWriter("{}/log/tensorboard".format(
            self.trainer.params["model_dir"]))
    else:
        self.board_writer = None

    self.epochs = self.trainer.params["epochs"]

    self.optimizer = self.trainer.elements["optimizer"]
    # For an optimizer wrapper such as lookahead ("None" is the default of getattr).
    if getattr(self.optimizer, "optimizer", None) is not None:
        self.optimizer = self.optimizer.optimizer

    self.device = "[{0}]".format(utils.get_device(self.trainer.elements["model"]))

    self.record_value = []

    self.start_write_log = False
    if not self.trainer.params["debug"] and default_params["record_file"] != "" \
            and default_params["record_file"] is not None:
        self.record_file = "{0}/log/{1}".format(self.trainer.params["model_dir"],
                                                default_params["record_file"])

        if self.trainer.params["start_epoch"] > 0:
            # The case of recovering training: append to train.csv rather than overwriting it.
            self.start_write_log = True
        elif os.path.exists(self.record_file):
            # Back up the old record to avoid clearing the loss log when re-running the same launcher.
            bk_file = "{0}.backup.{1}".format(
                self.record_file,
                time.strftime('%Y-%m-%d_%H:%M:%S', time.localtime(time.time())))
            shutil.move(self.record_file, bk_file)
    else:
        self.record_file = None

    # A format to show progress.
    # Do not use progressbar.Bar(marker="\x1b[32m█\x1b[39m") or
    # progressbar.SimpleProgress(format='%(value_s)s/%(max_value_s)s') to avoid an over-long string.
    widgets = [
        progressbar.Percentage(format='%(percentage)3.2f%%'), " | ",
        "Epoch:",
        progressbar.Variable('current_epoch', format='{formatted_value}', width=0, precision=0),
        "/{0}, ".format(self.epochs),
        "Iter:",
        progressbar.Variable('current_iter', format='{formatted_value}', width=0, precision=0),
        "/{0}".format(self.trainer.training_point[2]),
        " (", progressbar.Timer(format='ELA: %(elapsed)s'), ", ",
        progressbar.AdaptiveETA(), ")"
    ]

    # Total number of iters over the whole training run.
    max_value = self.trainer.params["epochs"] * self.trainer.training_point[2]
    self.bar = progressbar.ProgressBar(max_value=max_value, widgets=widgets,
                                       redirect_stdout=True)

    # Update the progress bar in a separate process.
    self.queue = Queue()
    self.process = Process(target=self._update, daemon=True)
    self.process.start()
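# Illustrative only (not part of the original file): a rough sketch of what the
# _update() target consumed by the reporting process could look like, assuming the
# queue carries dict snapshots and None is used as a stop signal. The real message
# format of the reporter is not shown in this section.
#
#   def _update(self):
#       while True:
#           snapshot = self.queue.get()
#           if snapshot is None:  # Stop signal.
#               self.bar.finish()
#               break
#           # progressbar2 fills Variable widgets through keyword arguments of update().
#           self.bar.update(snapshot["num_iters_done"],
#                           current_epoch=snapshot["epoch"],
#                           current_iter=snapshot["iter"])
#           if self.board_writer is not None:
#               self.board_writer.add_scalar("train/loss", snapshot["loss"],
#                                            snapshot["num_iters_done"])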