예제 #1
0
    def train_main(self, cache=False):
        """Train the model epoch by epoch and return the final metrics.

        Prints a banner and the wall-clock start time, records
        ``self.start``, moves ``self.model`` to ``self.device`` and then
        iterates epochs ``self.start_epoch`` .. ``config['max_epoch']``
        (inclusive). Every batch of ``config['train_loader']`` is moved to
        the device with ``self.batch_to_device`` and handed to
        ``self.train_iter_step()``.

        Every ``config['log_every']`` global steps, ``log_tensorboard`` is
        called with ``self.short_loss_list`` (``loss_only=True``). In the
        same step, the mean time per iteration and the current learning rate
        are written to ``config['writer']``, and the short-loss buffer and
        timing list are cleared.

        After each epoch ``self.train_epoch_step()`` runs, and the loop stops
        early as soon as ``self.terminate_training`` is truthy.

        Args:
            cache: Accepted for API compatibility; not used by this method.

        Returns:
            Tuple ``(self.best_val_metrics, self.test_metrics)`` read after
            ``self.end_training()`` has run.
        """
        rule = "=" * 100
        print(f"\n\n{rule}\n\t\t\t\t\t Training Network\n{rule}")

        self.start = time.time()
        print(f"\nBeginning training at:  {datetime.datetime.now()} \n")

        self.model.to(self.device)

        final_epoch = self.config['max_epoch']
        for self.epoch in range(self.start_epoch, final_epoch + 1):
            iter_durations = []
            # Loop targets are instance attributes on purpose:
            # train_epoch_step() reads self.iters, and train_iter_step()
            # presumably reads self.batch (confirm).
            for self.iters, self.batch in enumerate(self.config['train_loader']):
                # Re-enter train mode every step, presumably because
                # evaluation may have switched the model to eval mode.
                self.model.train()

                tick = time.time()
                self.batch = self.batch_to_device(self.batch)
                self.train_iter_step()
                iter_durations.append(time.time() - tick)

                # 1-based global step across all epochs.
                step = self.total_iters + self.iters + 1
                if step % self.config['log_every'] != 0:
                    continue

                # Extra diagnostics only when the config enables debug mode.
                if self.config['debug']:
                    LOGGER.info(
                        "Logging tensorboard at step %i with %i values" %
                        (step, len(self.short_loss_list)))

                log_tensorboard(self.config, self.config['writer'],
                                self.model, self.epoch, self.iters,
                                self.total_iters, self.short_loss_list,
                                loss_only=True, val=False)

                writer = self.config['writer']
                writer.add_scalar('Stats/time_per_train_iter',
                                  mean(iter_durations), step)
                writer.add_scalar('Stats/learning_rate',
                                  self.scheduler.get_last_lr()[0], step)

                # Start a fresh logging window.
                iter_durations = []
                self.short_loss_list = []

            self.train_epoch_step()
            if self.terminate_training:
                break

        self.end_training()
        return self.best_val_metrics, self.test_metrics
예제 #2
0
    def train_epoch_step(self):
        """Close out one training epoch: train metrics, validation, bookkeeping.

        Switches ``self.model`` to train mode and advances
        ``self.total_iters`` by ``self.iters + 1``. It then flattens the
        per-batch ``probs_list`` / ``labels_list`` buffers and scores them
        with ``standard_metrics(..., add_optimal_acc=True)``.

        Train results are logged via ``log_tensorboard``. ``self.eval_model()``
        is then run (timed, with the duration written as
        ``Stats/time_validation``). Stats are printed and validation results
        are logged. ``self.check_early_stopping()`` is called, and the
        per-epoch buffers are cleared.

        State left on ``self`` afterwards:
            * ``train_metrics``: this epoch's training metrics.
            * ``train_loss``: arithmetic mean of the epoch's ``loss_list``
              entries (``print_stats`` receives the raw list copy first).
            * ``probs_list``, ``preds_list``, ``labels_list``, ``loss_list``
              and ``id_list``: fresh empty lists.
            * ``val_metrics`` and ``val_loss``: deleted.
        """
        self.model.train()
        last_lrs = self.scheduler.get_last_lr()
        # NOTE(review): relies on self.iters still holding the last 0-based
        # batch index from train_main's inner loop; it would be stale (or
        # missing) if the train loader yielded no batches.
        self.total_iters += self.iters + 1

        # The buffers hold one sequence per batch (presumably filled by
        # train_iter_step -- confirm); unroll them into flat lists.
        flat_probs = [p for per_batch in self.probs_list for p in per_batch]
        self.probs_list = flat_probs
        flat_labels = [y for per_batch in self.labels_list for y in per_batch]
        self.labels_list = flat_labels

        # Score the training predictions gathered over the epoch.
        self.train_metrics = standard_metrics(
            torch.tensor(self.probs_list),
            torch.tensor(self.labels_list),
            add_optimal_acc=True,
        )
        log_tensorboard(self.config, self.config['writer'], self.model,
                        self.epoch, self.iters, self.total_iters,
                        self.loss_list, self.train_metrics, last_lrs[0],
                        loss_only=False, val=False)
        # Snapshot the raw per-step losses before the buffer is reset below.
        self.train_loss = self.loss_list[:]

        # Validation pass, timed on its own.
        started = time.time()
        self.val_metrics, self.val_loss = self.eval_model()
        self.config['writer'].add_scalar("Stats/time_validation",
                                         time.time() - started,
                                         self.total_iters)

        print_stats(self.config, self.epoch, self.train_metrics,
                    self.train_loss, self.val_metrics, self.val_loss,
                    self.start, last_lrs[0])

        log_tensorboard(self.config, self.config['writer'], self.model,
                        self.epoch, self.iters, self.total_iters,
                        self.val_loss, self.val_metrics, last_lrs[0],
                        loss_only=False, val=True)

        self.check_early_stopping()

        # Each buffer gets its own new empty list for the next epoch.
        for buffer_name in ('probs_list', 'preds_list', 'labels_list',
                            'loss_list', 'id_list'):
            setattr(self, buffer_name, [])

        # NOTE(review): raises ZeroDivisionError if no losses were recorded.
        epoch_losses = self.train_loss
        self.train_loss = sum(epoch_losses) / len(epoch_losses)
        del self.val_metrics, self.val_loss