Beispiel #1
0
def load_cifar10_data(datadir):
    """Build the CIFAR-10 train/test datasets and return their raw contents.

    The train and test transforms come from ``_data_transforms_cifar10()``
    and are handed to ``CIFAR10_truncated``; both datasets are created with
    ``download=True``.

    Args:
        datadir: Root directory passed straight to ``CIFAR10_truncated``.

    Returns:
        A 4-tuple ``(X_train, y_train, X_test, y_test)`` made of the ``data``
        and ``target`` attributes of the train and test datasets.
    """
    # NOTE(review): the values returned below are read straight from the
    # dataset attributes; whether the transforms have any effect on them
    # depends on CIFAR10_truncated -- confirm against its definition.
    tf_train, tf_test = _data_transforms_cifar10()

    train_split = CIFAR10_truncated(
        datadir, train=True, download=True, transform=tf_train
    )
    test_split = CIFAR10_truncated(
        datadir, train=False, download=True, transform=tf_test
    )

    return (
        train_split.data,
        train_split.target,
        test_split.data,
        test_split.target,
    )
Beispiel #2
0
def load_cifar10_data(datadir):
    """Build the CIFAR-10 train/test datasets and return their raw contents.

    Both datasets share a single ``ToTensor`` transform pipeline and are
    created with ``download=True``.

    Args:
        datadir: Root directory passed straight to ``CIFAR10_truncated``.

    Returns:
        A 4-tuple ``(X_train, y_train, X_test, y_test)`` made of the ``data``
        and ``target`` attributes of the train and test datasets.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])

    # The generator is consumed left to right, so the train split is built
    # before the test split.
    train_split, test_split = (
        CIFAR10_truncated(
            datadir, train=is_train, download=True, transform=to_tensor
        )
        for is_train in (True, False)
    )

    return (
        train_split.data,
        train_split.target,
        test_split.data,
        test_split.target,
    )
Beispiel #3
0
def get_dataloader(datadir, train_bs, test_bs, dataidxs=None):
    """Create the train and test ``DataLoader`` objects for CIFAR-10.

    Both datasets use the same ``ToTensor`` + ``Normalize`` pipeline. The
    flag ``args.resampling_balance`` is read from the enclosing scope (it is
    not a parameter of this function) and decides whether the loaders draw
    through samplers produced by ``resampling_balance`` or use plain
    shuffled (train) / sequential (test) iteration.

    Args:
        datadir: Root directory passed to ``CIFAR10_truncated``.
        train_bs: Batch size of the training loader.
        test_bs: Batch size of the test loader.
        dataidxs: Forwarded to the training dataset only; the test dataset
            is built without it.

    Returns:
        A ``(train_dl, test_dl)`` tuple of data loaders.
    """
    # NOTE(review): the constants are presumably the per-channel CIFAR-10
    # mean / std -- confirm against the source of these numbers.
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))
    pipeline = transforms.Compose([to_tensor, normalize])

    train_set = CIFAR10_truncated(datadir, dataidxs=dataidxs, train=True,
                                  transform=pipeline, download=True)
    test_set = CIFAR10_truncated(datadir, train=False,
                                 transform=pipeline, download=True)

    if not args.resampling_balance:
        # Plain path: shuffle the training data, keep the test order fixed.
        train_dl = data.DataLoader(dataset=train_set, batch_size=train_bs,
                                   shuffle=True, num_workers=2)
        test_dl = data.DataLoader(dataset=test_set, batch_size=test_bs,
                                  shuffle=False, num_workers=2)
        print("Resampling Turned Off!")
        return train_dl, test_dl

    # Resampling path: each loader draws indices through the sampler that
    # resampling_balance() returns for its dataset.
    # NOTE(review): the test split is resampled as well -- confirm this is
    # intended for evaluation.
    balanced_train = resampling_balance(train_set)
    train_dl = data.DataLoader(dataset=train_set, sampler=balanced_train,
                               batch_size=train_bs, num_workers=2)
    balanced_test = resampling_balance(test_set)
    test_dl = data.DataLoader(dataset=test_set, sampler=balanced_test,
                              batch_size=test_bs, shuffle=False,
                              num_workers=2)

    print("Resampling Works!")

    # NOTE(review): the counts below are computed from the dataset's own
    # `target` attribute, not from what the sampler yields, so they likely
    # show the distribution before resampling -- verify.
    _, train_counts = np.unique(train_dl.dataset.target, return_counts=True)
    print("Data training after resample:", train_counts)

    _, test_counts = np.unique(test_dl.dataset.target, return_counts=True)
    print("Data test after resample:", test_counts)

    return train_dl, test_dl