Example 1
def train(args):

    json_options = json_file_to_pyobj(args.config)
    no_teacher_configurations = json_options.training

    wrn_depth = no_teacher_configurations.wrn_depth
    wrn_width = no_teacher_configurations.wrn_width

    M = no_teacher_configurations.M
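    # M = samples kept per class; the loader branches below treat M == 5000 as
    # the full training set and M == 0 as the no-data setting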

    dataset = no_teacher_configurations.dataset
    seeds = [int(seed) for seed in no_teacher_configurations.seeds]
    log = no_teacher_configurations.log.lower() == 'true'

    if log:
        net_str = "WideResNet-{}-{}".format(wrn_depth, wrn_width)
        logfile = "No_Teacher-{}-{}-M-{}.txt".format(
            net_str, no_teacher_configurations.dataset, M)
        with open(os.path.join('./', logfile), "w") as temp:
            temp.write('No teacher {} in {} with M={}\n'.format(
                net_str, no_teacher_configurations.dataset, M))
    else:
        logfile = ''

    checkpoint = bool(no_teacher_configurations.checkpoint)

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    test_set_accuracies = []

    for seed in seeds:

        set_seed(seed)

        if dataset.lower() == 'cifar10':

            # Full data
            if M == 5000:
                from utils import cifar10loaders
                loaders = cifar10loaders()
            # No data
            elif M == 0:
                from utils import cifar10loaders
                _, test_loader = cifar10loaders()
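                # NOTE: on the M == 0 paths only the test loader is kept, so
                # `loaders` stays undefined and the training call below assumes M > 0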
            else:
                from utils import cifar10loadersM
                loaders = cifar10loadersM(M)

        elif dataset.lower() == 'svhn':

            # Full data
            if M == 5000:
                from utils import svhnLoaders
                loaders = svhnLoaders()
            # No data
            elif M == 0:
                from utils import svhnLoaders
                _, test_loader = svhnLoaders()
            else:
                from utils import svhnloadersM
                loaders = svhnloadersM(M)

        else:
            raise ValueError('Datasets to choose from: CIFAR10 and SVHN')

        if log:
            with open(os.path.join('./', logfile), "a") as temp:
                temp.write(
                    '------------------- SEED {} -------------------\n'.format(
                        seed))

        strides = [1, 1, 2, 2]

        net = WideResNet(d=wrn_depth,
                         k=wrn_width,
                         n_classes=10,
                         input_features=3,
                         output_features=16,
                         strides=strides)
        net = net.to(device)

        checkpointFile = 'No_teacher_wrn-{}-{}-M-{}-seed-{}-{}-dict.pth'.format(
            wrn_depth, wrn_width, M, seed, dataset) if checkpoint else ''

        best_test_set_accuracy = _train_seed_no_teacher(
            net, M, loaders, device, dataset, log, checkpoint, logfile,
            checkpointFile)

        if log:
            with open(os.path.join('./', logfile), "a") as temp:
                temp.write('Best test set accuracy of seed {} is {}\n'.format(
                    seed, best_test_set_accuracy))

        test_set_accuracies.append(best_test_set_accuracy)

    mean_test_set_accuracy, std_test_set_accuracy = np.mean(
        test_set_accuracies), np.std(test_set_accuracies)

    if log:
        with open(os.path.join('./', logfile), "a") as temp:
            temp.write(
                'Mean test set accuracy is {} with standard deviation equal to {}\n'
                .format(mean_test_set_accuracy, std_test_set_accuracy))
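Example 2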
def train(args):

    json_options = json_file_to_pyobj(args.config)
    training_configurations = json_options.training
    wandb.init(name=f"{training_configurations.checkpoint}_subset_{args.subset_index}_ensemble")
    device = torch.device(f'cuda:{args.device}')

    flag = False
    if training_configurations.train_pickle != 'None' and training_configurations.test_pickle != 'None':
        pickle_files = [
            training_configurations.train_pickle,
            training_configurations.test_pickle
        ]
        flag = True

    if args.subset_index is None:
        model = build_model(args)
        model = model.to(device)
        epochs = 40
        optimizer = optim.SGD(model.parameters(),
                              lr=1.25e-2,
                              momentum=0.9,
                              nesterov=True,
                              weight_decay=1e-4)
        scheduler = MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=0.1)

    dataset = args.dataset.lower()
    b = 0.2
    m = 0.4
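    # b weights the entropy-margin OOD term and m is the required entropy margin
    # (see the loss computation inside the training loop)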

    if not flag:
        trainloader, val_loader, testloader = fine_grained_image_loaders_subset(
            dataset,
            subset_index=args.subset_index,
            validation_test_split=800,
            save_to_pickle=True)
    else:
        pickle_files[0] = "pickle_files/" + pickle_files[0].split(
            ".pickle")[0] + f"_subset_{args.subset_index}.pickle"
        pickle_files[1] = "pickle_files/" + pickle_files[1].split(
            ".pickle")[0] + f"_subset_{args.subset_index}.pickle"
        trainloader, val_loader, testloader, num_classes = fine_grained_image_loaders_subset(
            dataset,
            subset_index=args.subset_index,
            validation_test_split=800,
            pickle_files=pickle_files,
            ret_num_classes=True)
        train_ood_loader = fine_grained_image_loaders_subset(
            dataset,
            single=True,
            subset_index=args.subset_index,
            validation_test_split=800,
            pickle_files=pickle_files)

        if 'genOdin' in training_configurations.checkpoint:
            weight_decay = 1e-4
            optimizer = optim.SGD([
                {
                    'params': model._conv_stem.parameters(),
                    'weight_decay': weight_decay
                },
                {
                    'params': model._bn0.parameters(),
                    'weight_decay': weight_decay
                },
                {
                    'params': model._blocks.parameters(),
                    'weight_decay': weight_decay
                },
                {
                    'params': model._conv_head.parameters(),
                    'weight_decay': weight_decay
                },
                {
                    'params': model._bn1.parameters(),
                    'weight_decay': weight_decay
                },
                {
                    'params': model._fc_denominator.parameters(),
                    'weight_decay': weight_decay
                },
                {
                    'params': model._denominator_batch_norm.parameters(),
                    'weight_decay': weight_decay
                },
                {
                    'params': model._fc_nominator.parameters(),
                    'weight_decay': 0
                },
            ],
                                  lr=1.25e-2,
                                  momentum=0.9,
                                  nesterov=True)
            scheduler = MultiStepLR(optimizer,
                                    milestones=[10, 20, 30],
                                    gamma=0.1)

    if args.subset_index is not None:
        model = build_model(args)
        model._fc = nn.Linear(model._fc.in_features, num_classes)
        model = model.to(device)
        epochs = 40
        optimizer = optim.SGD(model.parameters(),
                              lr=1.25e-2,
                              momentum=0.9,
                              nesterov=True,
                              weight_decay=1e-4)
        scheduler = MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=0.1)

    criterion = nn.CrossEntropyLoss()
    checkpoint_val_accuracy, best_val_acc, test_set_accuracy = 0, 0, 0
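    # NOTE: the loop below assumes the pickle ("flag") path with a non-None
    # subset_index, since train_ood_loader and num_classes are only defined there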

    ood_loader_iter = iter(train_ood_loader)

    for epoch in tqdm(range(epochs)):

        model.train()
        correct, total = 0, 0
        train_loss = 0
        for data in tqdm(trainloader):

            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            if 'genodin' in training_configurations.checkpoint.lower():
                outputs, _, _ = model(inputs)
            else:
                outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            ce_loss = criterion(outputs, labels)

            try:
                ood_inputs, _ = next(ood_loader_iter)
            except StopIteration:
                # restart the OOD stream once it is exhausted
                ood_loader_iter = iter(train_ood_loader)
                ood_inputs, _ = next(ood_loader_iter)

            ood_inputs = ood_inputs.to(device)
            ood_outputs = model(ood_inputs)
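            # predictive entropies of the in-distribution and OOD batches; the
            # margin term pushes OOD entropy above ID entropy by at least m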
            entropy_input = -torch.mean(
                torch.sum(
                    F.log_softmax(outputs, dim=1) * F.softmax(outputs, dim=1),
                    dim=1))
            entropy_output = -torch.mean(
                torch.sum(F.log_softmax(ood_outputs, dim=1) *
                          F.softmax(ood_outputs, dim=1),
                          dim=1))
            margin_loss = b * torch.clamp(m + entropy_input - entropy_output,
                                          min=0)

            loss = ce_loss + margin_loss
            train_loss += loss.item()
            loss.backward()
            optimizer.step()

        train_accuracy = correct / total
        wandb.log({'epoch': epoch}, commit=False)
        wandb.log({
            'Train Set Loss': train_loss / len(trainloader),
            'epoch': epoch
        })
        wandb.log({'Train Set Accuracy': train_accuracy, 'epoch': epoch})

        model.eval()
        correct, total = 0, 0

        with torch.no_grad():

            for data in val_loader:
                images, labels = data
                images = images.to(device)
                labels = labels.to(device)

                if 'genodin' in training_configurations.checkpoint.lower():
                    outputs, _, _ = model(images)
                else:
                    outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

            epoch_val_accuracy = correct / total

            wandb.log({
                'Validation Set Accuracy': epoch_val_accuracy,
                'epoch': epoch
            })

        if epoch_val_accuracy > best_val_acc:
            best_val_acc = epoch_val_accuracy

            if os.path.exists('/raid/ferles/'):
                torch.save(
                    model.state_dict(),
                    f'/raid/ferles/checkpoints/eb0/{dataset}/{training_configurations.checkpoint}_subset_ens_{args.subset_index}.pth'
                )
            else:
                torch.save(
                    model.state_dict(),
                    f'/home/ferles/checkpoints/eb0/{dataset}/{training_configurations.checkpoint}_subset_ens_{args.subset_index}.pth'
                )

            correct, total = 0, 0

            for data in testloader:
                images, labels = data
                images = images.to(device)
                labels = labels.to(device)

                if 'genodin' in training_configurations.checkpoint.lower():
                    outputs, h, g = model(images)
                else:
                    outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

            test_set_accuracy = correct / total

        wandb.log({'Test Set Accuracy': test_set_accuracy, 'epoch': epoch})

        scheduler.step()
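Example 3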
def train(args):

    device = torch.device(f'cuda:{args.device}')

    json_options = json_file_to_pyobj(args.config)
    training_configurations = json_options.training
    traincsv = training_configurations.traincsv
    testcsv = training_configurations.testcsv
    gtFileName = training_configurations.gtFile
    out_classes = training_configurations.out_classes
    exclude_class = training_configurations.exclude_class
    exclude_class = None if exclude_class == "None" else exclude_class

    if exclude_class is None:
        wandb.init(name='oe_isic')
    else:
        wandb.init(name=f'oe_{exclude_class}')

    batch_size = 32

    if exclude_class is None:
        train_loader, val_loader, test_loader, columns = oversampling_loaders_custom(csvfiles=[traincsv, testcsv], train_batch_size=32, val_batch_size=16, gtFile=gtFileName)
    else:
        train_loader, val_loader, test_loader, columns = oversampling_loaders_exclude_class_custom_no_gts(csvfiles=[traincsv, testcsv], train_batch_size=32, val_batch_size=16, gtFile=gtFileName, exclude_class=exclude_class)
    ood_loader = imageNetLoader(dataset='isic', batch_size=batch_size)
    ood_loader_iter = iter(ood_loader)

    model = build_model(args).to(device)
    epochs = 40
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=1.25e-2, momentum=0.9, nesterov=True, weight_decay=1e-4)
    scheduler = MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=0.1)

    # uniform target distribution over the output classes, used for the
    # outlier-exposure term on OOD batches
    uniform = torch.ones(size=(batch_size, out_classes)) / float(out_classes)
    uniform = uniform.to(device)
    lamda = 0.5

    best_val_acc, test_detection_accuracy = 0, 0
    for epoch in tqdm(range(epochs)):

        model.train()
        loss_acc = []

        for data in tqdm(train_loader):

            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)

            try:
                ood_inputs, _ = next(ood_loader_iter)
            except StopIteration:
                # restart the OOD stream once it is exhausted
                ood_loader_iter = iter(ood_loader)
                ood_inputs, _ = next(ood_loader_iter)

            ood_inputs = ood_inputs.to(device)
            ood_outputs = model(ood_inputs)

            _labels = torch.argmax(labels, dim=1)
            ce_loss = criterion(outputs, _labels)
            if ood_outputs.size(0) < batch_size:
                uniform = torch.ones(size=(ood_outputs.size(0), out_classes)) / float(out_classes)
                uniform = uniform.to(device)

            # outlier-exposure term: cross-entropy between the OOD predictions
            # and the uniform target, weighted by lamda
            outlier_loss = lamda * -(uniform * F.log_softmax(ood_outputs, dim=1)).sum(dim=1).mean()
            loss = ce_loss + outlier_loss

            loss_acc.append(loss.item())
            loss.backward()
            optimizer.step()

            if ood_outputs.size(0) < batch_size:
                uniform = torch.ones(size=(batch_size, out_classes)) / float(out_classes)
                uniform = uniform.to(device)

        wandb.log({'epoch': epoch}, commit=False)
        wandb.log({'Train Set Loss': sum(loss_acc) / float(len(train_loader)), 'epoch': epoch})

        model.eval()
        correct, total = 0, 0

        with torch.no_grad():

            for data in val_loader:
                images, labels = data
                images = images.to(device)
                labels = labels.to(device)
                _labels = torch.argmax(labels, dim=1)

                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == _labels).sum().item()

            val_detection_accuracy = round(100*correct/total, 2)
            wandb.log({'Validation Detection Accuracy': val_detection_accuracy, 'epoch': epoch})

            if val_detection_accuracy > best_val_acc:
                best_val_acc = val_detection_accuracy

                save_root = '/raid/ferles' if os.path.exists('/raid/ferles/') else '/home/ferles'
                if exclude_class is None:
                    torch.save(model.state_dict(), f'{save_root}/checkpoints/isic_classifiers/outlier_exposure_isic.pth')
                else:
                    torch.save(model.state_dict(), f'{save_root}/checkpoints/isic_classifiers/outlier_exposure_{exclude_class}.pth')
                correct, total = 0, 0

                for data in test_loader:
                    _, images, labels = data
                    images = images.to(device)
                    labels = labels.to(device)
                    _labels = torch.argmax(labels, dim=1)

                    outputs = model(images)
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == _labels).sum().item()

                test_detection_accuracy = round(100 * correct / total, 2)

        wandb.log({'Detection Accuracy': test_detection_accuracy, 'epoch': epoch})

        scheduler.step()
Example 4
def train(args):

    use_wandb = True

    device = torch.device(f'cuda:{args.device}')

    json_options = json_file_to_pyobj(args.config)
    training_configurations = json_options.training
    traincsv = training_configurations.traincsv
    testcsv = training_configurations.testcsv
    gtFileName = training_configurations.gtFile
    checkpointFileName = training_configurations.checkpointFile
    out_classes = training_configurations.out_classes
    exclude_class = training_configurations.exclude_class
    exclude_class = None if exclude_class == "None" else exclude_class

    if use_wandb:
        wandb.init(name=checkpointFileName)

    if exclude_class is None:
        train_loader, val_loader, test_loader, columns = oversampling_loaders_custom(csvfiles=[traincsv, testcsv], train_batch_size=32, val_batch_size=16, gtFile=gtFileName)
    else:
        train_loader, val_loader, test_loader, columns = oversampling_loaders_exclude_class_custom_no_gts(csvfiles=[traincsv, testcsv], train_batch_size=32, val_batch_size=16, gtFile=gtFileName, exclude_class=exclude_class)

    model = build_model(args).to(device)
    epochs = 40
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=1.25e-2, momentum=0.9, nesterov=True, weight_decay=1e-4)
    scheduler = MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=0.1)

    best_val_detection_accuracy, test_detection_accuracy = 0, 0
    train_loss, val_loss, balanced_accuracies = [], [], []

    early_stopping = False
    early_stopping_cnt = 0
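    # early stopping is disabled by default; when enabled, training stops after
    # three consecutive epochs without a validation-accuracy improvement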
        
    for epoch in tqdm(range(epochs)):

        model.train()
        loss_acc = []

        for data in tqdm(train_loader):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()

            if 'genodin' in training_configurations.checkpointFile.lower():
                outputs, _, _ = model(inputs)
            else:
                outputs = model(inputs)

            _labels = torch.argmax(labels, dim=1)
            loss = criterion(outputs, _labels)
            loss_acc.append(loss.item())
            loss.backward()
            optimizer.step()

        wandb.log({'epoch': epoch}, commit=False)
        wandb.log({'Train Set Loss': sum(loss_acc) / float(len(train_loader)), 'epoch': epoch})
        train_loss.append(sum(loss_acc) / float(len(train_loader)))
        loss_acc.clear()

        model.eval()
        with torch.no_grad():
            correct, total = 0, 0
            for data in tqdm(val_loader):
                images, labels = data
                images = images.to(device)
                labels = labels.to(device)

                if 'genodin' in training_configurations.checkpointFile.lower():
                    outputs, _, _ = model(images)
                else:
                    outputs = model(images)

                softmax_outputs = torch.softmax(outputs, dim=1)
                max_idx = torch.argmax(softmax_outputs, dim=1)
                _labels = torch.argmax(labels, dim=1)
                correct += (max_idx == _labels).sum().item()
                total += max_idx.size(0)
                loss = criterion(outputs, _labels)
                loss_acc.append(loss.item())

        val_detection_accuracy = round(100*correct/total, 2)
        wandb.log({'Validation Detection Accuracy': val_detection_accuracy, 'epoch': epoch})

        if val_detection_accuracy > best_val_detection_accuracy:
            best_val_detection_accuracy = val_detection_accuracy
            if 'genodin' in training_configurations.checkpointFile.lower():
                # test_loss, auc, balanced_accuracy, test_detection_accuracy = _test_set_eval(model, epoch, device, test_loader, out_classes, columns, gtFileName, gen=True)
                test_loss, test_detection_accuracy = _test_set_eval(model, epoch, device, test_loader, out_classes, columns, gtFileName, gen=True)
            else:
                # test_loss, auc, balanced_accuracy, test_detection_accuracy = _test_set_eval(model, epoch, device, test_loader, out_classes, columns, gtFileName)
                test_loss, test_detection_accuracy = _test_set_eval(model, epoch, device, test_loader, out_classes, columns, gtFileName)
            checkpointFile = os.path.join(f'/raid/ferles/checkpoints/isic_classifiers/{checkpointFileName}-best-model.pth')
            if os.path.exists('/raid/ferles/'):
                torch.save(model.state_dict(), checkpointFile)
            else:
                torch.save(model.state_dict(), checkpointFile.replace('raid', 'home'))
        else:
            if early_stopping:
                early_stopping_cnt += 1
                if early_stopping_cnt == 3:
                    break

        epoch_val_loss = sum(loss_acc) / float(len(val_loader))
        val_loss.append(epoch_val_loss)
        wandb.log({'Val Set Loss': epoch_val_loss, 'epoch': epoch})
        wandb.log({'Detection Accuracy': test_detection_accuracy, 'epoch': epoch})
        # wandb.log({'Balanced Accuracy': balanced_accuracy, 'epoch': epoch})
        # wandb.log({'AUC': auc, 'epoch': epoch})

        scheduler.step()
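Example 5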
def train(args):
    json_options = json_file_to_pyobj(args.config)
    kd_att_configurations = json_options.training

    wrn_depth_teacher = kd_att_configurations.wrn_depth_teacher
    wrn_width_teacher = kd_att_configurations.wrn_width_teacher
    wrn_depth_student = kd_att_configurations.wrn_depth_student
    wrn_width_student = kd_att_configurations.wrn_width_student

    M = kd_att_configurations.M

    dataset = kd_att_configurations.dataset
    seeds = [int(seed) for seed in kd_att_configurations.seeds]
    log = kd_att_configurations.log.lower() == 'true'

    if log:
        teacher_str = "WideResNet-{}-{}".format(wrn_depth_teacher,
                                                wrn_width_teacher)
        student_str = "WideResNet-{}-{}".format(wrn_depth_student,
                                                wrn_width_student)
        logfile = "Teacher-{}-Student-{}-{}-M-{}-seeds-1-2.txt".format(
            teacher_str, student_str, kd_att_configurations.dataset, M)
        print(logfile)
        with open(os.path.join('./', logfile), "w") as temp:
            temp.write(
                'KD_ATT with teacher {} and student {} in {} with M={}\n'.
                format(teacher_str, student_str, kd_att_configurations.dataset,
                       M))
    else:
        logfile = ''

    checkpoint = bool(kd_att_configurations.checkpoint)

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    test_set_accuracies = []

    for seed in seeds:

        set_seed(seed)

        if dataset.lower() == 'cifar10':

            # Full data
            if M == 5000:
                from utils import cifar10loaders
                loaders = cifar10loaders()
            # No data
            elif M == 0:
                from utils import cifar10loaders
                _, test_loader = cifar10loaders()
            else:
                from utils import cifar10loadersM
                loaders = cifar10loadersM(M)

        elif dataset.lower() == 'svhn':

            # Full data
            if M == 5000:
                from utils import svhnLoaders
                loaders = svhnLoaders()
            # No data
            elif M == 0:
                from utils import svhnLoaders
                _, test_loader = svhnLoaders()
            else:
                from utils import svhnloadersM
                loaders = svhnloadersM(M)

        else:
            raise ValueError('Datasets to choose from: CIFAR10 and SVHN')

        if log:
            with open(os.path.join('./', logfile), "a") as temp:
                temp.write(
                    '------------------- SEED {} -------------------\n'.format(
                        seed))

        strides = [1, 1, 2, 2]

        teacher_net = WideResNet(d=wrn_depth_teacher,
                                 k=wrn_width_teacher,
                                 n_classes=10,
                                 input_features=3,
                                 output_features=16,
                                 strides=strides)
        teacher_net = teacher_net.to(device)
        if dataset.lower() == 'cifar10':
            torch_checkpoint = torch.load(
                './PreTrainedModels/PreTrainedScratches/CIFAR10/wrn-{}-{}-seed-{}-dict.pth'
                .format(wrn_depth_teacher, wrn_width_teacher, seed),
                map_location=device)
        else:
            torch_checkpoint = torch.load(
                './PreTrainedModels/PreTrainedScratches/SVHN/wrn-{}-{}-seed-svhn-{}-dict.pth'
                .format(wrn_depth_teacher, wrn_width_teacher, seed),
                map_location=device)

        teacher_net.load_state_dict(torch_checkpoint)

        student_net = WideResNet(d=wrn_depth_student,
                                 k=wrn_width_student,
                                 n_classes=10,
                                 input_features=3,
                                 output_features=16,
                                 strides=strides)
        student_net = student_net.to(device)

        checkpointFile = 'kd_att_teacher_wrn-{}-{}_student_wrn-{}-{}-M-{}-seed-{}-{}-dict.pth'.format(
            wrn_depth_teacher, wrn_width_teacher, wrn_depth_student,
            wrn_width_student, M, seed, dataset) if checkpoint else ''
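        # M == 0 means there is no transfer set: skip distillation and only
        # evaluate the student as-is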
        if M != 0:

            best_test_set_accuracy = _train_seed_kd_att(
                teacher_net, student_net, M, loaders, device, dataset, log,
                checkpoint, logfile, checkpointFile)

            if log:
                with open(os.path.join('./', logfile), "a") as temp:
                    temp.write(
                        'Best test set accuracy of seed {} is {}\n'.format(
                            seed, best_test_set_accuracy))

            test_set_accuracies.append(best_test_set_accuracy)

        else:

            best_test_set_accuracy = _test_set_eval(student_net, device,
                                                    test_loader)
            test_set_accuracies.append(best_test_set_accuracy)

    mean_test_set_accuracy, std_test_set_accuracy = np.mean(
        test_set_accuracies), np.std(test_set_accuracies)

    if log:
        with open(os.path.join('./', logfile), "a") as temp:
            temp.write(
                'Mean test set accuracy is {} with standard deviation equal to {}\n'
                .format(mean_test_set_accuracy, std_test_set_accuracy))
Example 6
def train(args):

    json_options = json_file_to_pyobj(args.config)
    training_configurations = json_options.training
    wandb.init(name=f'rot_{training_configurations.checkpoint}')
    device = torch.device('cuda')

    flag = False
    if training_configurations.train_pickle != 'None' and training_configurations.test_pickle != 'None':
        pickle_files = [training_configurations.train_pickle, training_configurations.test_pickle]
        flag = True

    if args.checkpoint is None:
        model = build_model(args, rot=True)
        model = nn.DataParallel(model).to(device)
    else:
        model = build_model_with_checkpoint(modelName='rot' + training_configurations.model.lower(), model_checkpoint=args.checkpoint, device=device, out_classes=training_configurations.out_classes, rot=True)
        model = nn.DataParallel(model).to(device)

    dataset = args.dataset.lower()
    if 'wide' in training_configurations.model.lower():
        resize = False
        epochs = 200
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4)
        scheduler = MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)
    else:
        resize = True
        epochs = 40
        optimizer = optim.SGD(model.parameters(), lr=1.25e-2, momentum=0.9, nesterov=True, weight_decay=1e-4)
        scheduler = MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=0.1)

    if not flag:
        trainloader, val_loader, testloader = natural_image_loaders(dataset, train_batch_size=32, test_batch_size=16, validation_test_split=1000, save_to_pickle=True, resize=resize)
    else:
        trainloader, val_loader, testloader = natural_image_loaders(dataset, train_batch_size=32, test_batch_size=16, validation_test_split=1000, pickle_files=pickle_files, resize=resize)

    criterion = nn.CrossEntropyLoss()
    checkpoint_val_accuracy, best_val_acc, test_set_accuracy = 0, 0, 0

    for epoch in tqdm(range(epochs)):

        model.train()
        correct, total = 0, 0
        train_loss = 0
        for data in tqdm(trainloader):
            inputs, labels = data
            labels = labels.to(device)
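            # inputs intentionally stay on the CPU here: the rotations below use
            # numpy, and nn.DataParallel scatters the CPU batch onto the GPUs in
            # the forward pass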

            optimizer.zero_grad()
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            ce_loss = criterion(outputs, labels)
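            # rotation self-supervision: replicate the batch at 0/90/180/270
            # degrees and train the auxiliary head to predict the rotation class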
            rot_gt = torch.cat((torch.zeros(inputs.size(0)), torch.ones(inputs.size(0)),
                                2*torch.ones(inputs.size(0)), 3*torch.ones(inputs.size(0))), 0).long().to(device)

            rot_inputs = inputs.detach().cpu().numpy()

            rot_inputs = np.concatenate((rot_inputs, np.rot90(rot_inputs, 1, axes=(2, 3)),
                                         np.rot90(rot_inputs, 2, axes=(2, 3)), np.rot90(rot_inputs, 3, axes=(2, 3))), 0)

            rot_inputs = torch.FloatTensor(rot_inputs)

            rot_preds = model(rot_inputs, rot=True)
            rot_loss = criterion(rot_preds, rot_gt)

            loss = ce_loss + rot_loss
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_accuracy = correct / total
        wandb.log({'epoch': epoch}, commit=False)
        wandb.log({'Train Set Loss': train_loss / len(trainloader), 'epoch': epoch})
        wandb.log({'Train Set Accuracy': train_accuracy, 'epoch': epoch})

        model.eval()
        correct, total = 0, 0

        with torch.no_grad():

            for data in val_loader:
                images, labels = data
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

            epoch_val_accuracy = correct / total
            wandb.log({'Validation Set Accuracy': epoch_val_accuracy, 'epoch': epoch})

        if epoch_val_accuracy > best_val_acc:
            best_val_acc = epoch_val_accuracy
            if os.path.exists('/raid/ferles/'):
                torch.save(model.state_dict(), f'/raid/ferles/checkpoints/eb0/{dataset}/rot_{training_configurations.checkpoint}.pth')
            else:
                torch.save(model.state_dict(), f'/home/ferles/checkpoints/eb0/{dataset}/rot_{training_configurations.checkpoint}.pth')

            correct, total = 0, 0

            for data in testloader:
                images, labels = data
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

            test_set_accuracy = correct / total

        wandb.log({'Test Set Accuracy': test_set_accuracy, 'epoch': epoch})

        scheduler.step()
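Example 7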
def train(args):

    json_options = json_file_to_pyobj(args.config)
    extra_M_configuration = json_options.training

    wrn_depth_teacher = extra_M_configuration.wrn_depth_teacher
    wrn_width_teacher = extra_M_configuration.wrn_width_teacher
    wrn_depth_student = extra_M_configuration.wrn_depth_student
    wrn_width_student = extra_M_configuration.wrn_width_student

    M = extra_M_configuration.M

    dataset = extra_M_configuration.dataset
    seeds = [int(seed) for seed in extra_M_configuration.seeds]
    log = extra_M_configuration.log.lower() == 'true'

    if dataset.lower() == 'cifar10':
        epochs = 200
    elif dataset.lower() == 'svhn':
        epochs = 100
    else:
        raise ValueError('Unknown dataset')

    if log:
        teacher_str = 'WideResNet-{}-{}'.format(wrn_depth_teacher,
                                                wrn_width_teacher)
        student_str = 'WideResNet-{}-{}'.format(wrn_depth_student,
                                                wrn_width_student)
        logfile = 'Extra_M_samples_Reproducibility_Zero_Shot_Teacher-{}-Student-{}-{}-M-{}-Zero-Shot.txt'.format(
            teacher_str, student_str, extra_M_configuration.dataset, M)
        with open(logfile, 'w') as temp:
            temp.write(
                'Zero-Shot with teacher {} and student {} in {} with M-{}\n'.
                format(teacher_str, student_str, extra_M_configuration.dataset,
                       M))
    else:
        logfile = ''

    checkpoint = bool(extra_M_configuration.checkpoint)

    if torch.cuda.is_available():
        device = torch.device('cuda:2')
    else:
        device = torch.device('cpu')

    test_set_accuracies = []

    for seed in seeds:

        set_seed(seed)

        if dataset.lower() == 'cifar10':

            from utils import cifar10loadersM
            loaders = cifar10loadersM(M)

        elif dataset.lower() == 'svhn':

            from utils import svhnloadersM
            loaders = svhnloadersM(M)

        else:
            raise ValueError('Datasets to choose from: CIFAR10 and SVHN')

        if log:
            with open(logfile, 'a') as temp:
                temp.write(
                    '------------------- SEED {} -------------------\n'.format(
                        seed))

        strides = [1, 1, 2, 2]

        teacher_net = WideResNet(d=wrn_depth_teacher,
                                 k=wrn_width_teacher,
                                 n_classes=10,
                                 input_features=3,
                                 output_features=16,
                                 strides=strides)
        teacher_net = teacher_net.to(device)

        if dataset.lower() == 'cifar10':
            torch_checkpoint = torch.load(
                './PreTrainedModels/PreTrainedScratches/CIFAR10/wrn-{}-{}-seed-{}-dict.pth'
                .format(wrn_depth_teacher, wrn_width_teacher, seed),
                map_location=device)
        elif dataset.lower() == 'svhn':
            torch_checkpoint = torch.load(
                './PreTrainedModels/PreTrainedScratches/SVHN/wrn-{}-{}-seed-svhn-{}-dict.pth'
                .format(wrn_depth_teacher, wrn_width_teacher, seed),
                map_location=device)
        else:
            raise ValueError('Dataset not found')

        teacher_net.load_state_dict(torch_checkpoint)

        student_net = WideResNet(d=wrn_depth_student,
                                 k=wrn_width_student,
                                 n_classes=10,
                                 input_features=3,
                                 output_features=16,
                                 strides=strides)
        student_net = student_net.to(device)

        # warm-start the student from the matching zero-shot checkpoint before
        # fine-tuning it on the M extra samples
        if dataset.lower() == 'cifar10':
            torch_checkpoint = torch.load(
                './PreTrainedModels/Zero-Shot/CIFAR10/reproducibility_zero_shot_teacher_wrn-{}-{}_student_wrn-{}-{}-M-0-seed-{}-CIFAR10-dict.pth'
                .format(wrn_depth_teacher, wrn_width_teacher,
                        wrn_depth_student, wrn_width_student, seed),
                map_location=device)
        elif dataset.lower() == 'svhn':
            torch_checkpoint = torch.load(
                './PreTrainedModels/Zero-Shot/SVHN/reproducibility_zero_shot_teacher_wrn-{}-{}_student_wrn-{}-{}-M-0-seed-{}-SVHN-dict.pth'
                .format(wrn_depth_teacher, wrn_width_teacher,
                        wrn_depth_student, wrn_width_student, seed),
                map_location=device)
        else:
            raise ValueError('Dataset not found')

        student_net.load_state_dict(torch_checkpoint)

        if checkpoint:
            teacher_str = 'WideResNet-{}-{}'.format(wrn_depth_teacher,
                                                    wrn_width_teacher)
            student_str = 'WideResNet-{}-{}'.format(wrn_depth_student,
                                                    wrn_width_student)
            checkpointFile = 'Checkpoint_Extra_M_samples_Reproducibility_Zero_Shot_Teacher-{}-Student-{}-{}-M-{}-Zero-Shot-seed-{}.pth'.format(
                teacher_str, student_str, extra_M_configuration.dataset, M,
                seed)
        else:
            checkpointFile = ''

        best_test_set_accuracy = _train_extra_M(epochs, teacher_net,
                                                student_net, M, loaders,
                                                device, log, checkpoint,
                                                logfile, checkpointFile)

        test_set_accuracies.append(best_test_set_accuracy)

        if log:
            with open(logfile, 'a') as temp:
                temp.write('Best test set accuracy of seed {} is {}\n'.format(
                    seed, best_test_set_accuracy))

    mean_test_set_accuracy, std_test_set_accuracy = np.mean(
        test_set_accuracies), np.std(test_set_accuracies)

    if log:
        with open(logfile, 'a') as temp:
            temp.write(
                'Mean test set accuracy is {} with standard deviation equal to {}\n'
                .format(mean_test_set_accuracy, std_test_set_accuracy))
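Example 8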
def train(args):
    json_options = json_file_to_pyobj(args.config)
    training_configurations = json_options.training

    wrn_depth = training_configurations.wrn_depth
    wrn_width = training_configurations.wrn_width
    dataset = training_configurations.dataset.lower()
    seeds = [int(seed) for seed in training_configurations.seeds]
    log = training_configurations.log.lower() == 'true'

    if log:
        logfile = 'WideResNet-{}-{}-{}.txt'.format(
            wrn_depth, wrn_width, training_configurations.dataset)
        with open(logfile, 'w') as temp:
            temp.write('WideResNet-{}-{} on {}\n'.format(
                wrn_depth, wrn_width, training_configurations.dataset))
    else:
        logfile = ''

    checkpoint = training_configurations.checkpoint.lower() == 'true'
    loaders = get_loaders(dataset)

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    test_set_accuracies = []

    for seed in seeds:
        set_seed(seed)

        if log:
            with open(logfile, 'a') as temp:
                temp.write(
                    '------------------- SEED {} -------------------\n'.format(
                        seed))

        strides = [1, 1, 2, 2]
        net = WideResNet(d=wrn_depth,
                         k=wrn_width,
                         n_classes=10,
                         input_features=3,
                         output_features=16,
                         strides=strides)
        net = net.to(device)

        checkpointFile = 'wrn-{}-{}-seed-{}-{}-dict.pth'.format(
            wrn_depth, wrn_width, dataset, seed) if checkpoint else ''
        best_test_set_accuracy = _train_seed(net, loaders, device, dataset,
                                             log, checkpoint, logfile,
                                             checkpointFile)

        if log:
            with open(logfile, 'a') as temp:
                temp.write('Best test set accuracy of seed {} is {}\n'.format(
                    seed, best_test_set_accuracy))

        test_set_accuracies.append(best_test_set_accuracy)

    mean_test_set_accuracy, std_test_set_accuracy = np.mean(
        test_set_accuracies), np.std(test_set_accuracies)

    if log:
        with open(logfile, 'a') as temp:
            temp.write(
                'Mean test set accuracy is {} with standard deviation equal to {}\n'
                .format(mean_test_set_accuracy, std_test_set_accuracy))
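Example 9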
def train(args):

    json_options = json_file_to_pyobj(args.config)
    training_configurations = json_options.training
    wandb.init(name=training_configurations.checkpoint)
    device = torch.device(f'cuda:{args.device}')

    flag = False
    if training_configurations.train_pickle != 'None' and training_configurations.test_pickle != 'None':
        pickle_files = [
            training_configurations.train_pickle,
            training_configurations.test_pickle
        ]
        flag = True

    dataset = args.dataset.lower()

    model = build_model(args, dropout=0.5)
    # model = build_model(args)
    model = model.to(device)
    optimizer = optim.SGD(model.parameters(),
                          lr=1.25e-2,
                          momentum=0.9,
                          nesterov=True,
                          weight_decay=1e-4)

    resize = True
    epochs = 40
    scheduler = MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=0.1)
    # epochs = 90
    # scheduler = MultiStepLR(optimizer, milestones=[30, 60, 80], gamma=0.1)

    if 'genOdin' in training_configurations.checkpoint:
        weight_decay = 1e-4
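        # per-parameter groups: every backbone block keeps weight decay 1e-4,
        # while the nominator head (_fc_nominator) is deliberately left
        # unregularized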
        optimizer = optim.SGD([
            {
                'params': model._conv_stem.parameters(),
                'weight_decay': weight_decay
            },
            {
                'params': model._bn0.parameters(),
                'weight_decay': weight_decay
            },
            {
                'params': model._blocks.parameters(),
                'weight_decay': weight_decay
            },
            {
                'params': model._conv_head.parameters(),
                'weight_decay': weight_decay
            },
            {
                'params': model._bn1.parameters(),
                'weight_decay': weight_decay
            },
            {
                'params': model._fc_denominator.parameters(),
                'weight_decay': weight_decay
            },
            {
                'params': model._denominator_batch_norm.parameters(),
                'weight_decay': weight_decay
            },
            {
                'params': model._fc_nominator.parameters(),
                'weight_decay': 0
            },
        ],
                              lr=1.25e-2,
                              momentum=0.9,
                              nesterov=True)
        scheduler = MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=0.1)

    if not flag:
        trainloader, val_loader, testloader = natural_image_loaders(
            dataset,
            train_batch_size=32,
            test_batch_size=32,
            validation_test_split=1000,
            save_to_pickle=True,
            resize=resize)
    else:
        trainloader, val_loader, testloader = natural_image_loaders(
            dataset,
            train_batch_size=32,
            test_batch_size=32,
            validation_test_split=1000,
            pickle_files=pickle_files,
            resize=resize)

    criterion = nn.CrossEntropyLoss()
    checkpoint_val_accuracy, best_val_acc, test_set_accuracy = 0, 0, 0

    for epoch in tqdm(range(epochs)):

        model.train()
        correct, total = 0, 0
        train_loss = 0
        for data in tqdm(trainloader):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            if 'genOdin' in training_configurations.checkpoint:
                outputs, _, _ = model(inputs)
            else:
                outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            loss = criterion(outputs, labels)
            train_loss += loss.item()
            loss.backward()
            optimizer.step()

        scheduler.step()
        train_accuracy = correct / total
        wandb.log({'epoch': epoch}, commit=False)
        wandb.log({
            'Train Set Loss': train_loss / len(trainloader),
            'epoch': epoch
        })
        wandb.log({'Train Set Accuracy': train_accuracy, 'epoch': epoch})

        model.eval()
        correct, total = 0, 0

        with torch.no_grad():

            for data in val_loader:
                images, labels = data
                images = images.to(device)
                labels = labels.to(device)

                if 'genOdin' in training_configurations.checkpoint:
                    outputs, _, _ = model(images)
                else:
                    outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

            epoch_val_accuracy = correct / total
            wandb.log({
                'Validation Set Accuracy': epoch_val_accuracy,
                'epoch': epoch
            })

        if epoch_val_accuracy > best_val_acc:
            best_val_acc = epoch_val_accuracy
            if os.path.exists('/raid/ferles/'):
                torch.save(
                    model.state_dict(),
                    f'/raid/ferles/checkpoints/eb0/{dataset}/{training_configurations.checkpoint}.pth'
                )
                # torch.save(model.state_dict(), f'/raid/ferles/checkpoints/eb0/{dataset}/extended_{training_configurations.checkpoint}.pth')
            else:
                torch.save(
                    model.state_dict(),
                    f'/home/ferles/checkpoints/eb0/{dataset}/{training_configurations.checkpoint}.pth'
                )
                # torch.save(model.state_dict(), f'/home/ferles/checkpoints/eb0/{dataset}/low_dropout_extended_{training_configurations.checkpoint}.pth')

            correct, total = 0, 0

            for data in testloader:
                images, labels = data
                images = images.to(device)
                labels = labels.to(device)

                if 'genOdin' in training_configurations.checkpoint:
                    outputs, _, _ = model(images)
                else:
                    outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

            test_set_accuracy = correct / total

        wandb.log({'Test Set Accuracy': test_set_accuracy, 'epoch': epoch})
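Example 10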
def train(args):

    json_options = json_file_to_pyobj(args.config)
    modified_zero_shot_configurations = json_options.training

    wrn_depth_teacher = modified_zero_shot_configurations.wrn_depth_teacher
    wrn_width_teacher = modified_zero_shot_configurations.wrn_width_teacher
    wrn_depth_student = modified_zero_shot_configurations.wrn_depth_student
    wrn_width_student = modified_zero_shot_configurations.wrn_width_student

    M = modified_zero_shot_configurations.M

    dataset = modified_zero_shot_configurations.dataset
    seeds = [int(seed) for seed in modified_zero_shot_configurations.seeds]
    log = modified_zero_shot_configurations.log.lower() == 'true'

    if log:
        teacher_str = 'WideResNet-{}-{}'.format(wrn_depth_teacher, wrn_width_teacher)
        student_str = 'WideResNet-{}-{}'.format(wrn_depth_student, wrn_width_student)
        logfile = 'Teacher-{}-Student-{}-{}-M-{}-Zero-Shot.txt'.format(teacher_str, student_str, modified_zero_shot_configurations.dataset, M)
        with open(logfile, 'w') as temp:
            temp.write('Zero-Shot with teacher {} and student {} in {} with M-{}\n'.format(teacher_str, student_str, modified_zero_shot_configurations.dataset, M))
    else:
        logfile = ''

    checkpoint = bool(modified_zero_shot_configurations.checkpoint)

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    test_set_accuracies = []

    for seed in seeds:

        set_seed(seed)

        if dataset.lower() == 'cifar10':

            from utils import cifar10loaders
            _, test_loader = cifar10loaders()

        elif dataset.lower() == 'svhn':

            from utils import svhnLoaders
            _, test_loader = svhnLoaders()

        else:
            raise ValueError('Datasets to choose from: CIFAR10 and SVHN')

        if log:
            with open(logfile, 'a') as temp:
                temp.write('------------------- SEED {} -------------------\n'.format(seed))

        strides = [1, 1, 2, 2]

        teacher_net = WideResNet(d=wrn_depth_teacher, k=wrn_width_teacher, n_classes=10, input_features=3, output_features=16, strides=strides)
        teacher_net = teacher_net.to(device)
        if dataset.lower() == 'cifar10':
            torch_checkpoint = torch.load('./PreTrainedModels/PreTrainedScratches/CIFAR10/wrn-{}-{}-seed-{}-dict.pth'.format(wrn_depth_teacher, wrn_width_teacher, seed), map_location=device)
        elif dataset.lower() == 'svhn':
            torch_checkpoint = torch.load('./PreTrainedModels/PreTrainedScratches/SVHN/wrn-{}-{}-seed-svhn-{}-dict.pth'.format(wrn_depth_teacher, wrn_width_teacher, seed), map_location=device)
        else:
            raise ValueError('Dataset not found')
        teacher_net.load_state_dict(torch_checkpoint)

        student_net = WideResNet(d=wrn_depth_student, k=wrn_width_student, n_classes=10, input_features=3, output_features=16, strides=strides)
        student_net = student_net.to(device)

        # the generator synthesizes pseudo-inputs for zero-shot distillation, so
        # no real training data is needed beyond the pre-trained teacher
        generator_net = Generator()
        generator_net = generator_net.to(device)

        checkpointFile = 'zero_shot_teacher_wrn-{}-{}_student_wrn-{}-{}-M-{}-seed-{}-{}-dict.pth'.format(wrn_depth_teacher, wrn_width_teacher, wrn_depth_student, wrn_width_student, M, seed, dataset) if checkpoint else ''
        finalCheckpointFile = 'zero_shot_teacher_wrn-{}-{}_student_wrn-{}-{}-M-{}-seed-{}-{}-final-dict.pth'.format(wrn_depth_teacher, wrn_width_teacher, wrn_depth_student, wrn_width_student, M, seed, dataset) if checkpoint else ''
        genCheckpointFile = 'zero_shot_teacher_wrn-{}-{}_student_wrn-{}-{}-M-{}-seed-{}-{}-generator-dict.pth'.format(wrn_depth_teacher, wrn_width_teacher, wrn_depth_student, wrn_width_student, M, seed, dataset) if checkpoint else ''

        best_test_set_accuracy = _train_seed_zero_shot(teacher_net, student_net, generator_net, test_loader, device, log, checkpoint, logfile, checkpointFile, finalCheckpointFile, genCheckpointFile)

        if log:
            with open(logfile, 'a') as temp:
                temp.write('Best test set accuracy of seed {} is {}\n'.format(seed, best_test_set_accuracy))

        test_set_accuracies.append(best_test_set_accuracy)

    mean_test_set_accuracy, std_test_set_accuracy = np.mean(test_set_accuracies), np.std(test_set_accuracies)

    if log:
        with open(logfile, 'a') as temp:
            temp.write('Mean test set accuracy is {} with standard deviation equal to {}\n'.format(mean_test_set_accuracy, std_test_set_accuracy))
Example 11
def train(args):

    json_options = json_file_to_pyobj(args.config)
    training_configurations = json_options.training
    wandb.init(name=training_configurations.checkpoint + 'Ensemble')
    device = torch.device(f'cuda:{args.device}')

    dataset = args.dataset.lower()
    pickle_files = [
        training_configurations.train_pickle,
        training_configurations.test_pickle
    ]
    train_ind_loaders, train_ood_loaders, val_ind_loaders, test_ind_loaders, num_classes, dicts = create_ensemble_loaders(
        dataset,
        num_classes=training_configurations.out_classes,
        pickle_files=pickle_files)

    criterion = nn.CrossEntropyLoss()
    b = 0.2
    m = 0.4

    for index in range(len(train_ind_loaders)):

        epochs = 40
        model = build_model(args)
        model._fc = nn.Linear(model._fc.in_features, num_classes[index])
        model = model.to(device)
        optimizer = optim.SGD(model.parameters(),
                              lr=1.25e-02,
                              momentum=0.9,
                              nesterov=True,
                              weight_decay=1e-4)
        scheduler = MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=0.1)

        train_ind_loader, train_ood_loader = train_ind_loaders[
            index], train_ood_loaders[index]
        val_ind_loader = val_ind_loaders[index]
        test_ind_loader = test_ind_loaders[index]
        dic = dicts[index]

        ood_loader_iter = iter(train_ood_loader)

        best_val_acc = 0
        test_epoch_accuracy = 0

        for epoch in tqdm(range(epochs)):

            model.train()
            correct, total = 0, 0
            train_loss = 0
            for data in tqdm(train_ind_loader):
                inputs, labels = data
                inputs = inputs.to(device)
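                # `dic` remaps original dataset labels to subset-local class
                # indices for this ensemble member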
                _labels = torch.LongTensor([dic[int(l)] for l in labels])
                labels = _labels.to(device)

                optimizer.zero_grad()

                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                ce_loss = criterion(outputs, labels)

                try:
                    ood_inputs, _ = next(ood_loader_iter)
                except StopIteration:
                    # restart the OOD stream once it is exhausted
                    ood_loader_iter = iter(train_ood_loader)
                    ood_inputs, _ = next(ood_loader_iter)

                ood_inputs = ood_inputs.to(device)
                ood_outputs = model(ood_inputs)
                entropy_input = -torch.mean(
                    torch.sum(F.log_softmax(outputs, dim=1) *
                              F.softmax(outputs, dim=1),
                              dim=1))
                entropy_output = -torch.mean(
                    torch.sum(F.log_softmax(ood_outputs, dim=1) *
                              F.softmax(ood_outputs, dim=1),
                              dim=1))
                margin_loss = b * torch.clamp(
                    m + entropy_input - entropy_output, min=0)

                loss = ce_loss + margin_loss
                train_loss += loss.item()
                loss.backward()
                optimizer.step()

            train_accuracy = correct / total
            wandb.log({'epoch': epoch}, commit=False)
            epoch_train_set_loss = train_loss / len(train_ind_loader)

            wandb.log({
                f'Train Set Loss {index}': epoch_train_set_loss,
                'epoch': epoch
            })
            wandb.log({
                f'Train Set Accuracy {index}': train_accuracy,
                'epoch': epoch
            })

            with torch.no_grad():

                model.eval()
                v_correct, v_total = 0, 0

                for data in val_ind_loader:
                    images, labels = data
                    images = images.to(device)
                    _labels = torch.LongTensor([dic[int(l)] for l in labels])
                    labels = _labels.to(device)

                    outputs = model(images)
                    _, predicted = torch.max(outputs.data, 1)
                    v_total += labels.size(0)
                    v_correct += (predicted == labels).sum().item()

                val_epoch_accuracy = v_correct / v_total
                wandb.log({
                    f'Validation Set Accuracy {index}': val_epoch_accuracy,
                    'epoch': epoch
                })

                if val_epoch_accuracy > best_val_acc:
                    best_val_acc = val_epoch_accuracy
                    torch.save(
                        model.state_dict(),
                        f'/raid/ferles/checkpoints/eb0/{dataset}/{training_configurations.checkpoint}_best_accuracy_ensemble_{index}.pth'
                    )

                    correct, total = 0, 0
                    for data in test_ind_loader:
                        images, labels = data
                        images = images.to(device)
                        _labels = torch.LongTensor(
                            [dic[int(l)] for l in labels])
                        labels = _labels.to(device)

                        outputs = model(images)
                        _, predicted = torch.max(outputs.data, 1)
                        total += labels.size(0)
                        correct += (predicted == labels).sum().item()

                    test_epoch_accuracy = correct / total

            wandb.log({
                f'Test Set Accuracy {index}': test_epoch_accuracy,
                'epoch': epoch
            })

            scheduler.step()
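
A minimal, self-contained sketch of the entropy-margin objective used above (the helper name and signature are illustrative, not part of this repo): it hinges the gap between the mean predictive entropy on an out-of-distribution batch and that on an in-distribution batch, exactly as b * clamp(m + H_in - H_out, 0) does in the loop.

import torch
import torch.nn.functional as F

def entropy_margin_loss(ind_logits, ood_logits, margin=0.4, weight=0.2):
    """Hinge loss on the entropy gap between ID and OOD predictions."""
    def mean_entropy(logits):
        # H(p) = -sum_c p_c log p_c, averaged over the batch
        return -torch.mean(
            torch.sum(F.softmax(logits, dim=1) * F.log_softmax(logits, dim=1),
                      dim=1))
    h_ind = mean_entropy(ind_logits)
    h_ood = mean_entropy(ood_logits)
    # Penalize whenever OOD entropy fails to exceed ID entropy by `margin`
    return weight * torch.clamp(margin + h_ind - h_ood, min=0)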
def train(args):

    use_wandb = True
    device = torch.device('cuda')

    json_options = json_file_to_pyobj(args.config)
    training_configurations = json_options.training
    traincsv = training_configurations.traincsv
    testcsv = training_configurations.testcsv
    gtFileName = training_configurations.gtFile
    checkpointFileName = training_configurations.checkpointFile
    out_classes = training_configurations.out_classes
    exclude_class = training_configurations.exclude_class
    exclude_class = None if exclude_class == "None" else exclude_class

    if use_wandb:
        wandb.init(name=checkpointFileName, entity='ferles')

    if exclude_class is None:
        train_loader, val_loader, test_loader, columns = oversampling_loaders_custom(
            csvfiles=[traincsv, testcsv],
            train_batch_size=32,
            val_batch_size=16,
            gtFile=gtFileName)
    else:
        train_loader, val_loader, test_loader, columns = oversampling_loaders_exclude_class_custom_no_gts(
            csvfiles=[traincsv, testcsv],
            train_batch_size=32,
            val_batch_size=16,
            gtFile=gtFileName,
            exclude_class=exclude_class)
    model = build_model(args, rot=True)
    model = nn.DataParallel(model).to(device)

    epochs = 40
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(),
                          lr=1.25e-2,
                          momentum=0.9,
                          nesterov=True,
                          weight_decay=1e-4)
    scheduler = MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=0.1)

    test_loss, best_val_detection_accuracy, test_detection_accuracy = 10, 0, 0

    train_loss, val_loss, balanced_accuracies = [], [], []

    early_stopping = False  # set to True to stop after 3 epochs without val improvement

    early_stopping_cnt = 0
    for epoch in tqdm(range(epochs)):

        model.train()
        loss_acc = []

        for data in tqdm(train_loader):
            inputs, labels = data
            labels = labels.to(device)
            optimizer.zero_grad()

            outputs = model(inputs)

            _labels = torch.argmax(labels, dim=1)
            ce_loss = criterion(outputs, _labels)

            # Rotation Loss
            rot_gt = torch.cat(
                (torch.zeros(inputs.size(0)), torch.ones(inputs.size(0)), 2 *
                 torch.ones(inputs.size(0)), 3 * torch.ones(inputs.size(0))),
                0).long().to(device)

            rot_inputs = inputs.detach().cpu().numpy()

            rot_inputs = np.concatenate(
                (rot_inputs, np.rot90(rot_inputs, 1, axes=(2, 3)),
                 np.rot90(rot_inputs, 2, axes=(2, 3)),
                 np.rot90(rot_inputs, 3, axes=(2, 3))), 0)

            rot_inputs = torch.FloatTensor(rot_inputs)
            rot_preds = model(rot_inputs, rot=True)

            rot_loss = 0.5 * criterion(rot_preds, rot_gt.to(device))

            loss = ce_loss + rot_loss

            loss_acc.append(loss.item())
            loss.backward()
            optimizer.step()

        wandb.log({'epoch': epoch}, commit=False)
        wandb.log({
            'Train Set Loss': sum(loss_acc) / float(len(train_loader)),
            'epoch': epoch
        })
        train_loss.append(sum(loss_acc) / float(len(train_loader)))
        loss_acc.clear()

        with torch.no_grad():
            model.eval()
            correct, total = 0, 0
            for data in tqdm(val_loader):
                images, labels = data
                images = images.to(device)
                labels = labels.to(device)

                outputs = model(images)

                softmax_outputs = torch.softmax(outputs, 1)
                max_idx = torch.argmax(softmax_outputs, dim=1)
                _labels = torch.argmax(labels, dim=1)
                correct += (max_idx == _labels).sum().item()
                total += max_idx.size(0)
                loss = criterion(outputs, _labels)
                loss_acc.append(loss.item())

            val_detection_accuracy = round(100 * correct / total, 2)
            wandb.log({
                'Validation Detection Accuracy': val_detection_accuracy,
                'epoch': epoch
            })

            if val_detection_accuracy > best_val_detection_accuracy:
                best_val_detection_accuracy = val_detection_accuracy
                test_loss, test_detection_accuracy = _test_set_eval(
                    model, epoch, device, test_loader, out_classes, columns,
                    gtFileName)

                if exclude_class is None:
                    checkpointFile = '/raid/ferles/checkpoints/isic_classifiers/rot_isic-best-model.pth'
                else:
                    checkpointFile = f'/raid/ferles/checkpoints/isic_classifiers/rot_isic-_{exclude_class}-best-model.pth'
                if os.path.exists('/raid/ferles/'):
                    torch.save(model.state_dict(), checkpointFile)
                else:
                    torch.save(model.state_dict(),
                               checkpointFile.replace('raid', 'home'))
            else:
                if early_stopping:
                    early_stopping_cnt += 1
                    if early_stopping_cnt == 3:
                        wandb.log({'Test Set Loss': test_loss, 'epoch': epoch})
                        wandb.log({
                            'Detection Accuracy': test_detection_accuracy,
                            'epoch': epoch
                        })
                        break

            # Hard-coded stopping epochs, one per excluded class (apparently tuned per run).
            if exclude_class is None and epoch == 20:
                break
            elif exclude_class == 'AK' and epoch == 19:
                break
            elif exclude_class == 'BCC' and epoch == 12:
                break
            elif exclude_class == 'BKL' and epoch == 15:
                break
            elif exclude_class == 'DF' and epoch == 10:
                break
            elif exclude_class == 'MEL' and epoch == 12:
                break
            elif exclude_class == 'NV' and epoch == 27:
                break
            elif exclude_class == 'SCC' and epoch == 10:
                break
            elif exclude_class == 'VASC' and epoch == 10:
                break

            wandb.log({'Test Set Loss': test_loss, 'epoch': epoch})
            wandb.log({
                'Detection Accuracy': test_detection_accuracy,
                'epoch': epoch
            })
        # val_loss, auc, balanced_accuracy = _test_set_eval(model, epoch, device, val_loader, out_classes, columns, gtFileName)

        # if auc > best_auc:
        #     best_auc = auc
        #     checkpointFile = os.path.join(f'{abs_path}/checkpoints/rotation/{checkpointFileName}-best-auc-model.pth')
        #     torch.save(model.state_dict(), checkpointFile)

        # if balanced_accuracy > best_balanced_accuracy:
        #     best_balanced_accuracy = balanced_accuracy
        #     checkpointFile = os.path.join(f'{abs_path}/checkpoints/rotation/{checkpointFileName}-best-balanced-accuracy-model.pth')
        #     torch.save(model.state_dict(), checkpointFile)

        # if val_loss < best_val_loss:
        #     best_val_loss = val_loss
        #     checkpointFile = os.path.join(f'{abs_path}/checkpoints/rotation/{checkpointFileName}-best-val-loss-model.pth')
        #     torch.save(model.state_dict(), checkpointFile)
        #     early_stopping_cnt = 0

        scheduler.step()
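
The rotation pretext batch above makes a round-trip through NumPy's np.rot90; the same stacked batch and the same 0/1/2/3 targets as rot_inputs and rot_gt can be built directly in torch. A sketch, assuming NCHW float inputs (the helper name is illustrative):

import torch

def make_rotation_batch(inputs):
    """Stack 0/90/180/270-degree rotations of an NCHW batch with rotation labels."""
    n = inputs.size(0)
    rotated = torch.cat(
        [torch.rot90(inputs, k, dims=(2, 3)) for k in range(4)], dim=0)
    targets = torch.arange(4).repeat_interleave(n)  # n zeros, n ones, n twos, n threes
    return rotated, targets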
def train(args):

    json_options = json_file_to_pyobj(args.config)
    training_configurations = json_options.training
    wandb.init(
        name=f"{training_configurations.checkpoint}_subset_{args.subset_index}"
    )
    device = torch.device(f'cuda:{args.device}')

    use_pickle_files = False
    if training_configurations.train_pickle != 'None' and training_configurations.test_pickle != 'None':
        pickle_files = [
            training_configurations.train_pickle,
            training_configurations.test_pickle
        ]
        use_pickle_files = True

    model = build_model(args)
    model = model.to(device)
    dataset = args.dataset.lower()

    if 'wide' in training_configurations.model.lower():
        epochs = 100
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=0.1,
                                    momentum=0.9,
                                    nesterov=True,
                                    weight_decay=5e-4)
        scheduler = MultiStepLR(optimizer, milestones=[20, 50, 80], gamma=0.2)
    else:
        epochs = 40
        optimizer = optim.SGD(model.parameters(),
                              lr=1.25e-2,
                              momentum=0.9,
                              nesterov=True,
                              weight_decay=1e-4)
        scheduler = MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=0.1)

    if not use_pickle_files:
        if args.subset_index is None:
            trainloader, val_loader, testloader = fine_grained_image_loaders(
                dataset,
                train_batch_size=32,
                test_batch_size=32,
                validation_test_split=1000,
                save_to_pickle=True)
        else:
            trainloader, val_loader, testloader = fine_grained_image_loaders_subset(
                dataset,
                subset_index=args.subset_index,
                validation_test_split=800,
                save_to_pickle=True)
    else:
        if args.subset_index is None:
            trainloader, val_loader, testloader = fine_grained_image_loaders(
                dataset,
                train_batch_size=32,
                test_batch_size=32,
                validation_test_split=1000,
                pickle_files=pickle_files)
        else:
            pickle_files[0] = pickle_files[0].split(
                ".pickle")[0] + f"_subset_{args.subset_index}.pickle"
            pickle_files[1] = pickle_files[1].split(
                ".pickle")[0] + f"_subset_{args.subset_index}.pickle"
            trainloader, val_loader, testloader = fine_grained_image_loaders_subset(
                dataset,
                subset_index=args.subset_index,
                validation_test_split=800,
                pickle_files=pickle_files)

        # GenODIN: rebuild the optimizer so the cosine-similarity head gets no weight decay.
        if 'genodin' in training_configurations.checkpoint.lower():
            weight_decay = 1e-4
            optimizer = optim.SGD([
                {
                    'params': model._conv_stem.parameters(),
                    'weight_decay': weight_decay
                },
                {
                    'params': model._bn0.parameters(),
                    'weight_decay': weight_decay
                },
                {
                    'params': model._blocks.parameters(),
                    'weight_decay': weight_decay
                },
                {
                    'params': model._conv_head.parameters(),
                    'weight_decay': weight_decay
                },
                {
                    'params': model._bn1.parameters(),
                    'weight_decay': weight_decay
                },
                {
                    'params': model._fc_denominator.parameters(),
                    'weight_decay': weight_decay
                },
                {
                    'params': model._denominator_batch_norm.parameters(),
                    'weight_decay': weight_decay
                },
                {
                    'params': model._fc_nominator.parameters(),
                    'weight_decay': 0
                },
            ],
                                  lr=1.25e-2,
                                  momentum=0.9,
                                  nesterov=True)
            scheduler = MultiStepLR(optimizer,
                                    milestones=[10, 20, 30],
                                    gamma=0.1)

    criterion = nn.CrossEntropyLoss()
    checkpoint_val_accuracy, best_val_acc, test_set_accuracy = 0, 0, 0

    for epoch in tqdm(range(epochs)):

        model.train()
        correct, total = 0, 0
        train_loss = 0
        for data in tqdm(trainloader):

            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            if 'genodin' in training_configurations.checkpoint.lower():
                outputs, h, g = model(inputs)
            else:
                outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            loss = criterion(outputs, labels)
            train_loss += loss.item()
            loss.backward()
            optimizer.step()

            # if epoch < 2:
            #     model.eval()
            #     v_correct, v_total = 0, 0
            #
            #     with torch.no_grad():
            #
            #         for v_data in testloader:
            #             v_images, v_labels = v_data
            #             v_images = v_images.to(device)
            #             v_labels = v_labels.to(device)
            #
            #             v_outputs = model(v_images)
            #             _, v_predicted = torch.max(v_outputs.data, 1)
            #             v_total += v_labels.size(0)
            #             v_correct += (v_predicted == v_labels).sum().item()
            #
            #         acc = v_correct / v_total
            #         if os.path.exists('/raid/ferles/'):
            #             torch.save(model.state_dict(), f'/raid/ferles/checkpoints/eb0/{dataset}/{training_configurations.checkpoint}_acc_{acc}.pth')
            #         else:
            #             torch.save(model.state_dict(), f'/home/ferles/checkpoints/eb0/{dataset}/{training_configurations.checkpoint}_acc_{acc}.pth')

        train_accuracy = correct / total
        wandb.log({'epoch': epoch}, commit=False)
        wandb.log({
            'Train Set Loss': train_loss / len(trainloader),
            'epoch': epoch
        })
        wandb.log({'Train Set Accuracy': train_accuracy, 'epoch': epoch})

        model.eval()
        correct, total = 0, 0

        with torch.no_grad():

            for data in val_loader:
                images, labels = data
                images = images.to(device)
                labels = labels.to(device)

                if 'genodin' in training_configurations.checkpoint.lower():
                    outputs, h, g = model(images)
                else:
                    outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

            epoch_val_accuracy = correct / total

            wandb.log({
                'Validation Set Accuracy': epoch_val_accuracy,
                'epoch': epoch
            })

        if epoch_val_accuracy > best_val_acc:
            best_val_acc = epoch_val_accuracy

            if os.path.exists('/raid/ferles/'):
                if args.subset_index is None:
                    torch.save(
                        model.state_dict(),
                        f'/raid/ferles/checkpoints/eb0/{dataset}/{training_configurations.checkpoint}.pth'
                    )
                else:
                    torch.save(
                        model.state_dict(),
                        f'/raid/ferles/checkpoints/eb0/{dataset}/{training_configurations.checkpoint}_subset_{args.subset_index}.pth'
                    )
            else:
                if args.subset_index is None:
                    torch.save(
                        model.state_dict(),
                        f'/home/ferles/checkpoints/eb0/{dataset}/{training_configurations.checkpoint}.pth'
                    )
                else:
                    torch.save(
                        model.state_dict(),
                        f'/home/ferles/checkpoints/eb0/{dataset}/{training_configurations.checkpoint}_subset_{args.subset_index}.pth'
                    )

            # if best_val_acc - checkpoint_val_accuracy > 0.05:
            #     checkpoint_val_accuracy = best_val_acc
            #     torch.save(model.state_dict(), f'/raid/ferles/checkpoints/eb0/{dataset}/{training_configurations.checkpoint}_epoch_{epoch}_accuracy_{best_val_acc}.pth')

            correct, total = 0, 0

            with torch.no_grad():
                for data in testloader:
                    images, labels = data
                    images = images.to(device)
                    labels = labels.to(device)

                    if 'genodin' in training_configurations.checkpoint.lower():
                        outputs, h, g = model(images)
                    else:
                        outputs = model(images)
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

            test_set_accuracy = correct / total

        wandb.log({'Test Set Accuracy': test_set_accuracy, 'epoch': epoch})

        scheduler.step()
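
The genOdin branches above enumerate one parameter group per submodule just to turn weight decay off for the cosine-similarity head (_fc_nominator). A more compact sketch of the same pattern, using only standard torch APIs (the helper name is illustrative, not from this repo):

import torch.optim as optim

def sgd_no_decay_on_head(model, head, lr=1.25e-2, weight_decay=1e-4):
    """SGD with weight decay on everything except the given head submodule."""
    head_ids = {id(p) for p in head.parameters()}
    groups = [
        {'params': [p for p in model.parameters() if id(p) not in head_ids],
         'weight_decay': weight_decay},
        {'params': list(head.parameters()), 'weight_decay': 0},
    ]
    return optim.SGD(groups, lr=lr, momentum=0.9, nesterov=True)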
def train(args):

    json_options = json_file_to_pyobj(args.config)
    training_configurations = json_options.training
    wandb.init(
        name=f"{training_configurations.checkpoint}_subset_{args.subset_index}"
    )
    device = torch.device(f'cuda:{args.device}')

    use_pickle_files = False
    if training_configurations.train_pickle != 'None' and training_configurations.test_pickle != 'None':
        pickle_files = [
            training_configurations.train_pickle,
            training_configurations.test_pickle
        ]
        use_pickle_files = True

    if args.subset_index is None:
        model = build_model(args)
        model = model.to(device)
        if training_configurations.model == 'EfficientNet':
            epochs = 40
            optimizer = optim.SGD(model.parameters(),
                                  lr=1.25e-2,
                                  momentum=0.9,
                                  nesterov=True,
                                  weight_decay=1e-4)
            scheduler = MultiStepLR(optimizer,
                                    milestones=[10, 20, 30],
                                    gamma=0.1)
            batch_size = 32
        elif training_configurations.model == 'DenseNet':
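            # Note: this replaces the model built by build_model above with a
            # pretrained torchvision DenseNet-121 from torch.hub.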
            model = torch.hub.load('pytorch/vision:v0.6.0',
                                   'densenet121',
                                   pretrained=True)
            model = model.to(device)
            epochs = 200
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=0.1,
                                        momentum=0.9,
                                        weight_decay=0.0001)
            scheduler = MultiStepLR(
                optimizer,
                milestones=[int(0.5 * epochs),
                            int(0.75 * epochs)],
                gamma=0.1)
            batch_size = 16

    dataset = args.dataset.lower()

    if not use_pickle_files:
        if args.subset_index is None:
            trainloader, val_loader, testloader = fine_grained_image_loaders(
                dataset,
                train_batch_size=batch_size,
                test_batch_size=batch_size,
                validation_test_split=1000,
                save_to_pickle=True)
        else:
            trainloader, val_loader, testloader, num_classes = fine_grained_image_loaders_subset(
                dataset,
                subset_index=args.subset_index,
                validation_test_split=800,
                save_to_pickle=True,
                ret_num_classes=True)
    else:
        if args.subset_index is None:
            trainloader, val_loader, testloader = fine_grained_image_loaders(
                dataset,
                train_batch_size=batch_size,
                test_batch_size=batch_size,
                validation_test_split=1000,
                pickle_files=pickle_files)
        else:
            pickle_files[0] = "pickle_files/" + pickle_files[0].split(
                ".pickle")[0] + f"_subset_{args.subset_index}.pickle"
            pickle_files[1] = "pickle_files/" + pickle_files[1].split(
                ".pickle")[0] + f"_subset_{args.subset_index}.pickle"
            trainloader, val_loader, testloader, num_classes = fine_grained_image_loaders_subset(
                dataset,
                subset_index=args.subset_index,
                validation_test_split=800,
                pickle_files=pickle_files,
                ret_num_classes=True)

    if args.subset_index is not None:
        model = build_model(args)
        if 'genodin' in training_configurations.checkpoint.lower():
            from efficientnet_pytorch.gen_odin_model import CosineSimilarity
            model._fc_nominator = CosineSimilarity(feat_dim=1280,
                                                   num_centers=num_classes)
        else:
            model._fc = nn.Linear(model._fc.in_features, num_classes)
        model = model.to(device)
        epochs = 40
        optimizer = optim.SGD(model.parameters(),
                              lr=1.25e-2,
                              momentum=0.9,
                              nesterov=True,
                              weight_decay=1e-4)
        scheduler = MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=0.1)

    criterion = nn.CrossEntropyLoss()
    checkpoint_val_accuracy, best_val_acc, test_set_accuracy = 0, 0, 0

    # GenODIN: rebuild the optimizer so the cosine-similarity head gets no weight decay.
    if 'genodin' in training_configurations.checkpoint.lower():
        weight_decay = 1e-4
        optimizer = optim.SGD([
            {
                'params': model._conv_stem.parameters(),
                'weight_decay': weight_decay
            },
            {
                'params': model._bn0.parameters(),
                'weight_decay': weight_decay
            },
            {
                'params': model._blocks.parameters(),
                'weight_decay': weight_decay
            },
            {
                'params': model._conv_head.parameters(),
                'weight_decay': weight_decay
            },
            {
                'params': model._bn1.parameters(),
                'weight_decay': weight_decay
            },
            {
                'params': model._fc_denominator.parameters(),
                'weight_decay': weight_decay
            },
            {
                'params': model._denominator_batch_norm.parameters(),
                'weight_decay': weight_decay
            },
            {
                'params': model._fc_nominator.parameters(),
                'weight_decay': 0
            },
        ],
                              lr=1.25e-2,
                              momentum=0.9,
                              nesterov=True)
        scheduler = MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=0.1)

    for epoch in tqdm(range(epochs)):

        model.train()
        correct, total = 0, 0
        train_loss = 0
        for data in tqdm(trainloader):

            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            if 'genodin' in training_configurations.checkpoint.lower():
                outputs, h, g = model(inputs)
            else:
                outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            loss = criterion(outputs, labels)
            train_loss += loss.item()
            loss.backward()
            optimizer.step()

        train_accuracy = correct / total
        wandb.log({'epoch': epoch}, commit=False)
        wandb.log({
            'Train Set Loss': train_loss / len(trainloader),
            'epoch': epoch
        })
        wandb.log({'Train Set Accuracy': train_accuracy, 'epoch': epoch})

        model.eval()
        correct, total = 0, 0

        with torch.no_grad():

            for data in val_loader:
                images, labels = data
                images = images.to(device)
                labels = labels.to(device)

                if 'genodin' in training_configurations.checkpoint.lower():
                    outputs, h, g = model(images)
                else:
                    outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

            epoch_val_accuracy = correct / total

            wandb.log({
                'Validation Set Accuracy': epoch_val_accuracy,
                'epoch': epoch
            })

        if epoch_val_accuracy > best_val_acc:
            best_val_acc = epoch_val_accuracy

            if os.path.exists('/raid/ferles/'):
                if args.subset_index is None:
                    torch.save(
                        model.state_dict(),
                        f'/raid/ferles/checkpoints/eb0/{dataset}/{training_configurations.checkpoint}.pth'
                    )
                else:
                    torch.save(
                        model.state_dict(),
                        f'/raid/ferles/checkpoints/eb0/{dataset}/{training_configurations.checkpoint}_subset_{args.subset_index}.pth'
                    )
            else:
                if args.subset_index is None:
                    torch.save(
                        model.state_dict(),
                        f'/home/ferles/checkpoints/eb0/{dataset}/{training_configurations.checkpoint}.pth'
                    )
                else:
                    torch.save(
                        model.state_dict(),
                        f'/home/ferles/checkpoints/eb0/{dataset}/{training_configurations.checkpoint}_subset_{args.subset_index}.pth'
                    )

            # if best_val_acc - checkpoint_val_accuracy > 0.05:
            #     checkpoint_val_accuracy = best_val_acc
            #     torch.save(model.state_dict(), f'/raid/ferles/checkpoints/eb0/{dataset}/{training_configurations.checkpoint}_epoch_{epoch}_accuracy_{best_val_acc}.pth')

            correct, total = 0, 0

            with torch.no_grad():
                for data in testloader:
                    images, labels = data
                    images = images.to(device)
                    labels = labels.to(device)

                    if 'genodin' in training_configurations.checkpoint.lower():
                        outputs, h, g = model(images)
                    else:
                        outputs = model(images)
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

            test_set_accuracy = correct / total

        wandb.log({'Test Set Accuracy': test_set_accuracy, 'epoch': epoch})

        scheduler.step()
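
Every run in these functions pairs SGD with MultiStepLR; under the EfficientNet settings (lr=1.25e-2, milestones [10, 20, 30], gamma=0.1) the learning rate decays to 1.25e-3, 1.25e-4 and finally 1.25e-5. A standalone sanity check (PyTorch warns that optimizer.step() is never called; that is harmless for this illustration):

import torch
from torch.optim.lr_scheduler import MultiStepLR

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1.25e-2)
scheduler = MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=0.1)
for epoch in range(40):
    scheduler.step()  # stepped once per epoch, as in the training loops above
    print(epoch, scheduler.get_last_lr())  # tenfold drop at steps 10, 20 and 30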
def adversarial_belief_matching(args):
    json_options = json_file_to_pyobj(args.config)
    abm_configurations = json_options.abm_setting

    wrn_depth_teacher = abm_configurations.wrn_depth_teacher
    wrn_width_teacher = abm_configurations.wrn_width_teacher
    wrn_depth_student = abm_configurations.wrn_depth_student
    wrn_width_student = abm_configurations.wrn_width_student

    dataset = abm_configurations.dataset.lower()
    seeds = abm_configurations.seeds
    mode = abm_configurations.mode

    eval_teacher = abm_configurations.eval_teacher.lower() == 'true'

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    for seed in seeds:

        teacher_net, student_net = _load_teacher_and_student(
            abm_configurations, seed, device)
        test_loader = get_matching_indices(dataset,
                                           teacher_net,
                                           student_net,
                                           device,
                                           n=1000)
        cnt = len(test_loader)

        teacher_net.eval()

        criterion = nn.CrossEntropyLoss()

        eta = 1  # adversarial step size
        K = 100  # number of update steps per fake label

        student_image_average_transition_curves_acc, teacher_image_average_transition_curves_acc = [], []
        mean_transition_error = 0

        # test_loader contains only samples on which teacher and student initially
        # agree and are both correct (selected by get_matching_indices above).
        for data in tqdm(test_loader):

            images, _ = data
            images = images.to(device)

            student_net.eval()
            student_outputs = student_net(images)[0]
            _, student_predicted = torch.max(student_outputs.data, 1)

            teacher_outputs = teacher_net(images)[0]
            _, teacher_predicted = torch.max(teacher_outputs.data, 1)

            x0 = deepcopy(images.detach())
            student_transition_curves, teacher_transition_curves = [], []

            for fake_label in range(0, 10):

                if fake_label != student_predicted:

                    fake_label = torch.tensor([fake_label], device=device)
                    student_probs_acc, teacher_probs_acc = [], []

                    x_adv = deepcopy(x0)
                    x_adv.requires_grad = True

                    for _ in range(K):
                        if eval_teacher:
                            teacher_fake_outputs = teacher_net(x_adv)[0]
                            with torch.no_grad():
                                student_fake_outputs = student_net(x_adv)[0]
                            loss = criterion(teacher_fake_outputs, fake_label)

                            teacher_net.zero_grad()
                            loss.backward()
                            x_adv.data -= eta * x_adv.grad.data
                            x_adv.grad.data.zero_()
                        else:
                            student_fake_outputs = student_net(x_adv)[0]
                            with torch.no_grad():
                                teacher_fake_outputs = teacher_net(x_adv)[0]
                            loss = criterion(student_fake_outputs, fake_label)

                            student_net.zero_grad()
                            loss.backward()
                            x_adv.data -= eta * x_adv.grad.data
                            x_adv.grad.data.zero_()

                        teacher_probs = F.softmax(teacher_fake_outputs, dim=1)
                        student_probs = F.softmax(student_fake_outputs, dim=1)

                        pj_b = teacher_probs[0][fake_label].item()
                        pj_a = student_probs[0][fake_label].item()

                        student_probs_acc.append(pj_a)
                        teacher_probs_acc.append(pj_b)

                        mean_transition_error += abs(pj_b - pj_a)

                    student_transition_curves.append(student_probs_acc)
                    teacher_transition_curves.append(teacher_probs_acc)

            # Average this image's transition curves over its nine fake labels.
            student_image_average_transition_curves_acc.append(
                np.average(np.array(student_transition_curves), axis=0))
            teacher_image_average_transition_curves_acc.append(
                np.average(np.array(teacher_transition_curves), axis=0))

        student_image_average_transition_curves_acc_np = np.average(
            np.array(student_image_average_transition_curves_acc), axis=0)
        teacher_image_average_transition_curves_acc_np = np.average(
            np.array(teacher_image_average_transition_curves_acc), axis=0)

        np.savez(
            'Teacher_WRN-{}-{}_transition_curve-{}-seed-{}.npz'.format(
                wrn_depth_teacher, wrn_width_teacher, mode, seed),
            teacher_image_average_transition_curves_acc_np)
        np.savez(
            'Student_WRN-{}-{}_transition_curve-{}-seed-{}.npz'.format(
                wrn_depth_student, wrn_width_student, mode, seed),
            student_image_average_transition_curves_acc_np)

        # Average MTE over C-1 classes, K adversarial steps and correct initial samples
        mean_transition_error /= float(9 * K * cnt)
        write_mode = 'w' if seed == seeds[0] else 'a'
        with open(
                'Teacher_WRN-{}-{}-Student_WRN-{}-{}-{}-MTE.txt'.format(
                    wrn_depth_teacher, wrn_width_teacher, wrn_depth_student,
                    wrn_width_student, mode), write_mode) as logfile:
            logfile.write(
                'Teacher WideResNet-{}-{} and Student WideResNet-{}-{} trained with {} Mean Transition Error on seed {}: {}\n'
                .format(wrn_depth_teacher, wrn_width_teacher,
                        wrn_depth_student, wrn_width_student, mode, seed,
                        mean_transition_error))
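
For reference, a condensed sketch of a single probe update from the loop above (assuming, as here, a network whose forward returns the logits first): gradient descent on the cross-entropy toward the fake label drags x_adv across the attacked network's decision boundary while both networks' beliefs are recorded.

import torch
import torch.nn as nn

def probe_step(net, x_adv, fake_label, eta=1.0):
    """One adversarial step of x_adv toward fake_label (x_adv must require grad)."""
    criterion = nn.CrossEntropyLoss()
    loss = criterion(net(x_adv)[0], fake_label)
    net.zero_grad()
    loss.backward()
    with torch.no_grad():
        x_adv -= eta * x_adv.grad  # descend the fake-label loss
        x_adv.grad.zero_()
    return x_adv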