Example #1
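All of the examples on this page are excerpts from larger training scripts, so their imports are omitted. As a rough sketch, they typically assume something along these lines (the exact modules and names vary from example to example):

# Typical imports assumed by the excerpts below (illustrative, not taken from any single example).
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, ConcatDataset
from torchvision import transforms
from torchvision.datasets import MNIST, FashionMNIST, CIFAR10, CIFAR100, SVHN, ImageFolder
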
def multi_indicator_dataset(dataset, num, limit, class_object, args):
    if (dataset == 'MNIST'):
        print("Loading {}-multi-indicator for MNIST dataset...".format(num))
        data_train = MNIST(root='./data',
                           train=True,
                           download=True,
                           transform=transform)
        data_test = MNIST(root='./data',
                          train=False,
                          download=True,
                          transform=transform)
    elif (dataset == 'fMNIST'):
        print("Loading full Fashion-MNIST dataset...")
        data_train = FashionMNIST(root='./data',
                                  train=True,
                                  download=True,
                                  transform=transform)
        data_test = FashionMNIST(root='./data',
                                 train=False,
                                 download=True,
                                 transform=transform)
    else:
        print("Loading full QuickDraw! dataset...")
        # QuickDraw! loading is not shown in this excerpt; Example #3 below contains the full branch.

    if (dataset == 'MNIST' or dataset == 'fMNIST'):
        # train batch
        idx = (data_train.targets < limit)
        data_train.targets = data_train.targets[idx]
        data_train.data = data_train.data[idx]

        # map each indicator class to 10 + k, collapse the rest to 0, then shift to labels 1..len(num)
        idx = 1
        for i in num:
            data_train.targets[data_train.targets == i] = 10 + idx
            print('adding...')
            idx += 1

        data_train.targets[data_train.targets < 10] = 0
        data_train.targets[data_train.targets > 10] -= 10

        # balance the set: keep only as many label-0 samples as the average per-class count of positives
        idx_0 = (data_train.targets == 0)
        idx_1 = (data_train.targets != 0)
        sum_idx_0 = 0
        total = sum(idx_1) // len(num)

        for i in range(len(idx_0)):
            sum_idx_0 += idx_0[i]

            if sum_idx_0 == total:
                idx_0[i + 1:] = False
                break

        idx = idx_0 | idx_1  # logical OR of the two boolean masks
        print(sum(idx))
        data_train.targets = data_train.targets[idx]
        data_train.data = data_train.data[idx]

        train_label = data_train.targets.cpu().detach().numpy()
        trainloader = DataLoader(data_train,
                                 batch_size=args.batch_size_train,
                                 shuffle=True)

        # test batch
        idx = (data_test.targets < limit)
        data_test.targets = data_test.targets[idx]
        data_test.data = data_test.data[idx]

        idx = 1
        for i in num:
            data_test.targets[data_test.targets == i] = 10 + idx
            idx += 1

        data_test.targets[data_test.targets < 10] = 0
        data_test.targets[data_test.targets > 10] -= 10

        idx_0 = (data_test.targets == 0)
        idx_1 = (data_test.targets != 0)
        sum_idx_0 = 0
        total = sum(idx_1) // len(num)
        print(sum(idx_1))
        # total = 843

        for i in range(len(idx_0)):
            sum_idx_0 += idx_0[i]

            if sum_idx_0 == total:
                idx_0[i + 1:] = False
                break

        idx = idx_0 | idx_1
        print(sum(idx))
        data_test.targets = data_test.targets[idx]
        data_test.data = data_test.data[idx]

        test_label = data_test.targets.cpu().detach().numpy()
        testloader = DataLoader(data_test,
                                batch_size=args.batch_size_test,
                                shuffle=False)

    return trainloader, testloader, train_label, test_label
#dataloaderzeros = DataLoader(
#    MNIST('./data', train=True, download=True, transform=img_transform),
#    batch_size=batch_size, shuffle=True, collate_fn = my_collate)
print('Making each number training set')
datasets = []
for i in range(10):
    dataset = MNIST('./data',
                    transform=img_transform,
                    download=True,
                    train=True)
    #print(dataset)
    #[0:5851]
    idx = dataset.targets == i
    dataset.targets = dataset.targets[idx]
    dataset.data = dataset.data[idx]
    dataset = torch.utils.data.random_split(
        dataset,
        [num_datapoints, len(dataset) - num_datapoints])[0]
    #dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    datasets.append(dataset)
    #print(len(dataset))

print('Making each number combined with 0 by percent predicter set')
VAE_dataloaders_w_zeros = []
for i in range(10):
    numbersets = []
    for j in range(11):
        numberzeros = int(num_datapoints * (10 - j) / 10)
        numbervals = int(num_datapoints * j / 10)
        #print(numberzeros)
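
A minimal sketch of how multi_indicator_dataset above might be called; the argparse fields batch_size_train and batch_size_test appear in the function body, while the parser itself, the chosen digits, and the module-level transform are assumptions:

# Hypothetical call site for multi_indicator_dataset (sketch only, not from the original script).
import argparse
from torchvision import transforms

transform = transforms.ToTensor()  # the function relies on a module-level `transform`

parser = argparse.ArgumentParser()
parser.add_argument('--batch_size_train', type=int, default=128)
parser.add_argument('--batch_size_test', type=int, default=256)
args = parser.parse_args([])

# Treat digits 3 and 5 as the indicator classes; `limit` keeps only labels below 10.
trainloader, testloader, train_label, test_label = multi_indicator_dataset(
    'MNIST', num=[3, 5], limit=10, class_object=None, args=args)
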
Example #3
def indicator_dataset(dataset, num, limit, class_object, args):
    if (dataset == 'MNIST'):
        print("Loading {}-indicator for MNIST dataset...".format(num))
        data_train = MNIST(root='./data',
                           train=True,
                           download=True,
                           transform=transform)
        data_test = MNIST(root='./data',
                          train=False,
                          download=True,
                          transform=transform)
    elif (dataset == 'fMNIST'):
        print("Loading full Fashion-MNIST dataset...")
        data_train = FashionMNIST(root='./data',
                                  train=True,
                                  download=True,
                                  transform=transform)
        data_test = FashionMNIST(root='./data',
                                 train=False,
                                 download=True,
                                 transform=transform)
    else:
        print("Loading full QuickDraw! dataset...")
        train_data = []
        train_label = []
        test_data = []
        test_label = []
        for i in range(len(class_object)):
            # load npy file and concatenate data
            ob = np.load('./data/quickdraw/full_numpy_bitmap_' +
                         class_object[i] + '.npy')
            # choose train size and test size
            train = ob[0:5000, ]
            test = ob[5000:6000, ]
            train_label = np.concatenate(
                (train_label, i * np.ones(train.shape[0])), axis=0)
            test_label = np.concatenate(
                (test_label, i * np.ones(test.shape[0])), axis=0)

            if i == 0:
                train_data = train
                test_data = test
            else:
                train_data = np.concatenate((train_data, train), axis=0)
                test_data = np.concatenate((test_data, test), axis=0)

        train_label[train_label != num] = -1
        train_label[train_label == num] = 1
        train_label[train_label == -1] = 0

        test_label[test_label != num] = -1
        test_label[test_label == num] = 1
        test_label[test_label == -1] = 0

        # generate dataloader
        trainset = feature_Dataset(train_data, train_label.astype(int),
                                   transform)
        trainloader = DataLoader(trainset,
                                 batch_size=args.batch_size_train,
                                 shuffle=True)

        testset = feature_Dataset(test_data, test_label.astype(int), transform)
        testloader = DataLoader(testset,
                                batch_size=args.batch_size_test,
                                shuffle=False)

    if (dataset == 'MNIST' or dataset == 'fMNIST'):
        # train batch
        idx = (data_train.targets < limit)
        data_train.targets = data_train.targets[idx]
        data_train.data = data_train.data[idx]

        print("Changing label...")
        data_train.targets[data_train.targets == 1] = 10
        data_train.targets[data_train.targets == 6] = 1
        data_train.targets[data_train.targets == 10] = 6

        for i in num:
            data_train.targets[data_train.targets == i] = 10

        data_train.targets[data_train.targets != 10] = 0
        data_train.targets[data_train.targets == 10] = 1

        # balance the set: keep only as many negative (label 0) samples as positives
        idx_0 = (data_train.targets == 0)
        idx_1 = (data_train.targets == 1)
        sum_idx_0 = 0
        total = sum(idx_1)

        for i in range(len(idx_0)):
            sum_idx_0 += idx_0[i]

            if sum_idx_0 == total:
                idx_0[i + 1:] = False
                break

        idx = idx_0 | idx_1
        print(sum(idx))
        data_train.targets = data_train.targets[idx]
        data_train.data = data_train.data[idx]

        train_label = data_train.targets.cpu().detach().numpy()
        trainloader = DataLoader(data_train,
                                 batch_size=args.batch_size_train,
                                 shuffle=True)

        # test batch
        idx = (data_test.targets < limit)
        data_test.targets = data_test.targets[idx]
        data_test.data = data_test.data[idx]

        print("Changing label...")
        data_test.targets[data_test.targets == 1] = 10
        data_test.targets[data_test.targets == 6] = 1
        data_test.targets[data_test.targets == 10] = 6

        for i in num:
            data_test.targets[data_test.targets == i] = 10

        data_test.targets[data_test.targets != 10] = 0
        data_test.targets[data_test.targets == 10] = 1

        idx_0 = (data_test.targets == 0)
        idx_1 = (data_test.targets == 1)
        sum_idx_0 = 0
        print(sum(idx_1))
        # total = sum(idx_1)
        total = 1042

        for i in range(len(idx_0)):
            sum_idx_0 += idx_0[i]

            if sum_idx_0 == total:
                idx_0[i + 1:] = False
                break

        idx = idx_0 | idx_1
        print(sum(idx))
        data_test.targets = data_test.targets[idx]
        data_test.data = data_test.data[idx]

        test_label = data_test.targets.cpu().detach().numpy()
        testloader = DataLoader(data_test,
                                batch_size=args.batch_size_test,
                                shuffle=False)

    return trainloader, testloader, train_label, test_label
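
The relabelling trick in indicator_dataset (mark the target classes with a sentinel value of 10, then collapse everything else to 0 and the sentinel to 1) is easy to verify in isolation. A small standalone illustration:

# Standalone illustration of the one-vs-rest relabelling used in indicator_dataset.
import torch

targets = torch.tensor([0, 3, 7, 3, 9, 1])
num = [3, 7]                    # classes that should end up with label 1
for i in num:
    targets[targets == i] = 10  # sentinel for the positive classes
targets[targets != 10] = 0      # everything else becomes the negative class
targets[targets == 10] = 1
print(targets)                  # tensor([0, 1, 1, 1, 0, 0])
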
Example #4
def load_data(opt):
    """ Load Data

    Args:
        opt ([type]): Argument Parser

    Raises:
        IOError: Cannot Load Dataset

    Returns:
        [type]: dataloader
    """

    ##
    # LOAD DATA SET
    if opt.dataroot == '':
        opt.dataroot = './data/{}'.format(opt.dataset)

    if opt.dataset in ['cifar10']:
        splits = ['train', 'test']
        drop_last_batch = {'train': True, 'test': False}
        shuffle = {'train': True, 'test': False}

        transform = transforms.Compose(
            [
                transforms.Resize(opt.isize),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ]
        )

        classes = {
            'plane': 0, 'car': 1, 'bird': 2, 'cat': 3, 'deer': 4,
            'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9
        }

        dataset = {}
        dataset['train'] = CIFAR10(root=opt.dataroot, train=True, download=True,
                                   transform=transform)
        dataset['test'] = CIFAR10(root=opt.dataroot, train=False, download=True,
                                  transform=transform)

        dataset['train'].train_data, dataset['train'].train_labels, \
        dataset['test'].test_data, dataset['test'].test_labels = get_cifar_anomaly_dataset(
            trn_img=dataset['train'].train_data,
            trn_lbl=dataset['train'].train_labels,
            tst_img=dataset['test'].test_data,
            tst_lbl=dataset['test'].test_labels,
            abn_cls_idx=classes[opt.anomaly_class]
        )

        dataloader = {x: torch.utils.data.DataLoader(dataset=dataset[x],
                                                     batch_size=opt.batchsize,
                                                     shuffle=shuffle[x],
                                                     num_workers=int(opt.workers),
                                                     drop_last=drop_last_batch[x]) for x in splits}
        return dataloader

    elif opt.dataset in ['mnist']:
        opt.anomaly_class = int(opt.anomaly_class)
        # ZJ: set to match mnist channel
        opt.nc = 1

        splits = ['train', 'test']
        drop_last_batch = {'train': True, 'test': False}
        shuffle = {'train': True, 'test': True}

        transform = transforms.Compose(
            [
                transforms.Resize(opt.isize),  # transforms.Scale is deprecated in favor of Resize
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ]
        )

        dataset = {}
        dataset['train'] = MNIST(root=opt.dataroot, train=True, download=True, transform=transform)
        dataset['test'] = MNIST(root=opt.dataroot, train=False, download=True, transform=transform)

        # dataset['train'].train_data, dataset['train'].train_labels, \
        # dataset['test'].test_data, dataset['test'].test_labels = get_mnist_anomaly_dataset(
        dataset['train'].data, dataset['train'].targets, \
        dataset['test'].data, dataset['test'].targets = get_mnist_anomaly_dataset(
            trn_img=dataset['train'].data,
            trn_lbl=dataset['train'].targets,
            tst_img=dataset['test'].data,
            tst_lbl=dataset['test'].targets,
            abn_cls_idx=int(opt.anomaly_class)
        )

        dataloader = {x: torch.utils.data.DataLoader(dataset=dataset[x],
                                                     batch_size=opt.batchsize,
                                                     shuffle=shuffle[x],
                                                     num_workers=int(opt.workers),
                                                     drop_last=drop_last_batch[x]) for x in splits}
        return dataloader

    elif opt.dataset in ['mnist2']:
        opt.anomaly_class = int(opt.anomaly_class)

        splits = ['train', 'test']
        drop_last_batch = {'train': True, 'test': False}
        shuffle = {'train': True, 'test': True}

        transform = transforms.Compose(
            [
                transforms.Resize(opt.isize),
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ]
        )

        dataset = {}
        dataset['train'] = MNIST(root=opt.dataroot, train=True, download=True, transform=transform)
        dataset['test'] = MNIST(root=opt.dataroot, train=False, download=True, transform=transform)

        dataset['train'].train_data, dataset['train'].train_labels, \
        dataset['test'].test_data, dataset['test'].test_labels = get_mnist2_anomaly_dataset(
            trn_img=dataset['train'].train_data,
            trn_lbl=dataset['train'].train_labels,
            tst_img=dataset['test'].test_data,
            tst_lbl=dataset['test'].test_labels,
            nrm_cls_idx=opt.anomaly_class,
            proportion=opt.proportion
        )

        dataloader = {x: torch.utils.data.DataLoader(dataset=dataset[x],
                                                     batch_size=opt.batchsize,
                                                     shuffle=shuffle[x],
                                                     num_workers=int(opt.workers),
                                                     drop_last=drop_last_batch[x]) for x in splits}
        return dataloader

    else:
        splits = ['train', 'test']
        drop_last_batch = {'train': True, 'test': False}
        shuffle = {'train': True, 'test': True}
        transform = transforms.Compose([transforms.Resize(opt.isize),
                                        transforms.CenterCrop(opt.isize),
                                        transforms.ToTensor(),
                                        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ])

        dataset = {x: ImageFolder(os.path.join(opt.dataroot, x), transform) for x in splits}
        dataloader = {x: torch.utils.data.DataLoader(dataset=dataset[x],
                                                     batch_size=opt.batchsize,
                                                     shuffle=shuffle[x],
                                                     num_workers=int(opt.workers),
                                                     drop_last=drop_last_batch[x]) for x in splits}
        return dataloader
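
# A hedged sketch of how load_data above might be driven from an argument parser.
# The option names (dataset, dataroot, isize, batchsize, workers, anomaly_class) come from
# the function body; the parser, the default values, and parse_args([]) are assumptions.
# (load_data itself also relies on the project's get_*_anomaly_dataset helpers.)
def _example_load_data_usage():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='mnist')
    parser.add_argument('--dataroot', default='')
    parser.add_argument('--isize', type=int, default=32)
    parser.add_argument('--batchsize', type=int, default=64)
    parser.add_argument('--workers', type=int, default=2)
    parser.add_argument('--anomaly_class', default='0')
    opt = parser.parse_args([])

    dataloader = load_data(opt)  # {'train': DataLoader, 'test': DataLoader}
    images, labels = next(iter(dataloader['train']))
    return images, labels
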
def main():

    img_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda tensor:min_max_normalization(tensor, 0, 1)),
        transforms.Lambda(lambda tensor:tensor_round(tensor))
    ])
    

    dataset = MNIST('./data', train=True, transform=img_transform, download=True)
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    testset = MNIST('./data', train=False, transform=img_transform, download=True)
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=True)
    
    # visualize the distributions of the continuous feature U over 5,000 images
    visuadata =  MNIST('./data', train=False, transform=img_transform, download=True)
    X = dataset.data
    L = np.array(dataset.targets)
    
    first = True
    
    for label in range(10):
        index = np.where(L == label)[0]
    
        N = index.shape[0]
        np.random.seed(0)
        perm = np.random.permutation(N)
        index = index[perm]
    
        data = X[index[0:500]]
        labels = L[index[0:500]]
        if first:
            visualization_L = labels
            visualization_data = data
        else:
            visualization_L = np.concatenate((visualization_L, labels))
            visualization_data = torch.cat((visualization_data, data))
    
    
        first = False
    
        visuadata.data = visualization_data
        visuadata.targets = visualization_L
    
    # Data Loader
    visualization_loader = DataLoader(dataset=visuadata,
                                            batch_size=batch_size,
                                            shuffle=False,
                                            num_workers = 0)      
        
    
    
    model = autoencoder(encode_length=encode_length)
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=learning_rate, weight_decay=1e-5)

    
    for epoch in range(num_epochs):
        print('--------training epoch {}--------'.format(epoch))        
        adjust_learning_rate(optimizer, epoch)    
        
        # train the model using SGD        
        for i, (img, _) in enumerate(train_loader):   
            img = img.view(img.size(0), -1)
            img = Variable(img)
  
            # ===================forward=====================
            output, h, b = model(img)
            loss_BCE = criterion(output, img)
            onesvec  =  Variable(torch.ones(h.size(0), 1))  
            Tcode  = torch.transpose(b, 1, 0)
            loss_reg = torch.mean(torch.pow(Tcode.mm(onesvec)/h.size(0), 2))/2
            loss = loss_BCE + Alpha*loss_reg
            # ===================backward====================
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        # Test the Model using testset            
        if (epoch + 1) % 1 == 0:


            '''
            Calculate the mAP over test set            
            '''             

            retrievalB, retrievalL, queryB, queryL = compress(train_loader, testloader, model)            
            result_map = calculate_map(qB=queryB, rB=retrievalB, queryL=queryL, retrievalL=retrievalL)
            print('---{}_mAP: {}---'.format(name, result_map))  
            
          
            
            '''
            visualization of the latent variable over 5,000 images
            In this setting, we set encode_length = 3
            '''
            if encode_length == 3:
                z_buf = list([])
                label_buf = list([])
                for ii, (img, labelb) in enumerate(visualization_loader):
                    img = img.view(img.size(0), -1)
                    img = Variable(img)
                    # ===================forward=====================
                    _, qz, _ = model(img)        
                    z_buf.extend(qz.cpu().data.numpy())
                    label_buf.append(labelb)
                X = np.vstack(z_buf)
                Y = np.hstack(label_buf)
                plot_latent_variable3d(X, Y, epoch, name)   
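
The regularizer in the training loop above (loss_reg) is a bit-balance penalty on the hashing code: Tcode.mm(onesvec) / h.size(0) is the per-bit mean of the codes over the batch, so squaring and averaging it pushes each bit to be used roughly equally often. A small standalone check of that arithmetic, with made-up codes:

# Standalone check of the bit-balance regularizer from the training loop above.
import torch

N = 4                                       # batch size
b = torch.tensor([[ 1., -1.,  1.],
                  [ 1.,  1., -1.],
                  [-1., -1.,  1.],
                  [ 1., -1.,  1.]])         # toy binary codes in {-1, +1}, shape (N, encode_length)
onesvec = torch.ones(N, 1)
Tcode = torch.transpose(b, 1, 0)            # (encode_length, N)
per_bit_mean = Tcode.mm(onesvec) / N        # (encode_length, 1): mean of each bit over the batch
loss_reg = torch.mean(per_bit_mean.pow(2)) / 2
print(per_bit_mean.squeeze(), loss_reg)     # tensor([ 0.5000, -0.5000,  0.5000]) tensor(0.1250)
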
Example #6
def load_data(config):
    normal_class = config['normal_class']
    batch_size = config['batch_size']
    img_size = config['image_size']

    if config['dataset_name'] in ['cifar10']:
        img_transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        os.makedirs("./train/CIFAR10", exist_ok=True)
        dataset = CIFAR10('./train/CIFAR10',
                          train=True,
                          download=True,
                          transform=img_transform)
        dataset.data = dataset.data[np.array(dataset.targets) == normal_class]
        dataset.targets = [normal_class] * dataset.data.shape[0]

        train_set, val_set = torch.utils.data.random_split(
            dataset, [dataset.data.shape[0] - 851, 851])

        os.makedirs("./test/CIFAR10", exist_ok=True)
        test_set = CIFAR10("./test/CIFAR10",
                           train=False,
                           download=True,
                           transform=img_transform)

    elif config['dataset_name'] in ['mnist']:
        img_transform = transforms.Compose([
            transforms.Resize(img_size),
            transforms.ToTensor(),
        ])

        os.makedirs("./train/MNIST", exist_ok=True)
        dataset = MNIST('./train/MNIST',
                        train=True,
                        download=True,
                        transform=img_transform)
        dataset.data = dataset.data[np.array(dataset.targets) == normal_class]
        dataset.targets = [normal_class] * dataset.data.shape[0]

        train_set, val_set = torch.utils.data.random_split(
            dataset, [dataset.data.shape[0] - 851, 851])

        os.makedirs("./test/MNIST", exist_ok=True)
        test_set = MNIST("./test/MNIST",
                         train=False,
                         download=True,
                         transform=img_transform)

    elif config['dataset_name'] in ['fashionmnist']:
        img_transform = transforms.Compose([
            transforms.Resize(img_size),
            transforms.ToTensor(),
        ])

        os.makedirs("./train/FashionMNIST", exist_ok=True)
        dataset = FashionMNIST('./train/FashionMNIST',
                               train=True,
                               download=True,
                               transform=img_transform)
        dataset.data = dataset.data[np.array(dataset.targets) == normal_class]
        dataset.targets = [normal_class] * dataset.data.shape[0]

        train_set, val_set = torch.utils.data.random_split(
            dataset, [dataset.data.shape[0] - 851, 851])

        os.makedirs("./test/FashionMNIST", exist_ok=True)
        test_set = FashionMNIST("./test/FashionMNIST",
                                train=False,
                                download=True,
                                transform=img_transform)

    elif config['dataset_name'] in ['brain_tumor', 'head_ct']:
        img_transform = transforms.Compose([
            transforms.Resize([img_size, img_size]),
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor()
        ])

        root_path = 'Dataset/medical/' + config['dataset_name']
        train_data_path = root_path + '/train'
        test_data_path = root_path + '/test'
        dataset = ImageFolder(root=train_data_path, transform=img_transform)
        load_dataset = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        train_dataset_array = next(iter(load_dataset))[0]
        my_dataset = TensorDataset(train_dataset_array)
        train_set, val_set = torch.utils.data.random_split(
            my_dataset, [train_dataset_array.shape[0] - 5, 5])

        test_set = ImageFolder(root=test_data_path, transform=img_transform)

    elif config['dataset_name'] in ['coil100']:
        img_transform = transforms.Compose([transforms.ToTensor()])

        root_path = 'Dataset/coil100/' + config['dataset_name']
        train_data_path = root_path + '/train'
        test_data_path = root_path + '/test'
        dataset = ImageFolder(root=train_data_path, transform=img_transform)
        load_dataset = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        train_dataset_array = next(iter(load_dataset))[0]
        my_dataset = TensorDataset(train_dataset_array)
        train_set, val_set = torch.utils.data.random_split(
            my_dataset, [train_dataset_array.shape[0] - 5, 5])

        test_set = ImageFolder(root=test_data_path, transform=img_transform)

    elif config['dataset_name'] in ['MVTec']:
        data_path = 'Dataset/MVTec/' + normal_class + '/train'
        data_list = []

        orig_transform = transforms.Compose(
            [transforms.Resize(img_size),
             transforms.ToTensor()])

        orig_dataset = ImageFolder(root=data_path, transform=orig_transform)

        train_orig, val_set = torch.utils.data.random_split(
            orig_dataset, [len(orig_dataset) - 25, 25])
        data_list.append(train_orig)

        for i in range(3):
            img_transform = transforms.Compose([
                transforms.Resize(img_size),
                transforms.RandomAffine(0, scale=(1.05, 1.2)),
                transforms.ToTensor()
            ])

            dataset = ImageFolder(root=data_path, transform=img_transform)
            data_list.append(dataset)

        dataset = ConcatDataset(data_list)

        train_loader = torch.utils.data.DataLoader(dataset,
                                                   batch_size=800,
                                                   shuffle=True)
        train_dataset_array = next(iter(train_loader))[0]
        train_set = TensorDataset(train_dataset_array)

        test_data_path = 'Dataset/MVTec/' + normal_class + '/test'
        test_set = ImageFolder(root=test_data_path, transform=orig_transform)

    train_dataloader = torch.utils.data.DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
    )

    val_dataloader = torch.utils.data.DataLoader(
        val_set,
        batch_size=batch_size,
        shuffle=True,
    )

    test_dataloader = torch.utils.data.DataLoader(
        test_set,
        batch_size=batch_size,
        shuffle=True,
    )

    return train_dataloader, val_dataloader, test_dataloader
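
The config dict consumed by load_data above only needs a few keys; a minimal sketch of a call into the MNIST branch (the key names come from the function, the concrete values are illustrative):

# Hypothetical config for the MNIST branch of load_data (values are illustrative only).
config = {
    'dataset_name': 'mnist',
    'normal_class': 8,      # the single class kept for training/validation
    'batch_size': 64,
    'image_size': 32,
}
train_loader, val_loader, test_loader = load_data(config)
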
Example #7
def load_data(opt):
    """ Load Data
    Args:
        opt ([type]): Argument Parser
    Raises:
        IOError: Cannot Load Dataset
    Returns:
        [type]: dataloader
    """

    ##
    # LOAD DATA SET
    print(opt.dataset)

    if opt.dataroot == '':
        opt.dataroot = './data/{}'.format(opt.dataset)

    if opt.dataset in ['cifar10']:
        splits = ['train', 'test']
        drop_last_batch = {'train': True, 'test': False}
        shuffle = {'train': True, 'test': False}

        transform = transforms.Compose([
            transforms.Resize(opt.isize),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        classes = {
            'plane': 0,
            'car': 1,
            'bird': 2,
            'cat': 3,
            'deer': 4,
            'dog': 5,
            'frog': 6,
            'horse': 7,
            'ship': 8,
            'truck': 9
        }

        dataset = {}
        dataset['train'] = CIFAR10(root='./CIFAR10',
                                   train=True,
                                   download=True,
                                   transform=transform)
        dataset['test'] = CIFAR10(root='./CIFAR10',
                                  train=False,
                                  download=True,
                                  transform=transform)
        dataset['train'].data, dataset['train'].targets, \
        dataset['test'].data, dataset['test'].targets = get_cifar_anomaly_dataset(
            trn_img=dataset['train'].data,
            trn_lbl=dataset['train'].targets,
            tst_img=dataset['test'].data,
            tst_lbl=dataset['test'].targets,
            abn_cls_idx=classes[opt.abnormal_class],
            manualseed=opt.manualSeed
        )

        dataloader = {
            x: torch.utils.data.DataLoader(
                dataset=dataset[x],
                batch_size=opt.batchsize,
                shuffle=shuffle[x],
                num_workers=int(opt.workers),
                drop_last=drop_last_batch[x],
                worker_init_fn=(None if opt.manualSeed == -1 else
                                lambda x: np.random.seed(opt.manualSeed)))
            for x in splits
        }
        return dataloader

    elif opt.dataset in ['cifarop']:
        print('using cifar10: one normal class and nine abnormal classes')

        splits = ['train', 'test', 'val']
        drop_last_batch = {'train': True, 'test': False, 'val': False}
        shuffle = {'train': True, 'test': False, 'val': False}

        transform = transforms.Compose([
            transforms.Resize(opt.isize),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        classes = {
            'plane': 0,
            'car': 1,
            'bird': 2,
            'cat': 3,
            'deer': 4,
            'dog': 5,
            'frog': 6,
            'horse': 7,
            'ship': 8,
            'truck': 9
        }

        dataset = {}
        dataset['train'] = CIFAR10(root='./CIFAR10',
                                   train=True,
                                   download=True,
                                   transform=transform)
        dataset['test'] = CIFAR10(root='./CIFAR10',
                                  train=False,
                                  download=True,
                                  transform=transform)
        dataset['val'] = CIFAR10(root='./CIFAR10',
                                 train=False,
                                 download=True,
                                 transform=transform)

        dataset['train'].data, dataset['train'].targets, \
        dataset['test'].data, dataset['test'].targets, \
        dataset['val'].data, dataset['val'].targets = get_cifar_anomaly_datasetop(
            trn_img=dataset['train'].data,
            trn_lbl=dataset['train'].targets,
            tst_img=dataset['test'].data,
            tst_lbl=dataset['test'].targets,
            abn_cls_idx=classes[opt.abnormal_class],
            manualseed=opt.manualSeed
        )

        dataloader = {
            x: torch.utils.data.DataLoader(
                dataset=dataset[x],
                batch_size=opt.batchsize,
                shuffle=shuffle[x],
                num_workers=int(opt.workers),
                drop_last=drop_last_batch[x],
                worker_init_fn=(None if opt.manualSeed == -1 else
                                lambda x: np.random.seed(opt.manualSeed)))
            for x in splits
        }
        return dataloader

    elif opt.dataset in ['mnist']:
        opt.abnormal_class = int(opt.abnormal_class)

        splits = ['train', 'test']
        drop_last_batch = {'train': True, 'test': False}
        shuffle = {'train': True, 'test': True}

        transform = transforms.Compose([
            transforms.Resize(opt.isize),
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])

        dataset = {}
        dataset['train'] = MNIST(root='./data',
                                 train=True,
                                 download=True,
                                 transform=transform)
        dataset['test'] = MNIST(root='./data',
                                train=False,
                                download=True,
                                transform=transform)

        dataset['train'].data, dataset['train'].targets, \
        dataset['test'].data, dataset['test'].targets = get_mnist_anomaly_dataset(
            trn_img=dataset['train'].data,
            trn_lbl=dataset['train'].targets,
            tst_img=dataset['test'].data,
            tst_lbl=dataset['test'].targets,
            abn_cls_idx=opt.abnormal_class,
            manualseed=opt.manualSeed
        )

        dataloader = {
            x: torch.utils.data.DataLoader(
                dataset=dataset[x],
                batch_size=opt.batchsize,
                shuffle=shuffle[x],
                num_workers=int(opt.workers),
                drop_last=drop_last_batch[x],
                worker_init_fn=(None if opt.manualSeed == -1 else
                                lambda x: np.random.seed(opt.manualSeed)))
            for x in splits
        }

        return dataloader

    elif opt.dataset in ['mnist2']:
        opt.abnormal_class = int(opt.abnormal_class)

        splits = ['train', 'test']
        drop_last_batch = {'train': True, 'test': False}
        shuffle = {'train': True, 'test': True}

        transform = transforms.Compose([
            transforms.Resize(opt.isize),
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])

        dataset = {}
        dataset['train'] = MNIST(root='./data',
                                 train=True,
                                 download=True,
                                 transform=transform)
        dataset['test'] = MNIST(root='./data',
                                train=False,
                                download=True,
                                transform=transform)

        dataset['train'].data, dataset['train'].targets, \
        dataset['test'].data, dataset['test'].targets = get_mnist2_anomaly_dataset(
            trn_img=dataset['train'].data,
            trn_lbl=dataset['train'].targets,
            tst_img=dataset['test'].data,
            tst_lbl=dataset['test'].targets,
            nrm_cls_idx=opt.abnormal_class,
            proportion=opt.proportion,
            manualseed=opt.manualSeed
        )

        dataloader = {
            x: torch.utils.data.DataLoader(
                dataset=dataset[x],
                batch_size=opt.batchsize,
                shuffle=shuffle[x],
                num_workers=int(opt.workers),
                drop_last=drop_last_batch[x],
                worker_init_fn=(None if opt.manualSeed == -1 else
                                lambda x: np.random.seed(opt.manualSeed)))
            for x in splits
        }
        return dataloader

    else:
        splits = ['train', 'test']
        drop_last_batch = {'train': True, 'test': False}
        shuffle = {'train': True, 'test': True}
        transform = transforms.Compose([
            transforms.Resize(opt.isize),
            transforms.CenterCrop(opt.isize),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        dataset = {
            x: ImageFolder(os.path.join(opt.dataroot, x), transform)
            for x in splits
        }
        dataloader = {
            x: torch.utils.data.DataLoader(
                dataset=dataset[x],
                batch_size=opt.batchsize,
                shuffle=shuffle[x],
                num_workers=int(opt.workers),
                drop_last=drop_last_batch[x],
                worker_init_fn=(None if opt.manualSeed == -1 else
                                lambda x: np.random.seed(opt.manualSeed)))
            for x in splits
        }

        return dataloader
Example #8
        out = F.relu(out)
        out = F.dropout(out, 0.2)
        out = self.fc2(out)
        out = F.relu(out)
        out = F.dropout(out, 0.2)
        out = self.fc3(out)
        if not self.training:
            out = F.softmax(out, dim=1)
        return out


# initialize data for the sleeping network
trainset = MNIST(".", train=True, download=True, transform=transform)
testset = MNIST(".", train=False, download=True, transform=transform)
#split the data
trainset.data = trainset.data[0:27105]
trainset.targets = trainset.targets[0:27105]
# # print(trainset.targets[0:11905])
trainloader = DataLoader(trainset, batch_size=128, shuffle=True)
testloader = DataLoader(testset, batch_size=128, shuffle=True)

model = torch.load('save_ann.pkl')
loss_function = nn.CrossEntropyLoss()
# live_loss_plot = LiveLossPlot()
optimiser = optim.SGD(model.parameters(), lr=0.1, momentum=0.5)
# trial = torchbearer.Trial(model, optimiser, loss_function, callbacks=[live_loss_plot], metrics=['loss', 'accuracy']).to(device)
trial = torchbearer.Trial(model,
                          optimiser,
                          loss_function,
                          metrics=['loss', 'accuracy']).to(device)
trial.with_generators(trainloader, test_generator=testloader)
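
# The excerpt stops after wiring up the Trial. A hedged continuation, assuming torchbearer's
# standard Trial.run / Trial.evaluate API (an assumption, not part of the original snippet):
trial.run(epochs=10)
results = trial.evaluate(data_key=torchbearer.TEST_DATA)
print(results)
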
def fetch_dataloaders(args):
    # preprocessing transforms
    transform = T.Compose([
        T.ToTensor(),  # tensor in [0,1]
        lambda x: x.mul(255).div(2**(8 - args.n_bits)).floor(),  # lower bits
        partial(preprocess, n_bits=args.n_bits)
    ])  # to model space [-1,1]
    target_transform = (lambda y: torch.eye(args.n_cond_classes)[y]
                        ) if args.n_cond_classes else None

    if args.dataset == 'mnist':
        args.image_dims = (1, 28, 28)
        train_dataset = MNIST(args.data_path,
                              train=True,
                              transform=transform,
                              target_transform=target_transform)
        valid_dataset = MNIST(args.data_path,
                              train=False,
                              transform=transform,
                              target_transform=target_transform)
    elif args.dataset == 'cifar10':
        args.image_dims = (3, 32, 32)
        train_dataset = CIFAR10(args.data_path,
                                train=True,
                                transform=transform,
                                target_transform=target_transform)
        valid_dataset = CIFAR10(args.data_path,
                                train=False,
                                transform=transform,
                                target_transform=target_transform)
    elif args.dataset == 'colored-mnist':
        args.image_dims = (3, 28, 28)
        # NOTE -- data is quantized to 2 bits and in (N,H,W,C) format
        with open(args.data_path, 'rb'
                  ) as f:  # return dict {'train': np array; 'test': np array}
            data = pickle.load(f)
        # quantize to n_bits to match the transforms for other datasets and construct tensors in shape N,C,H,W
        train_data = torch.from_numpy(
            np.floor(data['train'].astype(np.float32) /
                     (2**(2 - args.n_bits)))).permute(0, 3, 1, 2)
        valid_data = torch.from_numpy(
            np.floor(data['test'].astype(np.float32) /
                     (2**(2 - args.n_bits)))).permute(0, 3, 1, 2)
        # preprocess to [-1,1] and setup datasets -- NOTE using 0s for labels to have a symmetric dataloader
        train_dataset = TensorDataset(preprocess(train_data, args.n_bits),
                                      torch.zeros(train_data.shape[0]))
        valid_dataset = TensorDataset(preprocess(valid_data, args.n_bits),
                                      torch.zeros(valid_data.shape[0]))
    else:
        raise RuntimeError('Dataset not recognized')

    if args.mini_data:  # reduce the dataset to a single batch
        if args.dataset == 'colored-mnist':
            train_dataset = train_dataset.tensors[0][:args.batch_size]
        else:
            train_dataset.data = train_dataset.data[:args.batch_size]
            train_dataset.targets = train_dataset.targets[:args.batch_size]
        valid_dataset = train_dataset

    print(
        'Dataset {}\n\ttrain len: {}\n\tvalid len: {}\n\tshape: {}\n\troot: {}'
        .format(args.dataset, len(train_dataset), len(valid_dataset),
                train_dataset[0][0].shape, args.data_path))

    train_dataloader = DataLoader(train_dataset,
                                  args.batch_size,
                                  shuffle=True,
                                  pin_memory=(args.device.type == 'cuda'),
                                  num_workers=4)
    valid_dataloader = DataLoader(valid_dataset,
                                  args.batch_size,
                                  shuffle=False,
                                  pin_memory=(args.device.type == 'cuda'),
                                  num_workers=4)

    # save a sample
    data_sample = next(iter(train_dataloader))[0]
    writer.add_image('data_sample',
                     make_grid(data_sample, normalize=True, scale_each=True),
                     args.step)
    save_image(data_sample,
               os.path.join(args.output_dir, 'data_sample.png'),
               normalize=True,
               scale_each=True)

    return train_dataloader, valid_dataloader
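
The transform in fetch_dataloaders first quantizes pixel values down to n_bits and only then maps them into model space via the project's preprocess helper. The bit-reduction arithmetic is easy to sanity-check on its own:

# Standalone check of the bit-reduction step used in fetch_dataloaders' transform.
import torch

n_bits = 4
x = torch.tensor([0.0, 0.25, 0.5, 1.0])        # ToTensor output lives in [0, 1]
q = x.mul(255).div(2 ** (8 - n_bits)).floor()  # integer levels in [0, 2**n_bits - 1]
print(q)                                       # tensor([ 0.,  3.,  7., 15.])
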
Example #10
def load_data(config):
    normal_class = config['normal_class']
    batch_size = config['batch_size']

    if config['dataset_name'] in ['cifar10']:
        img_transform = transforms.Compose([
            transforms.Resize((256, 256), Image.ANTIALIAS),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                 std=(0.229, 0.224, 0.225))
        ])

        os.makedirs("./Dataset/CIFAR10/train", exist_ok=True)
        dataset = CIFAR10('./Dataset/CIFAR10/train',
                          train=True,
                          download=True,
                          transform=img_transform)
        print("Cifar10 DataLoader Called...")
        print("All Train Data: ", dataset.data.shape)
        dataset.data = dataset.data[np.array(dataset.targets) == normal_class]
        dataset.targets = [normal_class] * dataset.data.shape[0]
        print("Normal Train Data: ", dataset.data.shape)

        os.makedirs("./Dataset/CIFAR10/test", exist_ok=True)
        test_set = CIFAR10("./Dataset/CIFAR10/test",
                           train=False,
                           download=True,
                           transform=img_transform)
        print("Test Train Data:", test_set.data.shape)

    elif config['dataset_name'] in ['mnist']:
        img_transform = transforms.Compose([
            #     transforms.Grayscale(num_output_channels=1),
            transforms.Resize((32, 32)),
            transforms.ToTensor()
            #  transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))]
        ])

        os.makedirs("./Dataset/MNIST/train", exist_ok=True)
        dataset = MNIST('./Dataset/MNIST/train',
                        train=True,
                        download=True,
                        transform=img_transform)
        print("MNIST DataLoader Called...")
        print("All Train Data: ", dataset.data.shape)
        dataset.data = dataset.data[np.array(dataset.targets) == normal_class]
        dataset.targets = [normal_class] * dataset.data.shape[0]
        print("Normal Train Data: ", dataset.data.shape)

        os.makedirs("./Dataset/MNIST/test", exist_ok=True)
        test_set = MNIST("./Dataset/MNIST/test",
                         train=False,
                         download=True,
                         transform=img_transform)
        print("Test Train Data:", test_set.data.shape)

    elif config['dataset_name'] in ['fashionmnist']:
        img_transform = transforms.Compose([
            #     transforms.Grayscale(num_output_channels=1),
            transforms.Resize((32, 32)),
            transforms.ToTensor()
            #  transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))]
        ])

        os.makedirs("./Dataset/FashionMNIST/train", exist_ok=True)
        dataset = FashionMNIST('./Dataset/FashionMNIST/train',
                               train=True,
                               download=True,
                               transform=img_transform)
        print("FashionMNIST DataLoader Called...")
        print("All Train Data: ", dataset.data.shape)
        dataset.data = dataset.data[np.array(dataset.targets) == normal_class]
        dataset.targets = [normal_class] * dataset.data.shape[0]
        print("Normal Train Data: ", dataset.data.shape)

        os.makedirs("./Dataset/FashionMNIST/test", exist_ok=True)
        test_set = FashionMNIST("./Dataset/FashionMNIST/test",
                                train=False,
                                download=True,
                                transform=img_transform)
        print("Test Train Data:", test_set.data.shape)

    elif config['dataset_name'] in ['mvtec']:
        data_path = 'Dataset/MVTec/' + normal_class + '/train'

        mvtec_img_size = config['mvtec_img_size']

        orig_transform = transforms.Compose([
            transforms.Resize([mvtec_img_size, mvtec_img_size]),
            transforms.ToTensor()
        ])

        dataset = ImageFolder(root=data_path, transform=orig_transform)

        test_data_path = 'Dataset/MVTec/' + normal_class + '/test'

        test_set = ImageFolder(root=test_data_path, transform=orig_transform)
    elif config['dataset_name'] in ['retina']:
        data_path = 'Dataset/OCT2017/train'

        orig_transform = transforms.Compose(
            [transforms.Resize([128, 128]),
             transforms.ToTensor()])

        dataset = ImageFolder(root=data_path, transform=orig_transform)

        test_data_path = 'Dataset/OCT2017/test'

        test_set = ImageFolder(root=test_data_path, transform=orig_transform)

    train_dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
    )
    test_dataloader = torch.utils.data.DataLoader(
        test_set,
        batch_size=batch_size,
        shuffle=True,
    )

    return train_dataloader, test_dataloader
Example #11
def main():
    # setup arguments
    parser = utils.ArgParser(description=__doc__)
    arguments.add_default_args(parser)
    arguments.add_exp_identifier_args(parser)
    arguments.add_trainer_args(parser)
    arguments.add_dataset_test_arg(parser)
    args = parser.parse_args()

    # load repository config yaml file to dict
    exp_group, exp_name, config_file = arguments.setup_experiment_identifier_from_args(
        args, EXP_TYPE)
    config = load_yaml_config_file(config_file)

    # update experiment config and dataset path given the script arguments
    config = arguments.update_config_from_args(config, args)
    dataset_path = arguments.update_path_from_args(args)

    # read experiment config dict
    cfg = MLPMNISTExperimentConfig(config)
    if args.print_config:
        print(cfg)

    # set seed
    verb = "Set seed"
    if cfg.random_seed is None:
        cfg.random_seed = np.random.randint(0, 2**15, dtype=np.int32)
        verb = "Randomly generated seed"
    print(f"{verb} {cfg.random_seed} deterministic {cfg.cudnn_deterministic} "
          f"benchmark {cfg.cudnn_benchmark}")
    set_seed(cfg.random_seed,
             cudnn_deterministic=cfg.cudnn_deterministic,
             cudnn_benchmark=cfg.cudnn_benchmark)

    # create datasets
    train_set = MNIST(str(dataset_path),
                      train=True,
                      download=True,
                      transform=ToTensor())
    val_set = MNIST(str(dataset_path),
                    train=False,
                    download=True,
                    transform=ToTensor())

    # make datasets smaller if requested in config
    if cfg.dataset_train.max_datapoints > -1:
        train_set.data = train_set.data[:cfg.dataset_train.max_datapoints]
    if cfg.dataset_val.max_datapoints > -1:
        val_set.data = val_set.data[:cfg.dataset_val.max_datapoints]

    # create dataloaders
    train_loader = create_loader(train_set,
                                 cfg.dataset_train,
                                 batch_size=cfg.train.batch_size)
    val_loader = create_loader(val_set,
                               cfg.dataset_val,
                               batch_size=cfg.val.batch_size)

    if args.test_dataset:
        # run dataset test and exit
        run_mlpmnist_dataset_test(train_set, train_loader)
        return
    print("---------- Setup done!")

    for run_number in range(1, args.num_runs + 1):
        run_name = f"{args.run_name}{run_number}"

        # create model
        model_mgr = MLPModelManager(cfg)

        # always load best epoch during validation
        load_best = args.load_best or args.validate

        # create trainer
        trainer = MLPMNISTTrainer(cfg,
                                  model_mgr,
                                  exp_group,
                                  exp_name,
                                  run_name,
                                  len(train_loader),
                                  log_dir=args.log_dir,
                                  log_level=args.log_level,
                                  logger=None,
                                  print_graph=args.print_graph,
                                  reset=args.reset,
                                  load_best=load_best,
                                  load_epoch=args.load_epoch,
                                  inference_only=args.validate)

        if args.validate:
            # run validation
            trainer.validate_epoch(val_loader)
        else:
            # run training
            trainer.train_model(train_loader, val_loader)

        # done with this round
        trainer.close()
        del model_mgr
        del trainer
Example #12
    split_model = SplitNN([model_part1, model_part2], optims)
    split_model.train()

    # ----- Data -----
    data_transform = transforms.Compose([
        transforms.ToTensor(),
        # PyTorch examples; https://github.com/pytorch/examples/blob/master/mnist/main.py
        transforms.Normalize((0.1307, ), (0.3081, )),
    ])
    train_data = MNIST(data_dir,
                       download=True,
                       train=True,
                       transform=data_transform)

    # We only want to use a subset of the data to force overfitting
    train_data.data = train_data.data[:args.n_train_data]
    train_data.targets = train_data.targets[:args.n_train_data]

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size)

    # Test data
    test_data = MNIST(data_dir,
                      download=True,
                      train=False,
                      transform=data_transform)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=1024)

    # ----- Train -----
    n_epochs = args.epochs
Example #13
def get_dataset(args,
                config,
                test=False,
                rev=False,
                one_hot=True,
                subset=False,
                shuffle=True):
    total_labels = 10 if config.data.dataset.lower().split(
        '_')[0] != 'cifar100' else 100
    reduce_labels = total_labels != config.n_labels
    if config.data.dataset.lower() in [
            'mnist_transferbaseline', 'cifar10_transferbaseline',
            'fashionmnist_transferbaseline', 'cifar100_transferbaseline'
    ]:
        print('loading baseline transfer dataset')
        rev = True
        test = False
        subset = True
        reduce_labels = True

    if config.data.random_flip is False:
        transform = transforms.Compose(
            [transforms.Resize(config.data.image_size),
             transforms.ToTensor()])
    else:
        if not test:
            transform = transforms.Compose([
                transforms.Resize(config.data.image_size),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ToTensor()
            ])
        else:
            transform = transforms.Compose([
                transforms.Resize(config.data.image_size),
                transforms.ToTensor()
            ])

    if config.data.dataset.lower().split('_')[0] == 'mnist':
        dataset = MNIST(os.path.join(args.run, 'datasets'),
                        train=not test,
                        download=True,
                        transform=transform)
    elif config.data.dataset.lower().split('_')[0] in [
            'fashionmnist', 'fmnist'
    ]:
        dataset = FashionMNIST(os.path.join(args.run, 'datasets'),
                               train=not test,
                               download=True,
                               transform=transform)
    elif config.data.dataset.lower().split('_')[0] == 'cifar10':
        dataset = CIFAR10(os.path.join(args.run, 'datasets'),
                          train=not test,
                          download=True,
                          transform=transform)
    elif config.data.dataset.lower().split('_')[0] == 'cifar100':
        dataset = CIFAR100(os.path.join(args.run, 'datasets'),
                           train=not test,
                           download=True,
                           transform=transform)
    else:
        raise ValueError('Unknown config dataset {}'.format(
            config.data.dataset))

    if type(dataset.targets) is list:
        # CIFAR10 and CIFAR100 store targets as list, unlike (F)MNIST which uses torch.Tensor
        dataset.targets = np.array(dataset.targets)

    if not rev:
        labels_to_consider = np.arange(config.n_labels)
        target_transform = lambda label: single_one_hot_encode(
            label, n_labels=config.n_labels)
        cond_size = config.n_labels

    else:
        labels_to_consider = np.arange(config.n_labels, total_labels)
        target_transform = lambda label: single_one_hot_encode_rev(
            label, start_label=config.n_labels, n_labels=total_labels)
        cond_size = total_labels - config.n_labels
    if reduce_labels:
        idx = np.any(
            [np.array(dataset.targets) == i for i in labels_to_consider],
            axis=0).nonzero()
        dataset.targets = dataset.targets[idx]
        dataset.data = dataset.data[idx]
    if one_hot:
        dataset.target_transform = target_transform
    if subset and args.subset_size != 0:
        dataset = torch.utils.data.Subset(dataset, np.arange(args.subset_size))
    dataloader = DataLoader(dataset,
                            batch_size=config.training.batch_size,
                            shuffle=shuffle,
                            num_workers=0)

    return dataloader, dataset, cond_size
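
single_one_hot_encode and single_one_hot_encode_rev are project helpers that are not shown in this excerpt. A hypothetical sketch of what they presumably do, inferred only from how they are called above (one-hot over the first n_labels classes, and over the remaining classes, respectively):

# Hypothetical implementations inferred from the call sites above (not the project's originals).
import numpy as np

def single_one_hot_encode(label, n_labels):
    # one-hot vector over classes [0, n_labels)
    vec = np.zeros(n_labels, dtype=np.float32)
    vec[label] = 1.0
    return vec

def single_one_hot_encode_rev(label, start_label, n_labels):
    # one-hot vector over classes [start_label, n_labels), indexed from zero
    vec = np.zeros(n_labels - start_label, dtype=np.float32)
    vec[label - start_label] = 1.0
    return vec
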
Example #14
def load_dataset(args):
    '''Loads the dataset specified by args.dataset.'''

    # MNIST dataset
    if args.dataset == 'mnist':
        trans_img = transforms.Compose([transforms.ToTensor()])

        print("Downloading MNIST data...")
        trainset = MNIST('./data',
                         train=True,
                         transform=trans_img,
                         download=True)
        testset = MNIST('./data',
                        train=False,
                        transform=trans_img,
                        download=True)

    # CIFAR-10 dataset
    if args.dataset == 'cifar10':
        # Data
        print('==> Preparing data..')
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        trainset = CIFAR10(root='./data',
                           train=True,
                           transform=transform_train,
                           download=True)
        testset = CIFAR10(root='./data',
                          train=False,
                          transform=transform_test,
                          download=True)

    if args.dataset == 'cifar100':
        # Data
        print('==> Preparing data..')
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        trainset = CIFAR100(root='./data',
                            train=True,
                            transform=transform_train,
                            download=True)
        testset = CIFAR100(root='./data',
                           train=False,
                           transform=transform_test,
                           download=True)

    if args.dataset == 'fashionmnist':

        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])

        trainset = FASHION(root='./data',
                           train=True,
                           transform=transform,
                           download=True)
        testset = FASHION(root='./data',
                          train=False,
                          transform=transform,
                          download=True)

    if args.dataset == 'svhn':
        train_transform = transforms.Compose([])

        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [109.9, 109.7, 113.8]],
            std=[x / 255.0 for x in [50.1, 50.6, 50.8]])

        train_transform.transforms.append(transforms.ToTensor())
        train_transform.transforms.append(normalize)

        trainset = SVHN(root='./data',
                        split='train',
                        transform=train_transform,
                        download=True)

        extra_dataset = SVHN(root='./data',
                             split='extra',
                             transform=train_transform,
                             download=True)

        # Combine both training splits, as is common practice for SVHN

        data = np.concatenate([trainset.data, extra_dataset.data], axis=0)
        labels = np.concatenate([trainset.labels, extra_dataset.labels],
                                axis=0)

        trainset.data = data
        trainset.labels = labels

        test_transform = transforms.Compose([transforms.ToTensor(), normalize])
        testset = SVHN(root='./data',
                       split='test',
                       transform=test_transform,
                       download=True)

    # Self-Paced Learning Enabled
    if args.spld:
        train_idx = np.arange(len(trainset))
        #numpy.random.shuffle(train_idx)
        n_train = len(train_idx)
        train_sampler = SubsetSequentialSamplerSPLDML(range(len(trainset)),
                                                      range(len(trainset)))
        trainloader = DataLoader(trainset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=4,
                                 sampler=train_sampler)

        testloader = DataLoader(testset,
                                batch_size=args.batch_size,
                                shuffle=True,
                                num_workers=4)
    elif args.spldml:
        n_train = len(trainset)
        train_sampler = SubsetSequentialSamplerSPLDML(range(len(trainset)),
                                                      range(args.batch_size))
        trainloader = DataLoader(trainset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=1,
                                 sampler=train_sampler)

        testloader = DataLoader(testset,
                                batch_size=args.batch_size,
                                shuffle=True,
                                num_workers=1)
    # Deep Metric Learning
    elif args.dml:
        n_train = len(trainset)
        train_sampler = SubsetSequentialSampler(range(len(trainset)),
                                                range(args.batch_size))
        trainloader = DataLoader(trainset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=1,
                                 sampler=train_sampler)

        testloader = DataLoader(testset,
                                batch_size=args.batch_size,
                                shuffle=True,
                                num_workers=1)
    elif args.stratified:
        n_train = len(trainset)
        # torchvision now exposes labels via `targets` ('train_labels' is deprecated)
        labels = getattr(trainset, 'targets', None)
        if labels is None:
            labels = trainset.labels  # e.g. SVHN stores its labels under `labels`

        if isinstance(labels, list):
            labels = torch.FloatTensor(np.array(labels))

        train_sampler = StratifiedSampler(labels, args.batch_size)
        trainloader = DataLoader(trainset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=4,
                                 sampler=train_sampler)

        testloader = DataLoader(testset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=4)
    # Random sampling
    else:
        n_train = len(trainset)
        trainloader = DataLoader(trainset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=4)

        testloader = DataLoader(testset,
                                batch_size=args.batch_size,
                                shuffle=True,
                                num_workers=4)

    return trainloader, testloader, trainset, testset, n_train
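
Finally, a hedged sketch of calling load_dataset above; the argparse fields are the ones the function actually reads (dataset, batch_size, and the sampler flags), everything else is assumed:

# Hypothetical call site for load_dataset (sketch only).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='mnist')
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--spld', action='store_true')
parser.add_argument('--spldml', action='store_true')
parser.add_argument('--dml', action='store_true')
parser.add_argument('--stratified', action='store_true')
args = parser.parse_args([])

trainloader, testloader, trainset, testset, n_train = load_dataset(args)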