def run_(
    is_inter=False,
    debug=False,
    keyword="",
    net="",
    fewshot=False,
    augment=False,
    shot=50,
    unlabeled_shot=50,
    autoaugment=False,
    data="cifar",
    pretrained=False,
    dim=512,
    seed=0,
    cutout=False,
    random_erase=False,
):
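    # Loads a trained generator (netG) and per-image latent table (netZ) from the
    # checkpoint directory given in `keyword`, then dumps either the real test set or
    # generated interpolations to .npz files for FID evaluation (see the end of this
    # function).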
    global netG, netZ, Zs_real, neigh, aug_param, criterion, aug_param_test
    manual_seed(seed)
    FULL_PATH = keyword
    data_dir = '../../data'
    cifar_dir_cs = '/cs/dataset/CIFAR/'
    if data == "cifar":
        classes = 100
        aug_param = aug_param_test = get_cifar100_param()
        if not fewshot:
            train_labeled_dataset = torchvision.datasets.CIFAR100(
                root=data_dir, train=True, download=True)
            train_unlabeled_dataset = torchvision.datasets.CIFAR100(
                root=data_dir, train=False, download=True)
            test_data = train_unlabeled_dataset
        else:
            print("=> Fewshot")
            # train_labeled_dataset, train_unlabeled_dataset = get_cifar100_small(cifar_dir_small, shot)
            train_labeled_dataset, train_unlabeled_dataset, _, test_data = get_cifar100(
                cifar_dir_cs, n_labeled=shot, n_unlabled=unlabeled_shot)

    if data == "cub":
        classes = 200
        aug_param = aug_param_test = get_cub_param()
        split_file = None
        if fewshot:
            samples_per_class = int(shot)
            split_file = 'train_test_split_{}.txt'.format(samples_per_class)
        train_labeled_dataset = Cub2011(root=f"../../data/{data}",
                                        train=True,
                                        split_file=split_file)
        train_unlabeled_dataset = Cub2011(root=f"../../data/{data}",
                                          train=False,
                                          split_file=split_file)
        test_data = Cub2011(root=f"../../data/{data}", train=False)

    noise_projection = "proj" in keyword
    print(f" noise_projection={noise_projection}")
    # netG = _netG(dim, aug_param['rand_crop'], 3, noise_projection)
    if data == 'cub':
        netG = DCGAN_G(dim, aug_param['image_size'], 3, noise_projection,
                       noise_projection)
    elif data == 'cifar':
        netG = _netG(dim, aug_param['image_size'], 3, noise_projection,
                     noise_projection)
    data_to_save = test_data
    if keyword != "":
        set_seed(seed)
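        # Transductive checkpoints (names containing "tr_") were trained with latent
        # codes for the labeled, unlabeled and test images, so the embedding table
        # must cover all three splits; otherwise it only covers the labeled set.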
        if "tr_" in FULL_PATH:
            netZ = _netZ(
                dim,
                len(train_labeled_dataset) + len(train_unlabeled_dataset) +
                len(test_data), classes, None)
            # netZ = _netZ(dim, len(train_labeled_dataset) +len(train_unlabeled_dataset), classes, None)
        else:
            netZ = _netZ(dim, len(train_labeled_dataset), classes, None)
            # netZ = _netZ(dim, train_data_size +len(test_data)+ len(train_unlabeled_dataset), classes, None)
        print(f"=> Loading model from {FULL_PATH}")
        _, netG = load_saved_model(f'{FULL_PATH}/netG_nag', netG)
        epoch, netZ = load_saved_model(f'{FULL_PATH}/netZ_nag', netZ)
        # _, netG = load_saved_model(f'{PATH}/runs2/nets_{rn}/netG_nag', netG)
        # epoch, netZ = load_saved_model(f'{PATH}/runs2/nets_{rn}/netZ_nag', netZ)
        netZ = netZ.cuda()
        netG = netG.cuda()
        print(f"=> Embedding size = {len(netZ.emb.weight)}")

        if epoch > 0:
            print(f"=> Loaded successfully! epoch:{epoch}")
        else:
            raise Exception("=> No checkpoint to resume")

        Zs_real = netZ.emb.weight.data.detach().cpu().numpy()

        train_data_loader = get_loader_with_idx(train_labeled_dataset,
                                                batch_size=100,
                                                augment=augment,
                                                shuffle=True,
                                                offset_idx=0,
                                                offset_label=0,
                                                autoaugment=autoaugment,
                                                **aug_param,
                                                cutout=cutout,
                                                random_erase=random_erase)

        netZ.eval()
    save_reals = True
    if save_reals:
        # save_inter_imgs_in_npz(netZ,netG,"fid")
        # save_dataset_in_npz(test_data,"fid",name=name)
        save_dataset_in_npz(test_data, "fid")
    else:
        # for i, (idx, inputs, _) in enumerate(train_data_loader):
        # 	inputs = inputs.cuda()
        # 	targets = validate_loader_consistency(netZ, idx)
        # normalize = transforms.Normalize(mean=aug_param['mean'], std=aug_param['std'])
        # transform_train_np = transforms.Compose([
        # 		# RandomPadandCrop(32),
        # 		# RandomFlip(),
        # 		ToTensor(),
        # 		normalize,
        # 		])
        # Zs_real = netZ.emb.weight.data
        # if is_inter:
        # # z = Zs_real_numpy[idx]
        # z = Zs_real[idx]
        # z_knn_idx = np.array([random.sample(netZ.label2idx[x], 1) for x in targets])
        # ratios = list(np.linspace(0.1, 0.4, 4))
        # rnd_ratio = random.sample(ratios, 1)[0]
        # z = z.float().cuda()
        # z_knn = Zs_real[z_knn_idx[:, 0]]  # * 0.9 + Zs_real[z_knn_idx[:, 1]] * 0.1
        # inter_z = slerp_torch(rnd_ratio, z.unsqueeze(0), z_knn.cuda().float().unsqueeze(0))
        # # inter_z_slerp = interpolate_points(z, z_knn.float(), 10, print_mode=False, slerp=True)
        # # inter_z = interpolate_points(z,z_knn.float(),10,print_mode=False,slerp=False)
        # # targets = torch.tensor(targets).long().cuda()
        # # targets = targets.repeat(10)
        #
        # code = get_code(idx)
        # generated_img = netG(inter_z.squeeze().cuda(), code)
        # try:
        # 	imgs = torch.stack([normalize((img)) for img in generated_img.detach().cpu()])
        # except:
        # 	imgs = torch.stack([normalize((img)) for img in generated_img])

        # save_inter_imgs(netZ, netG, "fid")
        save_inter_imgs_in_npz(netZ, netG, keyword.split("/")[-1])
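
# A hypothetical invocation of run_ (the checkpoint directory passed via `keyword`
# is illustrative, not a path from the original repository):
#
#     run_(keyword="runs/nets_tr_proj_example", data="cifar", fewshot=True,
#          shot=50, unlabeled_shot=50, dim=512, seed=0)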
Example 2
def run_eval_(is_inter=False,
              debug=False,
              keyword="",
              epochs=200,
              d="",
              fewshot=False,
              augment=False,
              shot=50,
              unlabeled_shot=50,
              autoaugment=False,
              loss_method="ce",
              data="cifar",
              pretrained=False,
              dim=512,
              seed=0,
              cutout=False,
              random_erase=False,
              lerp=False):
    global netG, netZ, Zs_real, neigh, aug_param, criterion, aug_param_test
    manual_seed(seed)
    title = f"{d}_{keyword}"
    if fewshot:
        title = title + "_fewshot"

    if is_inter:
        title = title + "_inter"
    dir = "eval"
    if not os.path.isdir(dir):
        os.mkdir(dir)
    logger = Logger(os.path.join(f'{dir}/', f'{title}_log.txt'), title=title)
    logger.set_names(
        ["valid_acc@1", "valid_acc@5", "test_acc@1", "test_acc@5"])
    PATH = os.getcwd()
    data_dir = '../../data'
    cifar_dir_cs = '/cs/dataset/CIFAR/'
    if data == "cifar":
        classes = 100
        lr = 0.1
        batch_size = 128
        WD = 5e-4
        aug_param = aug_param_test = get_cifar100_param()
        if not fewshot:
            train_labeled_dataset = torchvision.datasets.CIFAR100(
                root=data_dir, train=True, download=True)
            train_unlabeled_dataset = torchvision.datasets.CIFAR100(
                root=data_dir, train=False, download=True)
            test_data = train_unlabeled_dataset
        else:
            print("=> Fewshot")
            # train_labeled_dataset, train_unlabeled_dataset = get_cifar100_small(cifar_dir_small, shot)
            train_labeled_dataset, train_unlabeled_dataset, _, test_data = get_cifar100(
                cifar_dir_cs, n_labeled=shot, n_unlabled=unlabeled_shot)
    if data == "cifar-10":
        classes = 10
        lr = 0.1
        batch_size = 128
        WD = 5e-4
        aug_param = aug_param_test = get_cifar10_param()
        if not fewshot:
            train_labeled_dataset = torchvision.datasets.CIFAR10(root=data_dir,
                                                                 train=True,
                                                                 download=True)
            train_unlabeled_dataset = torchvision.datasets.CIFAR10(
                root=data_dir, train=False, download=True)
            test_data = train_unlabeled_dataset
        else:
            print("=> Fewshot")
            train_labeled_dataset, train_unlabeled_dataset, _, test_data = get_cifar10(
                cifar_dir_cs, n_labeled=shot, n_unlabled=unlabeled_shot)
    if data == "cub":
        classes = 200
        batch_size = 16
        lr = 0.001
        WD = 1e-5
        aug_param = aug_param_test = get_cub_param()
        split_file = None
        if fewshot:
            samples_per_class = int(shot)
            split_file = 'train_test_split_{}.txt'.format(samples_per_class)
        train_labeled_dataset = Cub2011(root=f"../../data/{data}",
                                        train=True,
                                        split_file=split_file)
        train_unlabeled_dataset = Cub2011(root=f"../../data/{data}",
                                          train=False,
                                          split_file=split_file)
        test_data = Cub2011(root=f"../../data/{data}", train=False)
    if data == "stl":
        print("STL-10")
        classes = 10
        WD = 4e-4
        batch_size = 32
        lr = 2e-3
        aug_param = get_train_n_unlabled_stl_param()
        aug_param_test = get_test_stl_param()
        train_labeled_dataset = torchvision.datasets.STL10(
            root=f"../../data/{data}", split='train', download=True)
        train_unlabeled_dataset = test_data = torchvision.datasets.STL10(
            root=f"../../data/{data}", split='unlabeled', download=True)
    train_data_size = len(train_labeled_dataset)
    print(f"train_labeled_dataset size:{train_data_size}")
    print(f"test_data size:{len(test_data)}")
    print(f"transductive data size:{len(train_unlabeled_dataset)}")

    # netZ.set_label2idx()
    noise_projection = "proj" in keyword
    print(f" noise_projection={noise_projection}")
    # netG = _netG(dim, aug_param['rand_crop'], 3, noise_projection)
    if data == 'cub':
        netG = DCGAN_G(dim, aug_param['image_size'], 3, noise_projection,
                       noise_projection)
        print(
            f"G: {dim}, {aug_param['image_size']}, 3, {noise_projection},{noise_projection}"
        )
    elif data == 'stl':
        netG = DCGAN_G_small(dim, aug_param['image_size'], 3, noise_projection,
                             noise_projection)
    elif 'cifar' in data:
        netG = _netG(dim, aug_param['image_size'], 3, noise_projection,
                     noise_projection)
    paths = list()
    print(f"{PATH}")
    dirs = glob.glob(PATH)
    print(dirs)
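    # Collect run names by scanning runs/<keyword>*log.txt; each name is later used to
    # locate the matching netG_nag / netZ_nag checkpoints under runs/nets_<name>.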
    for dir in dirs:
        for f in glob.iglob(f"{dir}/runs/{keyword}*log.txt"):
            fname = f.split("/")[-1]
            tmp = fname.split("_")
            name = '_'.join(tmp[:-1])
            # if is_model_classifier:
            # 	if "classifier" in name or "cnn" in name:
            paths.append(name)
    # else:
    # 	paths.append(name)
    scores_test_acc1_fewshot = dict()
    scores_test_acc5_fewshot = dict()
    set_seed(seed)
    print(f"=> Total runs: {len(paths)}\n{paths}")
    for rn in paths:
        classifier = get_classifier(classes, d, pretrained)
        if "tr_" in rn:
            print("=> Transductive mode")
            netZ = _netZ(
                dim, train_data_size + len(train_unlabeled_dataset) +
                len(test_data), classes, None)
        # netZ = _netZ(dim, train_data_size +len(test_data), classes, None)
        else:
            print("=> No Transductive")
            netZ = _netZ(dim, train_data_size, classes, None)
        # 	netZ = _netZ(dim, train_data_size +len(test_data)+ len(train_unlabeled_dataset), classes, None)
        try:
            print(f"=> Loading model from {rn}")
            _, netG = load_saved_model(f'runs/nets_{rn}/netG_nag', netG)
            epoch, netZ = load_saved_model(f'runs/nets_{rn}/netZ_nag', netZ)
            netZ = netZ.cuda()
            netG = netG.cuda()
            print(f"=> Embedding size = {len(netZ.emb.weight)}")
            print(' => netZ params: %.2fM' %
                  (sum(p.numel() for p in netZ.parameters()) / 1000000.0))
            print(' => netG params: %.2fM' %
                  (sum(p.numel() for p in netG.parameters()) / 1000000.0))
            if epoch > 0:
                print(f"=> Loaded successfully! epoch:{epoch}")
            else:
                print("=> No checkpoint to resume")
        except Exception as e:
            print(f"=> Failed resume job!\n {e}")
        Zs_real = netZ.emb.weight.data.detach().cpu().numpy()
        optimizer = optim.SGD(classifier.parameters(),
                              lr,
                              momentum=0.9,
                              weight_decay=WD,
                              nesterov=True)
        print("=> Train new classifier")
        if loss_method == "cosine":
            criterion = nn.CosineEmbeddingLoss().cuda()
        else:
            criterion = nn.CrossEntropyLoss().cuda()
        num_gpus = torch.cuda.device_count()
        if num_gpus > 1:
            print(f"=> Using {num_gpus} GPUs")
            classifier = nn.DataParallel(classifier).cuda()
            cudnn.benchmark = True
        else:
            classifier = maybe_cuda(classifier)

        print(' => Total params: %.2fM' %
              (sum(p.numel() for p in classifier.parameters()) / 1000000.0))
        print(f"=> {d}  Training model")
        print(f"=> Training Epochs = {str(epochs)}")
        print(f"=> Initial Learning Rate = {str(lr)}")
        generic_train_classifier(classifier,
                                 optimizer,
                                 train_labeled_dataset,
                                 criterion=criterion,
                                 batch_size=batch_size,
                                 is_inter=is_inter,
                                 num_epochs=epochs,
                                 augment=augment,
                                 fewshot=fewshot,
                                 autoaugment=autoaugment,
                                 test_data=test_data,
                                 loss_method=loss_method,
                                 n_classes=classes,
                                 aug_param=aug_param,
                                 aug_param_test=aug_param_test,
                                 shot=shot,
                                 cutout=cutout,
                                 random_erase=random_erase,
                                 is_lerp=lerp)

        print("=> Done training classifier")

        valid_acc_1, valid_acc_5 = accuracy(classifier,
                                            train_labeled_dataset,
                                            batch_size=batch_size,
                                            aug_param=aug_param_test)
        test_acc_1, test_acc_5 = accuracy(classifier,
                                          test_data,
                                          batch_size=batch_size,
                                          aug_param=aug_param_test)

        print('train_acc accuracy@1', valid_acc_1)
        print('train_acc accuracy@5', valid_acc_5)
        print('test accuracy@1', test_acc_1)
        print('test accuracy@5', test_acc_5)
        scores_test_acc1_fewshot[seed] = test_acc_1
        scores_test_acc5_fewshot[seed] = test_acc_5
        logger.append([valid_acc_1, valid_acc_5, test_acc_1, test_acc_5])
        w = csv.writer(open(f"eval_{rn}_shot_{shot}_acc1.csv", "w+"))
        for key, val in scores_test_acc1_fewshot.items():
            w.writerow([rn, key, val])
        w = csv.writer(open(f"eval_{rn}_shot_{shot}_acc5.csv", "w+"))
        for key, val in scores_test_acc5_fewshot.items():
            w.writerow([rn, key, val])

    logger.close()
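
# A hypothetical invocation of run_eval_ (run-name prefix, classifier name and
# shot counts are illustrative):
#
#     run_eval_(keyword="tr_cifar", d="wideresnet", data="cifar", fewshot=True,
#               shot=50, unlabeled_shot=50, epochs=200, augment=True, seed=0)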
Example 3
    if args.dataset == 'cifar10':
        trainset = torchvision.datasets.CIFAR10(
            root='./data',
            train=True,
            download=True,
            transform=torchvision.transforms.ToTensor())

    elif args.dataset == 'cub':

        transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize([224, 224],
                                          interpolation=PIL.Image.BICUBIC),
            torchvision.transforms.ToTensor()
        ])

        trainset = Cub2011(root='./data', train=True, transform=transform)
        trainset.class_to_idx = range(200)

    elif args.dataset == 'cars':

        transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize([224, 224],
                                          interpolation=PIL.Image.BICUBIC),
            torchvision.transforms.ToTensor()
        ])

        trainset = torchvision.datasets.ImageFolder(
            root='data/stanford-cars/car_data/car_data/train/',
            transform=transform)

    elif args.dataset == 'dogs':
Example 4
def run_eval_generic(seed=17, epochs=200, d="", fewshot=False, augment=False, autoaugment=False, data="", shot=0,
                     unlabeled_shot=None,
                     resume=False, pretrained=False, batch_size=128, cutout=False, random_erase=False):
    global classifier
    dir = "baselines_aug"
    if not os.path.isdir(dir):
        os.mkdir(dir)
    name = f"{data}_{d}_baseline_"
    if shot is not None:
        name = name + f"shot{shot}"
    if autoaugment:
        name = name + "_aug"

    logger = Logger(os.path.join(f"{dir}/", f'{d}_log.txt'), title=name)
    logger.set_names(["valid_acc@1", "valid_acc@5", "test_acc@1", "test_acc@5"])
    data_dir = '../../data'
    cifar_dir_cs = '/cs/dataset/CIFAR/'
    aug_param = aug_param_test = None

    if data == "cifar":
        classes = 100
        # batch_size = 128
        lr = 0.1
        WD = 5e-4
        if not fewshot:
            train_labeled_dataset = torchvision.datasets.CIFAR100(root=data_dir, train=True, download=True)
            train_unlabeled_dataset = torchvision.datasets.CIFAR100(root=data_dir, train=False, download=True)
            test_data = train_unlabeled_dataset
        else:
            print("=> Fewshot")
            # train_labeled_dataset, train_unlabeled_dataset = get_cifar100_small(cifar_dir_small, shot)
            train_labeled_dataset, train_unlabeled_dataset, _, test_data = get_cifar100(cifar_dir_cs, n_labeled=shot,
                                                                                        n_unlabled=unlabeled_shot)
        # test_data = torchvision.datasets.CIFAR100(root=data_dir, train=False, download=True)
        print(f"train_labeled_dataset size:{len(train_labeled_dataset)}")
        print(f"test_data size:{len(test_data)}")

    if data == "cifar-10":
        classes = 10
        lr = 0.1
        WD = 5e-4
        aug_param = get_cifar10_param()
        if not fewshot:
            train_labeled_dataset = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True)
            train_unlabeled_dataset = torchvision.datasets.CIFAR10(root=data_dir, train=False, download=True)
            test_data = train_unlabeled_dataset
        else:
            print("=> Fewshot")
            train_labeled_dataset, train_unlabeled_dataset, _, test_data = get_cifar10(cifar_dir_cs, n_labeled=shot,
                                                                                       n_unlabled=unlabeled_shot)
    print(f"train_labeled_dataset size:{len(train_labeled_dataset)}")
    print(f"test_data size:{len(test_data)}")

    if data == "cub":
        print("CUB-200")
        classes = 200
        # batch_size = 16
        lr = 0.001
        WD = 1e-5
        aug_param = aug_param_test = get_cub_param()
        split_file = None
        if fewshot:
            samples_per_class = int(shot)
            split_file = 'train_test_split_{}.txt'.format(samples_per_class)
        train_labeled_dataset = Cub2011(root=f"../../data/{data}", train=True, split_file=split_file)
        test_data = Cub2011(root=f"../../data/{data}", train=False)
        print(f"train_labeled_dataset size:{len(train_labeled_dataset)},{train_labeled_dataset.data.shape}")

        print(f"test_data size:{len(test_data)},{test_data.data.shape}")
    if data == "stl":
        print("STL-10")
        classes = 10
        WD = 4e-4
        # batch_size = 128
        lr = 2e-3
        aug_param = get_train_n_unlabled_stl_param()
        aug_param_test = get_test_stl_param()
        train_labeled_dataset = torchvision.datasets.STL10(root=f"../../data/{data}", split='train', download=True)
        # train_unlabeled_dataset = torchvision.datasets.STL10(root=f"../../data/{data}", split='unlabeled', download=True)
        test_data = torchvision.datasets.STL10(root=f"../../data/{data}", split='test', download=True)
        print(f"train_labeled_dataset size:{len(train_labeled_dataset)},{train_labeled_dataset.data.shape}")
        print(f"test_data size:{len(test_data)},{test_data.data.shape}")
    scores_test_acc1_fewshot = dict()
    scores_test_acc5_fewshot = dict()

    # offset_idx = len(train_labeled_dataset)
    print(f"{data} num classes: {classes}")
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    start_epoch = 0
    if resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
        checkpoint = torch.load(f'./checkpoint/{name}_{seed}_ckpt.t7')
        classifier = checkpoint['net']
        best_acc = checkpoint['acc']
        start_epoch = checkpoint['epoch'] + 1
        rng_state = checkpoint['rng_state']
        torch.set_rng_state(rng_state)
        print(f"=> Model loaded start_epoch{start_epoch}, acc={best_acc}")

    else:

        classifier = get_classifier(classes, d, pretrained)
    criterion = nn.CrossEntropyLoss().cuda()
    num_gpus = torch.cuda.device_count()
    if num_gpus > 1:
        classifier = nn.DataParallel(classifier).cuda()
    else:
        classifier = classifier.cuda()
    optimizer = optim.SGD(classifier.parameters(), lr, momentum=0.9, weight_decay=WD, nesterov=True)
    cudnn.benchmark = True
    print(' => Total params: %.2fM' % (sum(p.numel() for p in classifier.parameters()) / 1000000.0))
    print(f"=> {d}  Training model")
    print(f"=> Training Epochs = {str(epochs)}")
    print(f"=> Initial Learning Rate = {str(lr)}")

    # criterion = maybe_cuda(criterion)
    # optimizer = optim.Adam(classifier.parameters(), lr=1e-3)
    train_linear_classifier(classifier, criterion, optimizer, seed=seed, name=name,
                            train_labeled_dataset=train_labeled_dataset,
                            batch_size=batch_size,
                            num_epochs=epochs,
                            fewshot=fewshot,
                            augment=augment, autoaugment=autoaugment, aug_param=aug_param, test_data=test_data,
                            start_epoch=start_epoch, cutout=cutout, random_erase=random_erase, shot=shot)

    acc_1_valid, acc_5_valid = accuracy(classifier, train_labeled_dataset, batch_size=batch_size, aug_param=aug_param)
    fewshot_acc_1_test, fewshot_acc_5_test = accuracy(classifier, test_data, batch_size=batch_size,
                                                      aug_param=aug_param_test)

    print('fewshot_acc_1_valid accuracy@1', acc_1_valid)
    print('fewshot_acc_5_valid accuracy@5', acc_5_valid)

    print('fewshot_acc_1_test accuracy@1', fewshot_acc_1_test)
    print('fewshot_acc_5_test accuracy@5', fewshot_acc_5_test)
    logger.append([acc_1_valid, acc_5_valid, fewshot_acc_1_test, fewshot_acc_5_test])
    scores_test_acc1_fewshot[seed] = fewshot_acc_1_test
    scores_test_acc5_fewshot[seed] = fewshot_acc_5_test

    w = csv.writer(open(f"baseline_shot_{shot}_acc1.csv", "w+"))
    for key, val in scores_test_acc1_fewshot.items():
        w.writerow([key, val])
    w = csv.writer(open(f"baseline_shot_{shot}_acc5.csv", "w+"))
    for key, val in scores_test_acc5_fewshot.items():
        w.writerow([key, val])
    logger.close()
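
# A hypothetical baseline run (classifier name and shot counts are illustrative):
#
#     run_eval_generic(seed=17, epochs=200, d="wideresnet", data="cifar",
#                      fewshot=True, shot=50, unlabeled_shot=50, augment=True)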
Example 5
def main():
    global netG, netZ, Zs_real, aug_param, criterion, aug_param_test
    global best_acc
    print(args)
    if args.data == 'cifar10':
        print("cifar10")
        aug_param = aug_param_test = get_cifar10_param()
        root = '/cs/dataset/CIFAR/'
        classes = 10
        batch_size = min(args.batch_size, 128)
        normalize = transforms.Normalize(mean=aug_param['mean'],
                                         std=aug_param['std'])
        image_size = 32
        transform_train = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomCrop(image_size, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # RandomPadandCrop(32),
            # RandomFlip(),
            # ToTensor(),
            normalize,
        ])

        transform_val = transforms.Compose([transforms.ToTensor(), normalize])
        train_labeled_set, train_unlabeled_set, val_set, test_set = get_cifar10(
            root,
            args.n_labeled,
            args.n_unlabeled,
            transform_train=transform_train,
            transform_val=transform_val)
    elif args.data == 'cifar':
        aug_param = aug_param_test = get_cifar100_param()

        root = '/cs/dataset/CIFAR/'
        classes = 100
        batch_size = min(args.batch_size, 128)
        normalize = transforms.Normalize(mean=aug_param['mean'],
                                         std=aug_param['std'])
        image_size = 32
        transform_train = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomCrop(image_size, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # RandomPadandCrop(32),
            # RandomFlip(),
            # ToTensor(),
            normalize,
        ])

        transform_val = transforms.Compose([transforms.ToTensor(), normalize])
        train_labeled_set, train_unlabeled_set, val_set, test_set = get_cifar100(
            root,
            args.n_labeled,
            args.n_unlabeled,
            transform_train=transform_train,
            transform_val=transform_val)
    elif args.data == 'cub':
        split_file = None
        samples_per_class = int(args.n_labeled)
        split_file = f'train_test_split_{samples_per_class}.txt'
        # train_repeats = 30 // samples_per_class
        classes = 200
        aug_param = get_cub_param()
        image_size = aug_param['rand_crop']
        print(aug_param)
        normalize = transforms.Normalize(mean=aug_param['mean'],
                                         std=aug_param['std'])

        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(aug_param['rand_crop'],
                                         scale=(0.875, 1.)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(), normalize
        ])

        transform_val = transforms.Compose([
            transforms.Resize(aug_param['image_size']),
            transforms.CenterCrop(aug_param['rand_crop']),
            transforms.ToTensor(), normalize
        ])

        batch_size = min(args.batch_size, 32)
        train_labeled_set = Cub2011(root=f"../data/{args.data}",
                                    train=True,
                                    split_file=split_file,
                                    transform=transform_train)
        train_unlabeled_set = Cub2011(
            root=f"../data/{args.data}",
            train=False,
            split_file="train_test_split_30.txt",
            transform=TransformTwice(transform_train))
        test_set = Cub2011(root=f"../data/{args.data}",
                           train=False,
                           transform=transform_val)

    train_data_size = len(train_labeled_set)
    print(f"train_data size:{train_data_size}")
    print(f"test_data size:{len(test_set)}")
    print(f"train_unlabeled_set data size:{len(train_unlabeled_set)}")
    labeled_trainloader_2 = data.DataLoader(train_labeled_set,
                                            batch_size=batch_size,
                                            shuffle=True,
                                            num_workers=0,
                                            drop_last=True)
    labeled_trainloader = get_loader_with_idx(train_labeled_set,
                                              **aug_param,
                                              batch_size=batch_size,
                                              augment=None,
                                              drop_last=True)
    # unlabeled_trainloader = data.DataLoader(train_unlabeled_set, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
    offset_ = len(train_labeled_set)
    # offset_ += len(test_set)  # Uncomment for transductive mode
    unlabeled_trainloader = get_loader_with_idx(train_unlabeled_set,
                                                **aug_param,
                                                batch_size=batch_size,
                                                augment=None,
                                                drop_last=True,
                                                offset_idx=offset_)
    # val_loader = data.DataLoader(val_set, batch_size=args.batch_size, shuffle=False, num_workers=0)
    test_loader = data.DataLoader(test_set,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=0)
    if not os.path.isdir(args.out):
        mkdir_p(args.out)

    # Model

    def create_model(ema=False):
        # print("==> creating WRN-28-2")
        # classifier = WideResNet2(num_classes=100)
        print("==> creating resnet50")
        classifier = get_classifier(classes, "resnet50", True)
        num_gpus = torch.cuda.device_count()
        if num_gpus > 1:
            print(f"=> Using {num_gpus} GPUs")
            classifier = nn.DataParallel(classifier.cuda(),
                                         device_ids=list(
                                             range(num_gpus))).cuda()
            cudnn.benchmark = True
        else:
            classifier = classifier.cuda()
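        # The EMA copy is updated by WeightEMA (see below) rather than by
        # backpropagation, so its parameters are detached from the autograd graph.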
        if ema:
            for param in classifier.parameters():
                param.detach_()

        return classifier

    classifier = create_model()
    ema_classifier = create_model(ema=True)

    # Loading  pretrained glico_model
    keyword = args.keyword
    print('    Total params: %.2fM' %
          (sum(p.numel() for p in classifier.parameters()) / 1000000.0))

    train_criterion = SemiLoss()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(classifier.parameters(), lr=args.lr)

    ema_optimizer = WeightEMA(classifier, ema_classifier, alpha=args.ema_decay)
    start_epoch = 0

    # Resume
    title = 'cifar-10'
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint directory found!'
        args.out = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_acc = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
        classifier.load_state_dict(checkpoint['state_dict'])
        ema_classifier.load_state_dict(checkpoint['ema_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.out, 'log.txt'),
                        title=title,
                        resume=True)
    else:
        logger = Logger(os.path.join(args.out, 'log.txt'), title=title)
        logger.set_names([
            'Train Loss', 'Train Loss X', 'Train Loss U', 'Valid Loss',
            'Valid Acc.', 'Test Loss', 'Test Acc.'
        ])

    step = 0
    test_accs = []
    # Train and val
    for epoch in range(start_epoch, args.epochs):
        print('\nEpoch: [%d | %d] LR: %f' %
              (epoch + 1, args.epochs, state['lr']))

        train_loss, train_loss_x, train_loss_u = train(
            labeled_trainloader, unlabeled_trainloader, classifier, optimizer,
            ema_optimizer, train_criterion, epoch, use_cuda, normalize,
            classes)
        _, train_acc = validate(labeled_trainloader_2,
                                ema_classifier,
                                criterion,
                                epoch,
                                use_cuda,
                                mode='Train Stats')
        val_loss, val_acc = validate(test_loader,
                                     ema_classifier,
                                     criterion,
                                     epoch,
                                     use_cuda,
                                     mode='Valid Stats')
        test_loss, test_acc = validate(test_loader,
                                       ema_classifier,
                                       criterion,
                                       epoch,
                                       use_cuda,
                                       mode='Test Stats ')

        step = args.val_iteration * (epoch + 1)

        # append logger file
        logger.append([
            train_loss, train_loss_x, train_loss_u, val_loss, val_acc,
            test_loss, test_acc
        ])

        # save classifier
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': classifier.state_dict(),
                'ema_state_dict': ema_classifier.state_dict(),
                'acc': val_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            }, is_best)
        test_accs.append(test_acc)
    logger.close()

    print('Best acc:')
    print(best_acc)

    print('Mean acc:')
    print(np.mean(test_accs[-20:]))
def run_eval_gan(epochs=200,
                 d="wideresnet",
                 fewshot=False,
                 augment=False,
                 shot=50,
                 unlabeled_shot=1,
                 data="cifar",
                 pretrained=False,
                 seed=0):
    global aug_param, criterion, aug_param_test, batch_size
    manual_seed(seed)
    title = f"GAN_{d}"
    if fewshot:
        title = title + "_fewshot"
    dir = "eval"
    if not os.path.isdir(dir):
        os.mkdir(dir)
    data_dir = '../../data'
    cifar_dir_cs = '/cs/dataset/CIFAR/'
    if data == "cifar":
        classes = 100
        lr = 0.1
        batch_size = 128
        WD = 5e-4
        aug_param = aug_param_test = get_cifar100_param()
        if not fewshot:
            train_labeled_dataset = torchvision.datasets.CIFAR100(
                root=data_dir, train=True, download=True)
            train_unlabeled_dataset = torchvision.datasets.CIFAR100(
                root=data_dir, train=False, download=True)
            test_data = train_unlabeled_dataset
        else:
            print("=> Fewshot")
            # train_labeled_dataset, train_unlabeled_dataset = get_cifar100_small(cifar_dir_small, shot)
            train_labeled_dataset, train_unlabeled_dataset, _, test_data = get_cifar100(
                cifar_dir_cs, n_labeled=shot, n_unlabled=unlabeled_shot)

    if data == "cub":
        classes = 200
        batch_size = 16
        lr = 0.001
        WD = 1e-5
        aug_param = aug_param_test = get_cub_param()
        split_file = None
        if fewshot:
            samples_per_class = int(shot)
            split_file = 'train_test_split_{}.txt'.format(samples_per_class)
        train_labeled_dataset = Cub2011(root=f"../../data/{data}",
                                        train=True,
                                        split_file=split_file)
        train_unlabeled_dataset = Cub2011(root=f"../../data/{data}",
                                          train=False,
                                          split_file=split_file)
        test_data = Cub2011(root=f"../../data/{data}", train=False)

    train_data_size = len(train_labeled_dataset)
    print(f"train_labeled_dataset size:{train_data_size}")
    print(f"test_data size:{len(test_data)}")
    print(f"transductive data size:{len(train_unlabeled_dataset)}")

    set_seed(seed)
    generator = Generator()
    discriminator = Discriminator()
    classifier = get_classifier(classes, d, pretrained)
    epoch = 180
    generator, _ = load_model(epoch, generator, discriminator)
    generator = generator.cuda()
    optimizer = optim.SGD(classifier.parameters(),
                          lr,
                          momentum=0.9,
                          weight_decay=WD,
                          nesterov=True)
    print("=> Train new classifier")
    criterion = nn.CrossEntropyLoss().cuda()
    num_gpus = torch.cuda.device_count()
    if num_gpus > 1:
        print(f"=> Using {num_gpus} GPUs")
        classifier = nn.DataParallel(classifier).cuda()
        # cudnn.benchmark = True
    else:
        classifier = classifier.cuda()

    print(' => Total params: %.2fM' %
          (sum(p.numel() for p in classifier.parameters()) / 1000000.0))
    print(f"=> {d}  Training model")
    print(f"=> Training Epochs = {str(epochs)}")
    print(f"=> Initial Learning Rate = {str(lr)}")
    generic_train_classifier(generator,
                             classifier,
                             optimizer,
                             train_labeled_dataset,
                             batch_size=batch_size,
                             num_epochs=epochs,
                             augment=augment,
                             fewshot=fewshot,
                             test_data=test_data,
                             n_classes=classes,
                             aug_param=aug_param,
                             aug_param_test=aug_param_test,
                             shot=shot)

    print("=> Done training classifier")

    valid_acc_1, valid_acc_5 = accuracy(classifier,
                                        train_labeled_dataset,
                                        batch_size=batch_size,
                                        aug_param=aug_param_test)
    test_acc_1, test_acc_5 = accuracy(classifier,
                                      test_data,
                                      batch_size=batch_size,
                                      aug_param=aug_param_test)

    print('train_acc accuracy@1', valid_acc_1)
    print('train_acc accuracy@5', valid_acc_5)
    print('test accuracy@1', test_acc_1)
    print('test accuracy@5', test_acc_5)
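
# A hypothetical invocation, assuming a GAN checkpoint that load_model can restore
# at epoch 180 (values are illustrative):
#
#     run_eval_gan(epochs=200, d="wideresnet", data="cifar", fewshot=True,
#                  shot=50, unlabeled_shot=1, seed=0)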
Example 7
        train_labeled_dataset = torchvision.datasets.CIFAR100(root=data_dir,
                                                              train=True,
                                                              download=True)
        train_unlabeled_dataset = torchvision.datasets.CIFAR100(root=data_dir,
                                                                train=False,
                                                                download=True)
if dataset == 'cub':
    split_file = None
    if args.fewshot:
        samples_per_class = int(args.shot)
        split_file = 'train_test_split_{}.txt'.format(samples_per_class)
    # train_repeats = 30 // samples_per_class
    classes = 200
    batch_size = min(args.batch_size, 32)
    train_labeled_dataset = Cub2011(
        root=f"/cs/labs/daphna/idan.azuri/data/cub",
        train=True,
        split_file=split_file)
    train_unlabeled_dataset = Cub2011(
        root=f"/cs/labs/daphna/idan.azuri/data/cub",
        train=False,
        split_file=split_file)
    test_dataset = []
if dataset == "stl":
    print("STL-10")
    classes = 10
    batch_size = min(args.batch_size, 32)
    train_labeled_dataset = torchvision.datasets.STL10(
        root=f"../../data/{dataset}", split='train', download=True)
    train_unlabeled_dataset = torchvision.datasets.STL10(
        root=f"../../data/{dataset}", split='unlabeled', download=True)
    train_unlabeled_dataset = torch.utils.data.Subset(train_unlabeled_dataset, ...)

        testset = torchvision.datasets.CIFAR10(
            root='../data',
            train=False,
            download=False,
            transform=torchvision.transforms.ToTensor())

    elif cfg['dataset'] == 'cub':

        transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize([224, 224],
                                          interpolation=PIL.Image.BICUBIC),
            torchvision.transforms.ToTensor()
        ])

        trainset = Cub2011(root='../data', train=True, transform=transform)
        trainset.class_to_idx = range(200)
        testset = Cub2011(root='../data', train=False, transform=transform)

    elif cfg['dataset'] == 'cars':

        transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize([224, 224],
                                          interpolation=PIL.Image.BICUBIC),
            torchvision.transforms.ToTensor()
        ])

        trainset = torchvision.datasets.ImageFolder(
            root='../data/stanford-cars/car_data/car_data/train/',
            transform=transform)
        testset = torchvision.datasets.ImageFolder(