# Example no. 1
# (0)
# 80/10/10 train/val/test split of `dataset` (defined elsewhere in the file).
num_training = int(len(dataset) * 0.8)
num_val = int(len(dataset) * 0.1)
# Remainder goes to test so the three sizes always sum to len(dataset),
# even when the 0.8/0.1 fractions truncate.
num_test = len(dataset) - (num_training + num_val)
training_set, validation_set, test_set = random_split(
    dataset, [num_training, num_val, num_test])

# Only the training loader is shuffled; val/test keep a deterministic order.
train_loader = DataLoader(training_set,
                          batch_size=args.batch_size,
                          shuffle=True)
val_loader = DataLoader(validation_set,
                        batch_size=args.batch_size,
                        shuffle=False)
test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False)

model = Model(args).to(args.device)
# Adam with L2 regularisation via weight decay.
optimizer = torch.optim.Adam(model.parameters(),
                             lr=args.lr,
                             weight_decay=args.weight_decay)


def train():
    """Train `model` for up to `args.epochs` epochs.

    NOTE(review): this function is truncated in this excerpt — only the
    early-stopping bookkeeping setup and the start of the epoch loop are
    visible; the loss computation and patience logic follow elsewhere.
    """
    min_loss = 1e10        # best validation loss seen so far
    patience_cnt = 0       # epochs since last improvement (early stopping)
    val_loss_values = []   # per-epoch validation losses
    best_epoch = 0

    t = time.time()
    model.train()
    for epoch in range(args.epochs):
        loss_train = 0.0   # accumulated training loss for this epoch
        correct = 0        # correct predictions this epoch
# Example no. 2
# (0)
class ModelTrainer:
    """Class for training and testing of model.

    Depending on ``args.training_mode`` the trainer either only builds a
    dataloader (``"gmm"`` mode) or assembles the full stack: dataloader,
    model, optimiser, LR schedulers, loss, and optional file/tensorboard
    logging for supervised or semi-supervised (SSL) training.
    """

    def __init__(self, args):
        self.epoch_timer = utils.TimeIt(print_str="Epoch")
        self.args = args

        if self.args.training_mode == "gmm":
            # GMM clustering mode needs only the data; no model/optimiser.
            self.dataloader = DataLoader(self.args)
        else:
            if self.args.eval:
                if self.args.eval_checkpoint == "":
                    raise ValueError(
                        "Eval mode is set, but no checkpoint path is provided!"
                    )
                # State dict loaded into the model further below.
                self.loader = torch.load(self.args.eval_checkpoint)

            self.dataloader = DataLoader(self.args)

            # Load the model
            self.model = Model(self.args)

            if self.args.eval:
                self.model.load_state_dict(self.loader)

            if self.args.cuda:
                self.model.cuda()

            # Running best-test bookkeeping, updated in evaluate().
            self.best_test_accuracy = 0.0
            self.best_test_epoch = 0

            if self.args.eval is False:

                if self.args.optimiser == "sgd":
                    self.opt = optim.SGD(
                        self.model.parameters(),
                        lr=self.args.learning_rate,
                        momentum=self.args.momentum,
                        weight_decay=self.args.weight_decay,
                    )
                elif self.args.optimiser == "adam":
                    self.opt = optim.Adam(
                        self.model.parameters(),
                        lr=self.args.learning_rate,
                        weight_decay=self.args.weight_decay,
                    )
                else:
                    # FIX: previously formatted ``self.args.optim``, an
                    # attribute that does not exist (it is ``optimiser``),
                    # so this path raised AttributeError instead of the
                    # intended message.
                    raise Exception("Unknown optimiser {}".format(
                        self.args.optimiser))

                if self.args.lr_scheduler:
                    # Step-wise decay at fixed epoch milestones.
                    self.lr_scheduler = optim.lr_scheduler.MultiStepLR(
                        self.opt,
                        milestones=self.args.lr_schedule,
                        gamma=self.args.lr_decay_factor,
                    )
                if self.args.lr_reducer:
                    # Plateau-based decay driven by full-val loss at the
                    # end of each epoch (see train_val()).
                    self.lr_reducer = torch.optim.lr_scheduler.ReduceLROnPlateau(
                        self.opt,
                        factor=np.sqrt(0.1),
                        cooldown=0,
                        patience=5,
                        min_lr=0.5e-6,
                    )

                # Loss function
                self.criterion = nn.CrossEntropyLoss()

                self.args.logdir = os.path.join("checkpoints",
                                                self.args.exp_name)
                utils.create_dir(self.args.logdir)

                if self.args.filelogger:
                    self.logger_path = os.path.join(
                        "checkpoints",
                        self.args.exp_name,
                        "%s_values.log" % self.args.exp_name,
                    )
                    # JSON-serialisable log of [step, value] pairs per metric.
                    self.logger = {
                        "train_loss_per_iter": [],
                        "train_loss_per_epoch": [],
                        "val_loss_per_iter": [],
                        "val_loss_per_epoch": [],
                        "val_accuracy_per_iter": [],
                        "val_accuracy_per_epoch": [],
                        "test_loss": [],
                        "test_accuracy": [],
                        "best_epoch": 0,
                        "best_test_accuracy": 0.0,
                        "ssl_loss": [],
                        "ssl_accuracy": [],
                        "ssl_correct": [],
                    }
                if self.args.tensorboard:
                    self.writer = SummaryWriter(log_dir=self.args.logdir,
                                                flush_secs=30)
                    self.writer.add_text("Arguments",
                                         params.print_args(self.args))

    def train_val(self, epoch):
        """Train the model for one epoch and evaluate on val split if log_intervals have passed.

        Relies on ``self.iter`` having been initialised by the caller
        (train_val_test / ssl_train_val_test).
        """
        # Defaults so the epoch-level logging below cannot hit an
        # UnboundLocalError when the train loader yields no batches.
        train_loss, val_loss, val_acc = 0.0, 0.0, 0.0

        for batch_idx, batch in enumerate(self.dataloader.train_loader):
            self.model.train()
            self.opt.zero_grad()

            self.iter += 1

            images, targets, indices = batch
            if self.args.cuda:
                images, targets = images.cuda(), targets.cuda()

            logits, unnormalised_scores = self.model(images)
            loss = self.criterion(unnormalised_scores, targets)
            loss.backward()
            self.opt.step()

            if batch_idx % self.args.log_interval == 0:
                # Quick validation on a few batches only, to keep the
                # interval cheap; full-val happens below if lr_reducer.
                val_loss, val_acc = self.evaluate("Val", n_batches=4)

                train_loss, val_loss, val_acc = utils.convert_for_print(
                    loss, val_loss, val_acc)

                if self.args.filelogger:
                    self.logger["train_loss_per_iter"].append(
                        [self.iter, train_loss])
                    self.logger["val_loss_per_iter"].append(
                        [self.iter, val_loss])
                    self.logger["val_accuracy_per_iter"].append(
                        [self.iter, val_acc])

                if self.args.tensorboard:
                    self.writer.add_scalar("Loss_at_Iter/Train", train_loss,
                                           self.iter)
                    self.writer.add_scalar("Loss_at_Iter/Val", val_loss,
                                           self.iter)
                    self.writer.add_scalar("Accuracy_at_Iter/Val", val_acc,
                                           self.iter)

                examples_this_epoch = batch_idx * len(images)
                epoch_progress = 100.0 * batch_idx / len(
                    self.dataloader.train_loader)
                print("Train Epoch: %3d [%5d/%5d (%5.1f%%)]\t "
                      "Train Loss: %0.6f\t Val Loss: %0.6f\t Val Acc: %0.1f" %
                      (
                          epoch,
                          examples_this_epoch,
                          len(self.dataloader.train_loader.dataset),
                          epoch_progress,
                          train_loss,
                          val_loss,
                          val_acc,
                      ))
        if self.args.lr_reducer:
            # Full validation pass drives the plateau scheduler.
            val_loss, val_acc = self.evaluate("Val", n_batches=None)
            self.lr_reducer.step(val_loss)

        val_loss, val_acc = utils.convert_for_print(val_loss, val_acc)

        # NOTE: the per-epoch values below are the *last* logged train loss
        # and the last validation result, not epoch averages.
        if self.args.filelogger:
            self.logger["train_loss_per_epoch"].append([epoch, train_loss])
            self.logger["val_loss_per_epoch"].append([epoch, val_loss])
            self.logger["val_accuracy_per_epoch"].append([epoch, val_acc])

        if self.args.tensorboard:
            self.writer.add_scalar("Loss_at_Epoch/Train", train_loss, epoch)
            self.writer.add_scalar("Loss_at_Epoch/Val", val_loss, epoch)
            self.writer.add_scalar("Accuracy_at_Epoch/Val", val_acc, epoch)

    def evaluate(self, split, epoch=None, verbose=False, n_batches=None):
        """Evaluate model on val or test data.

        Args:
            split: ``"Val"`` or ``"Test"`` — which loader to iterate.
            epoch: current epoch, for logging / best-accuracy bookkeeping.
            verbose: print a summary and log test metrics.
            n_batches: if truthy, stop after roughly this many batches
                (note: inclusive bound, so n_batches+1 batches are seen —
                kept as-is for backward compatibility).

        Returns:
            (loss, acc): mean cross-entropy loss and accuracy in percent.
        """
        self.model.eval()
        with torch.no_grad():
            loss = 0
            correct = 0
            n_examples = 0

            if split == "Val":
                loader = self.dataloader.val_loader
            elif split == "Test":
                loader = self.dataloader.test_loader
            else:
                # FIX: previously fell through and crashed later with an
                # unbound-local NameError on ``loader``.
                raise ValueError("Unknown split %r; expected 'Val' or 'Test'"
                                 % split)

            for batch_idx, batch in enumerate(loader):
                images, targets, _ = batch
                # FIX: was the bare global ``args``; use the trainer's own
                # configuration.
                if self.args.cuda:
                    images, targets = images.cuda(), targets.cuda()

                logits, unnormalised_scores = self.model(images)
                # Sum over the batch; normalised by n_examples below.
                loss += F.cross_entropy(unnormalised_scores,
                                        targets,
                                        reduction="sum")
                pred = logits.max(1, keepdim=False)[1]
                correct += pred.eq(targets).sum()
                n_examples += pred.shape[0]
                if n_batches and (batch_idx >= n_batches):
                    break

            loss /= n_examples
            acc = 100.0 * correct / n_examples

            if split == "Test" and acc >= self.best_test_accuracy:
                self.best_test_accuracy = utils.convert_for_print(acc)
                self.best_test_epoch = epoch
                if self.args.filelogger:
                    self.logger["best_epoch"] = self.best_test_epoch
                    self.logger["best_test_accuracy"] = self.best_test_accuracy
            if verbose:
                if epoch is None:
                    # Standalone eval (no training epoch): report epoch 0.
                    epoch = 0
                    self.best_test_epoch = 0
                loss, acc = utils.convert_for_print(loss, acc)
                print(
                    "\n%s set Epoch: %2d \t Average loss: %0.4f, Accuracy: %d/%d (%0.1f%%)"
                    % (split, epoch, loss, correct, n_examples, acc))
                print(
                    "Best %s split Performance: Epoch %d - Accuracy: %0.1f%%" %
                    (split, self.best_test_epoch, self.best_test_accuracy))

                if self.args.filelogger:
                    self.logger["test_loss"].append([epoch, loss])
                    self.logger["test_accuracy"].append([epoch, acc])
                if self.args.tensorboard:
                    self.writer.add_scalar("Loss_at_Epoch/Test", loss, epoch)
                    self.writer.add_scalar("Accuracy_at_Epoch/Test", acc,
                                           epoch)
                    self.writer.add_scalar(
                        "Accuracy_at_Epoch/Best_Test_Accuracy",
                        self.best_test_accuracy,
                        self.best_test_epoch,
                    )

        return loss, acc

    def generate_labels_for_ssl(self, epoch, n_batches=None, verbose=False):
        """Predict pseudo-labels on the unsupervised split for SSL.

        Returns:
            (predictions_indices, predictions_labels): parallel lists of
            sample indices and the model's predicted class for each.
        """
        self.model.eval()
        with torch.no_grad():
            loss = 0
            correct = 0
            n_examples = 0

            predictions_indices = []
            predictions_labels = []

            loader = self.dataloader.unsupervised_train_loader

            for batch_idx, batch in enumerate(loader):
                images, targets, indices = batch
                # FIX: was the bare global ``args``; use the trainer's own
                # configuration.
                if self.args.cuda:
                    images, targets = images.cuda(), targets.cuda()

                logits, unnormalised_scores = self.model(images)
                # Loss/accuracy against the true targets is for diagnostics
                # only; the returned labels are the model's predictions.
                loss += F.cross_entropy(unnormalised_scores,
                                        targets,
                                        reduction="sum")
                pred = logits.max(1, keepdim=False)[1]
                correct += pred.eq(targets).sum()
                n_examples += pred.shape[0]

                predictions_indices.extend(indices.tolist())
                predictions_labels.extend(pred.tolist())

                if n_batches and (batch_idx >= n_batches):
                    break

            loss /= n_examples
            acc = 100.0 * correct / n_examples

            if verbose:
                loss, acc, correct = utils.convert_for_print(
                    loss, acc, correct)
                print(
                    "\nLabel Generation Performance Average loss: %0.4f, Accuracy: %d/%d (%0.1f%%)"
                    % (loss, correct, n_examples, acc))

                if self.args.filelogger:
                    self.logger["ssl_loss"].append([epoch, loss])
                    self.logger["ssl_accuracy"].append([epoch, acc])
                    self.logger["ssl_correct"].append([epoch, correct])
                if self.args.tensorboard:
                    self.writer.add_scalar("SSL_Loss_at_Epoch", loss, epoch)
                    self.writer.add_scalar("SSL_Accuracy_at_Epoch", acc, epoch)
                    self.writer.add_scalar("SSL_Correct_Labels_at_Epoch",
                                           correct, epoch)

        return predictions_indices, predictions_labels

    def train_val_test(self):
        """ Function to train, validate and evaluate the model"""
        self.iter = 0
        for epoch in range(1, self.args.epochs + 1):
            self.train_val(epoch)
            self.evaluate("Test", epoch, verbose=True)
            if self.args.lr_scheduler:
                self.lr_scheduler.step()
            if epoch % self.args.checkpoint_save_interval == 0:
                print("Saved %s/%s_epoch%d.pt\n" %
                      (self.args.logdir, self.args.exp_name, epoch))
                torch.save(
                    self.model.state_dict(),
                    "%s/%s_epoch%d.pt" %
                    (self.args.logdir, self.args.exp_name, epoch),
                )
            self.epoch_timer.tic(verbose=True)

        # Final reporting / cleanup of loggers.
        if self.args.tensorboard:
            if self.args.filelogger:
                text = "Epoch: %d Test Accuracy:%0.1f" % (
                    self.logger["best_epoch"],
                    self.logger["best_test_accuracy"],
                )
                self.writer.add_text("Best Test Performance", text)
            self.writer.close()
        if self.args.filelogger:
            utils.write_log_to_json(self.logger_path, self.logger)
        self.epoch_timer.time_since_init(print_str="Total")

    def ssl_train_val_test(self):
        """ Function to train, validate and evaluate the model"""
        self.iter = 0
        predictions_indices, predictions_labels = [], []
        self.dataloader.stop_label_generation = False

        for epoch in range(1, self.args.epochs + 1):
            if not self.dataloader.stop_label_generation:
                # Feed the previous epoch's pseudo-labels back into the
                # dataloader before training.
                self.dataloader.ssl_init_epoch(predictions_indices,
                                               predictions_labels)

            self.train_val(epoch)
            self.evaluate("Test", epoch, verbose=True)

            if not self.dataloader.stop_label_generation:
                predictions_indices, predictions_labels = self.generate_labels_for_ssl(
                    epoch, n_batches=4, verbose=True)

            if self.args.lr_scheduler:
                self.lr_scheduler.step()
            if epoch % self.args.checkpoint_save_interval == 0:
                print("Saved %s/%s_epoch%d.pt\n" %
                      (self.args.logdir, self.args.exp_name, epoch))
                torch.save(
                    self.model.state_dict(),
                    "%s/%s_epoch%d.pt" %
                    (self.args.logdir, self.args.exp_name, epoch),
                )
            self.epoch_timer.tic(verbose=True)

        # Final reporting / cleanup of loggers.
        if self.args.tensorboard:
            if self.args.filelogger:
                text = "Epoch: %d Test Accuracy:%0.1f" % (
                    self.logger["best_epoch"],
                    self.logger["best_test_accuracy"],
                )
                self.writer.add_text("Best Test Performance", text)
            self.writer.close()
        if self.args.filelogger:
            utils.write_log_to_json(self.logger_path, self.logger)
        self.epoch_timer.time_since_init(print_str="Total")

    def gmm_train_val_test(self):
        """Fit GaussianMixture models on the raw training images and report
        cluster/label agreement for each covariance type.

        NOTE(review): the 49000 denominator assumes the CIFAR-10-style
        49k-sample training split used elsewhere in this project — confirm.
        """
        train_data = self.dataloader.full_supervised_train_dataset.train_data
        train_labels = self.dataloader.full_supervised_train_dataset.train_labels

        # Normalise to [0,1] then standardise per channel.
        train_data = train_data / 255

        mean = np.array(
            self.args.cifar10_mean_color)[np.newaxis][np.newaxis][np.newaxis]
        std = np.array(
            self.args.cifar10_std_color)[np.newaxis][np.newaxis][np.newaxis]

        train_data = (train_data - mean) / std

        # Flatten each image into a single feature vector.
        train_data = train_data.reshape(train_data.shape[0], -1)

        cv_types = ["spherical", "diag", "full", "tied"]
        for cv_type in cv_types:
            gmm = mixture.GaussianMixture(n_components=10,
                                          covariance_type=cv_type)
            gmm.fit(train_data)
            clusters = gmm.predict(train_data)
            # Map each cluster to its majority true label.
            labels = np.zeros_like(clusters)
            for i in range(10):
                mask = clusters == i
                labels[mask] = mode(train_labels[mask])[0]

            correct1 = np.equal(clusters, train_labels).sum()
            correct2 = np.equal(labels, train_labels).sum()
            # FIX: the old code printed the raw fraction (e.g. 0.23) under a
            # '%%' format; scale by 100 so it reads as a real percentage.
            print("%d/49000 (%0.2f%%)" % (correct1, 100.0 * correct1 / 49000))
            print("%d/49000 (%0.2f%%)" % (correct2, 100.0 * correct2 / 49000))