Example #1
    parser.add_argument('--use_torch',
                        action='store_true',
                        help='use torch as the backend; otherwise numpy is used '
                        '(default=False)')
    args = parser.parse_args()

    if args.dataset == "cifar10":
        dataset = datasets.CIFAR10('./data',
                                   train=False,
                                   download=True,
                                   transform=transforms.ToTensor())
    elif args.dataset == "stl10":
        dataset = datasets.STL10('./data',
                                 split='unlabeled',
                                 download=True,
                                 transform=transforms.Compose([
                                     transforms.Resize((48, 48)),
                                     transforms.ToTensor()
                                 ]))
    else:
        # implement your dataset here
        raise NotImplementedError("Dataset %s is not supported" % args.dataset)

    def image_generator(dataset):
        for x, _ in dataset:
            yield x.numpy()

    m, s = get_statistics(image_generator(dataset),
                          num_images=len(dataset),
                          batch_size=50,
                          use_torch=args.use_torch,
Example #2
                            transforms.Resize(opt.imageSize),
                            transforms.CenterCrop(opt.imageSize),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                        ]))
elif opt.dataset == 'cifar10':
    dataset = dset.CIFAR10(root=opt.dataroot, download=True,
                           transform=transforms.Compose([
                               transforms.Resize(opt.imageSize),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
elif opt.dataset == 'stl10':
    dataset = dset.STL10(root=opt.dataroot, download=True,
                         transform=transforms.Compose([
                             transforms.Resize(opt.imageSize),
                             transforms.ToTensor(),
                             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                         ]))
elif opt.dataset == 'fake':
    dataset = dset.FakeData(image_size=(3, opt.imageSize, opt.imageSize),
                            transform=transforms.ToTensor())
assert dataset
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                         shuffle=True, num_workers=int(opt.workers))

device = torch.device("cuda:0" if opt.cuda else "cpu")
ngpu = int(opt.ngpu)
nz = int(opt.nz)
ngf = int(opt.ngf)
ndf = int(opt.ndf)
nc = 3
Example #3
def build_dataset(
    dataset,
    batch_size,
    input_dir=None,
    labeled_only=False
):  # Nawid - Loads the chosen dataset and builds dataloaders for the different implementations

    train_dir, val_dir = _get_directories(
        dataset, input_dir)  # Nawid - Obtains the directory of a dataset

    if dataset == Dataset.C10:
        num_classes = 10
        train_transform = TransformsC10()  # Nawid - Obtains the transform for the C10 training set
        test_transform = train_transform.test_transform
        train_dataset = datasets.CIFAR10(root='/content/gdrive/My Drive/',
                                         train=True,
                                         transform=train_transform,
                                         download=True)
        test_dataset = datasets.CIFAR10(root='/content/gdrive/My Drive/',
                                        train=False,
                                        transform=test_transform,
                                        download=True)
    elif dataset == Dataset.C100:
        num_classes = 100
        train_transform = TransformsC10()
        test_transform = train_transform.test_transform
        train_dataset = datasets.CIFAR100(root='/content/gdrive/My Drive/',
                                          train=True,
                                          transform=train_transform,
                                          download=True)
        test_dataset = datasets.CIFAR100(root='/content/gdrive/My Drive/',
                                         train=False,
                                         transform=test_transform,
                                         download=True)
    elif dataset == Dataset.STL10:
        num_classes = 10
        train_transform = TransformsSTL10()
        test_transform = train_transform.test_transform
        train_split = 'train' if labeled_only else 'train+unlabeled'
        train_dataset = datasets.STL10(root='/content/gdrive/My Drive/',
                                       split=train_split,
                                       transform=train_transform,
                                       download=True)
        test_dataset = datasets.STL10(root='/content/gdrive/My Drive/',
                                      split='test',
                                      transform=test_transform,
                                      download=True)
    elif dataset == Dataset.IN128:
        num_classes = 1000
        train_transform = TransformsImageNet128()
        test_transform = train_transform.test_transform
        train_dataset = datasets.ImageFolder(train_dir, train_transform)
        test_dataset = datasets.ImageFolder(val_dir, test_transform)
    elif dataset == Dataset.PLACES205:
        num_classes = 1000
        train_transform = TransformsImageNet128()
        test_transform = train_transform.test_transform
        train_dataset = datasets.ImageFolder(train_dir, train_transform)
        test_dataset = datasets.ImageFolder(val_dir, test_transform)

    # build pytorch dataloaders for the datasets
    train_loader = \
        torch.utils.data.DataLoader(dataset=train_dataset,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    pin_memory=True,
                                    drop_last=True,
                                    num_workers=16)
    test_loader = \
        torch.utils.data.DataLoader(dataset=test_dataset,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    pin_memory=True,
                                    drop_last=True,
                                    num_workers=16)

    return train_loader, test_loader, num_classes
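
# Usage sketch (added; not part of the original file). It assumes the Dataset
# enum and transform classes referenced above are importable from this repo;
# note the dataset roots are hard-coded to a Colab Drive path.
if __name__ == '__main__':
    train_loader, test_loader, num_classes = build_dataset(Dataset.C10, batch_size=128)
    print(num_classes, len(train_loader), len(test_loader))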
Example #4
def test(dataset_name, epoch):
    # assert dataset_name in ['MNIST', 'mnist_m']
    assert dataset_name in [
        'cifar', 'd_robust_CIFAR', 'd_non_robust_CIFAR', 'stl'
    ]

    model_root = 'models'
    image_root = os.path.join('dataset', dataset_name)

    cuda = True
    cudnn.benchmark = True
    batch_size = 128
    # image_size = 28
    image_size = 32
    alpha = 0

    cifar_cls_mapping = np.array([0, 1, 2, 3, 4, 5, -1, 6, 7, 8])
    stl_cls_mapping = np.array([0, 2, 1, 3, 4, 5, 6, -1, 7, 8])
    """load data"""

    # img_transform_source = transforms.Compose([
    #     transforms.Resize(image_size),
    #     transforms.ToTensor(),
    #     transforms.Normalize(mean=(0.1307,), std=(0.3081,))
    # ])
    #
    # img_transform_target = transforms.Compose([
    #     transforms.Resize(image_size),
    #     transforms.ToTensor(),
    #     transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    # ])

    img_transform_cifar = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])

    img_transform_stl = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    # if dataset_name == 'mnist_m':
    #     test_list = os.path.join(image_root, 'mnist_m_test_labels.txt')
    #
    #     dataset = GetLoader(
    #         data_root=os.path.join(image_root, 'mnist_m_test'),
    #         data_list=test_list,
    #         transform=img_transform_target
    #     )
    # else:
    #     dataset = datasets.MNIST(
    #         root='dataset',
    #         train=False,
    #         transform=img_transform_source,
    #     )

    if dataset_name == 'stl':
        dataset = datasets.STL10(root='dataset',
                                 split='test',
                                 transform=img_transform_stl)
        dataset.labels = stl_cls_mapping[dataset.labels]
        dataset_mask = dataset.labels != -1
        dataset.data = dataset.data[dataset_mask]
        dataset.labels = dataset.labels[dataset_mask]
    elif dataset_name == 'cifar':
        dataset = datasets.CIFAR10(root='dataset',
                                   train=False,
                                   transform=img_transform_cifar)
        dataset.targets = cifar_cls_mapping[dataset.targets]
        dataset_mask = dataset.targets != -1
        dataset.data = dataset.data[dataset_mask]
        dataset.targets = dataset.targets[dataset_mask]
    elif dataset_name == 'd_robust_CIFAR' or dataset_name == 'd_non_robust_CIFAR':
        train_data = torch.cat(
            torch.load(os.path.join(image_root, "CIFAR_ims")))
        train_labels = torch.cat(
            torch.load(os.path.join(image_root, "CIFAR_lab")))
        cls_mapping = torch.from_numpy(cifar_cls_mapping)
        train_labels = cls_mapping[train_labels]
        dataset_mask = train_labels != -1
        train_data = train_data[dataset_mask]
        train_labels = train_labels[dataset_mask]
        dataset = torch.utils.data.TensorDataset(train_data, train_labels)
    else:
        raise ValueError(f"Test dataset name is not valid: {dataset_name}")

    dataloader = torch.utils.data.DataLoader(dataset=dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=8)
    """ training """

    my_net = torch.load(
        os.path.join(
            # model_root, 'mnist_mnistm_model_epoch_' + str(epoch) + '.pth'
            # model_root, 'cifar_stl_model_epoch_' + str(epoch) + '.pth'
            model_root,
            'stl_cifar_model_epoch_' + str(epoch) + '.pth'))
    my_net = my_net.eval()

    if cuda:
        my_net = my_net.cuda()

    len_dataloader = len(dataloader)
    data_target_iter = iter(dataloader)

    i = 0
    n_total = 0
    n_correct = 0

    while i < len_dataloader:

        # test model using target data
        data_target = next(data_target_iter)  # iterator.next() was removed in Python 3
        t_img, t_label = data_target

        batch_size = len(t_label)

        input_img = torch.FloatTensor(batch_size, 3, image_size, image_size)
        class_label = torch.LongTensor(batch_size)

        if cuda:
            t_img = t_img.cuda()
            t_label = t_label.cuda()
            input_img = input_img.cuda()
            class_label = class_label.cuda()

        input_img.resize_as_(t_img).copy_(t_img)
        class_label.resize_as_(t_label).copy_(t_label)

        class_output, _ = my_net(input_data=input_img,
                                 alpha=alpha,
                                 training=False)
        pred = class_output.data.max(1, keepdim=True)[1]
        n_correct += pred.eq(class_label.data.view_as(pred)).cpu().sum()
        n_total += batch_size

        i += 1

    accu = n_correct.data.numpy() * 1.0 / n_total

    print('epoch: %d, accuracy of the %s dataset: %f' %
          (epoch, dataset_name, accu))

    return accu
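
# Usage sketch (added, not from the original script): evaluate saved checkpoints
# for a few epochs; the epoch range and dataset name are assumptions.
if __name__ == '__main__':
    for ep in range(10):
        test('cifar', ep)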
Example #5
def get_stl_dataloaders(root='data',
                        batch_size=128,
                        num_workers=8,
                        is_instance=False):
    train_transform = transforms.Compose([
        transforms.Resize(32),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4468, 0.4399, 0.4068),
                             (0.2419, 0.2384, 0.2541)),
    ])
    test_transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize((0.4468, 0.4399, 0.4068),
                             (0.2419, 0.2384, 0.2541)),
    ])

    if is_instance:
        train_set = STLInstance(root=root,
                                download=True,
                                split='train',
                                transform=train_transform)
        n_data = len(train_set)
    else:
        train_set = datasets.STL10(root=root,
                                   download=True,
                                   split='train',
                                   transform=train_transform)
    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_workers)

    test_set = datasets.STL10(root=root,
                              download=True,
                              split='test',
                              transform=test_transform)
    test_loader = DataLoader(test_set,
                             batch_size=int(batch_size / 2),
                             shuffle=False,
                             num_workers=int(num_workers / 2))

    if is_instance:
        return train_loader, test_loader, n_data
    else:
        return train_loader, test_loader


#
# if __name__ == '__main__':
#     import numpy as np
#     from tqdm import tqdm
#     num = 5000
#     batch = 50
#     size = 32
#     imgs = np.zeros(shape=(num, 3, size, size))
#     train_loader, test_loader = get_stl_dataloaders('../datasets', num_workers=4, batch_size=batch)
#
#     for i, (img, _) in enumerate(tqdm(train_loader)):
#         imgs[i*batch: (i+1)*batch] = img.numpy()
#
#
#     for i in range(3):
#         c = imgs[:, i, :, :]
#         print(round(c.mean(), 4), round(c.std(), 4))
Example #6
from utils import get_negative_mask, get_similarity_function
from data_aug.data_transform import DataTransform, get_data_transform_opes

torch.manual_seed(0)

config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader)

batch_size = config['batch_size']
out_dim = config['out_dim']
temperature = config['temperature']
use_cosine_similarity = config['use_cosine_similarity']

data_augment = get_data_transform_opes(s=config['s'], crop_size=96)

train_dataset = datasets.STL10('./data',
                               split='unlabeled',
                               download=True,
                               transform=DataTransform(data_augment))

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          num_workers=config['num_workers'],
                          drop_last=True,
                          shuffle=True)

# model = Encoder(out_dim=out_dim)
model = ResNetSimCLR(base_model=config["base_convnet"], out_dim=out_dim)

train_gpu = torch.cuda.is_available()
print("Is gpu available:", train_gpu)

# moves the model parameters to gpu
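
# Hedged completion (added): the original SimCLR-style script presumably moves
# the model to the GPU here when one is available.
if train_gpu:
    model = model.cuda()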
Example #7
def main():
    global best_acc, use_apex, mean, std, scale

    args = parse_args()
    args.mean, args.std, args.scale, args.use_apex = mean, std, scale, use_apex
    args.is_master = args.local_rank == 0

    if args.deterministic:
        cudnn.deterministic = True
        torch.manual_seed(0)
        random.seed(0)
        np.random.seed(0)

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1 and args.use_apex

    if args.is_master:
        print("opt_level = {}".format(args.opt_level))
        print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32),
              type(args.keep_batchnorm_fp32))
        print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))
        print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))
        print(f"Use Apex: {args.use_apex}")
        print(f"Distributed Training Enabled: {args.distributed}")

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        # Scale learning rate based on global batch size
        # args.lr *= args.batch_size * args.world_size / 256

    if args.use_apex:
        assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    # create model
    model = models.ResNet18(args.num_patches, args.num_angles)

    if args.sync_bn:
        import apex
        print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)

    model = model.cuda()
    optimiser = Ranger(model.parameters(), lr=args.lr)
    criterion = nn.CrossEntropyLoss().cuda()

    # Initialize Amp.  Amp accepts either values or strings for the optional override arguments,
    # for convenient interoperation with argparse.
    if args.use_apex:
        model, optimiser = amp.initialize(
            model,
            optimiser,
            opt_level=args.opt_level,
            keep_batchnorm_fp32=args.keep_batchnorm_fp32,
            loss_scale=args.loss_scale)

    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
    # This must be done AFTER the call to amp.initialize.  If model = DDP(model) is called
    # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
    if args.distributed:
        model = DDP(model, delay_allreduce=True)
    else:
        model = nn.DataParallel(model)

    # Optionally resume from a checkpoint
    if args.resume:
        # Use a local scope to avoid dangling references
        def resume():
            global best_acc
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(
                    args.resume,
                    map_location=lambda storage, loc: storage.cuda(args.gpu))
                args.start_epoch = checkpoint['epoch']
                best_acc = checkpoint['best_acc']
                args.poisson_rate = checkpoint["poisson_rate"]
                model.load_state_dict(checkpoint['state_dict'])
                optimiser.load_state_dict(checkpoint['optimiser'])
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))

        resume()

    if args.do_ssl:
        stl_unlabeled = datasets.STL10(root=args.data,
                                       split='unlabeled',
                                       download=args.download)
        indices = list(range(len(stl_unlabeled)))
        train_indices = indices[:int(len(indices) * 0.9)]
        val_indices = indices[int(len(indices) * 0.9):]
        train_dataset = SSLTrainDataset(Subset(stl_unlabeled, train_indices),
                                        args.num_patches, args.num_angles,
                                        args.poisson_rate)
        val_dataset = SSLValDataset(Subset(stl_unlabeled, val_indices),
                                    args.num_patches, args.num_angles)

        train_sampler = None
        val_sampler = None
        if args.distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset)
            val_sampler = torch.utils.data.distributed.DistributedSampler(
                val_dataset)

        train_loader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=(train_sampler is None),
                                  num_workers=args.workers,
                                  pin_memory=True,
                                  sampler=train_sampler,
                                  collate_fn=fast_collate)

        val_loader = DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=args.workers,
                                pin_memory=True,
                                sampler=val_sampler,
                                collate_fn=fast_collate)

        if args.evaluate:
            rot_val_loss, rot_val_acc, perm_val_loss, perm_val_acc = apex_validate(
                val_loader, model, criterion, args)
            if args.is_master:
                utils.logger.info(
                    f"Rot Val Loss = {rot_val_loss}, Rot Val Accuracy = {rot_val_acc}"
                )
                utils.logger.info(
                    f"Perm Val Loss = {perm_val_loss}, Perm Val Accuracy = {perm_val_acc}"
                )
            return

        # Create dir to save model and command-line args
        if args.is_master:
            model_dir = time.ctime().replace(" ", "_").replace(":", "_")
            model_dir = os.path.join("models", model_dir)
            os.makedirs(model_dir, exist_ok=True)
            with open(os.path.join(model_dir, "args.json"), "w") as f:
                json.dump(args.__dict__, f, indent=2)
            writer = SummaryWriter()

        for epoch in range(args.start_epoch, args.epochs):
            if args.distributed:
                train_sampler.set_epoch(epoch)

            # train for one epoch
            rot_train_loss, rot_train_acc, perm_train_loss, perm_train_acc = apex_train(
                train_loader, model, criterion, optimiser, args, epoch)

            # evaluate on validation set
            rot_val_loss, rot_val_acc, perm_val_loss, perm_val_acc = apex_validate(
                val_loader, model, criterion, args)

            if (epoch + 1) % args.learn_prd == 0:
                args.poisson_rate += 1
                train_loader.dataset.set_poisson_rate(args.poisson_rate)

            # remember best Acc and save checkpoint
            if args.is_master:
                is_best = perm_val_acc > best_acc
                best_acc = max(perm_val_acc, best_acc)
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'best_acc': best_acc,
                        'optimiser': optimiser.state_dict(),
                        "poisson_rate": args.poisson_rate
                    }, is_best, model_dir)

                writer.add_scalars("Rot_Loss", {
                    "train_loss": rot_train_loss,
                    "val_loss": rot_val_loss
                }, epoch)
                writer.add_scalars("Perm_Loss", {
                    "train_loss": perm_train_loss,
                    "val_loss": perm_val_loss
                }, epoch)
                writer.add_scalars("Rot_Accuracy", {
                    "train_acc": rot_train_acc,
                    "val_acc": rot_val_acc
                }, epoch)
                writer.add_scalars("Perm_Accuracy", {
                    "train_acc": perm_train_acc,
                    "val_acc": perm_val_acc
                }, epoch)
                writer.add_scalar("Poisson_Rate",
                                  train_loader.dataset.pdist.rate, epoch)
Example #8
def get_data_loader(args):

    if args.dataset == 'mnist':
        trans = transforms.Compose([
            transforms.Resize(32),  # transforms.Scale was deprecated in favour of Resize
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        train_dataset = MNIST(root=args.dataroot,
                              train=True,
                              download=args.download,
                              transform=trans)
        test_dataset = MNIST(root=args.dataroot,
                             train=False,
                             download=args.download,
                             transform=trans)

    elif args.dataset == 'fashion-mnist':
        trans = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        train_dataset = FashionMNIST(root=args.dataroot,
                                     train=True,
                                     download=args.download,
                                     transform=trans)
        test_dataset = FashionMNIST(root=args.dataroot,
                                    train=False,
                                    download=args.download,
                                    transform=trans)

    elif args.dataset == 'cifar':
        trans = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        train_dataset = dset.CIFAR10(root=args.dataroot,
                                     train=True,
                                     download=args.download,
                                     transform=trans)
        test_dataset = dset.CIFAR10(root=args.dataroot,
                                    train=False,
                                    download=args.download,
                                    transform=trans)

    elif args.dataset == 'stl10':
        trans = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
        ])
        # STL10 has no `train` flag; it selects data via `split`
        # ('train', 'test', or 'unlabeled')
        train_dataset = dset.STL10(root=args.dataroot,
                                   split='train',
                                   download=args.download,
                                   transform=trans)
        test_dataset = dset.STL10(root=args.dataroot,
                                  split='test',
                                  download=args.download,
                                  transform=trans)

    # Check if everything is ok with loading datasets
    assert train_dataset
    assert test_dataset

    train_dataloader = data_utils.DataLoader(train_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=True)
    test_dataloader = data_utils.DataLoader(test_dataset,
                                            batch_size=args.batch_size,
                                            shuffle=True)

    return train_dataloader, test_dataloader
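
# Usage sketch (added; not part of the original file). The argparse fields mirror
# the attributes referenced above and are assumptions about the real CLI.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='stl10')
    parser.add_argument('--dataroot', default='./data')
    parser.add_argument('--download', action='store_true')
    parser.add_argument('--batch_size', type=int, default=64)
    train_loader, test_loader = get_data_loader(parser.parse_args())
    print(len(train_loader), len(test_loader))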
Example #9
    def load_stl10(self):
        transforms = self.transform(True, True, True, False)
        dataset = dsets.STL10(self.path, transform=transforms, download=True)
        return dataset
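
    # Usage sketch (added): load_stl10 is a method, so it is called on whatever
    # loader object defines self.path and self.transform; the owner class name
    # below is hypothetical.
    #
    #   loader = DatasetFactory(path='./data')   # hypothetical owner class
    #   stl10 = loader.load_stl10()
    #   dataloader = torch.utils.data.DataLoader(stl10, batch_size=64, shuffle=True)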
Example #10
def dataLoadFunc(opt):
    # Data loading parameters
    use_cuda = torch.cuda.is_available()
    params = {'batch_size': opt.batch_size, 'shuffle': True, 'num_workers': 16, 'pin_memory': True} if use_cuda else {}
    
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    if opt.dataset in ['cifar10', 'cifar100']: 
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ToTensor(), normalize
        ])

        val_transform = transforms.Compose([
            transforms.ToTensor(), normalize
        ])
    
    elif opt.dataset == "stl10": 
        train_transform = transforms.Compose([
            transforms.RandomCrop(96, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(), normalize
        ])

        val_transform = transforms.Compose([
            transforms.ToTensor(), normalize
        ])
    elif opt.dataset == "imagenet":
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),  # RandomSizedCrop was renamed to RandomResizedCrop
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

        val_transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
        ])
    

    if opt.dataset == 'cifar10':
        train_set = datasets.CIFAR10(root='/raid/datasets/public/cifar', train=True, download=False, transform = train_transform)
        valid_set = datasets.CIFAR10(root='/raid/datasets/public/cifar', train=False, download=False, transform = val_transform)
    elif opt.dataset == 'cifar100':
        train_set = datasets.CIFAR100(root='/raid/datasets/public/cifar', train=True, download=False, transform = train_transform)
        valid_set = datasets.CIFAR100(root='/raid/datasets/public/cifar', train=False, download=False, transform = val_transform)
    elif opt.dataset == 'imagenet':
        train_set = datasets.ImageFolder(root='/raid/datasets/public/imagenet/train', transform = train_transform)
        valid_set = datasets.ImageFolder(root='/raid/datasets/public/imagenet/val', transform = val_transform)

    elif opt.dataset == 'stl10':
        train_set = datasets.STL10(root='/raid/datasets/public/stl10', split='train', download=False, transform = train_transform)
        valid_set = datasets.STL10(root='/raid/datasets/public/stl10', split='test', download=False, transform = val_transform)
   

    train_loader = data.DataLoader(train_set, **params)
    valid_loader = data.DataLoader(valid_set, **params)

    return train_loader, valid_loader
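
# Usage sketch (added, not from the original file): `opt` is assumed to be an
# argparse Namespace with `dataset` and `batch_size`; note the data roots above
# are hard-coded absolute paths with download=False.
if __name__ == '__main__':
    from types import SimpleNamespace
    opt = SimpleNamespace(dataset='stl10', batch_size=128)
    train_loader, valid_loader = dataLoadFunc(opt)
    print(len(train_loader), len(valid_loader))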
Example #11
def main(args):

    ### Hyperparameters setting ###
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    main_epochs = args.epochs
    classifier_epochs = args.c_epochs
    T = args.temperature
    patience = args.patience
    num_classes = args.num_classes
    classifier_hidden_dim = args.c_dim
    projection_hidden_dim = args.p_dim
    in_dim = 512  # Constant as long as we use ResNet18

    # model definition
    f, g = resnet18_encoder().to(device), ProjectionHead(
        in_dim, projection_hidden_dim).to(device)

    if not args.test:
        ### Train SimCLR ###
        dataset = DataSetWrapper(args.batch_size,
                                 args.num_worker,
                                 args.valid_size,
                                 input_shape=(96, 96, 3))
        train_loader, valid_loader = dataset.get_data_loaders()

        criterion = NT_XentLoss(T)
        optimizer = torch.optim.Adam(list(f.parameters()) +
                                     list(g.parameters()),
                                     3e-4,
                                     weight_decay=1e-5)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=len(train_loader), eta_min=0, last_epoch=-1)

        train_losses, val_losses = train(main_epochs, patience, optimizer,
                                         scheduler, train_loader, valid_loader,
                                         f, g, criterion)

        plot_loss_curve(train_losses, val_losses, 'results/train_loss.png',
                        'results/val_loss.png')

    else:
        ### Test ###
        load_checkpoint(f, g, args.checkpoint)
        classifier = Classifier(in_dim, num_classes,
                                classifier_hidden_dim).to(device)

        data_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        if not os.path.exists('checkpoints/classifier.pt'):
            ### Train Classifier ###
            train_dataset = datasets.STL10('./data',
                                           split='train',
                                           download=True,
                                           transform=data_transform)
            train_loader = DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      num_workers=args.num_worker)

            criterion = nn.CrossEntropyLoss()

            if args.fine_tuning:
                params = list(f.parameters()) + list(classifier.parameters())
            else:
                params = classifier.parameters()

            optimizer = torch.optim.Adam(params, lr=1e-4)

            train_classifier(classifier_epochs, train_loader, f, classifier,
                             criterion, optimizer)
            save_checkpoint_classifier(classifier, 'checkpoints/classifier.pt')

        else:
            load_checkpoint_classifier(classifier, 'checkpoints/classifier.pt')

        ### Test ###
        test_dataset = datasets.STL10('./data',
                                      split='test',
                                      download=True,
                                      transform=data_transform)
        test_loader = DataLoader(test_dataset,
                                 batch_size=args.batch_size,
                                 num_workers=args.num_worker)

        accuracy = test(test_loader, f, classifier)
        print("Test Accuracy : %.4f" % (accuracy))
Example #12
    def factory(
        self,
        pathname,
        name,
        subset='train',
        idenselect=[],
        download=False,
        transform=None,
    ):
        """Factory dataset
        """

        assert (self._checksubset(subset))
        pathname = os.path.expanduser(pathname)

        # PyTorch vision datasets supported

        if name == 'mnist':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, name)
            data = datasets.MNIST(pathname,
                                  train=btrain,
                                  transform=transform,
                                  download=download)
            data.labels = np.array(data.targets)

        elif name == 'fashion':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, name)
            data = datasets.FashionMNIST(pathname,
                                         train=btrain,
                                         transform=transform,
                                         download=download)
            data.labels = np.array(data.targets)

        elif name == 'emnist':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, name)
            data = datasets.EMNIST(pathname,
                                   split='byclass',
                                   train=btrain,
                                   transform=transform,
                                   download=download)
            data.labels = np.array(data.targets)

        elif name == 'cifar10':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, name)
            data = datasets.CIFAR10(pathname,
                                    train=btrain,
                                    transform=transform,
                                    download=download)
            data.labels = np.array(data.targets)

        elif name == 'cifar100':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, name)
            data = datasets.CIFAR100(pathname,
                                     train=btrain,
                                     transform=transform,
                                     download=download)
            data.labels = np.array(data.targets)

        elif name == 'stl10':
            split = 'train' if (subset == 'train') else 'test'
            pathname = create_folder(pathname, name)
            data = datasets.STL10(pathname,
                                  split=split,
                                  transform=transform,
                                  download=download)

        elif name == 'svhn':
            split = 'train' if (subset == 'train') else 'test'
            pathname = create_folder(pathname, name)
            data = datasets.SVHN(pathname,
                                 split=split,
                                 transform=transform,
                                 download=download)
            data.classes = np.unique(data.labels)

        # internet dataset

        elif name == 'cub2011':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, name)
            data = cub2011.CUB2011(pathname, train=btrain, download=download)
            data.labels = np.array(data.targets)

        elif name == 'cars196':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, name)
            data = cars196.Cars196(pathname, train=btrain, download=download)
            data.labels = np.array(data.targets)

        elif name == 'stanford_online_products':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, name)
            data = stanford_online_products.StanfordOnlineProducts(
                pathname, train=btrain, download=download)
            data.labels = np.array(data.targets)
            data.btrain = btrain

        # kaggle dataset
        elif name == 'imaterialist':
            pathname = create_folder(pathname, name)
            data = imaterialist.IMaterialistDatset(pathname, subset, 'jpg')

        # fer recognition datasets

        elif name == 'ferp':
            pathname = create_folder(pathname, name)
            if subset == 'train': subfolder = ferp.train
            elif subset == 'val': subfolder = ferp.valid
            elif subset == 'test': subfolder = ferp.test
            else: assert (False)
            data = ferp.FERPDataset(pathname, subfolder, download=download)

        elif name == 'ck':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, name)
            data = fer.FERClassicDataset(pathname,
                                         'ck',
                                         idenselect=idenselect,
                                         train=btrain)

        elif name == 'ckp':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, name)
            data = fer.FERClassicDataset(pathname,
                                         'ckp',
                                         idenselect=idenselect,
                                         train=btrain)

        elif name == 'jaffe':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, name)
            data = fer.FERClassicDataset(pathname,
                                         'jaffe',
                                         idenselect=idenselect,
                                         train=btrain)

        elif name == 'bu3dfe':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, name)
            data = fer.FERClassicDataset(pathname,
                                         'bu3dfe',
                                         idenselect=idenselect,
                                         train=btrain)

        elif name == 'afew':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, name)
            data = afew.Afew(pathname, train=btrain, download=download)
            data.labels = np.array(data.targets)

        elif name == 'celeba':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, name)
            data = celeba.CelebaDataset(pathname,
                                        train=btrain,
                                        download=download)

        elif name == 'affectnet':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, name)
            data = affect.create_affect(path=pathname, train=btrain)

#         elif name == 'ferblack':
#             btrain=(subset=='train')
#             pathname = create_folder(pathname, name)
#             data = ferfolder.FERFolderDataset(pathname, train=btrain, idenselect=idenselect, download=download)
#             data.labels = np.array( data.labels )

        elif name == 'ckdark':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, 'ck')
            data = fer.FERDarkClassicDataset(pathname,
                                             'ck',
                                             idenselect=idenselect,
                                             train=btrain)

        elif name == 'ckpdark':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, 'ckp')
            data = fer.FERDarkClassicDataset(pathname,
                                             'ckp',
                                             idenselect=idenselect,
                                             train=btrain)

        elif name == 'bu3dfedark':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, 'bu3dfe')
            data = fer.FERDarkClassicDataset(pathname,
                                             'bu3dfe',
                                             idenselect=idenselect,
                                             train=btrain)

        elif name == 'jaffedark':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, 'jaffe')
            data = fer.FERDarkClassicDataset(pathname,
                                             'jaffe',
                                             idenselect=idenselect,
                                             train=btrain)

        elif name == 'affectnetdark':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, 'affectnet')
            data = affect.create_affectdark(path=pathname, train=btrain)

        # metric learning dataset

        elif name == 'cub2011metric':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, 'cub2011')
            data = cub2011.CUB2011MetricLearning(pathname,
                                                 train=btrain,
                                                 download=download)
            data.labels = np.array(data.targets)

        elif name == 'cars196metric':
            btrain = (subset == 'train')
            pathname = create_folder(pathname, 'cars196')
            data = cars196.Cars196MetricLearning(pathname,
                                                 train=btrain,
                                                 download=download)
            data.labels = np.array(data.targets)

        else:
            assert (False)

        data.btrain = (subset == 'train')
        return data
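
    # Usage sketch (added; not part of the original class): factory is a method,
    # so it is called on the provider object that defines it; the provider name
    # below is hypothetical.
    #
    #   provider = DatasetProvider()
    #   train_data = provider.factory('~/.datasets', 'cifar10', subset='train',
    #                                 transform=transforms.ToTensor(), download=True)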
Example #13
                             transforms.ToTensor(),
                             transforms.Normalize((0.5, ), (0.5, )),
                         ]))
    nc = 1

elif opt.dataset == 'fake':
    dataset = dset.FakeData(image_size=(3, imageSize, imageSize),
                            transform=transforms.ToTensor())
    nc = 3

elif opt.dataset == 'stl10':
    dataset = dset.STL10(root=opt.dataroot,
                         split='unlabeled',
                         transform=transforms.Compose([
                             transforms.Resize(imageSize),
                             transforms.CenterCrop(imageSize),
                             transforms.ToTensor(),
                             transforms.Normalize((0.5, 0.5, 0.5),
                                                  (0.5, 0.5, 0.5))
                         ]),
                         download=True)
    nc = 3
    m_true, s_true = compute_dataset_statistics(target_set="STL10",
                                                batch_size=50,
                                                dims=2048,
                                                cuda=True,
                                                device=default_device)

assert dataset
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batchSize,
                                         shuffle=True,
Example #14
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = open(os.path.join(args.save_path,
                            'log_seed_{}.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("python version : {}".format(
        sys.version.replace('\n', ' ')), log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(
        torch.backends.cudnn.version()), log)

    # Init the tensorboard path and writer
    tb_path = os.path.join(args.save_path, 'tb_log')
    # logger = Logger(tb_path)
    # writer = SummaryWriter(tb_path)

    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    if args.dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif args.dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    elif args.dataset == 'svhn':
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
    elif args.dataset == 'mnist':
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
    elif args.dataset == 'imagenet':
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
    else:
        assert False, "Unknow dataset : {}".format(args.dataset)

    if args.dataset == 'imagenet':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        test_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])  # here is actually the validation dataset
    else:
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

    if args.dataset == 'mnist':
        train_data = dset.MNIST(
            args.data_path, train=True, transform=train_transform, download=True)
        test_data = dset.MNIST(args.data_path, train=False,
                               transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'cifar10':
        train_data = dset.CIFAR10(
            args.data_path, train=True, transform=train_transform, download=True)
        test_data = dset.CIFAR10(
            args.data_path, train=False, transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_data = dset.CIFAR100(
            args.data_path, train=True, transform=train_transform, download=True)
        test_data = dset.CIFAR100(
            args.data_path, train=False, transform=test_transform, download=True)
        num_classes = 100
    elif args.dataset == 'svhn':
        train_data = dset.SVHN(args.data_path, split='train',
                               transform=train_transform, download=True)
        test_data = dset.SVHN(args.data_path, split='test',
                              transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'stl10':
        train_data = dset.STL10(
            args.data_path, split='train', transform=train_transform, download=True)
        test_data = dset.STL10(args.data_path, split='test',
                               transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'imagenet':
        train_dir = os.path.join(args.data_path, 'train')
        test_dir = os.path.join(args.data_path, 'val')
        train_data = dset.ImageFolder(train_dir, transform=train_transform)
        test_data = dset.ImageFolder(test_dir, transform=test_transform)
        num_classes = 1000
    else:
        assert False, 'Unsupported dataset : {}'.format(args.dataset)

    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                                               num_workers=args.workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False,
                                              num_workers=args.workers, pin_memory=True)

    print_log("=> creating model '{}'".format(args.arch), log)

    # Init model, criterion, and optimizer
    net = models.__dict__[args.arch](num_classes)
    print_log("=> network :\n {}".format(net), log)

    if args.use_cuda:
        if args.ngpu > 1:
            net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    if args.optimizer == "SGD":
        print("using SGD as optimizer")
        optimizer = torch.optim.SGD(filter(lambda param: param.requires_grad, net.parameters()),
                                    lr=state['learning_rate'],
                                    momentum=state['momentum'], weight_decay=state['decay'], nesterov=True)

    elif args.optimizer == "Adam":
        print("using Adam as optimizer")
        optimizer = torch.optim.Adam(filter(lambda param: param.requires_grad, net.parameters()),
                                     lr=state['learning_rate'],
                                     weight_decay=state['decay'])

    elif args.optimizer == "YF":
        print("using YellowFin as optimizer")
        optimizer = YFOptimizer(filter(lambda param: param.requires_grad, net.parameters()), lr=state['learning_rate'],
                                mu=state['momentum'], weight_decay=state['decay'])

    elif args.optimizer == "RMSprop":
        print("using RMSprop as optimizer")
        optimizer = torch.optim.RMSprop(filter(lambda param: param.requires_grad, net.parameters()),
                                        lr=state['learning_rate'], alpha=0.99, eps=1e-08, weight_decay=0, momentum=0)

    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    recorder = RecorderMeter(args.epochs)  # counts the number of epochs

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            if not (args.fine_tune):
                args.start_epoch = checkpoint['epoch']
                recorder = checkpoint['recorder']
                optimizer.load_state_dict(checkpoint['optimizer'])

            state_tmp = net.state_dict()
            if 'state_dict' in checkpoint.keys():
                state_tmp.update(checkpoint['state_dict'])
            else:
                state_tmp.update(checkpoint)

            net.load_state_dict(state_tmp)
            # net.load_state_dict(checkpoint['state_dict'])

            print_log("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, args.start_epoch), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume), log)
    else:
        print_log(
            "=> do not use any checkpoint for {} model".format(args.arch), log)

    if args.evaluate:
        validate(test_loader, net, criterion, log)
        return


    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()

    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate, current_momentum = adjust_learning_rate(
            optimizer, epoch, args.gammas, args.schedule)
        # Display simulation time
        need_hour, need_mins, need_secs = convert_secs2time(
            epoch_time.avg * (args.epochs - epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(
            need_hour, need_mins, need_secs)

        print_log(
            '\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [LR={:6.4f}][M={:1.2f}]'.format(time_string(), epoch, args.epochs,
                                                                                   need_time, current_learning_rate,
                                                                                   current_momentum)
            + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False),
                                                               100 - recorder.max_accuracy(False)), log)

        # # ============ TensorBoard logging ============#
        # # we show the model param initialization to give an intuition when we do the fine-tuning

        # for name, param in net.named_parameters():
        #     name = name.replace('.', '/')
        #     if "delta_th" not in name:
        #         writer.add_histogram(name, param.clone().cpu().detach().numpy(), epoch)

        # # ============ TensorBoard logging ============#

        # train for one epoch
        train_acc, train_los = train(
            train_loader, net, criterion, optimizer, epoch, log)

        # evaluate on validation set
        val_acc, val_los = validate(test_loader, net, criterion, log)

        is_best = val_acc > recorder.max_accuracy(istrain=False)
        recorder.update(epoch, train_los, train_acc, val_los, val_acc)

        if args.model_only:
            checkpoint_state = {'state_dict': net.state_dict()}
        else:
            checkpoint_state = {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': net.state_dict(),
                'recorder': recorder,
                'optimizer': optimizer.state_dict(),
            }

        save_checkpoint(checkpoint_state, is_best,
                        args.save_path, 'checkpoint.pth.tar', log)

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        recorder.plot_curve(os.path.join(args.save_path, 'curve.png'))

        # save addition accuracy log for plotting
        accuracy_logger(base_dir=args.save_path,
                        epoch=epoch,
                        train_accuracy=train_acc,
                        test_accuracy=val_acc)

        # ============ TensorBoard logging ============#

        # Log the gradient distributions
        # for name, param in net.named_parameters():
        #     name = name.replace('.', '/')
        #     writer.add_histogram(name + '/grad',
        #                          param.grad.clone().cpu().data.numpy(), epoch + 1, bins='tensorflow')

        # ## Log the weight and bias distribution
        # for name, module in net.named_modules():
        #     name = name.replace('.', '/')
        #     class_name = str(module.__class__).split('.')[-1].split("'")[0]

        #     if "Conv2d" in class_name or "Linear" in class_name:
        #         if module.weight is not None:
        #             writer.add_histogram(name + '/weight/',
        #                                  module.weight.clone().cpu().data.numpy(), epoch + 1, bins='tensorflow')

        # writer.add_scalar('loss/train_loss', train_los, epoch + 1)
        # writer.add_scalar('loss/test_loss', val_los, epoch + 1)
        # writer.add_scalar('accuracy/train_accuracy', train_acc, epoch + 1)
        # writer.add_scalar('accuracy/test_accuracy', val_acc, epoch + 1)
    # ============ TensorBoard logging ============#

    log.close()
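
# Entry-point sketch (added): the original script presumably invokes main()
# under a __main__ guard; shown here for completeness.
if __name__ == '__main__':
    main()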
Example #15
                [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
            ),
        ),
        batch_size=opt.batch_size,
        shuffle=True,
    )

if opt.dataset == 3:                # STL 10
    # Configure data loader
    os.makedirs("../../data/STL10", exist_ok=True)
    dataloader = torch.utils.data.DataLoader(
        datasets.STL10(
            "../../data/STL10",
            split='train',
            transform=transforms.Compose(
                [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
            ),
            # target_transform=None,
            download=True,
        ),
        batch_size=opt.batch_size,
        shuffle=True,
    )

if opt.dataset == 4:                # FASHION MNIST
    # Configure data loader
    os.makedirs("../../data/FashionMNIST", exist_ok=True)
    dataloader = torch.utils.data.DataLoader(
        datasets.FashionMNIST(
            "../../data/FashionMNIST",
Example #16
def get_data(train_size=200):

    if args.dataset == 'mnist':
        test_dir = '/home/y/yx277/research/ImageDataset/mnist'

        test_dataset = datasets.MNIST(root=test_dir,
                                      train=False,
                                      download=True,
                                      transform=None)
        data = test_dataset.data
        label = test_dataset.targets

        index = label < args.n_classes
        data = data[index]
        labels = label[index]
        # labels = label
        # data = torch.from_numpy(np.load('../data/mnist/test_image.npy'))
        # labels = torch.from_numpy(np.load('../data/mnist/test_label.npy'))

        indices = np.random.permutation(data.shape[0])
        sub_x = data[indices[:train_size]].float().reshape((-1, 1, 28, 28))
        sub_x /= 255
        test_data = data[indices[train_size:]].float().reshape((-1, 1, 28, 28))
        test_data /= 255
        test_label = labels[indices[train_size:]].long()

    elif args.dataset == 'cifar10':
        test_dir = '/home/y/yx277/research/ImageDataset/cifar10'

        test_dataset = datasets.CIFAR10(root=test_dir,
                                        train=False,
                                        download=True,
                                        transform=None)
        data = torch.from_numpy(np.array(test_dataset.data, dtype=np.float32))
        label = torch.from_numpy(np.array(test_dataset.targets,
                                          dtype=np.int64))
        index = label < args.n_classes
        data = data[index]
        labels = label[index]

        # data, labels = binary_class(data, labels, 6, 8)
        # labels = label

        indices = np.random.permutation(data.shape[0])
        sub_x = data[indices[:train_size]].float()
        sub_x /= 255
        test_data = data[indices[train_size:]].float()
        test_data /= 255
        test_label = labels[indices[train_size:]].long()

    elif args.dataset == 'cifar10_binary':
        test_dir = '/home/y/yx277/research/ImageDataset/cifar10'

        test_dataset = datasets.CIFAR10(root=test_dir,
                                        train=False,
                                        download=True,
                                        transform=None)
        data = torch.from_numpy(np.array(test_dataset.data, dtype=np.float32))
        label = torch.from_numpy(np.array(test_dataset.targets,
                                          dtype=np.int64))

        data, labels = binary_class(data, label, 6, 8)
        # labels = label

        indices = np.random.permutation(data.shape[0])
        sub_x = data[indices[:train_size]].float()
        sub_x /= 255
        test_data = data[indices[train_size:]].float()
        test_data /= 255
        test_label = labels[indices[train_size:]].long()

    elif args.dataset == 'stl10':
        test_dir = '/home/y/yx277/research/ImageDataset/stl10'

        test_dataset = datasets.STL10(root=test_dir,
                                      split='test',
                                      download=False,
                                      transform=None)
        data = test_dataset.data
        label = test_dataset.labels
        index = label < 2
        data = data[index]
        labels = label[index]

        indices = np.random.permutation(data.shape[0])
        sub_x = torch.from_numpy(data[indices[:train_size]].transpose(
            [0, 2, 3, 1])).float()
        sub_x /= 255
        test_data = torch.from_numpy(data[indices[train_size:]].transpose(
            [0, 2, 3, 1])).float()
        test_data /= 255
        test_label = torch.from_numpy(labels[indices[train_size:]]).long()

    elif args.dataset == 'imagenet':
        test_dir = '../data/imagenet'

        data = torch.from_numpy(np.load('%s/test_image.npy' % test_dir))
        label = torch.from_numpy(np.load('%s/test_label.npy' % test_dir))
        index = label < args.n_classes
        data = data[index]
        labels = label[index]

        indices = np.random.permutation(data.shape[0])
        sub_x = data[indices[:train_size]].float()

        test_data = data[indices[train_size:]].float()

        test_label = labels[indices[train_size:]].long()

    elif args.dataset == 'gtsrb':
        test_dir = '../data/gtsrb'

        data = torch.from_numpy(np.load('%s/test_image.npy' % test_dir))
        label = torch.from_numpy(np.load('%s/test_label.npy' % test_dir))
        index = label < args.n_classes
        data = data[index]
        labels = label[index]

        indices = np.random.permutation(data.shape[0])
        sub_x = data[indices[:train_size]].float()

        test_data = data[indices[train_size:]].float()

        test_label = labels[indices[train_size:]].long()

    elif args.dataset == 'gtsrb_binary':
        test_dir = '../data/gtsrb_binary'

        data = torch.from_numpy(np.load('%s/test_image.npy' % test_dir))
        label = torch.from_numpy(np.load('%s/test_label.npy' % test_dir))
        index = label < args.n_classes
        data = data[index]
        labels = label[index]

        indices = np.random.permutation(data.shape[0])
        sub_x = data[indices[:train_size]].float()

        test_data = data[indices[train_size:]].float()

        test_label = labels[indices[train_size:]].long()

    if binarize:
        sub_x = bi(sub_x, args.eps)
        test_data = bi(test_data, args.eps)

    return sub_x, test_data, test_label
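# Hedged usage sketch, not part of the original snippet: get_data() relies on a
# module-level `args` (dataset, n_classes, eps), a module-level `binarize` flag
# and the helpers bi() / binary_class(); with those defined it could be called
# like this to split the test set into a train_size-image subset and the
# remaining held-out test set.
sub_x, test_data, test_label = get_data(train_size=200)
print('subset shape   :', tuple(sub_x.shape))
print('held-out shapes:', tuple(test_data.shape), tuple(test_label.shape))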
Example #17
def compute_dataset_statistics(target_set="STL10",
                               batch_size=50,
                               dims=2048,
                               cuda=True):
    imageSize = 64
    if target_set == "CIFAR10":
        dataset = datasets.CIFAR10(root="~/datasets/data_cifar10",
                                   train=False,
                                   download=True,
                                   transform=transforms.Compose([
                                       transforms.Resize(imageSize),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5),
                                                            (0.5, 0.5, 0.5)),
                                   ]))
    elif target_set == "MNIST":
        dataset = datasets.MNIST(root="~/datasets",
                                 train=True,
                                 download=True,
                                 transform=transforms.Compose([
                                     transforms.Resize(imageSize),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, ), (0.5, )),
                                 ]))
    elif target_set == "LSUN":
        dataset = datasets.LSUN(root='~/datasets/data_lsun',
                                classes='church_outdoor',
                                transform=transforms.Compose([
                                    transforms.Resize(imageSize),
                                    transforms.CenterCrop(imageSize),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5),
                                                         (0.5, 0.5, 0.5)),
                                ]))
    elif target_set == 'STL10':
        dataset = datasets.STL10(root='~/datasets/data_stl10',
                                 split='unlabeled',
                                 transform=transforms.Compose([
                                     transforms.Resize(imageSize),
                                     transforms.CenterCrop(imageSize),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5),
                                                          (0.5, 0.5, 0.5)),
                                 ]),
                                 download=True)
        nc = 3

    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batch_size,
                                              shuffle=True)

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
    model = InceptionV3([block_idx])
    if cuda:
        model.cuda()

    model.eval()
    pred_arr = np.empty((len(dataset), dims))
    start = 0

    print("Computing statistics of the given dataset...")

    for (x, y) in tqdm(data_loader):
        if target_set == "MNIST":
            tmp = torch.zeros((x.size()[0], 3, x.size()[2], x.size()[3]))
            for i in range(3):
                tmp[:, i, :, :] = x[:, 0, :, :]
            x = tmp
        end = start + x.size(0)
        if cuda:
            x = x.cuda()
        pred = model(x)[0]
        if pred.size(2) != 1 or pred.size(3) != 1:
            pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
        pred_arr[start:end] = pred.cpu().data.numpy().reshape(pred.size(0), -1)
        start = end

    mu = np.mean(pred_arr, axis=0)
    sigma = np.cov(pred_arr, rowvar=False)
    return mu, sigma
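# Hedged sketch, not from the original snippet: the (mu, sigma) pair returned
# above are Inception statistics; two such pairs can be combined into a Frechet
# Inception Distance. Assumes scipy is available.
import numpy as np
from scipy import linalg


def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        # Regularize the covariances slightly if sqrtm is numerically unstable.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)

# mu1, sigma1 = compute_dataset_statistics("CIFAR10")
# mu2, sigma2 = compute_dataset_statistics("STL10")
# print(frechet_distance(mu1, sigma1, mu2, sigma2))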
def load_data_subset(data_aug, batch_size, workers, dataset, data_target_dir, labels_per_class=100, valid_labels_per_class=500):
    ## copied from GibbsNet_pytorch/load.py
    import numpy as np
    from functools import reduce
    from operator import __or__
    from torch.utils.data.sampler import SubsetRandomSampler
        
    if dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    elif dataset == 'svhn':
        mean = [x / 255 for x in [127.5, 127.5, 127.5]]
        std = [x / 255 for x in [127.5, 127.5, 127.5]]
    elif dataset == 'tiny-imagenet-200':
        mean = [x / 255 for x in [127.5, 127.5, 127.5]]
        std = [x / 255 for x in [127.5, 127.5, 127.5]]
    elif dataset == 'mnist':
        pass 
    else:
        assert False, "Unknow dataset : {}".format(dataset)
    
    if data_aug==1:
        print ('data aug')
        if dataset == 'svhn':
            train_transform = transforms.Compose(
                                             [ transforms.RandomCrop(32, padding=2), transforms.ToTensor(),
                                              transforms.Normalize(mean, std)])
            test_transform = transforms.Compose(
                                            [transforms.ToTensor(), transforms.Normalize(mean, std)])
        elif dataset == 'mnist':
            hw_size = 24
            train_transform = transforms.Compose([
                                transforms.RandomCrop(hw_size),                
                                transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))
                           ])
            test_transform = transforms.Compose([
                                transforms.CenterCrop(hw_size),                       
                                transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))
                           ])
        elif dataset == 'tiny-imagenet-200':
            train_transform = transforms.Compose(
                                                 [transforms.RandomHorizontalFlip(),
                                                  transforms.RandomCrop(64, padding=4),
                                                  transforms.ToTensor(),
                                                  transforms.Normalize(mean, std)])
            test_transform = transforms.Compose(
                                                [transforms.ToTensor(), transforms.Normalize(mean, std)])
        else:    
            train_transform = transforms.Compose(
                                                 [transforms.RandomHorizontalFlip(),
                                                  transforms.RandomCrop(32, padding=2),
                                                  transforms.ToTensor(),
                                                  transforms.Normalize(mean, std)])
            test_transform = transforms.Compose(
                                                [transforms.ToTensor(), transforms.Normalize(mean, std)])
    else:
        print ('no data aug')
        if dataset == 'mnist':
            hw_size = 28
            train_transform = transforms.Compose([
                                transforms.ToTensor(),       
                                transforms.Normalize((0.1307,), (0.3081,))
                           ])
            test_transform = transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))
                           ])
                
        else:   
            train_transform = transforms.Compose(
                                                 [transforms.ToTensor(),
                                                 transforms.Normalize(mean, std)])
            test_transform = transforms.Compose(
                                                [transforms.ToTensor(), transforms.Normalize(mean, std)])
    if dataset == 'cifar10':
        train_data = datasets.CIFAR10(data_target_dir, train=True, transform=train_transform, download=True)
        test_data = datasets.CIFAR10(data_target_dir, train=False, transform=test_transform, download=True)
        num_classes = 10
    elif dataset == 'cifar100':
        train_data = datasets.CIFAR100(data_target_dir, train=True, transform=train_transform, download=True)
        test_data = datasets.CIFAR100(data_target_dir, train=False, transform=test_transform, download=True)
        num_classes = 100
    elif dataset == 'svhn':
        train_data = datasets.SVHN(data_target_dir, split='train', transform=train_transform, download=True)
        test_data = datasets.SVHN(data_target_dir, split='test', transform=test_transform, download=True)
        num_classes = 10
    elif dataset == 'mnist':
        train_data = datasets.MNIST(data_target_dir, train=True, transform=train_transform, download=True)
        test_data = datasets.MNIST(data_target_dir, train=False, transform=test_transform, download=True)
        num_classes = 10
    #print ('svhn', train_data.labels.shape)
    elif dataset == 'stl10':
        train_data = datasets.STL10(data_target_dir, split='train', transform=train_transform, download=True)
        test_data = datasets.STL10(data_target_dir, split='test', transform=test_transform, download=True)
        num_classes = 10
    elif dataset == 'tiny-imagenet-200':
        train_root = os.path.join(data_target_dir, 'train')  # this is path to training images folder
        validation_root = os.path.join(data_target_dir, 'val/images')  # this is path to validation images folder
        train_data = datasets.ImageFolder(train_root, transform=train_transform)
        test_data = datasets.ImageFolder(validation_root,transform=test_transform)
        num_classes = 200
    elif dataset == 'imagenet':
        assert False, 'imagenet code is not finished'
    else:
        assert False, 'Unsupported dataset: {}'.format(dataset)

        
    n_labels = num_classes
    
    def get_sampler(labels, n=None, n_valid= None):
        # Only choose digits in n_labels
        # n = number of labels per class for training
        # n_val = number of labels per class for validation
        #print type(labels)
        #print (n_valid)
        (indices,) = np.where(reduce(__or__, [labels == i for i in np.arange(n_labels)]))
        # Ensure uniform distribution of labels
        np.random.shuffle(indices)
        
        indices_valid = np.hstack([list(filter(lambda idx: labels[idx] == i, indices))[:n_valid] for i in range(n_labels)])
        indices_train = np.hstack([list(filter(lambda idx: labels[idx] == i, indices))[n_valid:n_valid+n] for i in range(n_labels)])
        indices_unlabelled = np.hstack([list(filter(lambda idx: labels[idx] == i, indices))[:] for i in range(n_labels)])
        #import pdb; pdb.set_trace()
        #print (indices_train.shape)
        #print (indices_valid.shape)
        #print (indices_unlabelled.shape)
        indices_train = torch.from_numpy(indices_train)
        indices_valid = torch.from_numpy(indices_valid)
        indices_unlabelled = torch.from_numpy(indices_unlabelled)
        sampler_train = SubsetRandomSampler(indices_train)
        sampler_valid = SubsetRandomSampler(indices_valid)
        sampler_unlabelled = SubsetRandomSampler(indices_unlabelled)
        return sampler_train, sampler_valid, sampler_unlabelled
    
    #print type(train_data.train_labels)
    
    # Dataloaders for MNIST
    if dataset == 'svhn':
        train_sampler, valid_sampler, unlabelled_sampler = get_sampler(train_data.labels, labels_per_class, valid_labels_per_class)
    elif dataset == 'mnist':
        train_sampler, valid_sampler, unlabelled_sampler = get_sampler(train_data.train_labels.numpy(), labels_per_class, valid_labels_per_class)
    elif dataset == 'tiny-imagenet-200':
        pass
    else: 
        train_sampler, valid_sampler, unlabelled_sampler = get_sampler(train_data.targets, labels_per_class, valid_labels_per_class)

    if dataset == 'tiny-imagenet-200':
        labelled = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)
        validation = None
        unlabelled = None
        test = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=workers, pin_memory=True)
    else:
        labelled = torch.utils.data.DataLoader(train_data, batch_size=batch_size, sampler = train_sampler, shuffle=False, num_workers=workers, pin_memory=True)
        validation = torch.utils.data.DataLoader(train_data, batch_size=batch_size, sampler = valid_sampler, shuffle=False, num_workers=workers, pin_memory=True)
        unlabelled = torch.utils.data.DataLoader(train_data, batch_size=batch_size, sampler = unlabelled_sampler, shuffle=False, num_workers=workers, pin_memory=True)
        test = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=workers, pin_memory=True)

    return labelled, validation, unlabelled, test, num_classes
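# Hedged usage sketch, not part of the original snippet: a semi-supervised
# CIFAR-10 split with 100 labelled and 500 validation images per class.
labelled, validation, unlabelled, test, num_classes = load_data_subset(
    data_aug=1, batch_size=128, workers=2, dataset='cifar10',
    data_target_dir='./data/cifar10',
    labels_per_class=100, valid_labels_per_class=500)
print(num_classes, len(labelled.sampler), len(validation.sampler), len(unlabelled.sampler))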
def main():
  # Init logger
  
  if not os.path.isdir(args.save_path):
    os.makedirs(args.save_path)
  log = open(os.path.join(args.save_path, 'log_seed_{}.txt'.format(args.manualSeed)), 'w')
  print_log('save path : {}'.format(args.save_path), log)
  state = {k: v for k, v in args._get_kwargs()}
  print_log(state, log)
  print_log("Random Seed: {}".format(args.manualSeed), log)
  print_log("python version : {}".format(sys.version.replace('\n', ' ')), log)
  print_log("torch  version : {}".format(torch.__version__), log)
  print_log("cudnn  version : {}".format(torch.backends.cudnn.version()), log)

  # Init dataset
  
  if not os.path.exists(args.data_path):
    os.makedirs(args.data_path)

  if args.dataset == 'cifar10':
    mean = [x / 255 for x in [125.3, 123.0, 113.9]]
    std = [x / 255 for x in [63.0, 62.1, 66.7]]
  elif args.dataset == 'cifar100':
    mean = [x / 255 for x in [129.3, 124.1, 112.4]]
    std = [x / 255 for x in [68.2, 65.4, 70.4]]
  else:
    assert False, "Unknow dataset : {}".format(args.dataset)

  train_transform = transforms.Compose(
    [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
     transforms.Normalize(mean, std)])
  test_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean, std)])

  if args.dataset == 'cifar10':
    train_data = dset.CIFAR10(args.data_path, train=True, transform=train_transform, download=True)
    test_data = dset.CIFAR10(args.data_path, train=False, transform=test_transform, download=True)
    num_classes = 10
  elif args.dataset == 'cifar100':
    train_data = dset.CIFAR100(args.data_path, train=True, transform=train_transform, download=True)
    test_data = dset.CIFAR100(args.data_path, train=False, transform=test_transform, download=True)
    num_classes = 100
  elif args.dataset == 'svhn':
    train_data = dset.SVHN(args.data_path, split='train', transform=train_transform, download=True)
    test_data = dset.SVHN(args.data_path, split='test', transform=test_transform, download=True)
    num_classes = 10
  elif args.dataset == 'stl10':
    train_data = dset.STL10(args.data_path, split='train', transform=train_transform, download=True)
    test_data = dset.STL10(args.data_path, split='test', transform=test_transform, download=True)
    num_classes = 10
  elif args.dataset == 'imagenet':
    assert False, 'imagenet code is not finished'
  else:
    assert False, 'Unsupported dataset: {}'.format(args.dataset)

  train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                         num_workers=args.workers, pin_memory=True)
  test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False,
                        num_workers=args.workers, pin_memory=True)

  print_log("=> creating model '{}'".format(args.arch), log)
  # Init model, criterion, and optimizer
  net = models.__dict__[args.arch](num_classes)
  print_log("=> network :\n {}".format(net), log)

  net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

  # define loss function (criterion) and optimizer
  criterion = torch.nn.CrossEntropyLoss()
  optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'],
                weight_decay=state['decay'], nesterov=False)
  #optimizer = AccSGD(net.parameters(), state['learning_rate'], kappa = 1000.0, xi = 10.0)
  #optimizer.zero_grad()
  #loss_fn(model(input), target).backward()
 # optimizer.step()              
  if args.use_cuda:
    net.cuda()
    criterion.cuda()

  recorder = RecorderMeter(args.epochs)
  # optionally resume from a checkpoint
  if args.resume:
    if os.path.isfile(args.resume):
      print_log("=> loading checkpoint '{}'".format(args.resume), log)
      checkpoint = torch.load(args.resume)
      recorder = checkpoint['recorder']
      args.start_epoch = checkpoint['epoch']
      net.load_state_dict(checkpoint['state_dict'])
      optimizer.load_state_dict(checkpoint['optimizer'])
      print_log("=> loaded checkpoint '{}' (epoch {})" .format(args.resume, checkpoint['epoch']), log)
    else:
      raise ValueError("=> no checkpoint found at '{}'".format(args.resume))
  else:
    print_log("=> do not use any checkpoint for {} model".format(args.arch), log)

  if args.evaluate:
    validate(test_loader, net, criterion, log)
    return

  # Main loop
  start_time = time.time()
  epoch_time = AverageMeter()
  train_cc = 0
  loss_e = 0
  
  for epoch in range(args.start_epoch, args.epochs):
    current_learning_rate = adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule,loss_e)
    optimizer = torch.optim.SGD(net.parameters(), current_learning_rate, momentum=state['momentum'],
                weight_decay=state['decay'], nesterov=False)
    need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs-epoch))
    need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)

    print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
                + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

    # train for one epoch
    train_acc, train_los,loss_e = train(train_loader, net, criterion, optimizer, epoch, log,train_cc)

    # evaluate on validation set
    #val_acc,   val_los   = extract_features(test_loader, net, criterion, log)
    val_acc,   val_los   = validate(test_loader, net, criterion, log)
    is_best = recorder.update(epoch, train_los, train_acc, val_los, val_acc)

    save_checkpoint({
      'epoch': epoch + 1,
      'arch': args.arch,
      'state_dict': net.state_dict(),
      'recorder': recorder,
      'optimizer' : optimizer.state_dict(),
      'args'      : copy.deepcopy(args),
    }, is_best, args.save_path, 'sgd8_40lcheck.pth.tar')

    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()
    recorder.plot_curve( os.path.join(args.save_path, 'sgd8_40l.png') )

  log.close()
def load_data_subset_unpre(data_aug, batch_size, workers, dataset, data_target_dir, labels_per_class=100, valid_labels_per_class=500):
    ## loads the data without any preprocessing ##
    import numpy as np
    from functools import reduce
    from operator import __or__
    from torch.utils.data.sampler import SubsetRandomSampler
    
    """
    def per_image_standarize(x):
        mean = x.mean()
        std = x.std()
        adjusted_std = torch.max(std, 1.0/torch.sqrt(torch.FloatTensor([x.shape[0]*x.shape[1]*x.shape[2]])))
        standarized_input = (x- mean)/ adjusted_std
        return standarized_input
    """   
    if data_aug==1:
        print ('data aug')
        if dataset == 'svhn':
            train_transform = transforms.Compose(
                                            [transforms.RandomCrop(32, padding=2),
                                            transforms.ToTensor(), 
                                            transforms.Lambda(lambda x : x.mul(255))
                                            ])
        else:    
            train_transform = transforms.Compose(
                                                [transforms.RandomHorizontalFlip(),
                                                transforms.RandomCrop(32, padding=2),
                                                transforms.ToTensor(), 
                                                transforms.Lambda(lambda x : x.mul(255))
                                                ])
        test_transform = transforms.Compose(
                                            [transforms.ToTensor(),
                                             transforms.Lambda(lambda x : x.mul(255))])
    else:
        print ('no data aug')
        train_transform = transforms.Compose(
                                            [transforms.ToTensor(),
                                             transforms.Lambda(lambda x : x.mul(255))
                                            ])
        test_transform = transforms.Compose(
                                            [transforms.ToTensor(),
                                             transforms.Lambda(lambda x : x.mul(255))])
    
    if dataset == 'cifar10':
        train_data = datasets.CIFAR10(data_target_dir, train=True, transform=train_transform, download=True)
        test_data = datasets.CIFAR10(data_target_dir, train=False, transform=test_transform, download=True)
        num_classes = 10
    elif dataset == 'cifar100':
        train_data = datasets.CIFAR100(data_target_dir, train=True, transform=train_transform, download=True)
        test_data = datasets.CIFAR100(data_target_dir, train=False, transform=test_transform, download=True)
        num_classes = 100
    elif dataset == 'svhn':
        train_data = datasets.SVHN(data_target_dir, split='train', transform=train_transform, download=True)
        test_data = datasets.SVHN(data_target_dir, split='test', transform=test_transform, download=True)
        num_classes = 10
    elif dataset == 'mnist':
        train_data = datasets.MNIST(data_target_dir, train=True, transform=train_transform, download=True)
        test_data = datasets.MNIST(data_target_dir, train=False, transform=test_transform, download=True)
        num_classes = 10
    #print ('svhn', train_data.labels.shape)
    elif dataset == 'stl10':
        train_data = datasets.STL10(data_target_dir, split='train', transform=train_transform, download=True)
        test_data = datasets.STL10(data_target_dir, split='test', transform=test_transform, download=True)
        num_classes = 10
    elif dataset == 'imagenet':
        assert False, 'imagenet code is not finished'
    else:
        assert False, 'Unsupported dataset: {}'.format(dataset)
        
    n_labels = num_classes
    
    def get_sampler(labels, n=None, n_valid= None):
        # Only choose digits in n_labels
        # n = number of labels per class for training
        # n_val = number of labels per class for validation
        #print type(labels)
        #print (n_valid)
        (indices,) = np.where(reduce(__or__, [labels == i for i in np.arange(n_labels)]))
        # Ensure uniform distribution of labels
        np.random.shuffle(indices)
        
        indices_valid = np.hstack([list(filter(lambda idx: labels[idx] == i, indices))[:n_valid] for i in range(n_labels)])
        indices_train = np.hstack([list(filter(lambda idx: labels[idx] == i, indices))[n_valid:n_valid+n] for i in range(n_labels)])
        indices_unlabelled = np.hstack([list(filter(lambda idx: labels[idx] == i, indices))[:] for i in range(n_labels)])
        #print (indices_train.shape)
        #print (indices_valid.shape)
        #print (indices_unlabelled.shape)
        indices_train = torch.from_numpy(indices_train)
        indices_valid = torch.from_numpy(indices_valid)
        indices_unlabelled = torch.from_numpy(indices_unlabelled)
        sampler_train = SubsetRandomSampler(indices_train)
        sampler_valid = SubsetRandomSampler(indices_valid)
        sampler_unlabelled = SubsetRandomSampler(indices_unlabelled)
        return sampler_train, sampler_valid, sampler_unlabelled
    
    #print type(train_data.train_labels)
    
    # Dataloaders for MNIST
    if dataset == 'svhn':
        train_sampler, valid_sampler, unlabelled_sampler = get_sampler(train_data.labels, labels_per_class, valid_labels_per_class)
    elif dataset == 'mnist':
        train_sampler, valid_sampler, unlabelled_sampler = get_sampler(train_data.train_labels.numpy(), labels_per_class, valid_labels_per_class)
    else: 
        train_sampler, valid_sampler, unlabelled_sampler = get_sampler(train_data.targets, labels_per_class, valid_labels_per_class)
    
    labelled = torch.utils.data.DataLoader(train_data, batch_size=batch_size, sampler = train_sampler, shuffle=False, num_workers=workers, pin_memory=True)
    validation = torch.utils.data.DataLoader(train_data, batch_size=batch_size, sampler = valid_sampler, shuffle=False, num_workers=workers, pin_memory=True)
    unlabelled = torch.utils.data.DataLoader(train_data, batch_size=batch_size, sampler = unlabelled_sampler,shuffle=False,  num_workers=workers, pin_memory=True)
    test = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=workers, pin_memory=True)

    return labelled, validation, unlabelled, test, num_classes
Example #21
def get_dataloader(dataset_name, split, batch_size,
                   add_split=None, shuffle=True, ratio=-1, num_workers=4):

    print('[%s] Loading %s-%s from %s' %
          (datetime.now(), split, add_split, dataset_name))

    if dataset_name == 'MNIST':

        data_root_list = []

        for data_root in data_root_list:
            if os.path.exists(data_root):
                print('Found %s in %s' % (dataset_name, data_root))
                break

        normalize = transforms.Normalize((0.1307, ), (0.3081, ))
        # if split == 'train':
        MNIST_transform = transforms.Compose(
            [transforms.Resize(32),
             transforms.ToTensor(), normalize])
        # else:
        dataset = MNIST_utils.MNIST(root=data_root,
                                    train=True if split == 'train' else False,
                                    add_split=add_split,
                                    download=False,
                                    transform=MNIST_transform,
                                    ratio=ratio)
        loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=shuffle,
                                             num_workers=2)

    elif dataset_name == 'SVHN':

        data_root_list = []

        for data_root in data_root_list:
            if os.path.exists(data_root):
                print('Found %s in %s' % (dataset_name, data_root))
                break
        normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

        if split == 'train':
            trainset = SVHN_utils.SVHN(
                root=data_root,
                split='train',
                add_split=add_split,
                download=True,
                transform=transforms.Compose([
                    # transforms.RandomCrop(32, padding=4),
                    # transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize
                ]),
                ratio=ratio)
            loader = torch.utils.data.DataLoader(trainset,
                                                 batch_size=batch_size,
                                                 shuffle=shuffle,
                                                 num_workers=2)
        elif split == 'test':
            testset = SVHN_utils.SVHN(root=data_root,
                                      split='test',
                                      download=True,
                                      transform=transforms.Compose(
                                          [transforms.ToTensor(), normalize]))
            loader = torch.utils.data.DataLoader(testset,
                                                 batch_size=batch_size,
                                                 shuffle=shuffle,
                                                 num_workers=2)

    elif dataset_name in ['CIFAR10', 'cifar10']:

        data_root_list = [
            '/home/shangyu/datasets/CIFAR10', '/data/datasets/CIFAR10',
            '/home/sinno/datasets/CIFAR10',
            '/Users/shangyu/Documents/datasets/CIFAR10'
        ]

        for data_root in data_root_list:
            if os.path.exists(data_root):
                print('Found %s in %s' % (dataset_name, data_root))
                break

        if split == 'train':

            transform_train = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010)),
            ])

            trainset = CIFAR10_utils.CIFAR10(root=data_root,
                                             train=True,
                                             download=True,
                                             transform=transform_train,
                                             ratio=ratio)
            print('Number of training instances used: %d' % (len(trainset)))
            loader = torch.utils.data.DataLoader(trainset,
                                                 batch_size=batch_size,
                                                 shuffle=shuffle,
                                                 num_workers=2)

        elif split == 'test' or split == 'val':
            transform_test = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010)),
            ])
            testset = torchvision.datasets.CIFAR10(root=data_root,
                                                   train=False,
                                                   download=True,
                                                   transform=transform_test)
            loader = torch.utils.data.DataLoader(testset,
                                                 batch_size=batch_size,
                                                 shuffle=shuffle,
                                                 num_workers=2)

    elif dataset_name == 'CIFAR100':

        data_root_list = []
        for data_root in data_root_list:
            if os.path.exists(data_root):
                break
        normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

        if split == 'train':
            transform_train = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
            trainset = CIFAR10_utils.CIFAR100(root=data_root,
                                              train=True,
                                              download=True,
                                              transform=transform_train,
                                              ratio=ratio)
            print('Number of training instances used: %d' % (len(trainset)))
            loader = torch.utils.data.DataLoader(trainset,
                                                 batch_size=batch_size,
                                                 shuffle=shuffle,
                                                 num_workers=2)

        elif split == 'test' or split == 'val':
            transform_test = transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])
            testset = torchvision.datasets.CIFAR100(root=data_root,
                                                    train=False,
                                                    download=True,
                                                    transform=transform_test)
            loader = torch.utils.data.DataLoader(testset,
                                                 batch_size=batch_size,
                                                 shuffle=shuffle,
                                                 num_workers=2)

    elif dataset_name == 'STL10':

        data_root_list = []
        for data_root in data_root_list:
            if os.path.exists(data_root):
                print('Found STL10 in %s' % data_root)
                break

        if split == 'train':
            loader = torch.utils.data.DataLoader(datasets.STL10(
                root=data_root,
                split='train',
                download=True,
                transform=transforms.Compose([
                    transforms.Pad(4),
                    transforms.RandomCrop(96),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ])),
                                                 batch_size=batch_size,
                                                 shuffle=True)

        if split in ['test', 'val']:
            loader = torch.utils.data.DataLoader(datasets.STL10(
                root=data_root,
                split='test',
                download=True,
                transform=transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ])),
                                                 batch_size=batch_size,
                                                 shuffle=False)

    elif dataset_name == 'ImageNet':
        data_root_list = []
        for data_root in data_root_list:
            if os.path.exists(data_root):
                break
        traindir = ('../train_imagenet_list.pkl', '../classes.pkl',
                    '../classes-to-idx.pkl', '%s/train' % data_root)
        valdir = ('../val_imagenet_list.pkl', '../classes.pkl',
                  '../classes-to-idx.pkl', '%s/val-pytorch' % data_root)
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        if split == 'train':
            trainDataset = imagenet_utils.ImageFolder(
                traindir,
                transforms.Compose([
                    transforms.RandomResizedCrop(224),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize,
                ]),
                ratio=ratio)
            print('Number of training data used: %d' % (len(trainDataset)))
            loader = torch.utils.data.DataLoader(trainDataset, batch_size=batch_size, shuffle=True, \
                                                 num_workers = num_workers, pin_memory=True)

        elif split == 'val' or split == 'test':
            valDataset = imagenet_utils.ImageFolder(
                valdir,
                transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    normalize,
                ]))
            loader = torch.utils.data.DataLoader(valDataset, batch_size=batch_size, shuffle=True, \
                                                 num_workers = num_workers, pin_memory=True)

    else:
        raise NotImplementedError(
            'Dataset %s is not implemented in this project; please implement it yourself'
            % dataset_name
        )

    print ('[DATA LOADING] Loading from %s-%s-%s finish. Number of images: %d, Number of batches: %d' \
           %(dataset_name, split, add_split, len(loader.dataset), len(loader)))

    return loader
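# Hedged usage sketch, not part of the original snippet: only the CIFAR10
# branch above ships with candidate data roots, and CIFAR10_utils is a
# project-local wrapper, so this only runs inside that project.
train_loader = get_dataloader('CIFAR10', 'train', batch_size=128, ratio=0.1)
test_loader = get_dataloader('CIFAR10', 'test', batch_size=128, shuffle=False)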
def load_data(data_aug, batch_size, workers, dataset, data_target_dir):
    
    if dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
        
    elif dataset == 'svhn':
        mean = [x / 255 for x in [127.5, 127.5, 127.5]]
        std = [x / 255 for x in [127.5, 127.5, 127.5]]
    else:
        assert False, "Unknow dataset : {}".format(dataset)
    
    if data_aug==1:
        if dataset == 'svhn':
            train_transform = transforms.Compose(
                                             [ transforms.RandomCrop(32, padding=2), transforms.ToTensor(),
                                              transforms.Normalize(mean, std)])
            test_transform = transforms.Compose(
                                            [transforms.ToTensor(), transforms.Normalize(mean, std)])
        else:
            train_transform = transforms.Compose(
                                                 [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
                                                  transforms.Normalize(mean, std)])
            test_transform = transforms.Compose(
                                                [transforms.ToTensor(), transforms.Normalize(mean, std)])
    else:
        train_transform = transforms.Compose(
                                             [ transforms.ToTensor(),
                                              transforms.Normalize(mean, std)])
        test_transform = transforms.Compose(
                                            [transforms.ToTensor(), transforms.Normalize(mean, std)])
    if dataset == 'cifar10':
        train_data = datasets.CIFAR10(data_target_dir, train=True, transform=train_transform, download=True)
        test_data = datasets.CIFAR10(data_target_dir, train=False, transform=test_transform, download=True)
        num_classes = 10
    elif dataset == 'cifar100':
        train_data = datasets.CIFAR100(data_target_dir, train=True, transform=train_transform, download=True)
        test_data = datasets.CIFAR100(data_target_dir, train=False, transform=test_transform, download=True)
        num_classes = 100
    elif dataset == 'svhn':
        train_data = datasets.SVHN(data_target_dir, split='train', transform=train_transform, download=True)
        test_data = datasets.SVHN(data_target_dir, split='test', transform=test_transform, download=True)
        num_classes = 10
    elif dataset == 'stl10':
        train_data = datasets.STL10(data_target_dir, split='train', transform=train_transform, download=True)
        test_data = datasets.STL10(data_target_dir, split='test', transform=test_transform, download=True)
        num_classes = 10
    elif dataset == 'imagenet':
        assert False, 'imagenet code is not finished'
    else:
        assert False, 'Unsupported dataset: {}'.format(dataset)

    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True,
                         num_workers=workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False,
                        num_workers=workers, pin_memory=True)
    
    return train_loader, test_loader, num_classes
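# Hedged usage sketch, not part of the original snippet:
train_loader, test_loader, num_classes = load_data(
    data_aug=1, batch_size=128, workers=4,
    dataset='cifar10', data_target_dir='./data/cifar10')
images, targets = next(iter(train_loader))
print(images.shape, targets.shape, num_classes)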
Example #23
                                  transform=img_transform_source,
                                  download=True)

from modify_cifar_stl import modify_cifar

modify_cifar(dataset_source)

dataloader_source = torch.utils.data.DataLoader(dataset=dataset_source,
                                                batch_size=batch_size,
                                                shuffle=True,
                                                num_workers=0)

train_list = os.path.join(target_image_root, 'svhn_train_labels.txt')

dataset_target = datasets.STL10(root='dataset',
                                transform=img_transform_target,
                                download=True)

from modify_cifar_stl import modify_stl

modify_stl(dataset_target)

dataloader_target = torch.utils.data.DataLoader(dataset=dataset_target,
                                                batch_size=batch_size,
                                                shuffle=True,
                                                num_workers=0)

# load model

my_net = CNNModel()
def stl10():
    return collect_download_configs(
        lambda: datasets.STL10(ROOT, download=True),
        name="STL10",
    )
Example #25
def _get_stl10(root, split, transform, target_transform, download):
    return datasets.STL10(root=root,
                          split=split,
                          transform=transform,
                          target_transform=target_transform,
                          download=download)
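# Hedged usage sketch, not part of the original snippet: the factory above maps
# directly onto torchvision's STL10 splits.
from torchvision import transforms

train_set = _get_stl10(root='./data', split='train',
                       transform=transforms.ToTensor(),
                       target_transform=None, download=True)
test_set = _get_stl10(root='./data', split='test',
                      transform=transforms.ToTensor(),
                      target_transform=None, download=True)
print(len(train_set), len(test_set))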
Example #26
def train():
    if FLAGS.dataset == 'cifar10':
        dataset = datasets.CIFAR10(
            './data',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                transforms.Lambda(lambda x: x + torch.rand_like(x) / 128)
            ]))
    if FLAGS.dataset == 'stl10':
        dataset = datasets.STL10(
            './data',
            split='unlabeled',
            download=True,
            transform=transforms.Compose([
                transforms.Resize((48, 48)),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                transforms.Lambda(lambda x: x + torch.rand_like(x) / 128)
            ]))

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=FLAGS.batch_size,
                                             shuffle=True,
                                             num_workers=4,
                                             drop_last=True)

    net_G = net_G_models[FLAGS.arch](FLAGS.z_dim).to(device)
    net_D = net_D_models[FLAGS.arch]().to(device)
    loss_fn = loss_fns[FLAGS.loss]()

    optim_G = optim.Adam(net_G.parameters(), lr=FLAGS.lr_G, betas=FLAGS.betas)
    optim_D = optim.Adam(net_D.parameters(), lr=FLAGS.lr_D, betas=FLAGS.betas)
    sched_G = optim.lr_scheduler.LambdaLR(
        optim_G, lambda step: 1 - step / FLAGS.total_steps)
    sched_D = optim.lr_scheduler.LambdaLR(
        optim_D, lambda step: 1 - step / FLAGS.total_steps)

    os.makedirs(os.path.join(FLAGS.logdir, 'sample'))
    writer = SummaryWriter(os.path.join(FLAGS.logdir))
    sample_z = torch.randn(FLAGS.sample_size, FLAGS.z_dim).to(device)
    with open(os.path.join(FLAGS.logdir, "flagfile.txt"), 'w') as f:
        f.write(FLAGS.flags_into_string())
    writer.add_text("flagfile",
                    FLAGS.flags_into_string().replace('\n', '  \n'))

    real, _ = next(iter(dataloader))
    grid = (make_grid(real[:FLAGS.sample_size]) + 1) / 2
    writer.add_image('real_sample', grid)

    looper = infiniteloop(dataloader)
    with trange(1, FLAGS.total_steps + 1, dynamic_ncols=True) as pbar:
        for step in pbar:
            # Discriminator
            for _ in range(FLAGS.n_dis):
                with torch.no_grad():
                    z = torch.randn(FLAGS.batch_size, FLAGS.z_dim).to(device)
                    fake = net_G(z).detach()
                real = next(looper).to(device)
                net_D_real = net_D(real)
                net_D_fake = net_D(fake)
                loss = loss_fn(net_D_real, net_D_fake)

                optim_D.zero_grad()
                loss.backward()
                optim_D.step()

                if FLAGS.loss == 'was':
                    loss = -loss
                pbar.set_postfix(loss='%.4f' % loss)
            writer.add_scalar('loss', loss, step)

            # Generator
            z = torch.randn(FLAGS.batch_size * 2, FLAGS.z_dim).to(device)
            loss = loss_fn(net_D(net_G(z)))

            optim_G.zero_grad()
            loss.backward()
            optim_G.step()

            sched_G.step()
            sched_D.step()
            pbar.update(1)

            if step == 1 or step % FLAGS.sample_step == 0:
                fake = net_G(sample_z).cpu()
                grid = (make_grid(fake) + 1) / 2
                writer.add_image('sample', grid, step)
                save_image(
                    grid, os.path.join(FLAGS.logdir, 'sample',
                                       '%d.png' % step))

            if step == 1 or step % FLAGS.eval_step == 0:
                torch.save(
                    {
                        'net_G': net_G.state_dict(),
                        'net_D': net_D.state_dict(),
                        'optim_G': optim_G.state_dict(),
                        'optim_D': optim_D.state_dict(),
                        'sched_G': sched_G.state_dict(),
                        'sched_D': sched_D.state_dict(),
                    }, os.path.join(FLAGS.logdir, 'model.pt'))
                if FLAGS.record:
                    imgs = generate_imgs(net_G, device, FLAGS.z_dim, 50000,
                                         FLAGS.batch_size)
                    is_score, fid_score = get_inception_and_fid_score(
                        imgs, device, FLAGS.fid_cache, verbose=True)
                    pbar.write("%s/%s Inception Score: %.3f(%.5f), "
                               "FID Score: %6.3f" %
                               (step, FLAGS.total_steps, is_score[0],
                                is_score[1], fid_score))
                    writer.add_scalar('inception_score', is_score[0], step)
                    writer.add_scalar('inception_score_std', is_score[1], step)
                    writer.add_scalar('fid_score', fid_score, step)
    writer.close()
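# Hedged sketch, not from the original snippet: a minimal version of the
# infiniteloop() helper the training loop above assumes. It cycles the
# dataloader forever and yields only the image batches.
def infiniteloop(dataloader):
    while True:
        for x, _ in iter(dataloader):
            yield x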
def main():
    # Init logger
    if not os.path.isdir(args.save_path):
        os.makedirs(args.save_path)
    log = open(os.path.join(args.save_path, 'log_seed_{}.txt'.format(args.manualSeed)), 'w')
    print_log('save path : {}'.format(args.save_path), log)
    state = {k: v for k, v in args._get_kwargs()}
    print_log(state, log)
    print_log("Random Seed: {}".format(args.manualSeed), log)
    print_log("python version : {}".format(sys.version.replace('\n', ' ')), log)
    print_log("torch  version : {}".format(torch.__version__), log)
    print_log("cudnn  version : {}".format(torch.backends.cudnn.version()), log)
    print_log("Compress Rate: {}".format(args.rate), log)
    print_log("Layer Begin: {}".format(args.layer_begin), log)
    print_log("Layer End: {}".format(args.layer_end), log)
    print_log("Layer Inter: {}".format(args.layer_inter), log)
    print_log("Epoch prune: {}".format(args.epoch_prune), log)
    # Init dataset
    if not os.path.isdir(args.data_path):
        os.makedirs(args.data_path)

    if args.dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif args.dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    else:
        assert False, "Unknow dataset : {}".format(args.dataset)

    train_transform = transforms.Compose(
        [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
         transforms.Normalize(mean, std)])
    test_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(mean, std)])

    if args.dataset == 'cifar10':
        train_data = dset.CIFAR10(args.data_path, train=True, transform=train_transform, download=True)
        test_data = dset.CIFAR10(args.data_path, train=False, transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_data = dset.CIFAR100(args.data_path, train=True, transform=train_transform, download=True)
        test_data = dset.CIFAR100(args.data_path, train=False, transform=test_transform, download=True)
        num_classes = 100
    elif args.dataset == 'svhn':
        train_data = dset.SVHN(args.data_path, split='train', transform=train_transform, download=True)
        test_data = dset.SVHN(args.data_path, split='test', transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'stl10':
        train_data = dset.STL10(args.data_path, split='train', transform=train_transform, download=True)
        test_data = dset.STL10(args.data_path, split='test', transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'imagenet':
        assert False, 'imagenet code is not finished'
    else:
        assert False, 'Unsupported dataset: {}'.format(args.dataset)

    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                                                 num_workers=args.workers, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False,
                                                num_workers=args.workers, pin_memory=True)

    print_log("=> creating model '{}'".format(args.arch), log)
    # Init model, criterion, and optimizer
    net = models.__dict__[args.arch](num_classes)
    print_log("=> network :\n {}".format(net), log)

    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))



    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(net.parameters(), state['learning_rate'], momentum=state['momentum'],
                                weight_decay=state['decay'], nesterov=True)

    if args.use_cuda:
        net.cuda()
        criterion.cuda()

    recorder = RecorderMeter(args.epochs)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            recorder = checkpoint['recorder']
            args.start_epoch = checkpoint['epoch']
            if args.use_state_dict:
                net.load_state_dict(checkpoint['state_dict'])
            else:
                net = checkpoint['state_dict']

            optimizer.load_state_dict(checkpoint['optimizer'])
            print_log("=> loaded checkpoint '{}' (epoch {})" .format(args.resume, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume), log)
    else:
        print_log("=> do not use any checkpoint for {} model".format(args.arch), log)

    if args.evaluate:
        time1 = time.time()
        validate(test_loader, net, criterion, log)
        time2 = time.time()
        print ('function took %0.3f ms' % ((time2-time1)*1000.0))
        return

    m = Mask(net)

    m.init_length()

    comp_rate = args.rate
    print("-" * 10 + " one epoch begins " + "-" * 10)
    print("current compression rate: %f" % comp_rate)

    val_acc_1, val_los_1 = validate(test_loader, net, criterion, log)

    print("accuracy before pruning: %.3f %%" % val_acc_1)

    m.model = net

    # m.init_mask(comp_rate)
    # m.if_zero()
    # m.do_mask()
    net = m.model
    # m.if_zero()
    if args.use_cuda:
        net = net.cuda()
    # val_acc_2, val_los_2 = validate(test_loader, net, criterion, log)
    # print("accuracy after pruning: %s %%" % val_acc_2)


    # Main loop
    start_time = time.time()
    epoch_time = AverageMeter()
    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate = adjust_learning_rate(optimizer, epoch, args.gammas, args.schedule)

        need_hour, need_mins, need_secs = convert_secs2time(epoch_time.avg * (args.epochs-epoch))
        need_time = '[Need: {:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)

        print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [learning_rate={:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, current_learning_rate) \
                                + ' [Best : Accuracy={:.2f}, Error={:.2f}]'.format(recorder.max_accuracy(False), 100-recorder.max_accuracy(False)), log)

        # train for one epoch
        train_acc, train_los = train(train_loader, net, criterion, optimizer, epoch, log)

        # evaluate on validation set
        val_acc_1, val_los_1 = validate(test_loader, net, criterion, log)

        # periodically re-select and re-apply the pruning mask
        # (soft pruning: zeroed filters can still recover in later epochs)
        if epoch % args.epoch_prune == 0 or epoch == args.epochs - 1:
            m.model = net
            m.if_zero()
            m.init_mask(comp_rate)
            m.do_mask()
            m.if_zero()
            net = m.model
            if args.use_cuda:
                net = net.cuda()

        val_acc_2, val_los_2 = validate(test_loader, net, criterion, log)

        is_best = recorder.update(epoch, train_los, train_acc, val_los_2, val_acc_2)

        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': net,
            'recorder': recorder,
            'optimizer' : optimizer.state_dict(),
        }, is_best, args.save_path, 'checkpoint.pth.tar')

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
        #recorder.plot_curve( os.path.join(args.save_path, 'curve.png') )

    log.close()
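The training script above drives pruning through a `Mask` helper (`init_length`, `init_mask`, `do_mask`, `if_zero`) whose definition is not shown. Below is a minimal, hypothetical sketch of the interface the script relies on, assuming a magnitude-based soft filter pruning scheme; it is not the original implementation, and the exact meaning of `args.rate` may differ in the original code.

import torch

class Mask:
    """Hypothetical sketch of the pruning helper used above (not the original code).

    Keeps one binary mask per Conv2d weight tensor and zeroes the lowest-norm
    filters; because the mask is re-selected every few epochs, pruned filters
    can recover (soft pruning).
    """

    def __init__(self, model):
        self.model = model
        self.masks = {}

    def init_length(self):
        # Record the shape of every prunable (4-D, i.e. convolutional) parameter.
        self.shapes = {name: p.shape for name, p in self.model.named_parameters()
                       if p.dim() == 4}

    def init_mask(self, comp_rate):
        # Rank filters by L2 norm and mark a comp_rate fraction of them for zeroing.
        for name, p in self.model.named_parameters():
            if p.dim() != 4:
                continue
            norms = p.detach().reshape(p.size(0), -1).norm(dim=1)
            n_prune = int(p.size(0) * comp_rate)
            prune_idx = norms.argsort()[:n_prune]
            mask = torch.ones(p.size(0), device=p.device)
            mask[prune_idx] = 0.0
            self.masks[name] = mask.view(-1, 1, 1, 1)

    def do_mask(self):
        # Zero out the selected filters in place.
        with torch.no_grad():
            for name, p in self.model.named_parameters():
                if name in self.masks:
                    p.mul_(self.masks[name])

    def if_zero(self):
        # Report how many convolutional weights are currently zero.
        total = sum(p.numel() for p in self.model.parameters() if p.dim() == 4)
        zeros = sum((p == 0).sum().item() for p in self.model.parameters() if p.dim() == 4)
        print("zeroed weights: %d / %d" % (zeros, total))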
Example #28
0
def main():
    config = yaml.load(open("./config/config.yaml", "r"),
                       Loader=yaml.FullLoader)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Training with: {device}")

    data_transform = get_simclr_data_transforms(**config['data_transforms'])
    data_transform2 = get_simclr_data_transforms(**config['data_transforms'],
                                                 blur=1.)

    # data_transform = get_simclr_data_transforms_randAugment(config['data_transforms']['input_shape'])
    # data_transform2 = get_simclr_data_transforms_randAugment(config['data_transforms']['input_shape'])

    train_dataset = datasets.STL10('/media/snowflake/Data/',
                                   split='train+unlabeled',
                                   download=True,
                                   transform=MultiViewDataInjector(
                                       [data_transform, data_transform2]))
    # train_dataset = STL(["/home/snowflake/Descargas/STL_data/unlabeled_images",
    #                      "/home/snowflake/Descargas/STL_data/train_images"],
    #                       transform=MultiViewDataInjector([data_transform, data_transform2]))

    # online network (the one that is trained)
    online_network = ResNet(**config['network']).to(device)
    # online_network = MLPmixer(**config['network']).to(device)

    # target encoder
    # target_network = ResNet_BN_mom(**config['network']).to(device)
    target_network = ResNet(**config['network']).to(device)
    # target_network = MLPmixer(**config['network']).to(device)

    pretrained_folder = config['network']['fine_tune_from']

    # load pre-trained model if defined
    if pretrained_folder:
        try:
            checkpoints_folder = os.path.join('./runs', pretrained_folder,
                                              'checkpoints')

            # load pre-trained parameters
            load_params = torch.load(
                os.path.join(os.path.join(checkpoints_folder, 'model.pth')),
                map_location=torch.device(torch.device(device)))

            online_network.load_state_dict(
                load_params['online_network_state_dict'])
            target_network.load_state_dict(
                load_params['target_network_state_dict'])

        except FileNotFoundError:
            print("Pre-trained weights not found. Training from scratch.")

    # predictor network
    predictor = MLPHead(
        in_channels=online_network.projetion.net[-1].out_features,
        **config['network']['projection_head']).to(device)

    optimizer = torch.optim.SGD(
        list(online_network.parameters()) + list(predictor.parameters()),
        **config['optimizer']['params'])

    trainer = BYOLTrainer(online_network=online_network,
                          target_network=target_network,
                          optimizer=optimizer,
                          predictor=predictor,
                          device=device,
                          **config['trainer'])

    trainer.train(train_dataset)
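Example #28 wraps the STL10 dataset with `MultiViewDataInjector` so that every image comes back as two independently augmented views, one per transform, as BYOL requires. A minimal sketch of such a wrapper is shown below; the class name matches the snippet, but this implementation is an assumption rather than the original code.

class MultiViewDataInjector:
    """Hypothetical sketch: apply each transform in `transforms` to the same
    sample and return one view per transform (two views in the BYOL setup)."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, sample):
        return [t(sample) for t in self.transforms]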
Example #29
0
optimizer = torch.optim.SGD(model.parameters(),
                            starting_lr,
                            momentum=0.9,
                            weight_decay=1e-4)

print("Initialize Dataloaders...")
# Define the transform for the data. Note that we must resize to 224x224 for this dataset and model.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Initialize Datasets. STL10 will automatically download if not present
trainset = datasets.STL10(root='./data',
                          split='train',
                          download=True,
                          transform=transform)
valset = datasets.STL10(root='./data',
                        split='test',
                        download=True,
                        transform=transform)

# Create DistributedSampler to handle distributing the dataset across nodes when training
# This can only be called after torch.distributed.init_process_group is called
train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)

# Create the Dataloaders to feed data to the training and validation steps
train_loader = torch.utils.data.DataLoader(trainset,
                                           batch_size=batch_size,
                                           shuffle=(train_sampler is None),
                                           num_workers=workers,
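The snippet above is cut off at the source in the middle of the DataLoader call. For context, here is a hedged sketch of how a `DistributedSampler` is typically wired into the loader and the epoch loop; it uses only standard `torch.utils.data` arguments and is not taken from the original code (`num_epochs` is assumed to be defined elsewhere).

# Hypothetical continuation, not part of the original snippet.
train_loader = torch.utils.data.DataLoader(trainset,
                                           batch_size=batch_size,
                                           shuffle=False,  # shuffling is delegated to the sampler
                                           num_workers=workers,
                                           pin_memory=True,
                                           sampler=train_sampler)
val_loader = torch.utils.data.DataLoader(valset,
                                         batch_size=batch_size,
                                         shuffle=False,
                                         num_workers=workers)

# set_epoch() must be called once per epoch so every process reshuffles the
# same way while still seeing a different shard of the data.
for epoch in range(num_epochs):
    train_sampler.set_epoch(epoch)
    for inputs, targets in train_loader:
        ...  # forward / backward / optimizer step as usual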
Example #30
0
import time

import torch
import torch.nn as nn
import torchvision.transforms as transforms

from src.model import resnet18
from torch.utils.data import DataLoader
from torchvision import datasets
from src.utils import save_checkpoint_classifier

# Note: is_available must be called; referencing the function object is always truthy.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5])])

train_dataset = datasets.STL10(
    './data', split='train', download=False, transform=data_transform)
train_loader = DataLoader(
    train_dataset, batch_size=128, shuffle=True, num_workers=8)

model = resnet18().to(device)

### Train ###
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

epochs = 50

for epoch in range(epochs):
    t = time.time()
    for x, label in train_loader:
        x, label = x.to(device), label.to(device)
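        # Hypothetical completion (the original snippet is truncated here):
        # a standard supervised step using the criterion/optimizer defined above.
        # It assumes resnet18() returns class logits for the STL10 labels.
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, label)
        loss.backward()
        optimizer.step()
    print(f"epoch {epoch + 1}/{epochs} done in {time.time() - t:.1f}s, "
          f"last batch loss {loss.item():.4f}")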