Example #1
class WeightEMA(object):
    def __init__(self, model, ema_model, run_type=0, alpha=0.999):
        self.model = model
        self.ema_model = ema_model
        self.alpha = alpha
        if run_type == 1:
            self.tmp_model = Classifier(num_classes=10).cuda()
        else:
            self.tmp_model = WideResNet(num_classes=10).cuda()
        self.wd = 0.02 * args.lr  # weight-decay factor; args is assumed to be a module-level argparse namespace

        for param, ema_param in zip(self.model.parameters(),
                                    self.ema_model.parameters()):
            ema_param.data.copy_(param.data)

    def step(self, bn=False):
        if bn:
            # Copy BatchNorm statistics from the training model into the EMA model
            # while preserving the EMA parameters: stash the EMA parameters in the
            # temporary model, load the full state dict (which carries the BN
            # buffers), then restore the stashed parameters.
            for ema_param, tmp_param in zip(self.ema_model.parameters(),
                                            self.tmp_model.parameters()):
                tmp_param.data.copy_(ema_param.data.detach())

            self.ema_model.load_state_dict(self.model.state_dict())

            for ema_param, tmp_param in zip(self.ema_model.parameters(),
                                            self.tmp_model.parameters()):
                ema_param.data.copy_(tmp_param.data.detach())
        else:
            one_minus_alpha = 1.0 - self.alpha
            for param, ema_param in zip(self.model.parameters(),
                                        self.ema_model.parameters()):
                ema_param.data.mul_(self.alpha)
                ema_param.data.add_(param.data.detach() * one_minus_alpha)
                # customized weight decay
                param.data.mul_(1 - self.wd)
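
A minimal usage sketch for this wrapper, assuming model and ema_model are two identically constructed networks and a global args.lr exists (as the constructor above expects); after every optimizer step, ema_optimizer.step() blends the online weights into the EMA copy and applies the custom weight decay. The names model, ema_model and train_loader are placeholders, not from the original snippet.

import torch
import torch.nn.functional as F

# Hypothetical wiring around the WeightEMA class above.
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
ema_optimizer = WeightEMA(model, ema_model, alpha=0.999)

for x, y in train_loader:
    loss = F.cross_entropy(model(x.cuda()), y.cuda())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema_optimizer.step()  # EMA update + custom weight decay on the online model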
Example #2
def load(config):
    if config.model == 'wideresnet':
        model = WideResNet(num_classes=config.dataset_classes)
        ema_model = WideResNet(num_classes=config.dataset_classes)
    else:
        model = CNN13(num_classes=config.dataset_classes)
        ema_model = CNN13(num_classes=config.dataset_classes)

    if config.semi_supervised == 'mix_match':
        semi_supervised = MixMatch(config)
        semi_supervised_loss = mix_match_loss
    elif config.semi_supervised == 'pseudo_label':
        semi_supervised = PseudoLabel(config)
        semi_supervised_loss = pseudo_label_loss
    else:
        raise ValueError(f'Unknown semi-supervised method: {config.semi_supervised}')

    model.to(config.device)
    ema_model.to(config.device)

    torch.backends.cudnn.benchmark = True

    optimizer = Adam(model.parameters(), lr=config.learning_rate)
    ema_optimizer = WeightEMA(model, ema_model, alpha=config.ema_decay)

    if config.resume:
        checkpoint = torch.load(config.checkpoint_path, map_location=config.device)

        model.load_state_dict(checkpoint['model_state'])
        ema_model.load_state_dict(checkpoint['ema_model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])

        # optimizer state should be moved to corresponding device
        for optimizer_state in optimizer.state.values():
            for k, v in optimizer_state.items():
                if isinstance(v, torch.Tensor):
                    optimizer_state[k] = v.to(config.device)

    return model, ema_model, optimizer, ema_optimizer, semi_supervised, semi_supervised_loss
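
A sketch of a config object that would satisfy this loader. The field names mirror the attributes that load() reads above; the concrete values are only illustrative assumptions.

from types import SimpleNamespace
import torch

config = SimpleNamespace(
    model='wideresnet',
    dataset_classes=10,
    semi_supervised='mix_match',
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    learning_rate=0.002,
    ema_decay=0.999,
    resume=False,
    checkpoint_path='',
)
model, ema_model, optimizer, ema_optimizer, ssl, ssl_loss = load(config)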
Example #3
class FullySupervisedTrainer:
    def __init__(self, batch_size, model_params, n_steps, optimizer, adam, sgd,
                 steps_validation, steps_checkpoint, dataset, save_path):

        self.n_steps = n_steps
        self.start_step = 0
        self.steps_validation = steps_validation
        self.steps_checkpoint = steps_checkpoint
        self.num_labeled = 50000
        self.train_loader, _, self.val_loader, self.test_loader, self.lbl_idx, _, self.val_idx = \
            get_dataloaders_with_index(path='../data',
                                       batch_size=batch_size,
                                       num_labeled=self.num_labeled,
                                       which_dataset=dataset,
                                       validation=False)
        print('Labeled samples: ' + str(len(self.train_loader.sampler)))

        self.batch_size = self.train_loader.batch_size
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        print(self.device)

        depth, k, n_out = model_params
        self.model = WideResNet(depth=depth, k=k, n_out=n_out,
                                bias=True).to(self.device)
        self.ema_model = WideResNet(depth=depth, k=k, n_out=n_out,
                                    bias=True).to(self.device)
        for param in self.ema_model.parameters():
            param.detach_()

        if optimizer == 'adam':
            self.lr, self.weight_decay = adam
            self.momentum, self.lr_decay = None, None
            self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
            self.ema_optimizer = WeightEMA(self.model,
                                           self.ema_model,
                                           self.lr,
                                           alpha=0.999)

        else:
            self.lr, self.momentum, self.weight_decay, self.lr_decay = sgd
            self.optimizer = optim.SGD(self.model.parameters(),
                                       lr=self.lr,
                                       momentum=self.momentum,
                                       weight_decay=self.weight_decay,
                                       nesterov=True)
            self.ema_optimizer = None

        self.criterion = nn.CrossEntropyLoss()

        self.train_accuracies, self.train_losses = [], []
        self.val_accuracies, self.val_losses = [], []
        self.best_acc = 0

        self.augment = Augment(K=1)

        self.path = save_path
        self.writer = SummaryWriter()

    def train(self):
        iter_train_loader = iter(self.train_loader)

        for step in range(self.n_steps):
            self.model.train()
            # Get next batch of data
            try:
                x_input, x_labels, _ = next(iter_train_loader)
                # Check if batch size has been cropped for last batch
                if x_input.shape[0] < self.batch_size:
                    iter_train_loader = iter(self.train_loader)
                    x_input, x_labels, _ = next(iter_train_loader)
            except StopIteration:
                iter_train_loader = iter(self.train_loader)
                x_input, x_labels, _ = next(iter_train_loader)

            # Send to GPU
            x_input = x_input.to(self.device)
            x_labels = x_labels.to(self.device)

            # Augment
            x_input = self.augment(x_input)
            x_input = x_input.reshape((-1, 3, 32, 32))

            # Compute X' predictions
            x_output = self.model(x_input)

            # Compute loss
            loss = self.criterion(x_output, x_labels)

            # Step
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            if self.ema_optimizer:
                self.ema_optimizer.step()

            # Decaying learning rate. Used with the SGD Nesterov optimizer
            if not self.ema_optimizer and step in self.lr_decay:
                for g in self.optimizer.param_groups:
                    g['lr'] *= 0.2

            # Evaluate model
            self.model.eval()
            if step > 0 and not step % self.steps_validation:
                val_acc, is_best = self.evaluate_loss_acc(step)
                if is_best:
                    self.save_model(step=step,
                                    path=f'{self.path}/best_checkpoint.pt')

            # Save checkpoint
            if step > 10000 and not step % self.steps_checkpoint:
                self.save_model(step=step,
                                path=f'{self.path}/checkpoint_{step}.pt')

        # --- Training finished ---
        test_val, test_acc = self.evaluate(self.test_loader)
        print("Training done!!\t Test loss: %.3f \t Test accuracy: %.3f" %
              (test_val, test_acc))

        self.writer.flush()

    # --- support functions ---

    def evaluate_loss_acc(self, step):
        val_loss, val_acc = self.evaluate(self.val_loader)
        self.val_losses.append(val_loss)
        self.val_accuracies.append(val_acc)
        train_loss, train_acc = self.evaluate(self.train_loader)
        self.train_losses.append(train_loss)
        self.train_accuracies.append(train_acc)

        is_best = False
        if val_acc > self.best_acc:
            self.best_acc = val_acc
            is_best = True

        print(
            "Step %d.\tLoss train_lbl/valid  %.2f  %.2f\t Accuracy train_lbl/valid  %.2f  %.2f \tBest acc %.2f \t%s"
            % (step, train_loss, val_loss, train_acc, val_acc, self.best_acc,
               time.ctime()))

        self.writer.add_scalar("Loss train_label", train_loss, step)
        self.writer.add_scalar("Loss validation", val_loss, step)
        self.writer.add_scalar("Accuracy train_label", train_acc, step)
        self.writer.add_scalar("Accuracy validation", val_acc, step)
        return val_acc, is_best

    def evaluate(self, dataloader):
        correct, total, loss = 0, 0, 0
        with torch.no_grad():
            for i, data in enumerate(dataloader, 0):
                inputs, labels = data[0].to(self.device), data[1].to(
                    self.device)
                if self.ema_optimizer:
                    outputs = self.ema_model(inputs)
                else:
                    outputs = self.model(inputs)
                loss += self.criterion(outputs, labels).item()

                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
            loss /= len(dataloader)

        acc = correct / total * 100
        return loss, acc

    def save_model(self, step=None, path='../models/model.pt'):
        if not step:
            step = self.n_steps  # Training finished

        torch.save(
            {
                'step': step,
                'model_state_dict': self.model.state_dict(),
                'ema_state_dict': self.ema_model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'loss_train': self.train_losses,
                'loss_val': self.val_losses,
                'acc_train': self.train_accuracies,
                'acc_val': self.val_accuracies,
                'steps': self.n_steps,
                'batch_size': self.batch_size,
                'num_labels': self.num_labeled,
                'lr': self.lr,
                'weight_decay': self.weight_decay,
                'momentum': self.momentum,
                'lr_decay': self.lr_decay,
                'lbl_idx': self.lbl_idx,
                'val_idx': self.val_idx,
            }, path)

    def load_checkpoint(self, model_name):
        saved_model = torch.load(f'../models/{model_name}')
        self.model.load_state_dict(saved_model['model_state_dict'])
        self.ema_model.load_state_dict(saved_model['ema_state_dict'])
        self.optimizer.load_state_dict(saved_model['optimizer_state_dict'])
        self.start_step = saved_model['step']

        self.train_loader, _, self.val_loader, self.test_loader, self.lbl_idx, _, self.val_idx = \
            get_dataloaders_with_index(path='../data',
                                       batch_size=self.batch_size,
                                       num_labeled=self.num_labeled,
                                       which_dataset='cifar10',
                                       lbl_idxs=saved_model['lbl_idx'],
                                       unlbl_idxs=[],
                                       valid_idxs=saved_model['val_idx'])
        print('Model ' + model_name + ' loaded.')
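
A sketch of how this trainer might be instantiated. All values are illustrative assumptions; the tuples follow the (lr, weight_decay) and (lr, momentum, weight_decay, lr_decay) unpacking done in __init__ above.

trainer = FullySupervisedTrainer(
    batch_size=64,
    model_params=(28, 2, 10),               # depth, widen factor k, output classes
    n_steps=100000,
    optimizer='adam',
    adam=(0.002, 0.02),                     # lr, weight_decay
    sgd=(0.1, 0.9, 5e-4, [60000, 80000]),   # lr, momentum, weight_decay, lr_decay steps
    steps_validation=500,
    steps_checkpoint=10000,
    dataset='cifar10',
    save_path='../models')
trainer.train()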
Example #4
                            pin_memory=True,
                            shuffle=True)
    ulbl_loader = DataLoader(unlabeled_dataset,
                             batch_size=mu * B,
                             collate_fn=unlabeled_collate,
                             num_workers=args.num_workers3,
                             pin_memory=True,
                             shuffle=True)

    #  Model Settings

    model.to(device)
    ema = EMA(model, decay=0.999)
    ema.register()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=args.nesterov)
    scheduler = cosineLRreduce(optimizer, K, warmup=args.warmup_scheduler)

    train_fixmatch(model, ema, zip(lbl_loader, ulbl_loader), v_loader,
                   augmentation, optimizer, scheduler, device, K, tb_writer)
    tb_writer.close()

    # Save everything

    save_dir, prefix = os.path.expanduser(args.save_dir), str(
        datetime.datetime.now())
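
The EMA helper used here is not part of this excerpt. The following is a minimal sketch consistent with the EMA(model, decay=0.999) construction and the ema.register() call above; the update() and apply_shadow() methods are assumptions about the rest of its interface.

import torch

class EMA:
    """Minimal shadow-weight EMA sketch; interface assumed from the calls above."""
    def __init__(self, model, decay=0.999):
        self.model = model
        self.decay = decay
        self.shadow = {}

    def register(self):
        # Snapshot the current parameters as the initial EMA state.
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self):
        # shadow = decay * shadow + (1 - decay) * param, typically called after optimizer.step().
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if param.requires_grad:
                    self.shadow[name].mul_(self.decay).add_(param.data, alpha=1 - self.decay)

    def apply_shadow(self):
        # Copy the EMA weights into the model (e.g., before evaluation).
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                param.data.copy_(self.shadow[name])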
Example #5
#------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
if args.model == 'wideresnet':
    cnn = WideResNet(depth=28, num_classes=num_classes, widen_factor=10).cuda()
elif args.model == 'wideresnet16_8':
    cnn = WideResNet(depth=16, num_classes=num_classes, widen_factor=8).cuda()
elif args.model == 'densenet':
    cnn = DenseNet3(depth=100,
                    num_classes=num_classes,
                    growth_rate=12,
                    reduction=0.5).cuda()
elif args.model == 'vgg13':
    cnn = VGG(vgg_name='VGG13', num_classes=num_classes).cuda()

prediction_criterion = nn.NLLLoss().cuda()

cnn_optimizer = torch.optim.SGD(cnn.parameters(),
                                lr=args.learning_rate,
                                momentum=0.9,
                                nesterov=True,
                                weight_decay=5e-4)

if args.dataset == 'svhn':
    scheduler = MultiStepLR(cnn_optimizer, milestones=[80, 120], gamma=0.1)
else:
    scheduler = MultiStepLR(cnn_optimizer,
                            milestones=[60, 120, 160],
                            gamma=0.2)

if args.model == 'densenet':
    # The excerpt is cut off here; the remaining optimizer arguments are
    # assumed, since the original values are not shown in this snippet.
    cnn_optimizer = torch.optim.SGD(cnn.parameters(),
                                    lr=args.learning_rate,
                                    momentum=0.9,
                                    nesterov=True,
                                    weight_decay=1e-4)
Example #6
def main():
    global args, best_prec1
    args = parser.parse_args()
    # torch.cuda.set_device(args.gpu)
    if args.tensorboard:
        print("Using tensorboard")
        configure("exp/%s" % (args.name))

    # Data loading code

    if args.augment:
        print(
            "Doing image augmentation with\n"
            "Zoom: prob: {zoom_prob}  range: {zoom_range}\n"
            "Stretch: prob: {stretch_prob} range: {stretch_range}\n"
            "Rotation: prob: {rotation_prob} range: {rotation_degree}".format(
                zoom_prob=args.zoom_prob,
                zoom_range=args.zoom_range,
                stretch_prob=args.stretch_prob,
                stretch_range=args.stretch_range,
                rotation_prob=args.rotation_prob,
                rotation_degree=args.rotation_degree))
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: F.pad(
                Variable(x.unsqueeze(0), requires_grad=False, volatile=True),
                (4, 4, 4, 4),
                mode='replicate').data.squeeze()),
            transforms.ToPILImage(),
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(prob=args.rotation_prob,
                                      degree=args.rotation_degree),
            transforms.RandomZoom(prob=args.zoom_prob,
                                  zoom_range=args.zoom_range),
            transforms.RandomStretch(prob=args.stretch_prob,
                                     stretch_range=args.stretch_range),
            transforms.ToTensor(),
        ])

    else:
        transform_train = transforms.Compose([
            transforms.ToTensor(),
        ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])

    kwargs = {'num_workers': 1, 'pin_memory': True}
    assert (args.dataset == 'cifar10' or args.dataset == 'cifar100')
    train_loader = torch.utils.data.DataLoader(
        datasets.__dict__[args.dataset.upper()]('../data',
                                                train=True,
                                                download=True,
                                                transform=transform_train),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)
    val_loader = torch.utils.data.DataLoader(
        datasets.__dict__[args.dataset.upper()]('../data',
                                                train=False,
                                                transform=transform_test),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)

    # create model
    model = WideResNet(args.layers,
                       10 if args.dataset == 'cifar10' else 100,
                       args.widen_factor,
                       dropRate=args.droprate)

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # for training on multiple GPUs.
    # Use CUDA_VISIBLE_DEVICES=0,1 to specify which GPUs to use
    # model = torch.nn.DataParallel(model).cuda()
    model = model.cuda()

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                nesterov=args.nesterov,
                                weight_decay=args.weight_decay)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch + 1)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, epoch)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            }, is_best)
    print('Best accuracy: ', best_prec1)
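
adjust_learning_rate is called in the training loop above but not included in this excerpt. A plausible step-decay sketch is given below; the milestones and decay factor are assumptions, not the original schedule.

def adjust_learning_rate(optimizer, epoch):
    # Assumed step decay: multiply the base LR by 0.2 at each milestone epoch.
    lr = args.lr * (0.2 ** sum(epoch >= m for m in (60, 120, 160)))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr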
Example #7
        images = Variable(images, volatile=True)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the %d test images: %d %%' %
          (total, 100 * correct / total))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch FeedForward Example')
    parser.add_argument('--epochs',
                        type=int,
                        default=3,
                        help='number of epochs to train')
    parser.add_argument('--lr', type=float, default=0.01, help='learning rate')
    parser.add_argument('--batch_size',
                        type=int,
                        default=128,
                        help='batch size')
    args = parser.parse_args()

    train_loader, test_loader = data_loader()
    model = WideResNet(int(28), 10)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
    train(args.epochs)
    test()
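
data_loader(), train() and test() are only partially shown in this example. A plausible data_loader() sketch, assuming CIFAR-10 as suggested by WideResNet(int(28), 10), could look like this; the data path and normalization statistics are assumptions.

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

def data_loader(batch_size=128):
    # Assumed CIFAR-10 loaders with the commonly used channel statistics.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    train_set = torchvision.datasets.CIFAR10('../data', train=True, download=True, transform=transform)
    test_set = torchvision.datasets.CIFAR10('../data', train=False, download=True, transform=transform)
    return (DataLoader(train_set, batch_size=batch_size, shuffle=True),
            DataLoader(test_set, batch_size=batch_size, shuffle=False))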
Example #8
class MixMatchTrainer:

    def __init__(self, batch_size, num_lbls, model_params, n_steps, K, lambda_u, optimizer, adam,
                 sgd, steps_validation, steps_checkpoint, dataset, save_path, use_pseudo, tau):

        self.validation_set = False

        self.n_steps = n_steps
        self.start_step = 0
        self.K = K
        self.steps_validation = steps_validation
        self.steps_checkpoint = steps_checkpoint
        self.num_labeled = num_lbls
        self.labeled_loader, self.unlabeled_loader, self.val_loader, self.test_loader, self.lbl_idx, self.unlbl_idx, self.val_idx \
            = get_dataloaders_with_index(path='../data',
                                         batch_size=batch_size,
                                         num_labeled=num_lbls,
                                         which_dataset=dataset,
                                         validation=self.validation_set)
        print('Labeled samples: ' + str(len(self.labeled_loader.sampler)) + '\tUnlabeled samples: ' + str(len(self.unlabeled_loader.sampler)))
        self.targets_list = np.array(self.labeled_loader.dataset.targets)
        self.batch_size = self.labeled_loader.batch_size
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print(self.device)

        # -- Model --
        depth, k, n_out = model_params
        self.model = WideResNet(depth=depth, k=k, n_out=n_out, bias=False).to(self.device)
        self.ema_model = WideResNet(depth=depth, k=k, n_out=n_out, bias=False).to(self.device)
        for param in self.ema_model.parameters():
            param.detach_()

        if optimizer == 'adam':
            self.lr, self.weight_decay = adam
            self.momentum, self.lr_decay = None, None
            self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
            self.ema_optimizer = WeightEMA(self.model, self.ema_model, self.lr, alpha=0.999)

        else:
            self.lr, self.momentum, self.weight_decay, self.lr_decay = sgd
            self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr, momentum=self.momentum, weight_decay=self.weight_decay, nesterov=True)
            self.ema_optimizer = None

        self.lambda_u_max, self.step_top_up = lambda_u
        self.loss_mixmatch = Loss(self.lambda_u_max, self.step_top_up)
        self.criterion = nn.CrossEntropyLoss()

        self.train_accuracies, self.train_losses = [], []
        self.val_accuracies, self.val_losses = [], []
        self.best_acc = 0

        self.mixmatch = MixMatch(self.model, self.batch_size, self.device)

        self.writer = SummaryWriter()
        self.path = save_path

        # -- Pseudo label --
        self.use_pseudo = use_pseudo
        self.steps_pseudo_lbl = 5000
        self.tau = tau  # confidence threshold

        self.min_unlbl_samples = 1000
        # Make a deep copy of original unlabeled loader
        _, self.unlabeled_loader_original, _, _, _, _, _ \
            = get_dataloaders_with_index(path='../data',
                                         batch_size=batch_size,
                                         num_labeled=num_lbls,
                                         which_dataset=dataset,
                                         lbl_idxs=self.lbl_idx,
                                         unlbl_idxs=self.unlbl_idx,
                                         valid_idxs=self.val_idx,
                                         validation=self.validation_set)

    def train(self):

        iter_labeled_loader = iter(self.labeled_loader)
        iter_unlabeled_loader = iter(self.unlabeled_loader)

        for step in range(self.start_step, self.n_steps):

            # Get next batch of data
            self.model.train()
            try:
                x_imgs, x_labels, _ = next(iter_labeled_loader)
                # Check if batch size has been cropped for last batch
                if x_imgs.shape[0] < self.batch_size:
                    iter_labeled_loader = iter(self.labeled_loader)
                    x_imgs, x_labels, _ = next(iter_labeled_loader)
            except StopIteration:
                iter_labeled_loader = iter(self.labeled_loader)
                x_imgs, x_labels, _ = next(iter_labeled_loader)

            try:
                u_imgs, _, _ = next(iter_unlabeled_loader)
                if u_imgs.shape[0] < self.batch_size:
                    iter_unlabeled_loader = iter(self.unlabeled_loader)
                    u_imgs, _, _ = next(iter_unlabeled_loader)
            except StopIteration:
                iter_unlabeled_loader = iter(self.unlabeled_loader)
                u_imgs, _, _ = next(iter_unlabeled_loader)

            # Send to GPU
            x_imgs = x_imgs.to(self.device)
            x_labels = x_labels.to(self.device)
            u_imgs = u_imgs.to(self.device)

            # MixMatch algorithm
            x, u = self.mixmatch.run(x_imgs, x_labels, u_imgs)
            x_input, x_targets = x
            u_input, u_targets = u
            u_targets.detach_()  # stop gradients from propagating into the guessed labels

            # Compute X' predictions
            x_output = self.model(x_input)

            # Compute U' predictions. Separate in batches
            u_batch_outs = []
            for k in range(self.K):
                u_batch = u_input[k * self.batch_size:(k + 1) * self.batch_size]
                u_batch_outs.append(self.model(u_batch))
            u_outputs = torch.cat(u_batch_outs, dim=0)

            # Compute loss
            loss = self.loss_mixmatch(x_output, x_targets, u_outputs, u_targets, step)

            # Step
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            if self.ema_optimizer:
                self.ema_optimizer.step()

            # Decaying learning rate. Used with the SGD Nesterov optimizer
            if not self.ema_optimizer and step in self.lr_decay:
                for g in self.optimizer.param_groups:
                    g['lr'] *= 0.2

            # Evaluate model
            self.model.eval()
            if step > 0 and not step % self.steps_validation:
                val_acc, is_best = self.evaluate_loss_acc(step)
                if is_best and step > 10000:
                    self.save_model(step=step, path=f'{self.path}/best_checkpoint.pt')

            # Save checkpoint
            if step > 0 and not step % self.steps_checkpoint:
                self.save_model(step=step, path=f'{self.path}/checkpoint_{step}.pt')

            # Generate Pseudo-labels
            if self.use_pseudo and step >= 50000 and not step % self.steps_pseudo_lbl:
                # matrix columns: [index, confidence, pseudo_label, true_label, is_ground_truth]
                matrix = self.get_pseudo_labels()
                self.print_threshold_comparison(matrix)

                # Generate pseudo set based on threshold (same for all classes)
                if self.tau != -1:
                    matrix = self.generate_pseudo_set(matrix)

                # Generate balanced pseudo set (top 10% most confident guesses of each class)
                else:
                    matrix = self.generate_pseudo_set_balanced(matrix)

                iter_labeled_loader = iter(self.labeled_loader)
                iter_unlabeled_loader = iter(self.unlabeled_loader)

                # Save
                torch.save(matrix, f'{self.path}/pseudo_matrix_balanced_{step}.pt')

        # --- Training finished ---
        test_val, test_acc = self.evaluate(self.test_loader)
        print("Training done!!\t Test loss: %.3f \t Test accuracy: %.3f" % (test_val, test_acc))

        self.writer.flush()

    # --- support functions ---
    def evaluate_loss_acc(self, step):
        val_loss, val_acc = self.evaluate(self.val_loader)
        self.val_losses.append(val_loss)
        self.val_accuracies.append(val_acc)
        train_loss, train_acc = self.evaluate(self.labeled_loader)
        self.train_losses.append(train_loss)
        self.train_accuracies.append(train_acc)

        is_best = False
        if val_acc > self.best_acc:
            self.best_acc = val_acc
            is_best = True

        print("Step %d.\tLoss train_lbl/valid  %.2f  %.2f\t Accuracy train_lbl/valid  %.2f  %.2f \tBest acc %.2f \t%s" %
              (step, train_loss, val_loss, train_acc, val_acc, self.best_acc, time.ctime()))

        self.writer.add_scalar("Loss train_label", train_loss, step)
        self.writer.add_scalar("Loss validation", val_loss, step)
        self.writer.add_scalar("Accuracy train_label", train_acc, step)
        self.writer.add_scalar("Accuracy validation", val_acc, step)
        return val_acc, is_best

    def evaluate(self, dataloader):
        correct, total, loss = 0, 0, 0
        with torch.no_grad():
            for i, data in enumerate(dataloader, 0):
                inputs, labels = data[0].to(self.device), data[1].to(self.device)
                if self.ema_optimizer:
                    outputs = self.ema_model(inputs)
                else:
                    outputs = self.model(inputs)
                loss += self.criterion(outputs, labels).item()

                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
            loss /= len(dataloader)

        acc = correct / total * 100
        return loss, acc


    def get_losses(self):
        return self.loss_mixmatch.loss_list, self.loss_mixmatch.lx_list, self.loss_mixmatch.lu_list, self.loss_mixmatch.lu_weighted_list

    def save_model(self, step=None, path='../model.pt'):
        loss_list, lx, lu, lu_weighted = self.get_losses()
        if not step:
            step = self.n_steps     # Training finished

        torch.save({
            'step': step,
            'model_state_dict': self.model.state_dict(),
            'ema_state_dict': self.ema_model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'loss_train': self.train_losses,
            'loss_val': self.val_losses,
            'acc_train': self.train_accuracies,
            'acc_val': self.val_accuracies,
            'loss_batch': loss_list,
            'lx': lx,
            'lu': lu,
            'lu_weighted': lu_weighted,
            'steps': self.n_steps,
            'batch_size': self.batch_size,
            'num_labels': self.num_labeled,
            'lambda_u_max': self.lambda_u_max,
            'step_top_up': self.step_top_up,
            'lr': self.lr,
            'weight_decay': self.weight_decay,
            'momentum': self.momentum,
            'lr_decay': self.lr_decay,
            'lbl_idx': self.lbl_idx,
            'unlbl_idx': self.unlbl_idx,
            'val_idx': self.val_idx,
        }, path)

    def load_checkpoint(self, path_checkpoint):
        saved_model = torch.load(path_checkpoint)
        self.model.load_state_dict(saved_model['model_state_dict'])
        self.ema_model.load_state_dict(saved_model['ema_state_dict'])
        self.optimizer.load_state_dict(saved_model['optimizer_state_dict'])
        self.start_step = saved_model['step']
        self.train_losses = saved_model['loss_train']
        self.val_losses = saved_model['loss_val']
        self.train_accuracies = saved_model['acc_train']
        self.val_accuracies = saved_model['acc_val']
        self.batch_size = saved_model['batch_size']
        self.num_labeled = saved_model['num_labels']
        self.labeled_loader, self.unlabeled_loader, self.val_loader, self.test_loader, self.lbl_idx, self.unlbl_idx, self.val_idx = \
            get_dataloaders_with_index(path='../data',
                                       batch_size=self.batch_size,
                                       num_labeled=self.num_labeled,
                                       which_dataset='cifar10',
                                       lbl_idxs=saved_model['lbl_idx'],
                                       unlbl_idxs=saved_model['unlbl_idx'],
                                       valid_idxs=saved_model['val_idx'],
                                       validation=self.validation_set)
        self.unlabeled_loader_original = self.unlabeled_loader
        print('Model ' + path_checkpoint + ' loaded.')

    def get_pseudo_labels(self):
        matrix = torch.tensor([], device=self.device)

        # Iterate through unlabeled loader
        for batch_idx, (data, target, idx) in enumerate(self.unlabeled_loader_original):
            with torch.no_grad():
                # Get predictions for unlabeled samples
                out = self.model(data.to(self.device))
                p_out = torch.softmax(out, dim=1)  # turn into probability distribution
                confidence, pseudo_lbl = torch.max(p_out, dim=1)
                pseudo_lbl_batch = torch.vstack((idx.to(self.device), confidence, pseudo_lbl)).T
                # Append to matrix
                matrix = torch.cat((matrix, pseudo_lbl_batch), dim=0)  # (n_unlabeled, 3)

        n_unlabeled = matrix.shape[0]
        indices = matrix[:, 0].cpu().numpy().astype(int)
        ground_truth = self.targets_list[indices]
        matrix = torch.vstack((matrix.T, torch.tensor(ground_truth, device=self.device))).T   # (n_unlabeled, 4)
        matrix = torch.vstack((matrix.T, torch.zeros(n_unlabeled, device=self.device))).T   # (n_unlabeled, 5)

        # matrix columns: [index, confidence, pseudo_label, true_label, is_ground_truth]

        # Check if pseudo label is ground truth
        for i in range(n_unlabeled):
            if matrix[i, 2] == matrix[i, 3]:
                matrix[i, 4] = 1
        return matrix

    def generate_pseudo_set(self, matrix):

        unlbl_mask1 = (matrix[:, 1] < self.tau)
        # unlbl_mask2 = (matrix[:, 1] >= 0.99)
        pseudo_mask = (matrix[:, 1] >= self.tau)
        # pseudo_mask = (matrix[:, 1] >= self.tau) & (matrix[:, 1] < 0.99)

        # unlbl_indices = torch.cat((matrix[unlbl_mask1, 0], matrix[unlbl_mask2, 0]))
        unlbl_indices = matrix[unlbl_mask1, 0]
        matrix = matrix[pseudo_mask, :]
        indices = matrix[:, 0]

        new_lbl_idx = np.int_(torch.cat((torch.tensor(self.lbl_idx, device=self.device), indices)).tolist())
        new_unlbl_idx = np.int_(unlbl_indices.tolist())
        self.labeled_loader, self.unlabeled_loader, self.val_loader, self.test_loader, _, _, new_val_idx = \
            get_dataloaders_with_index(path='../data', batch_size=self.batch_size, num_labeled=self.num_labeled,
                                       which_dataset='cifar10', lbl_idxs=new_lbl_idx, unlbl_idxs=new_unlbl_idx,
                                       valid_idxs=self.val_idx, validation=self.validation_set)

        assert np.allclose(self.val_idx, new_val_idx), 'error'
        assert (len(self.labeled_loader.sampler) + len(self.unlabeled_loader.sampler) == 50000), 'error'

        # Change real labels for pseudo labels
        for i in range(matrix.shape[0]):
            index = int(matrix[i, 0].item())
            assert int(matrix[i, 3]) == self.labeled_loader.dataset.targets[index]
            pseudo_labels = int(matrix[i, 2].item())
            self.labeled_loader.dataset.targets[index] = pseudo_labels

        correct = torch.sum(matrix[:, 4]).item()
        pseudo_acc = correct / matrix.shape[0] * 100 if matrix.shape[0] > 0 else 0
        print('Generated labels: %d\t Correct: %d\t Accuracy: %.2f' % (matrix.shape[0], correct, pseudo_acc))
        print('Training with Labeled / Unlabeled / Validation samples\t %d %d %d' % (len(new_lbl_idx),
                                                                                     len(new_unlbl_idx),
                                                                                     len(self.val_idx)))

        return matrix

    def generate_pseudo_set_balanced(self, matrix_all):

        unlbl_indices = torch.tensor([], device=self.device)

        # Get top 10% confident guesses for each class
        matrix = torch.tensor([], device=self.device)
        for i in range(10):
            matrix_label = matrix_all[matrix_all[:, 2] == i, :]
            threshold = torch.quantile(matrix_label[:, 1], 0.9)     # confidence value at the 90th percentile
            unlbl_idxs = matrix_label[matrix_label[:, 1] < threshold, 0]
            matrix_label = matrix_label[matrix_label[:, 1] >= threshold, :]
            matrix = torch.cat((matrix, matrix_label), dim=0)
            unlbl_indices = torch.cat((unlbl_indices, unlbl_idxs))
        indices = matrix[:, 0]

        new_lbl_idx = np.int_(torch.cat((torch.tensor(self.lbl_idx, device=self.device), indices)).tolist())
        new_unlbl_idx = np.int_(unlbl_indices.tolist())
        self.labeled_loader, self.unlabeled_loader, self.val_loader, self.test_loader, _, _, new_val_idx = \
            get_dataloaders_with_index(path='../data', batch_size=self.batch_size, num_labeled=self.num_labeled,
                                       which_dataset='cifar10', lbl_idxs=new_lbl_idx, unlbl_idxs=new_unlbl_idx,
                                       valid_idxs=self.val_idx, validation=self.validation_set)

        assert np.allclose(self.val_idx, new_val_idx), 'error'
        assert (len(self.labeled_loader.sampler) + len(self.unlabeled_loader.sampler) == 50000), 'error'

        # Change real labels for pseudo labels
        for i in range(matrix.shape[0]):
            index = int(matrix[i, 0].item())
            assert int(matrix[i, 3]) == self.labeled_loader.dataset.targets[index]
            pseudo_labels = int(matrix[i, 2].item())
            self.labeled_loader.dataset.targets[index] = pseudo_labels

        correct = torch.sum(matrix[:, 4]).item()
        pseudo_acc = correct / matrix.shape[0] * 100 if matrix.shape[0] > 0 else 0
        print('Generated labels: %d\t Correct: %d\t Accuracy: %.2f' % (matrix.shape[0], correct, pseudo_acc))
        print('Training with Labeled / Unlabeled / Validation samples\t %d %d %d' % (len(new_lbl_idx),
                                                                                     len(new_unlbl_idx),
                                                                                     len(self.val_idx)))
        return matrix

    def print_threshold_comparison(self, matrix):
        m2 = matrix[matrix[:, 1] >= 0.9, :]
        for i, tau in enumerate([0.9, 0.95, 0.97, 0.99, 0.999]):
            pseudo_labels = m2[m2[:, 1] >= tau, :]
            total = pseudo_labels.shape[0]
            correct = torch.sum(pseudo_labels[:, 4]).item()
            print('Confidence threshold %.3f\t Generated / Correct / Precision\t %d\t%d\t%.2f '
                  % (tau, total, correct, correct / (total + np.finfo(float).eps) * 100))
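
A sketch of how this trainer might be driven. The values are illustrative assumptions; the tuples match the (lr, weight_decay), (lr, momentum, weight_decay, lr_decay) and (lambda_u_max, step_top_up) unpacking done in __init__.

trainer = MixMatchTrainer(
    batch_size=64,
    num_lbls=4000,
    model_params=(28, 2, 10),                 # depth, widen factor k, output classes
    n_steps=200000,
    K=2,                                      # augmentations per unlabeled image
    lambda_u=(75, 16000),                     # lambda_u_max, step_top_up
    optimizer='adam',
    adam=(0.002, 0.02),
    sgd=(0.1, 0.9, 5e-4, [100000, 150000]),
    steps_validation=500,
    steps_checkpoint=25000,
    dataset='cifar10',
    save_path='../models',
    use_pseudo=False,
    tau=0.95)
trainer.train()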