def save_training(self):
    """Checkpoint the model and optimizer state dicts, plus the scheduler if set."""
    print_with_time("Saving model ...")
    artifacts = [
        (MODEL_FILE_NAME, self.model.state_dict()),
        (OPTIMIZER_FILE_NAME, self.optimizer.state_dict()),
    ]
    # NOTE(review): unlike the model/optimizer above, the scheduler is saved as
    # a whole object rather than via state_dict() — confirm the loading side
    # expects the full object.
    if self.scheduler is not None:
        artifacts.append((SCHEDULER_FILE_NAME, self.scheduler))
    for file_name, payload in artifacts:
        self.checkpoint_this(file_name, payload, torch_save=True)
def train(self):
    """Run the full training loop for this run.

    For each epoch: fire callbacks, reset per-epoch (averaging) metrics,
    train over a freshly created dataloader, checkpoint every
    ``save_every`` epochs, evaluate every ``eval_every`` epochs, record
    end-of-epoch metrics, and step the scheduler (if any).
    """
    num_epochs = self.figure_num_epochs()
    self._mark_the_run()
    print_with_time("Training for run number: {:d}".format(self.run_number))
    # Resume from the current epoch counter so restarted runs continue in place.
    epoch_range_start = self.epoch_num
    epoch_range = range(epoch_range_start, epoch_range_start + num_epochs)
    # callback
    self.on_start_training(num_epochs)
    for epoch_num in epoch_range:
        self.epoch_num = epoch_num
        # callback
        self.on_start_epoch(epoch_num)
        # resetting metrics — only those that report per-epoch averages.
        # (Fix: iterate .values() and test m directly instead of the
        # redundant self.metrics[n] re-lookup; the key was otherwise unused.)
        for m in self.metrics.values():
            if m.report_average:
                m.reset_values()
        # training for 1 epoch; the dataloader is rebuilt every epoch
        dataloader = self.create_train_dataloader()
        self.train_1_epoch(epoch_num, dataloader)
        # saving
        if (epoch_num + 1) % self.save_every == 0:
            self.save_training()
        # evaluation — eval_results stays empty on non-eval epochs and is
        # still passed to the trackers / scheduler below.
        eval_results = []
        if (epoch_num + 1) % self.eval_every == 0:
            for evaluator in self.evaluators:
                eval_results.append(evaluator.evaluate())
                evaluator.reset_storage()
        self.track_end_of_epoch_metrics(eval_results, epoch_num)
        # scheduler
        if self.scheduler is not None:
            # noinspection PyArgumentList
            self.scheduler.step(**self.figure_scheduler_input(eval_results))
        # callback
        self.on_finish_epoch(epoch_num)
    # callback
    self.on_finish_training(num_epochs)
def add_value(self, dc_value: GeneralDataClass, step: int):
    """Record every scalar-like attribute of *dc_value* at *step*.

    Each value is sent to the summary writer (when one is attached),
    appended to the in-memory history, and optionally printed.
    """
    scalar_attrs = dc_value.filter_attributes(
        is_scalar_like, initial_attr_list=dc_value.get_attribute_names()
    )
    for name in scalar_attrs:
        tag = f"{self.base_name}/{name}"
        raw = getattr(dc_value, name)
        # Tensors are unwrapped to plain Python scalars before logging.
        value = raw.item() if isinstance(raw, Tensor) else raw
        if self.writer:
            self.writer.add_scalar(tag, scalar_value=value, global_step=step)
        self.values[name].append(value)
        if self.print_each_iter:
            print_with_time(f"(step {step}) {tag}: {value}")
def train_1_epoch(self, epoch_number: int, dataloader: DataLoader):
    """Train the model over every batch of *dataloader* (a single epoch)."""
    print_with_time("Training epoch %d ..." % (epoch_number + 1))
    # Make sure the model sits on the right device and is in training mode.
    self.model.to(self.device)
    self.model.train()
    for batch in tqdm(dataloader):
        # self.iter_num is re-read at each call site on purpose, so callbacks
        # always see the current global iteration counter.
        self.on_start_batch(self.iter_num, batch)  # callback
        # one optimization step on this batch
        batch_loss, forward_out = self._train_1_batch(self.iter_num, batch)
        # per-batch metric bookkeeping
        self.track_training_metrics(batch, forward_out, batch_loss, self.iter_num)
        self.on_finish_batch(self.iter_num, batch, forward_out, batch_loss)  # callback
        self.iter_num += 1
def epoch_finished(self, epoch_num: int):
    """Log, print, and persist per-epoch average values, then clear history."""
    # Guard clause: nothing to do for trackers that don't report averages.
    # NOTE(review): save()/reset_values() are treated as part of the
    # report_average branch — confirm they were not meant to run unconditionally.
    if not self.report_average:
        return
    average_values = {}
    for attr_name in self.values.keys():
        avg = self.average_value(attr_name)
        average_values[attr_name] = avg
        tag_name = f"{self.average_base_tag}/{attr_name}"
        if self.writer:
            self.writer.add_scalar(
                tag=tag_name,
                scalar_value=avg,
                global_step=epoch_num + 1,
            )
        print_with_time(f"{tag_name}: {avg}")
    # Persist the averages under a per-epoch directory (1-based epoch number).
    self.save(
        dictionary=average_values,
        name=Path(str(epoch_num + 1)) / self.base_name,
    )
    self.reset_values()