Example #1
import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder


def load_real_middle_points(data_type, num_mnist_train_for_GAN, image_size_ref,
                            distance_type, show_image):
    mnist_compose_augment = transforms.Compose([
        transforms.Resize((image_size_ref, image_size_ref)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, ))  # standard MNIST mean/std
    ])
    if distance_type == 'Linf':
        if num_mnist_train_for_GAN == 2000:
            MNIST_AUG_DATA_PATH = "./data/mnist_middle/arxiv/Linf/May_10_dual_direction"  # 400 (7->9) + 400 (9->7) points
        elif num_mnist_train_for_GAN == 500:
            MNIST_AUG_DATA_PATH = "./data/mnist_middle/arxiv/Linf(500_points)/May_16_all"  # 3100 points (start from middle, May 16)
        else:
            raise ValueError(f"unsupported num_mnist_train_for_GAN: {num_mnist_train_for_GAN}")
    elif distance_type == 'L2':
        MNIST_AUG_DATA_PATH = "./data/mnist_middle/arxiv/L2/May_11_dual_direction"  # 200 (7->9) + 200 (9->7) points
    elif distance_type == 'Linf+L2':
        MNIST_AUG_DATA_PATH = "./data/mnist_middle/arxiv/Linf+L2/May_10_11"  # 600 (7->9) + 600 (9->7) points (= May_10_dual + May_11_dual)
    else:
        raise ValueError(f"unsupported distance_type: {distance_type}")

    mnist_train_aug = ImageFolder(root=MNIST_AUG_DATA_PATH,
                                  transform=mnist_compose_augment)
    mnist_train_aug.targets = torch.tensor(mnist_train_aug.targets)
    num_augment = len(mnist_train_aug)

    return mnist_train_aug, num_augment
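
A minimal usage sketch (not from the original repository; the argument values and the batch size are assumptions for illustration):

mnist_aug, num_augment = load_real_middle_points(
    data_type='mnist',
    num_mnist_train_for_GAN=2000,
    image_size_ref=28,
    distance_type='Linf',
    show_image=False,
)
aug_loader = torch.utils.data.DataLoader(mnist_aug, batch_size=64, shuffle=True)
print(f"loaded {num_augment} middle-point images")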
Example #2
    def get_ham10000_gray_normal(dataset_path, batch_size):
        # transforms: grayscale, tensor conversion, and [-1, 1] normalization
        transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.5, ), std=(0.5, )
            )  # transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])
        dataset = ImageFolder(root=dataset_path, transform=transform)

        print("BEFORE total samples = " + str(len(dataset.samples)))
        print("BEFORE total labels = " + str(len(dataset.targets)))
        dataset_size = len(dataset)
        classes = dataset.classes
        num_classes = len(dataset.classes)
        img_dict = {}
        for i in range(num_classes):
            img_dict[classes[i]] = 0

        for i in range(dataset_size):
            img, label = dataset[i]
            img_dict[classes[label]] += 1
        print(img_dict)

        #Oversampling INIT
        x_train = dataset.samples
        y_train = dataset.targets
        sampling_seed = 0

        print("before")
        print(type(x_train))
        print(type(x_train[0]))
        print(type(y_train))
        print(type(y_train[0]))
        print(x_train[0])
        print(y_train[0])

        from imblearn.over_sampling import RandomOverSampler
        sampler = RandomOverSampler(random_state=sampling_seed)
        #print(type(dataset.samples[0]))
        #print(type(dataset.targets[0]))
        x_train, y_train = sampler.fit_resample(x_train, y_train)
        # fit_resample coerces the (path, label) pairs to an array of strings,
        # so cast the labels back to int when rebuilding the samples list
        x_train = [(path, int(label)) for path, label in x_train]
        #y_train = y_train.tolist()
        print("\n\nafter")
        print(type(x_train))
        print(type(x_train[0]))
        print(type(y_train))
        print(type(y_train[0]))
        print(x_train[0])
        print(y_train[0])

        dataset.samples = x_train
        dataset.targets = y_train
        #Oversampling END

        print("total samples = " + str(len(dataset.samples)))
        print("total labels = " + str(len(dataset.targets)))
        dataset_size = len(dataset)
        classes = dataset.classes
        num_classes = len(dataset.classes)
        img_dict = {}
        for i in range(num_classes):
            img_dict[classes[i]] = 0

        for i in range(dataset_size):
            img, label = dataset[i]
            img_dict[classes[int(label)]] += 1
        print(img_dict)
        #exit(0)
        ### NEW TRANSFORMS INIT
        #dataset.transform = transforms.Compose([ #transforms.Resize((input_size,input_size)),
        #                              transforms.Grayscale(num_output_channels=1),
        #                              transforms.RandomHorizontalFlip(),
        #                              transforms.RandomVerticalFlip(),transforms.RandomRotation(20),
        #                              transforms.ColorJitter(brightness=0.1, contrast=0.1, hue=0.1),
        #                                transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])
        #
        ### NEW TRANSFORMS END

        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=8,  #num_workers=0,
            pin_memory=True,  # TODO -> careful, this can raise an error
            shuffle=True)
        return data_loader
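
The core pattern above, resampling ImageFolder's internal list of (path, class_index) pairs with imblearn before any image is loaded, can be isolated into a small helper. A minimal sketch; the helper name, transform, and defaults are illustrative assumptions:

import torch
from imblearn.over_sampling import RandomOverSampler
from torchvision import transforms
from torchvision.datasets import ImageFolder


def make_balanced_loader(dataset_path, batch_size=32, seed=0):
    # plain ToTensor transform; swap in the grayscale pipeline as needed
    dataset = ImageFolder(root=dataset_path, transform=transforms.ToTensor())
    sampler = RandomOverSampler(random_state=seed)
    # .samples is a list of (path, class_index) tuples, so the sampler only
    # duplicates file references; no image data is touched here
    samples, targets = sampler.fit_resample(dataset.samples, dataset.targets)
    dataset.samples = [(path, int(label)) for path, label in samples]
    dataset.targets = [int(t) for t in targets]
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)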
Example #3
def main(
    backbone,
    dim_z,
    data_root,
    resize_shape,
    opt_name,
    lr,
    margin,
    n_charts,
    n_epochs,
    n_classes,
    n_per_class,
    save_ckpt_freq,
    eval_freq,
    reg_loss_weight,
    q_loss_weight,
    use_wandb,
    out_path,
    num_workers,
    val_or_test,
):
    assert val_or_test in ["val", "test"]

    out_path = wandb.run.dir if use_wandb else out_path

    if n_charts is not None:
        net = TorchvisionMfldEmbed(backbone, dim_z, n_charts, pretrained=True)
    else:
        # dimension is dim_z + 1 since it gets normalized to the unit sphere
        net = TorchvisionEmbed(backbone, dim_z + 1, pretrained=True)

    train_transform = Compose([
        RandomResizedCrop(resize_shape),
        RandomHorizontalFlip(),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    val_transform = Compose([
        Resize((resize_shape, resize_shape)),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    train_dset = ImageFolder(root=data_root, transform=train_transform)
    val_dset = ImageFolder(root=data_root, transform=val_transform)
    targets = train_dset.targets

    if val_or_test == "val":
        # take first half for train and val, second half for holdout test
        train_indices = [
            i for i in range(len(train_dset))
            if train_dset.samples[i][1] < 0.8 * len(os.listdir(data_root)) / 2
        ]
        val_indices = [
            i for i in range(len(train_dset))
            if train_dset.samples[i][1] < len(os.listdir(data_root)) /
            2 and i not in train_indices
        ]
    else:
        # split classes in half for train and val
        train_indices = [
            i for i in range(len(train_dset))
            if train_dset.samples[i][1] < len(os.listdir(data_root)) / 2
        ]
        val_indices = [
            i for i in range(len(train_dset)) if i not in train_indices
        ]

    train_dset = Subset(train_dset, train_indices)
    train_dset.targets = [targets[i] for i in train_indices]
    val_dset = Subset(val_dset, val_indices)
    val_dset.targets = [targets[i] for i in val_indices]

    if data_root == "data/CUB_200_2011/images/":
        assert (len(np.unique(train_dset.targets)) +
                len(np.unique(val_dset.targets)) == 100
                if val_or_test == "val" else 200)
    else:
        assert (len(np.unique(train_dset.targets)) +
                len(np.unique(val_dset.targets)) == 98
                if val_or_test == "val" else 196)

    batch_sampler = BalancedClassBatchSampler(train_dset, n_classes,
                                              n_per_class)

    opt = getattr(optim, opt_name)(params=net.parameters(), lr=lr)

    train_data_loader = DataLoader(train_dset,
                                   batch_sampler=batch_sampler,
                                   num_workers=num_workers)
    val_data_loader = DataLoader(val_dset,
                                 batch_size=n_classes * n_per_class,
                                 num_workers=num_workers)

    if n_charts is not None:
        trainer = ManifoldTripletTrainer(
            net=net,
            opt=opt,
            dim_z=dim_z,
            reg_loss_weight=reg_loss_weight,
            q_loss_weight=q_loss_weight,
            out_path=out_path,
            data_loader=train_data_loader,
            eval_data_loader=val_data_loader,
            margin=margin,
            one_hot_q=True,
            use_wandb=use_wandb,
        )
    else:
        trainer = TripletTrainer(
            net=net,
            opt=opt,
            out_path=out_path,
            data_loader=train_data_loader,
            eval_data_loader=val_data_loader,
            margin=margin,
            use_wandb=use_wandb,
        )

    trainer.train(n_epochs=n_epochs,
                  save_ckpt_freq=save_ckpt_freq,
                  eval_freq=eval_freq)
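
The index bookkeeping above can be factored into a reusable class-disjoint split. A minimal sketch, assuming a simple fraction-based cutoff; the helper name and signature are illustrative, not from the original:

from torch.utils.data import Subset
from torchvision.datasets import ImageFolder


def split_by_class(dset: ImageFolder, class_fraction: float = 0.5):
    # samples whose class index falls below the cutoff go to the first subset,
    # the remaining (disjoint) classes go to the second subset
    cutoff = class_fraction * len(dset.classes)
    first = [i for i, (_, c) in enumerate(dset.samples) if c < cutoff]
    second = [i for i, (_, c) in enumerate(dset.samples) if c >= cutoff]
    return Subset(dset, first), Subset(dset, second)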
Example #4
import numpy as np
import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder

n_classes = 9  # assumed: RestrictedImgNet groups ImageNet into 9 super-classes

# trn_ds: training split; the ".../train" path is an assumption mirroring the
# ".../val" path used for tst_ds below
trn_ds = ImageFolder("./data/RestrictedImgNet/train",
                     transform=transforms.Compose([
                         transforms.Resize(72),
                         transforms.CenterCrop(64),
                         transforms.ToTensor(),
                     ]))
tst_ds = ImageFolder("./data/RestrictedImgNet/val",
                     transform=transforms.Compose([
                         transforms.Resize(72),
                         transforms.CenterCrop(64),
                         transforms.ToTensor(),
                     ]))

tst_dists = torch.ones((len(tst_ds), n_classes)).float()
batch_size = 256
############# for randomized labeling experiment
np.random.seed(0)
trn_ds.targets = np.random.choice(np.arange(9), size=len(trn_ds.targets))
for i in range(len(trn_ds.imgs)):
    trn_ds.imgs[i] = (trn_ds.imgs[i][0], trn_ds.targets[i])
tst_ds.targets = np.random.choice(np.arange(9), size=len(tst_ds.targets))
for i in range(len(tst_ds.imgs)):
    tst_ds.imgs[i] = (tst_ds.imgs[i][0], tst_ds.targets[i])
# drop into a debugger to inspect the randomly relabeled datasets
import ipdb
ipdb.set_trace()
############
trn_loader = torch.utils.data.DataLoader(trn_ds,
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=16)
tst_loader = torch.utils.data.DataLoader(tst_ds,
                                         batch_size=batch_size,
                                         shuffle=False,