Пример #1
0
def get_train_loader(args):
    """Build the training loader for the Lab colour view.

    Images are read from ``<args.data_folder>/train``. Each one is randomly
    resized-cropped to 224 (scale drawn from ``[args.crop_low, 1]``),
    randomly flipped horizontally, converted with ``RGB2Lab``, turned into a
    tensor, and normalized per channel.

    Args:
        args: parsed options; reads ``data_folder``, ``crop_low``,
            ``batch_size`` and ``num_workers``.

    Returns:
        tuple: ``(train_loader, n_data)`` where ``n_data`` is the number of
        training samples.
    """
    train_dir = os.path.join(args.data_folder, 'train')

    # Per-channel midpoint (mean) and half-range (std) of the bounds
    # [0, 100], [-86.183, 98.233], [-107.857, 94.478]; presumably the value
    # ranges of the L, a, b channels produced by RGB2Lab.
    lab_mean = [(0 + 100) / 2,
                (-86.183 + 98.233) / 2,
                (-107.857 + 94.478) / 2]
    lab_std = [(100 - 0) / 2,
               (86.183 + 98.233) / 2,
               (107.857 + 94.478) / 2]

    pipeline = [
        transforms.RandomResizedCrop(224, scale=(args.crop_low, 1.)),
        transforms.RandomHorizontalFlip(),
        RGB2Lab(),
        transforms.ToTensor(),
        transforms.Normalize(mean=lab_mean, std=lab_std),
    ]
    dataset = ImageFolderInstance(train_dir,
                                  transform=transforms.Compose(pipeline))

    # No custom sampler is used, so the loader shuffles on its own.
    sampler = None
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=sampler is None,
        num_workers=args.num_workers,
        pin_memory=True,
        sampler=sampler,
    )

    n_data = len(dataset)
    print('number of samples: {}'.format(n_data))
    return loader, n_data
Пример #2
0
def get_train_loader(args):
    """Build the training loader for an ImageNet-style folder or STL-10.

    For any dataset name containing 'imagenet', images are read from
    ``<args.data_folder>/train``, cropped to 224 and converted to the colour
    space chosen by ``args.view`` ('Lab' or 'YCbCr'). Otherwise the dataset
    must be 'stl10' with the 'Lab' view. In that case STL-10
    'train+unlabeled' is downloaded into ``args.data_folder``, cropped to 64,
    and wrapped in ``DatasetInstance``.

    Args:
        args: parsed options; reads ``dataset``, ``data_folder``, ``view``,
            ``crop_low``, ``batch_size`` and ``num_workers``.

    Returns:
        tuple: ``(train_loader, n_data)``.

    Raises:
        NotImplementedError: if ``args.view`` is not 'Lab'/'YCbCr' (imagenet).
        AssertionError: if a non-imagenet dataset is not 'stl10' or its view
            is not 'Lab'.
    """
    # Per-channel midpoint (mean) and half-range (std) of the bounds
    # [0, 100], [-86.183, 98.233], [-107.857, 94.478], used for the Lab view.
    lab_mean = [(0 + 100) / 2, (-86.183 + 98.233) / 2, (-107.857 + 94.478) / 2]
    lab_std = [(100 - 0) / 2, (86.183 + 98.233) / 2, (107.857 + 94.478) / 2]

    if 'imagenet' in args.dataset:
        data_folder = os.path.join(args.data_folder, 'train')

        if args.view == 'Lab':
            mean, std = lab_mean, lab_std
            color_transfer = RGB2Lab()
        elif args.view == 'YCbCr':
            mean = [116.151, 121.080, 132.342]
            std = [109.500, 111.855, 111.964]
            color_transfer = RGB2YCbCr()
        else:
            # `NotImplemented` is a constant, not an exception class: calling
            # it raises TypeError. NotImplementedError is the intended error.
            raise NotImplementedError(
                'view not implemented {}'.format(args.view))
        normalize = transforms.Normalize(mean=mean, std=std)

        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224, scale=(args.crop_low, 1.)),
            transforms.RandomHorizontalFlip(),
            color_transfer,
            transforms.ToTensor(),
            normalize,
        ])
        train_dataset = ImageFolderInstance(data_folder,
                                            transform=train_transform)
    else:
        assert args.dataset == 'stl10', \
            'unsupported dataset {}'.format(args.dataset)
        assert args.view == 'Lab', \
            'stl10 only supports the Lab view, got {}'.format(args.view)

        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(64, scale=(args.crop_low, 1)),
            transforms.RandomHorizontalFlip(),
            RGB2Lab(),
            transforms.ToTensor(),
            transforms.Normalize(mean=lab_mean, std=lab_std),
        ])
        train_dataset = datasets.STL10(
            args.data_folder, 'train+unlabeled',
            transform=train_transform, download=True)
        train_dataset = DatasetInstance(train_dataset)

    # No custom sampler is used, so the loader shuffles on its own.
    train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        num_workers=args.num_workers, pin_memory=True, sampler=train_sampler)

    n_data = len(train_dataset)
    print('number of samples: {}'.format(n_data))

    return train_loader, n_data
Пример #3
0
def main():
    """Train a single-output FCN-ResNet50 and checkpoint it periodically.

    Relies on a module-level ``args`` (it is not created in this function)
    and reads ``data``, ``batch_size``, ``workers``, ``loss``, ``lr``,
    ``momentum``, ``weight_decay`` and ``epochs`` from it. The model is
    wrapped in ``DataParallel`` and moved to CUDA. ``save_checkpoint`` is
    called on every epoch whose index is a multiple of 5. No best-metric
    tracking happens here.
    """
    # ImageFolderInstance receives two pipelines. The first adds colour
    # jitter; the second only centre-crops. What each one is used for is
    # decided inside ImageFolderInstance (presumably input vs. target;
    # confirm there).
    jittered = transforms.Compose([
        transforms.CenterCrop(224),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
        transforms.ToTensor(),
    ])
    plain = transforms.Compose([
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    dataset = ImageFolderInstance(args.data, jittered, plain)

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
    )

    net = torch.nn.DataParallel(fcn_resnet50(num_classes=1)).cuda()

    # Falsy `args.loss` selects the Dice loss; truthy selects cross-entropy.
    # NOTE(review): cross-entropy with num_classes=1 looks degenerate;
    # confirm this option is intended.
    loss_fn = nn.CrossEntropyLoss() if args.loss else diceLoss()
    loss_fn = loss_fn.cuda()

    optimizer = torch.optim.SGD(
        net.parameters(),
        args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )

    for epoch in range(args.epochs):
        adjust_learning_rate(optimizer, epoch)
        train(loader, net, loss_fn, optimizer, epoch)

        # Periodic checkpoint every 5 epochs (epoch 0, 5, 10, ...).
        if epoch % 5 == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': "res50",
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict(),
            })
Пример #4
0
def get_train_loader(args):
    """Build the training loader and collect the per-sample class labels.

    Images are read from ``<args.data_folder>/train``, randomly
    resized-cropped to 224, randomly flipped, converted to the colour space
    chosen by ``args.view`` ('Lab' or 'YCbCr'), turned into a tensor and
    normalized.

    Args:
        args: parsed options; reads ``data_folder``, ``view``, ``crop_low``,
            ``batch_size`` and ``num_workers``.

    Returns:
        tuple: ``(train_loader, train_ordered_labels, n_data)``.
        ``train_ordered_labels`` is a NumPy array holding the second field of
        each entry of ``train_dataset.dataset.samples``, in dataset order.

    Raises:
        NotImplementedError: if ``args.view`` is not supported.
    """
    data_folder = os.path.join(args.data_folder, 'train')

    if args.view == 'Lab':
        mean = [(0 + 100) / 2, (-86.183 + 98.233) / 2, (-107.857 + 94.478) / 2]
        std = [(100 - 0) / 2, (86.183 + 98.233) / 2, (107.857 + 94.478) / 2]
        color_transfer = RGB2Lab()
    elif args.view == 'YCbCr':
        mean = [116.151, 121.080, 132.342]
        std = [109.500, 111.855, 111.964]
        color_transfer = RGB2YCbCr()
    else:
        # `NotImplemented` is a constant, not an exception class: calling it
        # raises TypeError. NotImplementedError is the intended error.
        raise NotImplementedError('view not implemented {}'.format(args.view))
    normalize = transforms.Normalize(mean=mean, std=std)

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(args.crop_low, 1.)),
        transforms.RandomHorizontalFlip(),
        color_transfer,
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = ImageFolderInstance(data_folder, transform=train_transform)
    train_sampler = None

    # Assumes ImageFolderInstance wraps an ImageFolder-like `.dataset` whose
    # `samples` are (path, class_index) pairs; confirm in ImageFolderInstance.
    train_ordered_labels = np.array(
        [sample[1] for sample in train_dataset.dataset.samples])

    # No custom sampler is used, so the loader shuffles on its own.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.num_workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    n_data = len(train_dataset)
    print('number of samples: {}'.format(n_data))

    return train_loader, train_ordered_labels, n_data
Пример #5
0
def get_test_loader(args):
    """Build the evaluation loader over ``<args.data_folder>/validation``.

    Images are resized, centre-cropped to ``image_size``, converted to the
    colour space chosen by ``args.view`` ('Lab' or 'YCbCr'), turned into a
    tensor and normalized. Samples are returned in a fixed order (no
    shuffling).

    Args:
        args: parsed options; reads ``data_folder``, ``view``, ``batch_size``,
            ``num_workers`` and, if present, ``image_size`` (default 224).

    Returns:
        tuple: ``(test_loader, n_data)``.

    Raises:
        NotImplementedError: if ``args.view`` is not supported.
    """
    data_folder = os.path.join(args.data_folder, 'validation')

    if args.view == 'Lab':
        mean = [(0 + 100) / 2, (-86.183 + 98.233) / 2, (-107.857 + 94.478) / 2]
        std = [(100 - 0) / 2, (86.183 + 98.233) / 2, (107.857 + 94.478) / 2]
        color_transfer = RGB2Lab()
    elif args.view == 'YCbCr':
        mean = [116.151, 121.080, 132.342]
        std = [109.500, 111.855, 111.964]
        color_transfer = RGB2YCbCr()
    else:
        # `NotImplemented` is a constant, not an exception class: calling it
        # raises TypeError. NotImplementedError is the intended error.
        raise NotImplementedError('view not implemented {}'.format(args.view))
    normalize = transforms.Normalize(mean=mean, std=std)

    # The crop size comes from `args.image_size` when present, otherwise 224.
    # The resize keeps the 256/224 ratio, so a 224 crop uses Resize(256).
    image_size = getattr(args, 'image_size', 224)
    resize_size = image_size * 256 // 224

    test_transform = transforms.Compose([
        transforms.Resize(resize_size),
        transforms.CenterCrop(image_size),
        color_transfer,
        transforms.ToTensor(),
        normalize,
    ])
    test_dataset = ImageFolderInstance(data_folder, transform=test_transform)

    # Evaluation reads the validation set in a deterministic order.
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.num_workers,
                                              pin_memory=True)

    n_data = len(test_dataset)
    print('number of samples: {}'.format(n_data))

    return test_loader, n_data
Пример #6
0
def main():
    """Pre-train an InsResNet50 encoder with instance discrimination or MoCo.

    Options come from ``parse_option()``.

    - With ``args.moco``, a second encoder (``model_ema``) and a
      ``MemoryMoCo`` contrast module are used; otherwise a ``MemoryInsDis``
      module is used.
    - If ``args.resume`` points to an existing file, training resumes from it.
    - Loss, probability and learning rate are logged to TensorBoard each epoch.
    - ``current.pth`` is written every epoch, and ``ckpt_epoch_<epoch>.pth``
      every ``args.save_freq`` epochs.

    Raises:
        NotImplementedError: for an unsupported ``args.aug`` or
            ``args.model``.
    """
    args = parse_option()

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # set the data loader
    data_folder = os.path.join(args.data_folder, 'train')

    image_size = 224
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=mean, std=std)

    if args.aug == 'NULL':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    elif args.aug == 'CJ':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        # `NotImplemented` is a constant, not an exception class: calling it
        # raises TypeError. NotImplementedError is the intended error.
        raise NotImplementedError(
            'augmentation not supported: {}'.format(args.aug))

    train_dataset = ImageFolderInstance(data_folder,
                                        transform=train_transform,
                                        two_crop=args.moco)
    print(len(train_dataset))
    train_sampler = None
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.num_workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    # create model and optimizer
    n_data = len(train_dataset)

    # Constructor keyword arguments for each supported backbone.
    model_kwargs = {
        'resnet50': {},
        'resnet50x2': {'width': 2},
        'resnet50x4': {'width': 4},
    }
    if args.model not in model_kwargs:
        raise NotImplementedError('model not supported {}'.format(args.model))
    model = InsResNet50(**model_kwargs[args.model])
    model_ema = InsResNet50(**model_kwargs[args.model]) if args.moco else None

    # copy weights from `model' to `model_ema'
    if args.moco:
        moment_update(model, model_ema, 0)

    # set the contrast memory and criterion
    if args.moco:
        contrast = MemoryMoCo(128, n_data, args.nce_k, args.nce_t,
                              args.softmax).cuda(args.gpu)
    else:
        contrast = MemoryInsDis(128, n_data, args.nce_k, args.nce_t,
                                args.nce_m, args.softmax).cuda(args.gpu)

    criterion = NCESoftmaxLoss() if args.softmax else NCECriterion(n_data)
    criterion = criterion.cuda(args.gpu)

    model = model.cuda()
    if args.moco:
        model_ema = model_ema.cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    cudnn.benchmark = True

    if args.amp:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.opt_level)
        if args.moco:
            # model_ema is initialized with a dummy zero-lr optimizer here
            # alongside the main model.
            optimizer_ema = torch.optim.SGD(model_ema.parameters(),
                                            lr=0,
                                            momentum=0,
                                            weight_decay=0)
            model_ema, optimizer_ema = amp.initialize(model_ema,
                                                      optimizer_ema,
                                                      opt_level=args.opt_level)

    # optionally resume from a checkpoint
    args.start_epoch = 1
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch'] + 1
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            contrast.load_state_dict(checkpoint['contrast'])
            if args.moco:
                model_ema.load_state_dict(checkpoint['model_ema'])

            if args.amp and checkpoint['opt'].amp:
                print('==> resuming amp state_dict')
                amp.load_state_dict(checkpoint['amp'])

            print("=> loaded successfully '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            del checkpoint
            torch.cuda.empty_cache()
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # tensorboard
    logger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)

    # routine
    for epoch in range(args.start_epoch, args.epochs + 1):

        adjust_learning_rate(epoch, args, optimizer)
        print("==> training...")

        time1 = time.time()
        if args.moco:
            loss, prob = train_moco(epoch, train_loader, model, model_ema,
                                    contrast, criterion, optimizer, args)
        else:
            loss, prob = train_ins(epoch, train_loader, model, contrast,
                                   criterion, optimizer, args)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        # tensorboard logger
        logger.log_value('ins_loss', loss, epoch)
        logger.log_value('ins_prob', prob, epoch)
        logger.log_value('learning_rate', optimizer.param_groups[0]['lr'],
                         epoch)

        # Build the checkpoint once. It always goes to current.pth and, every
        # `save_freq` epochs, also to ckpt_epoch_<epoch>.pth.
        print('==> Saving...')
        state = {
            'opt': args,
            'model': model.state_dict(),
            'contrast': contrast.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
        }
        if args.moco:
            state['model_ema'] = model_ema.state_dict()
        if args.amp:
            state['amp'] = amp.state_dict()
        torch.save(state, os.path.join(args.model_folder, 'current.pth'))
        if epoch % args.save_freq == 0:
            save_file = os.path.join(
                args.model_folder,
                'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)
        # help release GPU memory
        del state
        torch.cuda.empty_cache()
Пример #7
0
def main():
    """Pre-train a contrastive encoder (MoCo, SimCLR or InsDis) on ImageNet or CIFAR.

    Options come from ``parse_option()``. ``args.contrastive_model`` selects
    the method:

    - 'moco' adds a second encoder (``model_ema``) and a ``MemoryMoCo``
      contrast module.
    - 'simclr' uses no contrast module and trains with ``BatchCriterion``
      unless ``args.softmax`` is set.
    - Anything else uses ``MemoryInsDis``.

    Every ``test_epoch`` epochs the model is evaluated with kNN. This only
    happens for CIFAR; no evaluation metric is computed for ImageNet.
    ``current.pth`` is written every epoch, and ``ckpt_epoch_<epoch>.pth``
    every ``args.save_freq`` epochs.

    Raises:
        NotImplementedError: for an unsupported augmentation, dataset or
            model.
    """
    args = parse_option()

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    use_moco = args.contrastive_model == 'moco'
    use_simclr = args.contrastive_model == 'simclr'

    # set the data loader
    data_folder = os.path.join(args.data_folder, 'train')
    val_folder = os.path.join(args.data_folder, 'val')

    crop_padding = 32
    image_size = 224
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=mean, std=std)
    cifar_mean = (0.4914, 0.4822, 0.4465)
    cifar_std = (0.2023, 0.1994, 0.2010)

    # Default evaluation pipeline: resize with padding, then centre crop.
    # Branches that need a different one override it. Without this default,
    # `test_transform` would be undefined for the 'NULL' and 'CJ' branches.
    test_transform = transforms.Compose([
        transforms.Resize(image_size + crop_padding),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        normalize,
    ])

    if args.aug == 'NULL' and args.dataset == 'imagenet':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    elif args.aug == 'CJ':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    elif args.aug == 'simple' and args.dataset == 'imagenet':
        # Evaluation uses the default (CMC-style) test_transform above.
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomHorizontalFlip(),
            get_color_distortion(1.0),
            transforms.ToTensor(),
            normalize,
        ])
    elif args.aug == 'simple' and args.dataset == 'cifar':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(size=32),
            transforms.RandomHorizontalFlip(p=0.5),
            get_color_distortion(0.5),
            transforms.ToTensor(),
            transforms.Normalize(cifar_mean, cifar_std),
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cifar_mean, cifar_std),
        ])
    else:
        # `NotImplemented` is a constant, not an exception class: calling it
        # raises TypeError. NotImplementedError is the intended error.
        raise NotImplementedError(
            'augmentation not supported: {}'.format(args.aug))

    # Get Datasets
    if args.dataset == "imagenet":
        # Two crops per sample for MoCo, selected via `contrastive_model`
        # like the rest of this function.
        train_dataset = ImageFolderInstance(data_folder,
                                            transform=train_transform,
                                            two_crop=use_moco)
        print(len(train_dataset))
        train_sampler = None
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            num_workers=args.num_workers,
            pin_memory=True,
            sampler=train_sampler)

        # torchvision's ImageFolder takes `transform` (singular).
        test_dataset = datasets.ImageFolder(val_folder,
                                            transform=test_transform)
        test_loader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=256,
                                                  shuffle=False,
                                                  num_workers=args.num_workers,
                                                  pin_memory=True)

    elif args.dataset == 'cifar':
        # cifar-10 dataset
        if use_simclr:
            train_dataset = CIFAR10Instance_double(root='./data',
                                                   train=True,
                                                   download=True,
                                                   transform=train_transform,
                                                   double=True)
        else:
            train_dataset = CIFAR10Instance(root='./data',
                                            train=True,
                                            download=True,
                                            transform=train_transform)
        train_sampler = None
        # drop_last keeps every batch at args.batch_size; BatchCriterion is
        # constructed with that fixed size below.
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            num_workers=args.num_workers,
            pin_memory=True,
            sampler=train_sampler,
            drop_last=True)

        test_dataset = CIFAR10Instance(root='./data',
                                       train=False,
                                       download=True,
                                       transform=test_transform)
        test_loader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=100,
                                                  shuffle=False,
                                                  num_workers=args.num_workers)
    else:
        raise NotImplementedError(
            'dataset not supported: {}'.format(args.dataset))

    # create model and optimizer
    n_data = len(train_dataset)

    model_factories = {
        'resnet50': lambda: InsResNet50(),
        'resnet50x2': lambda: InsResNet50(width=2),
        'resnet50x4': lambda: InsResNet50(width=4),
        'resnet50_cifar': lambda: InsResNet50_cifar(),
    }
    if args.model not in model_factories:
        raise NotImplementedError('model not supported {}'.format(args.model))
    model = model_factories[args.model]()
    model_ema = model_factories[args.model]() if use_moco else None

    # copy weights from `model' to `model_ema'
    if use_moco:
        moment_update(model, model_ema, 0)

    # set the contrast memory and criterion (SimCLR has no memory module)
    if use_moco:
        contrast = MemoryMoCo(128, n_data, args.nce_k, args.nce_t,
                              args.softmax).cuda(args.gpu)
    elif use_simclr:
        contrast = None
    else:
        contrast = MemoryInsDis(128, n_data, args.nce_k, args.nce_t,
                                args.nce_m, args.softmax).cuda(args.gpu)

    if args.softmax:
        criterion = NCESoftmaxLoss()
    elif use_simclr:
        criterion = BatchCriterion(1, args.nce_t, args.batch_size)
    else:
        criterion = NCECriterion(n_data)
    criterion = criterion.cuda(args.gpu)

    model = model.cuda()
    if use_moco:
        model_ema = model_ema.cuda()

    # Exclude BN and bias from weight decay if requested: the decay is then
    # carried by the parameter groups and the optimizer-level value is 0.
    weight_decay = args.weight_decay
    if weight_decay and args.filter_weight_decay:
        parameters = add_weight_decay(model, weight_decay,
                                      args.filter_weight_decay)
        weight_decay = 0.
    else:
        parameters = model.parameters()

    optimizer = torch.optim.SGD(parameters,
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=weight_decay)
    cudnn.benchmark = True

    if args.amp:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.opt_level)
        if use_moco:
            optimizer_ema = torch.optim.SGD(model_ema.parameters(),
                                            lr=0,
                                            momentum=0,
                                            weight_decay=0)
            model_ema, optimizer_ema = amp.initialize(model_ema,
                                                      optimizer_ema,
                                                      opt_level=args.opt_level)

    if args.LARS:
        optimizer = LARS(optimizer=optimizer, eps=1e-8, trust_coef=0.001)

    # optionally resume from a checkpoint
    args.start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch'] + 1
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # Older checkpoints from this script did not store 'contrast'.
            if contrast is not None and 'contrast' in checkpoint:
                contrast.load_state_dict(checkpoint['contrast'])
            if use_moco:
                model_ema.load_state_dict(checkpoint['model_ema'])

            if args.amp and checkpoint['opt'].amp:
                print('==> resuming amp state_dict')
                amp.load_state_dict(checkpoint['amp'])

            print("=> loaded successfully '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            del checkpoint
            torch.cuda.empty_cache()
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # tensorboard
    logger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)

    # run kNN evaluation every `test_epoch` epochs
    test_epoch = 2

    # routine
    for epoch in range(args.start_epoch, args.epochs + 1):

        print("==> training...")

        time1 = time.time()
        if use_moco:
            loss, prob = train_moco(epoch, train_loader, model, model_ema,
                                    contrast, criterion, optimizer, args)
        elif use_simclr:
            print("Train using simclr")
            loss, prob = train_simclr(epoch, train_loader, model, criterion,
                                      optimizer, args)
        else:
            print("Train using InsDis")
            loss, prob = train_ins(epoch, train_loader, model, contrast,
                                   criterion, optimizer, args)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        # tensorboard logger
        logger.log_value('ins_loss', loss, epoch)
        logger.log_value('ins_prob', prob, epoch)
        logger.log_value('learning_rate', optimizer.param_groups[0]['lr'],
                         epoch)

        if epoch % test_epoch == 0:
            model.eval()
            if use_moco:
                model_ema.eval()

            print('----------Evaluation---------')
            start = time.time()

            # kNN evaluation is only implemented for CIFAR.
            acc = None
            if args.dataset == 'cifar':
                acc = kNN(epoch,
                          model,
                          train_loader,
                          test_loader,
                          200,
                          args.nce_t,
                          n_data,
                          low_dim=128,
                          memory_bank=None)

            print("Evaluation Time: '{}'s".format(time.time() - start))
            if acc is not None:
                logger.log_value('Test accuracy', acc, epoch)
                print('[Epoch]: {}'.format(epoch))
                print('accuracy: {}%'.format(acc))

            # Restore training mode for the online encoder before the next
            # epoch (harmless if the train_* functions also set it).
            model.train()

        # Build the checkpoint once. It always goes to current.pth and, every
        # `save_freq` epochs, also to ckpt_epoch_<epoch>.pth. The contrast
        # memory is included whenever one exists, so resume can restore it.
        print('==> Saving...')
        state = {
            'opt': args,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
        }
        if contrast is not None:
            state['contrast'] = contrast.state_dict()
        if use_moco:
            state['model_ema'] = model_ema.state_dict()
        if args.amp:
            state['amp'] = amp.state_dict()
        torch.save(state, os.path.join(args.model_folder, 'current.pth'))
        if epoch % args.save_freq == 0:
            save_file = os.path.join(
                args.model_folder,
                'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)
        # help release GPU memory
        del state
        torch.cuda.empty_cache()
Пример #8
0
def get_train_loader(args):
    """Build the training loader, optionally wrapping the dataset in a mixing/erasing augmentation.

    Images are read from ``<args.data_folder>/train``, randomly
    resized-cropped to 224, randomly flipped, converted to the colour space
    chosen by ``args.view`` ('Lab' or 'YCbCr'), turned into a tensor and
    normalized.

    When ``args.IM`` is truthy, the dataset is wrapped according to
    ``args.IM_type``:

    - 'IM': ``IM``
    - 'global': ``global_``
    - 'region': ``region``
    - 'Cutout': ``Cutout``
    - 'RandomErasing': ``RandomErasing``

    Each wrapper's options are read from ``args``.

    Args:
        args: parsed options (see above, plus ``crop_low``, ``batch_size``
            and ``num_workers``).

    Returns:
        tuple: ``(train_loader, n_data)``.

    Raises:
        NotImplementedError: if ``args.view`` is not supported, or if
            ``args.IM`` is set with an unknown ``args.IM_type``.
    """
    data_folder = os.path.join(args.data_folder, 'train')

    if args.view == 'Lab':
        mean = [(0 + 100) / 2, (-86.183 + 98.233) / 2, (-107.857 + 94.478) / 2]
        std = [(100 - 0) / 2, (86.183 + 98.233) / 2, (107.857 + 94.478) / 2]
        color_transfer = RGB2Lab()
    elif args.view == 'YCbCr':
        mean = [116.151, 121.080, 132.342]
        std = [109.500, 111.855, 111.964]
        color_transfer = RGB2YCbCr()
    else:
        # `NotImplemented` is a constant, not an exception class: calling it
        # raises TypeError. NotImplementedError is the intended error.
        raise NotImplementedError('view not implemented {}'.format(args.view))
    normalize = transforms.Normalize(mean=mean, std=std)

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(args.crop_low, 1.)),
        transforms.RandomHorizontalFlip(),
        color_transfer,
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = ImageFolderInstance(data_folder, transform=train_transform)
    train_sampler = None

    if args.IM:
        # IM_type values are mutually exclusive, so exactly one wrapper
        # applies. An unrecognized type is reported rather than silently
        # training without the requested augmentation.
        if args.IM_type == 'IM':
            print("using IM space.................")
            train_dataset = IM(train_dataset,
                               g_alpha=args.g_alpha,
                               g_num_mix=args.g_num,
                               g_prob=args.g_prob,
                               r_beta=args.r_beta,
                               r_prob=args.r_prob,
                               r_num_mix=args.r_num,
                               r_decay=args.r_pixel_decay)
        elif args.IM_type == 'global':
            print("using global space.................")
            train_dataset = global_(train_dataset,
                                    g_alpha=args.g_alpha,
                                    g_num_mix=args.g_num,
                                    g_prob=args.g_prob)
        elif args.IM_type == 'region':
            print("using region space.................")
            train_dataset = region(train_dataset,
                                   r_beta=args.r_beta,
                                   r_prob=args.r_prob,
                                   r_num_mix=args.r_num,
                                   r_decay=args.r_pixel_decay)
        elif args.IM_type == 'Cutout':
            print("using Cutout aug.................")
            train_dataset = Cutout(train_dataset,
                                   mask_size=args.mask_size,
                                   p=args.cutout_p,
                                   cutout_inside=args.cutout_inside,
                                   mask_color=args.mask_color)
        elif args.IM_type == 'RandomErasing':
            print("using RandomErasing aug.................")
            train_dataset = RandomErasing(
                train_dataset,
                p=args.random_erasing_prob,
                area_ratio_range=args.area_ratio_range,
                min_aspect_ratio=args.min_aspect_ratio,
                max_attempt=args.max_attempt)
        else:
            raise NotImplementedError(
                'IM_type not implemented {}'.format(args.IM_type))

    # No custom sampler is used, so the loader shuffles on its own.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.num_workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    n_data = len(train_dataset)
    print('number of samples: {}'.format(n_data))

    return train_loader, n_data