예제 #1
0
 def save_training(self):
     """Checkpoint the current training state.

     Writes the model's and optimizer's ``state_dict()`` and, when a scheduler
     is configured, the scheduler, each through ``self.checkpoint_this`` with
     ``torch_save=True``.
     """
     print_with_time("Saving model ...")

     def _persist(file_name, obj):
         # Every artifact goes through the same torch-based checkpoint path.
         self.checkpoint_this(file_name, obj, torch_save=True)

     _persist(MODEL_FILE_NAME, self.model.state_dict())
     _persist(OPTIMIZER_FILE_NAME, self.optimizer.state_dict())

     # NOTE(review): unlike the model/optimizer, the scheduler object itself
     # is saved here, not its ``state_dict()``. If this is a torch LR
     # scheduler, pickling it also pickles its optimizer reference, and it
     # may fail for lambda-based schedulers. Switching to ``state_dict()``
     # needs a matching change in the (not visible) loading code.
     if self.scheduler is not None:
         _persist(SCHEDULER_FILE_NAME, self.scheduler)
예제 #2
0
    def train(self):
        """Run the training loop for ``self.figure_num_epochs()`` epochs.

        Epoch numbers start at the current value of ``self.epoch_num``, and
        ``self.epoch_num`` is updated as the loop advances. For each epoch the
        loop:

        1. fires the ``on_start_epoch`` callback,
        2. resets every metric whose ``report_average`` flag is set,
        3. trains one epoch on a freshly created train dataloader,
        4. checkpoints via ``save_training`` every ``self.save_every`` epochs,
        5. every ``self.eval_every`` epochs, runs all evaluators and passes
           their results to ``track_end_of_epoch_metrics``,
        6. steps ``self.scheduler`` if one is configured, and
        7. fires the ``on_finish_epoch`` callback.

        ``on_start_training`` / ``on_finish_training`` wrap the whole loop.
        """
        num_epochs = self.figure_num_epochs()
        self._mark_the_run()
        print_with_time(f"Training for run number: {self.run_number:d}")
        first_epoch = self.epoch_num
        epoch_range = range(first_epoch, first_epoch + num_epochs)

        # callback
        self.on_start_training(num_epochs)

        for epoch_num in epoch_range:
            self.epoch_num = epoch_num

            # callback
            self.on_start_epoch(epoch_num)

            # Reset only the metrics that report per-epoch averages. The
            # metric object is used directly (no second lookup by key).
            for metric in self.metrics.values():
                if metric.report_average:
                    metric.reset_values()

            # Train for one epoch on a freshly created dataloader.
            dataloader = self.create_train_dataloader()
            self.train_1_epoch(epoch_num, dataloader)

            # ``epoch_num`` is 0-based, so ``epoch_num + 1`` is the 1-based
            # epoch number used for the save/eval periods.
            # NOTE(review): a save_every / eval_every of 0 raises
            # ZeroDivisionError here; confirm that such values are rejected
            # upstream.
            if (epoch_num + 1) % self.save_every == 0:
                self.save_training()

            # Evaluation. Each evaluator's storage is cleared right after its
            # results are collected. On epochs that skip evaluation,
            # ``eval_results`` stays empty.
            eval_results = []
            if (epoch_num + 1) % self.eval_every == 0:
                for evaluator in self.evaluators:
                    eval_results.append(evaluator.evaluate())
                    evaluator.reset_storage()
                self.track_end_of_epoch_metrics(eval_results, epoch_num)

            # NOTE(review): the scheduler is stepped every epoch, even when no
            # evaluation ran and ``eval_results`` is empty. Presumably
            # ``figure_scheduler_input`` handles that case; confirm.
            if self.scheduler is not None:
                # noinspection PyArgumentList
                self.scheduler.step(**self.figure_scheduler_input(eval_results))

            # callback
            self.on_finish_epoch(epoch_num)

        # callback
        self.on_finish_training(num_epochs)
예제 #3
0
    def add_value(self, dc_value: GeneralDataClass, step: int):
        """Record every scalar-like attribute of ``dc_value`` at ``step``.

        The attribute names come from ``dc_value.filter_attributes`` with the
        ``is_scalar_like`` predicate. For each one:

        - a ``Tensor`` is converted to a Python number with ``.item()``;
          other values are used as-is;
        - the value is logged to ``self.writer`` (when set) under the tag
          ``"<base_name>/<attr_name>"``;
        - it is appended to ``self.values[attr_name]``, which presumably
          auto-creates lists for new names (e.g. a defaultdict); confirm;
        - it is printed when ``self.print_each_iter`` is truthy.
        """
        scalar_names = dc_value.filter_attributes(
            is_scalar_like, initial_attr_list=dc_value.get_attribute_names()
        )

        for name in scalar_names:
            tag = f"{self.base_name}/{name}"
            raw = getattr(dc_value, name)
            value = raw.item() if isinstance(raw, Tensor) else raw

            if self.writer:
                self.writer.add_scalar(tag, scalar_value=value, global_step=step)

            self.values[name].append(value)

            if self.print_each_iter:
                print_with_time(f"(step {step}) {tag}: {value}")
예제 #4
0
    def train_1_epoch(self, epoch_number: int, dataloader: DataLoader):
        """Train the model for one pass over ``dataloader``.

        Moves ``self.model`` to ``self.device`` and switches it to training
        mode. Then, for every batch (shown with a ``tqdm`` progress bar), it
        fires ``on_start_batch``, runs ``_train_1_batch``, updates the training
        metrics, fires ``on_finish_batch`` and advances ``self.iter_num``.

        Args:
            epoch_number: 0-based epoch index (printed 1-based).
            dataloader: iterable of training batches.
        """
        print_with_time("Training epoch %d ..." % (epoch_number + 1))

        # Put the weights on the target device and enable training mode
        # (presumably ``self.model`` is a torch ``nn.Module``; confirm).
        self.model.to(self.device)
        self.model.train()

        # ``self.iter_num`` is never reset in this method, so it keeps
        # counting batches across epochs.
        for batch in tqdm(dataloader):
            self.on_start_batch(self.iter_num, batch)

            loss, outputs = self._train_1_batch(self.iter_num, batch)
            self.track_training_metrics(batch, outputs, loss, self.iter_num)

            self.on_finish_batch(self.iter_num, batch, outputs, loss)
            self.iter_num += 1
예제 #5
0
 def epoch_finished(self, epoch_num: int):
     """Report per-epoch averages (if enabled) and clear the recorded values.

     When ``self.report_average`` is truthy, it does the following for every
     attribute recorded in ``self.values``:

     - compute its average with ``self.average_value``;
     - log it to ``self.writer`` (when set) under
       ``"<average_base_tag>/<attr_name>"``, with the 1-based epoch number as
       the global step;
     - print it.

     The averages are then passed to ``self.save`` under the name
     ``<epoch_num + 1>/<base_name>``. ``self.reset_values()`` is called in
     every case.

     Args:
         epoch_num: 0-based index of the epoch that just finished.
     """
     if self.report_average:
         one_based_epoch = epoch_num + 1
         averages = {}
         for name in self.values:
             avg = self.average_value(name)
             averages[name] = avg
             tag = f"{self.average_base_tag}/{name}"
             if self.writer:
                 self.writer.add_scalar(
                     tag=tag,
                     scalar_value=avg,
                     global_step=one_based_epoch,
                 )
             print_with_time(f"{tag}: {avg}")
         self.save(
             dictionary=averages,
             name=Path(str(one_based_epoch)) / self.base_name,
         )
     self.reset_values()