Example #1
def get_simclr_pipeline_transform(size, s=1, num_aug=5):
    """Return a set of data augmentation transformations as described in the SimCLR paper."""
    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s,
                                          0.2 * s)
    if num_aug == 5:
        data_transforms = transforms.Compose([
            transforms.RandomResizedCrop(size=size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([color_jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            GaussianBlur(kernel_size=int(0.1 * size)),
            transforms.ToTensor()
        ])
    elif num_aug == 7:
        data_transforms = transforms.Compose([
            transforms.RandomResizedCrop(size=size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([color_jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            GaussianBlur(kernel_size=int(0.1 * size)),
            transforms.RandomRotation(degrees=45),
            transforms.RandomAffine(degrees=45),
            transforms.ToTensor()
        ])
    else:
        # Guard against returning an unbound local for unsupported values.
        raise ValueError("num_aug must be 5 or 7, got {}".format(num_aug))
    return data_transforms
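All of these snippets rely on a custom GaussianBlur transform that is not shown (older torchvision releases had no built-in Gaussian blur). Below is a minimal sketch of such a transform, assuming PIL images and a sigma sampled from [0.1, 2.0] as in the SimCLR reference code; the kernel_size argument is kept only for API compatibility, since PIL's filter is parameterized by sigma alone, and variants such as Example #6's (which also takes a channels argument) typically implement the blur with OpenCV or a depthwise convolution instead.

import random
from PIL import ImageFilter

class GaussianBlur(object):
    """Blur a PIL image with a randomly sampled sigma (a sketch, not the
    exact implementation the examples above import)."""
    def __init__(self, kernel_size, sigma=(0.1, 2.0)):
        self.kernel_size = kernel_size  # unused by PIL; kept for API compatibility
        self.sigma = sigma

    def __call__(self, img):
        sigma = random.uniform(*self.sigma)
        return img.filter(ImageFilter.GaussianBlur(radius=sigma))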
Example #2
def _get_simclr_pipeline_transform(self):
    # Get a set of data augmentation transformations as described in the SimCLR paper.
    color_jitter = transforms.ColorJitter(0.8 * self.s, 0.8 * self.s, 0.8 * self.s, 0.2 * self.s)
    data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=self.input_shape[0]),
                                          transforms.RandomHorizontalFlip(),
                                          transforms.RandomApply([color_jitter], p=0.8),
                                          # transforms.RandomGrayscale(p=0.2),
                                          GaussianBlur(kernel_size=int(0.2 * self.input_shape[0])),
                                          transforms.ToTensor()])
    return data_transforms
Example #3

def get_simclr_transform(size, s=1):
    """Return a set of data augmentation transformations as described in the SimCLR paper."""
    color_jitter = transforms.ColorJitter(0.1 * s, 0.8 * s, 0.2 * s, 0 * s)
    data_transforms = transforms.Compose(
        [  # transforms.RandomResizedCrop(size=size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([color_jitter], p=1),
            transforms.RandomGrayscale(p=0.2),
            GaussianBlur(kernel_size=int(0.1 * size)),
            transforms.ToTensor()
        ])
    return data_transforms
Example #4
def get_simclr_pipeline_transform(size, s=1):
    """Transforms can be improved to match the SimCLR paper more closely; this doesn't
    seem to be exactly what the paper says. More transforms can be tried as well."""
    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s,
                                          0.2 * s)
    data_transforms = transforms.Compose([
        transforms.RandomResizedCrop(size=size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([color_jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        GaussianBlur(kernel_size=int(0.1 * size)),
        transforms.ToTensor()
    ])
    return data_transforms
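These pipelines are typically applied twice to each image to produce the two correlated views that SimCLR contrasts. A minimal wrapper in that style (the class name SimCLRDataTransform follows the common reference implementation; the size value is only illustrative):

class SimCLRDataTransform(object):
    """Apply the same stochastic transform twice to yield two correlated views."""
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, sample):
        return self.transform(sample), self.transform(sample)

# Example usage with the pipeline above:
two_view_transform = SimCLRDataTransform(get_simclr_pipeline_transform(size=96))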
Example #5
def get_simclr_transform(size, s=1):
    """Return a set of data augmentation transformations as described in the SimCLR paper."""
    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    data_transforms = transforms.Compose([
        transforms.Resize(size=size),
        transforms.RandomHorizontalFlip(),
        #transforms.RandomApply([color_jitter], p=0.8),
        #transforms.RandomGrayscale(p=0.2),
        GaussianBlur(kernel_size=int(0.1 * size)),
        transforms.ToTensor(),
        normalize
    ])
    return data_transforms
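A hypothetical way to plug this transform into a standard dataset (the path is a placeholder):

from torchvision import datasets

# Any folder-per-class image dataset works; the transform runs per sample.
dataset = datasets.ImageFolder('/path/to/images',
                               transform=get_simclr_transform(size=224))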
Example #6

def get_simclr_pipeline_transform(size, s=1, channels=3):
    """Return a set of data augmentation transformations as described in the SimCLR paper."""
    # Original
    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
    data_transforms = transforms.Compose([transforms.RandomResizedCrop(size=size),
                                          # transforms.RandomHorizontalFlip(),
                                          transforms.RandomApply([color_jitter], p=0.8),
                                          transforms.RandomGrayscale(p=0.2),
                                          GaussianBlur(kernel_size=int(0.1 * size), channels=channels),
                                          transforms.ToTensor()])

    # data_transforms = transforms.Compose([iaa.Sequential([iaa.SomeOf((1, 5),
    #                                       [iaa.LinearContrast((0.5, 1.0)),
    #                                       iaa.GaussianBlur((0.5, 1.5)),
    #                                       iaa.Crop(percent=((0, 0.4), (0, 0), (0, 0.4), (0, 0.0)), keep_size=True),
    #                                       iaa.Crop(percent=((0, 0.0), (0, 0.02), (0, 0), (0, 0.02)), keep_size=True),
    #                                       iaa.Sharpen(alpha=(0.0, 0.5), lightness=(0.0, 0.5)),
    #                                       iaa.PiecewiseAffine(scale=(0.02, 0.03), mode='edge'),
    #                                       iaa.PerspectiveTransform(scale=(0.01, 0.02))],
    #                                       random_order=True)]).augment_image,
    #                                       transforms.ToTensor()])

    return data_transforms
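If the commented-out imgaug pipeline above is re-enabled, note that imgaug augmenters operate on numpy arrays, not PIL images, so a conversion step is needed before augment_image. A sketch under that assumption, with a shortened augmenter list:

import numpy as np
import imgaug.augmenters as iaa

seq = iaa.Sequential([iaa.SomeOf((1, 5), [
    iaa.LinearContrast((0.5, 1.0)),
    iaa.GaussianBlur((0.5, 1.5)),
], random_order=True)])

imgaug_transforms = transforms.Compose([
    np.asarray,            # PIL.Image -> HxWxC uint8 array
    seq.augment_image,     # apply the imgaug pipeline
    transforms.ToTensor()  # back to a CHW float tensor in [0, 1]
])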
Example #7
def main():
    global args, best_prec1
    args = parser.parse_args()

    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size)
    # Use CUDA
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
    use_cuda = torch.cuda.is_available()

    # Random seed
    if args.manual_seed is None:
        args.manual_seed = random.randint(1, 10000)
    random.seed(args.manual_seed)
    torch.manual_seed(args.manual_seed)
    if use_cuda:
        torch.cuda.manual_seed_all(args.manual_seed)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    elif args.arch.startswith('resnext'):
        model = models.__dict__[args.arch](
                    baseWidth=args.base_width,
                    cardinality=args.cardinality,
                )
    elif args.arch.startswith('shufflenet'):
        model = models.__dict__[args.arch](
                    groups=args.groups
                )
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if not args.distributed:
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    else:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    title = 'CelebA-' + args.arch
    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
            args.checkpoint = os.path.dirname(args.resume)
            logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            # Fall back to a fresh logger so the logger.append calls below don't fail.
            logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
            logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'])
    else:
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'])


    cudnn.benchmark = True

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    color_jitter = transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)
    train_dataset = CelebA(
        args.data,
        'train_40_att_list.txt',
        transforms.Compose([
            transforms.RandomResizedCrop(size=(218, 178)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([color_jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            GaussianBlur(kernel_size=int(0.1 * 178)),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.train_batch, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        CelebA(args.data, 'val_40_att_list.txt', transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.test_batch, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    test_loader = torch.utils.data.DataLoader(
        CelebA(args.data, 'test_40_att_list.txt', transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.test_batch, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        validate(test_loader, model, criterion)
        return

    if args.private_test:
        private_loader = torch.utils.data.DataLoader(
            CelebA(args.data, 'testset', transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ]),test=True),
            batch_size=args.test_batch, shuffle=False,
            num_workers=args.workers, pin_memory=True)
        test(private_loader, model, criterion)
        return

    # visualization
    writer = SummaryWriter(os.path.join(args.checkpoint, 'logs'))

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        lr = adjust_learning_rate(optimizer, epoch)

        print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, lr))

        # train for one epoch
        train_loss, train_acc = train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        val_loss, prec1 = validate(val_loader, model, criterion)

        # append logger file
        logger.append([lr, train_loss, val_loss, train_acc, prec1])

        # tensorboardX
        writer.add_scalar('learning rate', lr, epoch + 1)
        writer.add_scalars('loss', {'train loss': train_loss, 'validation loss': val_loss}, epoch + 1)
        writer.add_scalars('accuracy', {'train accuracy': train_acc, 'validation accuracy': prec1}, epoch + 1)
        #for name, param in model.named_parameters():
        #    writer.add_histogram(name, param.clone().cpu().data.numpy(), epoch + 1)


        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer' : optimizer.state_dict(),
        }, is_best, checkpoint=args.checkpoint)

    logger.close()
    logger.plot()
    savefig(os.path.join(args.checkpoint, 'log.eps'))
    writer.close()

    print('Best accuracy:')
    print(best_prec1)
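The helpers train, validate, save_checkpoint, and adjust_learning_rate are defined elsewhere in this script's repository. adjust_learning_rate commonly follows the step-decay pattern from the PyTorch ImageNet example; a sketch under that assumption (the 30-epoch decay interval is illustrative, not taken from this script):

def adjust_learning_rate(optimizer, epoch):
    """Step decay: divide the base LR by 10 every 30 epochs (illustrative schedule)."""
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr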
Example #8

transform1 = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.49, 0.45, 0.4], std=[0.229, 0.224, 0.225]),
])

color_jitter = transforms.ColorJitter(0.8 * 1, 0.8 * 1, 0.8 * 1, 0.2 * 1)
data_transforms = transforms.Compose([
    transforms.RandomResizedCrop(size=224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([color_jitter], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    GaussianBlur(kernel_size=int(0.1 * 224)),  # scale with the 224-pixel crop, not the channel count
    transforms.ToTensor()
])


class DataSetWrapper(object):
    def __init__(self, batch_size, num_workers, valid_size, input_shape, s):
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.valid_size = valid_size
        self.s = s
        self.input_shape = eval(input_shape)

    def get_data_loaders(self):
        data_augment = self._get_simclr_pipeline_transform()