def main():

    args = parse_option()

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.dataset == 'CelebA':
        train_dataset = CelebAPrunedAligned_MAFLVal(args.data_folder,
                                                    train=True,
                                                    pair_image=True,
                                                    do_augmentations=True,
                                                    imwidth=args.image_size,
                                                    crop=args.image_crop)
    elif args.dataset == 'InatAve':
        train_dataset = InatAve(args.data_folder,
                                train=True,
                                pair_image=True,
                                do_augmentations=True,
                                imwidth=args.image_size,
                                imagelist=args.imagelist)
    else:
        raise NotImplementedError('dataset not supported {}'.format(
            args.dataset))

    print(len(train_dataset))
    train_sampler = None
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.num_workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    # create model and optimizer
    n_data = len(train_dataset)

    input_size = args.image_size - 2 * args.image_crop
    pool_size = int(input_size /
                    2**5)  # 96x96 --> 3; 160x160 --> 5; 224x224 --> 7;
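    # a standard ResNet downsamples its input by 2**5 = 32 overall, so
    # pool_size is simply the spatial size of the final feature map that the
    # global pooling head will see.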

    if args.model == 'resnet50':
        model = InsResNet50(pool_size=pool_size)
        model_ema = InsResNet50(pool_size=pool_size)
    elif args.model == 'resnet50x2':
        model = InsResNet50(width=2, pool_size=pool_size)
        model_ema = InsResNet50(width=2, pool_size=pool_size)
    elif args.model == 'resnet50x4':
        model = InsResNet50(width=4, pool_size=pool_size)
        model_ema = InsResNet50(width=4, pool_size=pool_size)
    elif args.model == 'resnet18':
        model = InsResNet18(width=1, pool_size=pool_size)
        model_ema = InsResNet18(width=1, pool_size=pool_size)
    elif args.model == 'resnet34':
        model = InsResNet34(width=1, pool_size=pool_size)
        model_ema = InsResNet34(width=1, pool_size=pool_size)
    elif args.model == 'resnet101':
        model = InsResNet101(width=1, pool_size=pool_size)
        model_ema = InsResNet101(width=1, pool_size=pool_size)
    elif args.model == 'resnet152':
        model = InsResNet152(width=1, pool_size=pool_size)
        model_ema = InsResNet152(width=1, pool_size=pool_size)
    else:
        raise NotImplementedError('model not supported {}'.format(args.model))

    # copy weights from `model' to `model_ema'
    moment_update(model, model_ema, 0)
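    # momentum 0 makes `model_ema' an exact copy of `model'; during training
    # moment_update is called again with a momentum close to 1, so the key
    # encoder only drifts slowly behind the query encoder (MoCo-style EMA).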

    # set the contrast memory and criterion
    contrast = MemoryMoCo(128,
                          n_data,
                          args.nce_k,
                          args.nce_t,
                          use_softmax=True).cuda(args.gpu)
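    # MemoryMoCo arguments: feature dimension (128), dataset size, queue
    # length K (args.nce_k) and temperature (args.nce_t); the memory keeps a
    # FIFO queue of K encoded keys that act as negatives for the InfoNCE loss.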

    criterion = NCESoftmaxLoss()
    criterion = criterion.cuda(args.gpu)

    model = model.cuda()
    model_ema = model_ema.cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    cudnn.benchmark = True

    # optionally resume from a checkpoint
    args.start_epoch = 1
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            # checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch'] + 1
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            contrast.load_state_dict(checkpoint['contrast'])
            model_ema.load_state_dict(checkpoint['model_ema'])
            print("=> loaded successfully '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            del checkpoint
            torch.cuda.empty_cache()
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # tensorboard
    logger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)

    # routine
    for epoch in range(args.start_epoch, args.epochs + 1):

        adjust_learning_rate(epoch, args, optimizer)
        print("==> training...")

        time1 = time.time()
        loss, prob = train_moco(epoch, train_loader, model, model_ema,
                                contrast, criterion, optimizer, args)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        # tensorboard logger
        logger.log_value('ins_loss', loss, epoch)
        logger.log_value('ins_prob', prob, epoch)
        logger.log_value('learning_rate', optimizer.param_groups[0]['lr'],
                         epoch)

        # save the model: 'current.pth' every epoch, plus a numbered
        # 'ckpt_epoch_<epoch>.pth' checkpoint every `save_freq' epochs
        print('==> Saving...')
        state = {
            'opt': args,
            'model': model.state_dict(),
            'contrast': contrast.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
        }
        state['model_ema'] = model_ema.state_dict()
        save_file = os.path.join(args.model_folder, 'current.pth')
        torch.save(state, save_file)
        if epoch % args.save_freq == 0:
            save_file = os.path.join(
                args.model_folder,
                'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)
        # help release GPU memory
        del state
        torch.cuda.empty_cache()
def main():

    args = parse_option()

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # set the data loader
    data_folder = os.path.join(args.data_folder, 'train')

    image_size = 224
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=mean, std=std)

    if args.aug == 'NULL':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    elif args.aug == 'CJ':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        raise NotImplementedError('augmentation not supported: {}'.format(
            args.aug))

    train_dataset = ImageFolderInstance(data_folder,
                                        transform=train_transform,
                                        two_crop=args.moco)
    print(len(train_dataset))
    train_sampler = None
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.num_workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    # create model and optimizer
    n_data = len(train_dataset)

    if args.model == 'resnet50':
        model = InsResNet50()
        if args.moco:
            model_ema = InsResNet50()
    elif args.model == 'resnet50x2':
        model = InsResNet50(width=2)
        if args.moco:
            model_ema = InsResNet50(width=2)
    elif args.model == 'resnet50x4':
        model = InsResNet50(width=4)
        if args.moco:
            model_ema = InsResNet50(width=4)
    else:
        raise NotImplementedError('model not supported {}'.format(args.model))

    # copy weights from `model' to `model_ema'
    if args.moco:
        moment_update(model, model_ema, 0)

    # set the contrast memory and criterion
    if args.moco:
        contrast = MemoryMoCo(128, n_data, args.nce_k, args.nce_t,
                              args.softmax).cuda(args.gpu)
    else:
        contrast = MemoryInsDis(128, n_data, args.nce_k, args.nce_t,
                                args.nce_m, args.softmax).cuda(args.gpu)

    criterion = NCESoftmaxLoss() if args.softmax else NCECriterion(n_data)
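    # NCESoftmaxLoss treats the (1 + K)-way similarity scores as a softmax
    # cross-entropy (InfoNCE); NCECriterion is the original noise-contrastive
    # estimation loss used for instance discrimination.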
    criterion = criterion.cuda(args.gpu)

    model = model.cuda()
    if args.moco:
        model_ema = model_ema.cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    cudnn.benchmark = True

    if args.amp:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.opt_level)
        if args.moco:
            optimizer_ema = torch.optim.SGD(model_ema.parameters(),
                                            lr=0,
                                            momentum=0,
                                            weight_decay=0)
            model_ema, optimizer_ema = amp.initialize(model_ema,
                                                      optimizer_ema,
                                                      opt_level=args.opt_level)

    # optionally resume from a checkpoint
    args.start_epoch = 1
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            # checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch'] + 1
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            contrast.load_state_dict(checkpoint['contrast'])
            if args.moco:
                model_ema.load_state_dict(checkpoint['model_ema'])

            if args.amp and checkpoint['opt'].amp:
                print('==> resuming amp state_dict')
                amp.load_state_dict(checkpoint['amp'])

            print("=> loaded successfully '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            del checkpoint
            torch.cuda.empty_cache()
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # tensorboard
    logger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)

    # routine
    for epoch in range(args.start_epoch, args.epochs + 1):

        adjust_learning_rate(epoch, args, optimizer)
        print("==> training...")

        time1 = time.time()
        if args.moco:
            loss, prob = train_moco(epoch, train_loader, model, model_ema,
                                    contrast, criterion, optimizer, args)
        else:
            loss, prob = train_ins(epoch, train_loader, model, contrast,
                                   criterion, optimizer, args)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        # tensorboard logger
        logger.log_value('ins_loss', loss, epoch)
        logger.log_value('ins_prob', prob, epoch)
        logger.log_value('learning_rate', optimizer.param_groups[0]['lr'],
                         epoch)

        # save the model: 'current.pth' every epoch, plus a numbered
        # 'ckpt_epoch_<epoch>.pth' checkpoint every `save_freq' epochs
        print('==> Saving...')
        state = {
            'opt': args,
            'model': model.state_dict(),
            'contrast': contrast.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
        }
        if args.moco:
            state['model_ema'] = model_ema.state_dict()
        if args.amp:
            state['amp'] = amp.state_dict()
        save_file = os.path.join(args.model_folder, 'current.pth')
        torch.save(state, save_file)
        if epoch % args.save_freq == 0:
            save_file = os.path.join(
                args.model_folder,
                'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)
        # help release GPU memory
        del state
        torch.cuda.empty_cache()
def main():

    global best_acc1
    best_acc1 = 0

    args = parse_option()

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # set the data loader
    train_folder = os.path.join(args.data_folder, 'train')
    val_folder = os.path.join(args.data_folder, 'val')

    image_size = 224
    crop_padding = 32
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=mean, std=std)

    if args.aug == 'NULL':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    elif args.aug == 'CJ':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        raise NotImplementedError('augmentation not supported: {}'.format(
            args.aug))

    train_dataset = datasets.ImageFolder(train_folder, train_transform)
    val_dataset = datasets.ImageFolder(
        val_folder,
        transforms.Compose([
            transforms.Resize(image_size + crop_padding),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            normalize,
        ]))

    print(len(train_dataset))
    train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.num_workers,
                                               pin_memory=True,
                                               sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.num_workers,
                                             pin_memory=True)

    # create model and optimizer
    if args.model == 'resnet50':
        model = InsResNet50()
        classifier = LinearClassifierResNet(args.layer, args.n_label, 'avg', 1)
    elif args.model == 'resnet50x2':
        model = InsResNet50(width=2)
        classifier = LinearClassifierResNet(args.layer, args.n_label, 'avg', 2)
    elif args.model == 'resnet50x4':
        model = InsResNet50(width=4)
        classifier = LinearClassifierResNet(args.layer, args.n_label, 'avg', 4)
    else:
        raise NotImplementedError('model not supported {}'.format(args.model))

    print('==> loading pre-trained model')
    ckpt = torch.load(args.model_path)
    model.load_state_dict(ckpt['model'])
    print("==> loaded checkpoint '{}' (epoch {})".format(
        args.model_path, ckpt['epoch']))
    print('==> done')

    model = model.cuda()
    classifier = classifier.cuda()

    criterion = torch.nn.CrossEntropyLoss().cuda(args.gpu)

    if not args.adam:
        optimizer = torch.optim.SGD(classifier.parameters(),
                                    lr=args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.Adam(classifier.parameters(),
                                     lr=args.learning_rate,
                                     betas=(args.beta1, args.beta2),
                                     weight_decay=args.weight_decay,
                                     eps=1e-8)

    model.eval()
    cudnn.benchmark = True

    # set mixed precision training
    # if args.amp:
    #     model = amp.initialize(model, opt_level=args.opt_level)
    #     classifier, optimizer = amp.initialize(classifier, optimizer, opt_level=args.opt_level)

    # optionally resume from a checkpoint
    args.start_epoch = 1
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            # checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch'] + 1
            classifier.load_state_dict(checkpoint['classifier'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            best_acc1 = checkpoint['best_acc1']
            best_acc1 = best_acc1.cuda()
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            if 'opt' in checkpoint.keys():
                # resume optimization hyper-parameters
                print('=> resume hyper parameters')
                if 'bn' in vars(checkpoint['opt']):
                    print('using bn: ', checkpoint['opt'].bn)
                if 'adam' in vars(checkpoint['opt']):
                    print('using adam: ', checkpoint['opt'].adam)
                if 'cosine' in vars(checkpoint['opt']):
                    print('using cosine: ', checkpoint['opt'].cosine)
                args.learning_rate = checkpoint['opt'].learning_rate
                # args.lr_decay_epochs = checkpoint['opt'].lr_decay_epochs
                args.lr_decay_rate = checkpoint['opt'].lr_decay_rate
                args.momentum = checkpoint['opt'].momentum
                args.weight_decay = checkpoint['opt'].weight_decay
                args.beta1 = checkpoint['opt'].beta1
                args.beta2 = checkpoint['opt'].beta2
            del checkpoint
            torch.cuda.empty_cache()
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # set cosine annealing scheduler
    if args.cosine:

        # last_epoch = args.start_epoch - 2
        # eta_min = args.learning_rate * (args.lr_decay_rate ** 3) * 0.1
        # scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min, last_epoch)

        eta_min = args.learning_rate * (args.lr_decay_rate**3) * 0.1
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.epochs, eta_min, -1)
        # dummy loop to catch up with current epoch
        for i in range(1, args.start_epoch):
            scheduler.step()
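        # stepping the scheduler (start_epoch - 1) times keeps the cosine
        # schedule aligned when resuming, since this script also steps the
        # scheduler once at the top of every training epoch.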

    # tensorboard
    logger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)

    # routine
    for epoch in range(args.start_epoch, args.epochs + 1):

        if args.cosine:
            scheduler.step()
        else:
            adjust_learning_rate(epoch, args, optimizer)
        print("==> training...")

        time1 = time.time()
        train_acc, train_acc5, train_loss = train(epoch, train_loader, model,
                                                  classifier, criterion,
                                                  optimizer, args)
        time2 = time.time()
        print('train epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        logger.log_value('train_acc', train_acc, epoch)
        logger.log_value('train_acc5', train_acc5, epoch)
        logger.log_value('train_loss', train_loss, epoch)
        logger.log_value('learning_rate', optimizer.param_groups[0]['lr'],
                         epoch)

        print("==> testing...")
        test_acc, test_acc5, test_loss = validate(val_loader, model,
                                                  classifier, criterion, args)

        logger.log_value('test_acc', test_acc, epoch)
        logger.log_value('test_acc5', test_acc5, epoch)
        logger.log_value('test_loss', test_loss, epoch)

        # save the best model
        if test_acc > best_acc1:
            best_acc1 = test_acc
            state = {
                'opt': args,
                'epoch': epoch,
                'classifier': classifier.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }
            save_name = '{}_layer{}.pth'.format(args.model, args.layer)
            save_name = os.path.join(args.save_folder, save_name)
            print('saving best model!')
            torch.save(state, save_name)

        # save model
        if epoch % args.save_freq == 0:
            print('==> Saving...')
            state = {
                'opt': args,
                'epoch': epoch,
                'classifier': classifier.state_dict(),
                'best_acc1': test_acc,
                'optimizer': optimizer.state_dict(),
            }
            save_name = 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch)
            save_name = os.path.join(args.save_folder, save_name)
            print('saving regular model!')
            torch.save(state, save_name)

def main():

    args = parse_option()

    test_dataset = getattr(module_data, args.dataset)(args.data_folder,
                                                      test=True,
                                                      pair_image=False,
                                                      imwidth=args.image_size,
                                                      crop=args.image_crop)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.num_workers,
                                              pin_memory=True)
    print('Number of test images: %d' % len(test_dataset))

    input_size = args.image_size - 2 * args.image_crop
    pool_size = int(input_size /
                    2**5)  # 96x96 --> 3; 160x160 --> 5; 224x224 --> 7;
    args.output_shape = (48, 48)
    args.boxsize = 48

    if args.model == 'resnet50':
        model = InsResNet50(pool_size=pool_size)
        desc_dim = {1: 64, 2: 256, 3: 512, 4: 1024, 5: 2048}
    elif args.model == 'resnet50x2':
        model = InsResNet50(width=2, pool_size=pool_size)
        desc_dim = {1: 128, 2: 512, 3: 1024, 4: 2048, 5: 4096}
    elif args.model == 'resnet50x4':
        model = InsResNet50(width=4, pool_size=pool_size)
        desc_dim = {1: 512, 2: 1024, 3: 2048, 4: 4096, 5: 8192}
    elif args.model == 'resnet18':
        model = InsResNet18(width=1, pool_size=pool_size)
        desc_dim = {1: 64, 2: 64, 3: 128, 4: 256, 5: 512}
    elif args.model == 'resnet34':
        model = InsResNet34(width=1, pool_size=pool_size)
        desc_dim = {1: 64, 2: 64, 3: 128, 4: 256, 5: 512}
    elif args.model == 'resnet101':
        model = InsResNet101(width=1, pool_size=pool_size)
        desc_dim = {1: 64, 2: 256, 3: 512, 4: 1024, 5: 2048}
    elif args.model == 'resnet152':
        model = InsResNet152(width=1, pool_size=pool_size)
        desc_dim = {1: 64, 2: 256, 3: 512, 4: 1024, 5: 2048}
    elif args.model == 'hourglass':
        model = HourglassNet()
    else:
        raise NotImplementedError('model not supported {}'.format(args.model))

    if args.model == 'hourglass':
        feat_dim = 64
    else:
        if args.use_hypercol:
            feat_dim = 0
            for i in range(args.layer):
                feat_dim += desc_dim[5 - i]
        else:
            feat_dim = desc_dim[args.layer]
    args.feat_dim = feat_dim
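    # with hypercolumns, the features of the top `args.layer' residual stages
    # are concatenated channel-wise, so feat_dim is the sum of their channel
    # counts; otherwise only the single selected stage is used.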

    print('==> loading pre-trained MOCO')
    ckpt = torch.load(args.trained_model_path, map_location='cpu')
    if args.model == 'hourglass':
        model.load_state_dict(clean_state_dict(ckpt["state_dict"]))
    else:
        model.load_state_dict(ckpt['model'], strict=False)
    print("==> loaded checkpoint '{}'".format(args.trained_model_path))
    print('==> done')

    model = model.cuda()
    cudnn.benchmark = True

    if args.vis_PCA:
        PCA(test_loader, model, args)
    else:
        criterion = selected_regression_loss
        regressor = IntermediateKeypointPredictor(
            feat_dim,
            num_annotated_points=args.num_points,
            num_intermediate_points=50,
            softargmax_mul=100.0)
        regressor = regressor.cuda()
        print('==> loading pre-trained landmark regressor {}'.format(
            args.ckpt_path))
        checkpoint = torch.load(args.ckpt_path, map_location='cpu')
        regressor.load_state_dict(checkpoint['regressor'])
        del checkpoint
        torch.cuda.empty_cache()
        test_PCK, test_loss = validate(test_loader, model, regressor,
                                       criterion, args)
def main():

    # Load the arguments
    args = parse_option()

    dataset = args.dataset
    sample_size = args.sample_size
    layername = args.layer

    # Other values for places and imagenet MoCo model
    epoch = 240
    image_size = 224
    crop = 0.2
    crop_padding = 32
    batch_size = 1
    num_workers = 24
    train_sampler = None
    moco = True

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=mean, std=std)

    # Set appropriate paths
    folder_path = "/data/vision/torralba/ganprojects/yyou/CMC_data/{}_models".format(
        dataset)
    model_name = "/{}_MoCo0.999_softmax_16384_resnet50_lr_0.03".format(dataset) \
                     + "_decay_0.0001_bsz_128_crop_0.2_aug_CJ"
    epoch_name = "/ckpt_epoch_{}.pth".format(epoch)
    my_path = folder_path + model_name + epoch_name

    data_path = "/data/vision/torralba/datasets/"
    web_path = "/data/vision/torralba/scratch/yyou/wednesday/dissection/"

    if dataset == "imagenet":
        data_path += "imagenet_pytorch"
        web_path += dataset + "/" + layername
    elif dataset == "places365":
        data_path += "places/places365_standard/places365standard_easyformat"
        web_path += dataset + "/" + layername

    # Create web path folder directory for this layer
    if not os.path.exists(web_path):
        os.makedirs(web_path)

    # Load validation data loader
    val_folder = data_path + "/val"
    val_transform = transforms.Compose([
        transforms.Resize(image_size + crop_padding),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        normalize,
    ])

    ds = QuickImageFolder(val_folder,
                          transform=val_transform,
                          shuffle=True,
                          two_crop=False)
    ds_loader = torch.utils.data.DataLoader(ds,
                                            batch_size=batch_size,
                                            shuffle=(train_sampler is None),
                                            num_workers=num_workers,
                                            pin_memory=True,
                                            sampler=train_sampler)

    # Load model from checkpoint
    checkpoint = torch.load(my_path)
    model_checkpoint = {
        key.replace(".module", ""): val
        for key, val in checkpoint['model'].items()
    }

    model = InsResNet50(parallel=False)
    model.load_state_dict(model_checkpoint)
    model = nethook.InstrumentedModel(model)
    model.cuda()

    # Renormalize RGB data from the staistical scaling in ds to [-1...1] range
    renorm = renormalize.renormalizer(source=ds, target='zc')

    # Retain desired layer with nethook
    batch = next(iter(ds_loader))[0]
    model.retain_layer(layername)
    model(batch.cuda())
    acts = model.retained_layer(layername).cpu()

    upfn = upsample.upsampler(
        target_shape=(56, 56),
        data_shape=(7, 7),
    )
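    # upsamples the retained 7x7 activation grid to 56x56 (a 224 input
    # downsampled by 4), so activations can later be compared location-wise
    # with the segmentation computed at downsample=4.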

    def flatten_activations(batch, *args):
        image_batch = batch
        _ = model(image_batch.cuda())
        acts = model.retained_layer(layername).cpu()
        hacts = upfn(acts)
        return hacts.permute(0, 2, 3, 1).contiguous().view(-1, acts.shape[1])

    def tally_quantile_for_layer(layername):
        rq = tally.tally_quantile(
            flatten_activations,
            dataset=ds,
            sample_size=sample_size,
            batch_size=100,
            cachefile='results/{}/{}_rq_cache.npz'.format(dataset, layername))
        return rq

    rq = tally_quantile_for_layer(layername)

    # Visualize range of activations (statistics of each filter over the sample images)
    fig, axs = plt.subplots(2, 2, figsize=(10, 8))
    axs = axs.flatten()
    quantiles = [0.5, 0.8, 0.9, 0.99]
    for i in range(4):
        axs[i].plot(rq.quantiles(quantiles[i]))
        axs[i].set_title("Rq quantiles ({})".format(quantiles[i]))
    fig.suptitle("{}  -  sample size of {}".format(dataset, sample_size))
    plt.savefig(web_path + "/rq_quantiles")

    # Set the image visualizer with the rq and percent level
    iv = imgviz.ImageVisualizer(224,
                                source=ds,
                                percent_level=0.95,
                                quantiles=rq)

    # Tally top k images that maximize the mean activation of the filter
    def max_activations(batch, *args):
        image_batch = batch.cuda()
        _ = model(image_batch)
        acts = model.retained_layer(layername)
        return acts.view(acts.shape[:2] + (-1, )).max(2)[0]

    def mean_activations(batch, *args):
        image_batch = batch.cuda()
        _ = model(image_batch)
        acts = model.retained_layer(layername)
        return acts.view(acts.shape[:2] + (-1, )).mean(2)

    topk = tally.tally_topk(
        mean_activations,
        dataset=ds,
        sample_size=sample_size,
        batch_size=100,
        cachefile='results/{}/{}_cache_mean_topk.npz'.format(
            dataset, layername))

    top_indexes = topk.result()[1]

    # Visualize top-activating images for a particular unit
    if not os.path.exists(web_path + "/top_activating_imgs"):
        os.makedirs(web_path + "/top_activating_imgs")

    def top_activating_imgs(unit):
        img_ids = [i for i in top_indexes[unit, :12]]
        images = [iv.masked_image(ds[i][0], \
                      model.retained_layer(layername)[0], unit) \
                      for i in img_ids]
        preds = [ds.classes[model(ds[i][0][None].cuda()).max(1)[1].item()]\
                    for i in img_ids]

        fig, axs = plt.subplots(3, 4, figsize=(16, 12))
        axs = axs.flatten()

        for i in range(12):
            axs[i].imshow(images[i])
            axs[i].tick_params(axis='both', which='both', bottom=False, \
                               left=False, labelbottom=False, labelleft=False)
            axs[i].set_title("img {} \n pred: {}".format(img_ids[i], preds[i]))
        fig.suptitle("unit {}".format(unit))

        plt.savefig(web_path + "/top_activating_imgs/unit_{}".format(unit))

    for unit in np.random.randint(len(top_indexes), size=10):
        top_activating_imgs(unit)

    def compute_activations(image_batch):
        image_batch = image_batch.cuda()
        _ = model(image_batch)
        acts_batch = model.retained_layer(layername)
        return acts_batch

    unit_images = iv.masked_images_for_topk(
        compute_activations,
        ds,
        topk,
        k=5,
        num_workers=10,
        pin_memory=True,
        cachefile='results/{}/{}_cache_top10images.npz'.format(
            dataset, layername))

    file = open("results/{}/unit_images.pkl".format(dataset, layername), 'wb')
    pickle.dump(unit_images, file)

    # Load a segmentation model
    segmodel, seglabels, segcatlabels = setting.load_segmenter('netpqc')

    # Intersections between every unit's 99th activation
    # and every segmentation class identified
    level_at_99 = rq.quantiles(0.99).cuda()[None, :, None, None]
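    # per-unit 99th-percentile activation level, reshaped to (1, C, 1, 1) so
    # it broadcasts against the (N, C, H, W) upsampled activations below.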

    def compute_selected_segments(batch, *args):
        image_batch = batch.cuda()
        seg = segmodel.segment_batch(renorm(image_batch), downsample=4)
        _ = model(image_batch)
        acts = model.retained_layer(layername)
        hacts = upfn(acts)
        iacts = (hacts >
                 level_at_99).float()  # indicator where > 0.99 percentile.
        return tally.conditional_samples(iacts, seg)

    condi99 = tally.tally_conditional_mean(
        compute_selected_segments,
        dataset=ds,
        sample_size=sample_size,
        cachefile='results/{}/{}_cache_condi99.npz'.format(dataset, layername))

    iou99 = tally.iou_from_conditional_indicator_mean(condi99)
    file = open("results/{}/{}_iou99.pkl".format(dataset, layername), 'wb')
    pickle.dump(iou99, file)

    # Show units with best match to a segmentation class
    iou_unit_label_99 = sorted(
        [(unit, concept.item(), seglabels[concept], bestiou.item())
         for unit, (bestiou, concept) in enumerate(zip(*iou99.max(0)))],
        key=lambda x: -x[-1])
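    # iou99 is indexed (concept, unit); max(0) yields, for every unit, the
    # best-matching segmentation concept and its IoU, and the list is sorted
    # by IoU in descending order.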

    fig, axs = plt.subplots(20, 1, figsize=(20, 80))
    axs = axs.flatten()

    for i, (unit, concept, label, score) in enumerate(iou_unit_label_99[:20]):
        axs[i].imshow(unit_images[unit])
        axs[i].set_title('unit %d; iou %g; label "%s"' % (unit, score, label))
        axs[i].set_xticks([])
        axs[i].set_yticks([])
    plt.savefig(web_path + "/best_unit_segmentation")
def main():

    args = parse_option()

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # set the data loader
    data_folder = os.path.join(args.data_folder, 'train')
    val_folder = os.path.join(args.data_folder, 'val')

    crop_padding = 32
    image_size = 224
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=mean, std=std)

    if args.aug == 'NULL' and args.dataset == 'imagenet':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    elif args.aug == 'CJ':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    # elif args.aug == 'NULL' and args.dataset == 'cifar':
    #     train_transform = transforms.Compose([
    #         transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),
    #         transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
    #         transforms.RandomGrayscale(p=0.2),
    #         transforms.RandomHorizontalFlip(p=0.5),
    #         transforms.ToTensor(),
    #         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    #     ])
    #
    #     test_transform = transforms.Compose([
    #         transforms.ToTensor(),
    #         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    #     ])
    elif args.aug == 'simple' and args.dataset == 'imagenet':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomHorizontalFlip(),
            get_color_distortion(1.0),
            transforms.ToTensor(),
            normalize,
        ])

        # TODO: Currently follow CMC
        test_transform = transforms.Compose([
            transforms.Resize(image_size + crop_padding),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            normalize,
        ])
    elif args.aug == 'simple' and args.dataset == 'cifar':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(size=32),
            transforms.RandomHorizontalFlip(p=0.5),
            get_color_distortion(0.5),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

    else:
        raise NotImplementedError('augmentation not supported: {}'.format(
            args.aug))

    # Get Datasets
    if args.dataset == "imagenet":
        train_dataset = ImageFolderInstance(data_folder,
                                            transform=train_transform,
                                            two_crop=args.moco)
        print(len(train_dataset))
        train_sampler = None
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            num_workers=args.num_workers,
            pin_memory=True,
            sampler=train_sampler)

        test_dataset = datasets.ImageFolder(val_folder,
                                            transform=test_transform)

        test_loader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=256,
                                                  shuffle=False,
                                                  num_workers=args.num_workers,
                                                  pin_memory=True)

    elif args.dataset == 'cifar':
        # cifar-10 dataset
        if args.contrastive_model == 'simclr':
            train_dataset = CIFAR10Instance_double(root='./data',
                                                   train=True,
                                                   download=True,
                                                   transform=train_transform,
                                                   double=True)
        else:
            train_dataset = CIFAR10Instance(root='./data',
                                            train=True,
                                            download=True,
                                            transform=train_transform)
        train_sampler = None
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            num_workers=args.num_workers,
            pin_memory=True,
            sampler=train_sampler,
            drop_last=True)

        test_dataset = CIFAR10Instance(root='./data',
                                       train=False,
                                       download=True,
                                       transform=test_transform)
        test_loader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=100,
                                                  shuffle=False,
                                                  num_workers=args.num_workers)

    # create model and optimizer
    n_data = len(train_dataset)

    if args.model == 'resnet50':
        model = InsResNet50()
        if args.contrastive_model == 'moco':
            model_ema = InsResNet50()
    elif args.model == 'resnet50x2':
        model = InsResNet50(width=2)
        if args.contrastive_model == 'moco':
            model_ema = InsResNet50(width=2)
    elif args.model == 'resnet50x4':
        model = InsResNet50(width=4)
        if args.contrastive_model == 'moco':
            model_ema = InsResNet50(width=4)
    elif args.model == 'resnet50_cifar':
        model = InsResNet50_cifar()
        if args.contrastive_model == 'moco':
            model_ema = InsResNet50_cifar()
    else:
        raise NotImplementedError('model not supported {}'.format(args.model))

    # copy weights from `model' to `model_ema'
    if args.contrastive_model == 'moco':
        moment_update(model, model_ema, 0)

    # set the contrast memory and criterion
    if args.contrastive_model == 'moco':
        contrast = MemoryMoCo(128, n_data, args.nce_k, args.nce_t,
                              args.softmax).cuda(args.gpu)
    elif args.contrastive_model == 'simclr':
        contrast = None
    else:
        contrast = MemoryInsDis(128, n_data, args.nce_k, args.nce_t,
                                args.nce_m, args.softmax).cuda(args.gpu)

    if args.softmax:
        criterion = NCESoftmaxLoss()
    elif args.contrastive_model == 'simclr':
        criterion = BatchCriterion(1, args.nce_t, args.batch_size)
    else:
        criterion = NCECriterion(n_data)
    criterion = criterion.cuda(args.gpu)

    model = model.cuda()
    if args.contrastive_model == 'moco':
        model_ema = model_ema.cuda()

    # Exclude BN and bias if needed
    weight_decay = args.weight_decay
    if weight_decay and args.filter_weight_decay:
        parameters = add_weight_decay(model, weight_decay,
                                      args.filter_weight_decay)
        weight_decay = 0.
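        # add_weight_decay returns parameter groups that already carry their
        # own weight_decay (with BN and bias terms excluded), so the
        # optimizer-level weight_decay is zeroed to avoid applying it twice.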
    else:
        parameters = model.parameters()

    optimizer = torch.optim.SGD(parameters,
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=weight_decay)
    cudnn.benchmark = True

    if args.amp:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.opt_level)
        if args.contrastive_model == 'moco':
            optimizer_ema = torch.optim.SGD(model_ema.parameters(),
                                            lr=0,
                                            momentum=0,
                                            weight_decay=0)
            model_ema, optimizer_ema = amp.initialize(model_ema,
                                                      optimizer_ema,
                                                      opt_level=args.opt_level)

    if args.LARS:
        optimizer = LARS(optimizer=optimizer, eps=1e-8, trust_coef=0.001)
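        # LARS wraps the base optimizer and rescales every layer's update by
        # a layer-wise trust ratio, which helps keep training stable when
        # very large batch sizes are used.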

    # optionally resume from a checkpoint
    args.start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            # checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch'] + 1
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if contrast:
                contrast.load_state_dict(checkpoint['contrast'])
            if args.contrastive_model == 'moco':
                model_ema.load_state_dict(checkpoint['model_ema'])

            if args.amp and checkpoint['opt'].amp:
                print('==> resuming amp state_dict')
                amp.load_state_dict(checkpoint['amp'])

            print("=> loaded successfully '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            del checkpoint
            torch.cuda.empty_cache()
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # tensorboard
    logger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)

    # routine
    for epoch in range(args.start_epoch, args.epochs + 1):

        print("==> training...")

        time1 = time.time()
        if args.contrastive_model == 'moco':
            loss, prob = train_moco(epoch, train_loader, model, model_ema,
                                    contrast, criterion, optimizer, args)
        elif args.contrastive_model == 'simclr':
            print("Train using simclr")
            loss, prob = train_simclr(epoch, train_loader, model, criterion,
                                      optimizer, args)
        else:
            print("Train using InsDis")
            loss, prob = train_ins(epoch, train_loader, model, contrast,
                                   criterion, optimizer, args)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        # tensorboard logger
        logger.log_value('ins_loss', loss, epoch)
        logger.log_value('ins_prob', prob, epoch)
        logger.log_value('learning_rate', optimizer.param_groups[0]['lr'],
                         epoch)

        test_epoch = 2
        if epoch % test_epoch == 0:
            model.eval()

            if args.contrastive_model == 'moco':
                model_ema.eval()

            print('----------Evaluation---------')
            start = time.time()

            if args.dataset == 'cifar':
                acc = kNN(epoch,
                          model,
                          train_loader,
                          test_loader,
                          200,
                          args.nce_t,
                          n_data,
                          low_dim=128,
                          memory_bank=None)

                print("Evaluation Time: '{}'s".format(time.time() - start))
                # writer.add_scalar('nn_acc', acc, epoch)
                logger.log_value('Test accuracy', acc, epoch)

                # print('accuracy: {}% \t (best acc: {}%)'.format(acc, best_acc))
                print('[Epoch]: {}'.format(epoch))
                print('accuracy: {}%'.format(acc))
                # test_log_file.flush()

        # save the model: 'current.pth' every epoch, plus a numbered
        # 'ckpt_epoch_<epoch>.pth' checkpoint every `save_freq' epochs
        print('==> Saving...')
        state = {
            'opt': args,
            'model': model.state_dict(),
            # 'contrast': contrast.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
        }
        if args.contrastive_model == 'moco':
            state['model_ema'] = model_ema.state_dict()
        if args.amp:
            state['amp'] = amp.state_dict()
        save_file = os.path.join(args.model_folder, 'current.pth')
        torch.save(state, save_file)
        if epoch % args.save_freq == 0:
            save_file = os.path.join(
                args.model_folder,
                'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)
        # help release GPU memory
        del state
        torch.cuda.empty_cache()
def main():

    global best_error
    best_error = np.inf

    args = parse_option()
    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))
    torch.manual_seed(0)

    # train on celebA unlabeled dataset
    train_dataset = CelebAPrunedAligned_MAFLVal(args.data_folder, 
                                                train=True, 
                                                pair_image=False, 
                                                do_augmentations=True,
                                                imwidth=args.image_size,
                                                crop=args.image_crop)
    print('Number of training images: %d' % len(train_dataset))
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True, sampler=None)

    
    # validation set from MAFLAligned trainset for hyperparameter searching
    # we sample 2000 images as our val set
    val_dataset = MAFLAligned(args.data_folder,
                              train=True,  # train set
                              pair_image=True,
                              do_augmentations=False,
                              TPS_aug=True,
                              imwidth=args.image_size,
                              crop=args.image_crop)
    print('Initial number of validation images: %d' % len(val_dataset)) 
    val_dataset.restrict_annos(num=2000, outpath=args.save_folder, repeat_flag=False)
    print('After restricting the size of validation set: %d' % len(val_dataset))
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=2, shuffle=False,
        num_workers=args.num_workers, pin_memory=True)


    # testing set from MAFLAligned test for evaluating image matching
    test_dataset = MAFLAligned(args.data_folder,
                               train=False,  # test set
                               pair_image=True,
                               do_augmentations=False,
                               TPS_aug=True,  # match landmarks between deformed images
                               imwidth=args.image_size,
                               crop=args.image_crop)
    print('Number of testing images: %d' % len(test_dataset)) 
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=2, shuffle=False,
        num_workers=args.num_workers, pin_memory=True)

    assert len(val_dataset) == 2000
    assert len(test_dataset) == 1000


    # create model and optimizer
    input_size = args.image_size - 2 * args.image_crop
    pool_size = int(input_size / 2**5) # 96x96 --> 3; 160x160 --> 5; 224x224 --> 7;
    # we use smaller feature map when training the feature distiller for memory issue
    args.train_output_shape = (args.train_out_size, args.train_out_size)
    # we use the original size of the image (e.g. 96x96 face images) during testing
    args.val_output_shape = (args.val_out_size, args.val_out_size)

    if args.model == 'resnet50':
        model = InsResNet50(pool_size=pool_size)
        desc_dim = {1:64, 2:256, 3:512, 4:1024, 5:2048}
    elif args.model == 'resnet50_half':
        model = InsResNet50(width=0.5, pool_size=pool_size)
        desc_dim = {1:int(64/2), 2:int(256/2), 3:int(512/2), 4:int(1024/2), 5:int(2048/2)}
    elif args.model == 'resnet50x2':
        model = InsResNet50(width=2, pool_size=pool_size)
        desc_dim = {1:128, 2:512, 3:1024, 4:2048, 5:4096}
    elif args.model == 'resnet50x4':
        model = InsResNet50(width=4, pool_size=pool_size)
        desc_dim = {1:512, 2:1024, 3:2048, 4:4096, 5:8192}
    elif args.model == 'resnet18':
        model = InsResNet18(width=1, pool_size=pool_size)
        desc_dim = {1:64, 2:64, 3:128, 4:256, 5:512}
    elif args.model == 'resnet34':
        model = InsResNet34(width=1, pool_size=pool_size)
        desc_dim = {1:64, 2:64, 3:128, 4:256, 5:512}
    elif args.model == 'resnet101':
        model = InsResNet101(width=1, pool_size=pool_size)
        desc_dim = {1:64, 2:256, 3:512, 4:1024, 5:2048}
    elif args.model == 'resnet152':
        model = InsResNet152(width=1, pool_size=pool_size)
        desc_dim = {1:64, 2:256, 3:512, 4:1024, 5:2048}
    elif args.model == 'hourglass':
        model = HourglassNet()
    else:
        raise NotImplementedError('model not supported {}'.format(args.model))
    
    
    # xxx_feat_spectral records the feat dim per layer in hypercol
    # this information is useful to do layer-wise feat normalization in landmark matching
    train_feat_spectral = [] 
    if args.train_use_hypercol:
        for i in range(args.train_layer):
            train_feat_spectral.append(desc_dim[5-i])
    else:
        train_feat_spectral.append(desc_dim[args.train_layer])
    args.train_feat_spectral = train_feat_spectral

    val_feat_spectral = []
    if args.val_use_hypercol:
        for i in range(args.val_layer):
            val_feat_spectral.append(desc_dim[5-i])
    else:
        val_feat_spectral.append(desc_dim[args.val_layer])
    args.val_feat_spectral = val_feat_spectral
    

    # load pretrained moco 
    if args.trained_model_path != 'none':
        print('==> loading pre-trained model')
        ckpt = torch.load(args.trained_model_path, map_location='cpu')
        model.load_state_dict(ckpt['model'], strict=True)
        print("==> loaded checkpoint '{}' (epoch {})".format(
                            args.trained_model_path, ckpt['epoch']))
        print('==> done')
    else:
        print('==> use randomly initialized model')


    # Define feature distiller, set pretrained model to eval mode
    if args.feat_distill:
        model.eval()
        assert np.sum(train_feat_spectral) == np.sum(val_feat_spectral)
        feat_distiller = FeatDistiller(np.sum(val_feat_spectral),
                                       kernel_size=args.kernel_size,
                                       mode=args.distill_mode,
                                       out_dim=args.out_dim,
                                       softargmax_mul=args.softargmax_mul)
        feat_distiller = nn.DataParallel(feat_distiller)
        feat_distiller.train()
        print('Feature distillation is used: kernel_size:{}, mode:{}, out_dim:{}'.format(
                args.kernel_size, args.distill_mode, args.out_dim))
        feat_distiller = feat_distiller.cuda()
    else:
        feat_distiller = None


    #  evaluate feat distiller on landmark matching, given pretrained moco and feature distiller
    model = model.cuda()
    if args.evaluation_mode:
        if args.feat_distill:
            print("==> use pretrained feature distiller ...")
            feat_ckpt = torch.load(args.trained_feat_model_path, map_location='cpu') 
            # the checkpoint key 'feat_disiller' is misspelled in the saved
            # models; keep the misspelling so pretrained checkpoints still load.
            feat_distiller.load_state_dict(feat_ckpt['feat_disiller'], strict=False)
            print("==> loaded checkpoint '{}' (epoch {})".format(
                                args.trained_feat_model_path, feat_ckpt['epoch']))
            same_err, diff_err = validate(test_loader, model, args, 
                                        feat_distiller=feat_distiller, 
                                        visualization=args.visualize_matching)
        else:
            print("==> use hypercolumn ...")
            same_err, diff_err = validate(test_loader, model, args, 
                                        feat_distiller=None, 
                                        visualization=args.visualize_matching)
        exit()


    ## define optimizer for feature distiller  
    if not args.adam:
        if not args.feat_distill:
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=args.learning_rate,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
        else:
            optimizer = torch.optim.SGD(feat_distiller.parameters(),
                                        lr=args.learning_rate,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
    else:
        if not args.feat_distill:
            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=args.learning_rate,
                                         betas=(args.beta1, args.beta2),
                                         weight_decay=args.weight_decay,
                                         eps=1e-8,
                                         amsgrad=args.amsgrad)
        else:
            optimizer = torch.optim.Adam(feat_distiller.parameters(),
                                         lr=args.learning_rate,
                                         betas=(args.beta1, args.beta2),
                                         weight_decay=args.weight_decay,
                                         eps=1e-8,
                                         amsgrad=args.amsgrad)



    # set lr scheduler
    if args.cosine: # we use cosine scheduler by default
        eta_min = args.learning_rate * (args.lr_decay_rate ** 3) * 0.1
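        # illustrative example (hypothetical values): learning_rate=0.03 and
        # lr_decay_rate=0.1 give eta_min = 0.03 * 0.1**3 * 0.1 = 3e-6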
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min, -1)
    elif args.multistep:
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 250], gamma=0.1)

    # tensorboard
    logger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)
    cudnn.benchmark = True

    # report the performance of hypercol on landmark matching tasks
    print("==> Testing of initial model on validation set...")
    same_err, diff_err = validate(val_loader, model, args, feat_distiller=None)
    print("==> Testing of initial model on test set...")
    same_err, diff_err = validate(test_loader, model, args, feat_distiller=None)
    
    # training loss for the feature projector
    criterion = dense_corr_loss

    # track the best validation error (used when saving the best model below)
    best_error = np.Inf

    # training feature distiller
    for epoch in range(1, args.epochs + 1):
        if args.cosine or args.multistep:
            scheduler.step()
        else:
            adjust_learning_rate(epoch, args, optimizer)

        print("==> training ...")
        time1 = time.time()
        train_loss = train_point_contrast(epoch, train_loader, model, criterion, optimizer, args, 
                                            feat_distiller=feat_distiller)
        time2 = time.time()
        print('train epoch {}, total time {:.2f}, train_loss {:.4f}'.format(epoch, 
                                            time2 - time1, train_loss))
        logger.log_value('train_loss', train_loss, epoch)
        logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch)


        print("==> validation ...")
        val_same_err, val_diff_err = validate(val_loader, model, args, 
                                            feat_distiller=feat_distiller)


        print("==> testing ...")
        test_same_err, test_diff_err = validate(test_loader, model, args, 
                                            feat_distiller=feat_distiller)

        # save model
        if epoch % args.save_freq == 0:
            print('==> Saving...')
            state = {
                'opt': args,
                'epoch': epoch,
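                # the misspelled key 'feat_disiller' is kept for checkpoint compatibility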
                'feat_disiller': feat_distiller.state_dict(),
                'val_error': [val_same_err, val_diff_err],
                'test_error': [test_same_err, test_diff_err],
            }
            save_name = 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch)
            save_name = os.path.join(args.save_folder, save_name)
            print('saving regular model!')
            torch.save(state, save_name)

            if val_diff_err < best_error:
                best_error = val_diff_err
                save_name = 'best.pth'
                save_name = os.path.join(args.save_folder, save_name)
                print('saving best model! val_same: {} val_diff: {} test_same: {} test_diff: {}'.format(val_same_err, val_diff_err, test_same_err, test_diff_err))
                torch.save(state, save_name)
def main():

    global best_error
    best_error = np.Inf

    args = parse_option()

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    val_dataset = getattr(module_data, args.dataset)(args.data_folder,
                                                     train=False,
                                                     pair_image=False,
                                                     do_augmentations=False,
                                                     imwidth=args.image_size,
                                                     crop=args.image_crop)

    print('Number of validation images: %d' % len(val_dataset))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.num_workers,
                                             pin_memory=True)

    # create model and optimizer
    input_size = args.image_size - 2 * args.image_crop
    pool_size = int(input_size /
                    2**5)  # 96x96 --> 3; 160x160 --> 5; 224x224 --> 7;
    args.output_shape = (48, 48)

    if args.model == 'resnet50':
        model = InsResNet50(pool_size=pool_size)
        desc_dim = {1: 64, 2: 256, 3: 512, 4: 1024, 5: 2048}
    elif args.model == 'resnet50x2':
        model = InsResNet50(width=2, pool_size=pool_size)
        desc_dim = {1: 128, 2: 512, 3: 1024, 4: 2048, 5: 4096}
    elif args.model == 'resnet50x4':
        model = InsResNet50(width=4, pool_size=pool_size)
        desc_dim = {1: 512, 2: 1024, 3: 2048, 4: 4096, 5: 8192}
    elif args.model == 'resnet18':
        model = InsResNet18(width=1, pool_size=pool_size)
        desc_dim = {1: 64, 2: 64, 3: 128, 4: 256, 5: 512}
    elif args.model == 'resnet34':
        model = InsResNet34(width=1, pool_size=pool_size)
        desc_dim = {1: 64, 2: 64, 3: 128, 4: 256, 5: 512}
    elif args.model == 'resnet101':
        model = InsResNet101(width=1, pool_size=pool_size)
        desc_dim = {1: 64, 2: 256, 3: 512, 4: 1024, 5: 2048}
    elif args.model == 'resnet152':
        model = InsResNet152(width=1, pool_size=pool_size)
        desc_dim = {1: 64, 2: 256, 3: 512, 4: 1024, 5: 2048}
    elif args.model == 'hourglass':
        model = HourglassNet()
    else:
        raise NotImplementedError('model not supported {}'.format(args.model))

    if args.model == 'hourglass':
        feat_dim = 64
    else:
        if args.use_hypercol:
            feat_dim = 0
            for i in range(args.layer):
                feat_dim += desc_dim[5 - i]
        else:
            feat_dim = desc_dim[args.layer]

    args.feat_dim = feat_dim
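    # For example (assuming the resnet50 desc_dim above): layer=3 with use_hypercol
    # gives feat_dim = 2048 + 1024 + 512 = 3584; without hypercolumns it is just 512.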

    if not args.random_network:
        print('==> loading pre-trained model')
        ckpt = torch.load(args.trained_model_path, map_location='cpu')
        model.load_state_dict(ckpt['model'], strict=False)
        print("==> loaded checkpoint '{}' (epoch {})".format(
            args.trained_model_path, ckpt['epoch']))
    else:
        print("==> loaded randomly initialized model")
    print('==> done')

    model = model.cuda()
    model.eval()
    cudnn.benchmark = True

    # visualize the PCA projections of representation
    if args.vis_PCA:
        PCA(val_loader, model, args)
    else:
        criterion = regression_loss
        # visualize the landmark detection results
        regressor = IntermediateKeypointPredictor(
            feat_dim,
            num_annotated_points=args.num_points,
            num_intermediate_points=50,
            softargmax_mul=100.0)
        regressor = regressor.cuda()
        checkpoint = torch.load(args.ckpt_path, map_location='cpu')
        regressor.load_state_dict(checkpoint['regressor'])
        del checkpoint
        torch.cuda.empty_cache()
        # routine
        # print("==> testing epoch %d" % ckpt_epoch)
        print("===> testing check points: %s" % args.ckpt_path)
        test_InterOcularError, test_loss = validate(val_loader, model,
                                                    regressor, criterion, args)
def main():

    global best_error
    best_error = np.Inf

    args = parse_option()

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    train_dataset = getattr(module_data, args.dataset)(args.data_folder, 
                                         train=True, 
                                         pair_image=False, 
                                         do_augmentations=False,
                                         imwidth=args.image_size, 
                                         crop=args.image_crop,
                                         TPS_aug=args.TPS_aug) # using TPS for data augmentation
    val_dataset   = getattr(module_data, args.dataset)(args.data_folder, 
                                         train=False, 
                                         pair_image=False, 
                                         do_augmentations=False,
                                         imwidth=args.image_size, 
                                         crop=args.image_crop)

    print('Number of training images: %d' % len(train_dataset))
    print('Number of validation images: %d' % len(val_dataset))


    # for the few-shot experiments: using limited annotations to train the landmark regression
    if args.restrict_annos > -1:
        if args.resume:
            train_dataset.restrict_annos(num=args.restrict_annos, datapath=args.save_folder, 
                                                    repeat_flag=args.repeat, num_per_epoch = args.num_per_epoch)
        else:
            train_dataset.restrict_annos(num=args.restrict_annos, outpath=args.save_folder,  
                                                    repeat_flag=args.repeat, num_per_epoch = args.num_per_epoch)
        print('Restricting annotated images to %d (sanity check, dataset length: %d); images per epoch: %d' % (args.restrict_annos, 
                                                    len(train_dataset), args.num_per_epoch))
    

    train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.num_workers, pin_memory=True, sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True)

    # create model and optimizer
    input_size = args.image_size - 2 * args.image_crop
    pool_size = int(input_size / 2**5) # 96x96 --> 3; 160x160 --> 5; 224x224 --> 7;
    args.output_shape = (48,48)

    if args.model == 'resnet50':
        model = InsResNet50(pool_size=pool_size)
        desc_dim = {1:64, 2:256, 3:512, 4:1024, 5:2048}
    elif args.model == 'resnet50x2':
        model = InsResNet50(width=2, pool_size=pool_size)
        desc_dim = {1:128, 2:512, 3:1024, 4:2048, 5:4096}
    elif args.model == 'resnet50x4':
        model = InsResNet50(width=4, pool_size=pool_size)
        desc_dim = {1:512, 2:1024, 3:2048, 4:4096, 5:8192}
    elif args.model == 'resnet18':
        model = InsResNet18(width=1, pool_size=pool_size)
        desc_dim = {1:64, 2:64, 3:128, 4:256, 5:512}
    elif args.model == 'resnet34':
        model = InsResNet34(width=1, pool_size=pool_size)
        desc_dim = {1:64, 2:64, 3:128, 4:256, 5:512}
    elif args.model == 'resnet101':
        model = InsResNet101(width=1, pool_size=pool_size)
        desc_dim = {1:64, 2:256, 3:512, 4:1024, 5:2048}
    elif args.model == 'resnet152':
        model = InsResNet152(width=1, pool_size=pool_size)
        desc_dim = {1:64, 2:256, 3:512, 4:1024, 5:2048}
    elif args.model == 'hourglass':
        model = HourglassNet()
    else:
        raise NotImplementedError('model not supported {}'.format(args.model))


    if args.model == 'hourglass':
        feat_dim = 64
    else:
        if args.use_hypercol:
            feat_dim = 0
            for i in range(args.layer):
                feat_dim += desc_dim[5-i]
        else:
            feat_dim = desc_dim[args.layer]

   
    regressor = IntermediateKeypointPredictor(feat_dim,
                                               num_annotated_points=args.num_points,
                                               num_intermediate_points=50,
                                               softargmax_mul=100.0)

    print('==> loading pre-trained model')
    ckpt = torch.load(args.trained_model_path, map_location='cpu')
    model.load_state_dict(ckpt['model'], strict=False)
    print("==> loaded checkpoint '{}' (epoch {})".format(args.trained_model_path, ckpt['epoch']))
    print('==> done')

    model = model.cuda()
    regressor = regressor.cuda()

    criterion = regression_loss

    if not args.adam:
        optimizer = torch.optim.SGD(regressor.parameters(),
                                    lr=args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.Adam(regressor.parameters(),
                                     lr=args.learning_rate,
                                     betas=(args.beta1, args.beta2),
                                     weight_decay=args.weight_decay,
                                     eps=1e-8,
                                     amsgrad=args.amsgrad)
    model.eval()
    cudnn.benchmark = True

    # optionally resume from a checkpoint
    args.start_epoch = 1
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            # checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch'] + 1
            regressor.load_state_dict(checkpoint['regressor'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            best_error = checkpoint['best_error']
            best_error = best_error.cuda()
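            # best_error appears to be stored as a tensor in the checkpoint, hence the .cuda() call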
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
            if 'opt' in checkpoint.keys():
                # resume optimization hyper-parameters
                print('=> resume hyper parameters')
                if 'bn' in vars(checkpoint['opt']):
                    print('using bn: ', checkpoint['opt'].bn)
                if 'adam' in vars(checkpoint['opt']):
                    print('using adam: ', checkpoint['opt'].adam)
                if 'cosine' in vars(checkpoint['opt']):
                    print('using cosine: ', checkpoint['opt'].cosine)
                args.learning_rate = checkpoint['opt'].learning_rate
                # args.lr_decay_epochs = checkpoint['opt'].lr_decay_epochs
                args.lr_decay_rate = checkpoint['opt'].lr_decay_rate
                args.momentum = checkpoint['opt'].momentum
                args.weight_decay = checkpoint['opt'].weight_decay
                args.beta1 = checkpoint['opt'].beta1
                args.beta2 = checkpoint['opt'].beta2
            del checkpoint
            torch.cuda.empty_cache()
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # set cosine annealing scheduler
    if args.cosine:

        # last_epoch = args.start_epoch - 2
        # eta_min = args.learning_rate * (args.lr_decay_rate ** 3) * 0.1
        # scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min, last_epoch)

        eta_min = args.learning_rate * (args.lr_decay_rate ** 3) * 0.1
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min, -1)
        # dummy loop to catch up with current epoch
        for i in range(1, args.start_epoch):
            scheduler.step()
    elif args.multistep:
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 250], gamma=0.1)
        # dummy loop to catch up with current epoch
        for i in range(1, args.start_epoch):
            scheduler.step()

    # tensorboard
    logger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)

    # routine
    for epoch in range(args.start_epoch, args.epochs + 1):

        if args.cosine or args.multistep:
            scheduler.step()
        else:
            adjust_learning_rate(epoch, args, optimizer)
        print("==> training...")

        time1 = time.time()
        InterOcularError, train_loss = train(epoch, train_loader, model, regressor, criterion, optimizer, args)
        time2 = time.time()
        print('train epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        logger.log_value('InterOcularError', InterOcularError, epoch)
        logger.log_value('train_loss', train_loss, epoch)
        logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch)

        print("==> testing...")
        test_InterOcularError, test_loss = validate(val_loader, model, regressor, criterion, args)

        logger.log_value('Test_InterOcularError', test_InterOcularError, epoch)
        logger.log_value('test_loss', test_loss, epoch) 

        # save the best model
        if test_InterOcularError < best_error:
            best_error = test_InterOcularError
            state = {
                'opt': args,
                'epoch': epoch,
                'regressor': regressor.state_dict(),
                'best_error': best_error,
                'optimizer': optimizer.state_dict(),
            }
            save_name = '{}.pth'.format(args.model)
            save_name = os.path.join(args.save_folder, save_name)
            print('saving best model!')
            torch.save(state, save_name)

        # save model
        if epoch % args.save_freq == 0:
            print('==> Saving...')
            state = {
                'opt': args,
                'epoch': epoch,
                'regressor': regressor.state_dict(),
                'best_error': test_InterOcularError,
                'optimizer': optimizer.state_dict(),
            }
            save_name = 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch)
            save_name = os.path.join(args.save_folder, save_name)
            print('saving regular model!')
            torch.save(state, save_name)

        # tensorboard logger
        pass
def main():
    
    args = parse_option()
    args.moco = True
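    # per-channel normalization with the standard ImageNet mean/std statistics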
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=mean, std=std)

    my_path = "/data/vision/torralba/ganprojects/yyou/CMC"
    
    model_name = "/{}_MoCo0.999_softmax_16384_resnet50_lr_0.03_decay_0.0001_bsz_128_crop_0.2_aug_CJ".format(args.dataset)
    
    if args.dataset == "imagenet":
        dataset_path = "/data/vision/torralba/datasets/imagenet_pytorch/imagenet_pytorch"
        model_ckpt_path = "/imagenet_models"
    elif args.dataset == "places365":
        dataset_path = "/data/vision/torralba/datasets/places/places365_standard/places365standard_easyformat"
        model_ckpt_path = "/places365_models"
    else:
        raise ValueError("Unsupported dataset type of {}".format(args.datset))

    train_folder = dataset_path + "/train"
    val_folder = dataset_path + "/val"
        
    if args.folder == "train":
        print("=> Loading train set")
        crop = 0.8
        train_sampler = None
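        # MoCo-style augmentation for the train split: random resized crop,
        # random grayscale, color jitter and horizontal flip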

        transform = transforms.Compose([
                        transforms.RandomResizedCrop(args.imsize, scale=(crop, 1.)),
                        transforms.RandomGrayscale(p=0.2),
                        transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
                        transforms.RandomHorizontalFlip(),
                        transforms.ToTensor(),
                        normalize,
                    ])
        dataset = QuickImageFolder(train_folder, transform=transform, two_crop=args.moco)
        loader = torch.utils.data.DataLoader(
                    dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
                    num_workers=args.num_workers, pin_memory=True, sampler=train_sampler)
    else:
        print("=> Loading val set")
        transform = transforms.Compose([
                            transforms.Resize(args.imsize + args.crop_pad),
                            transforms.CenterCrop(args.imsize),
                            transforms.ToTensor(),
                            normalize,
                        ])

        dataset = QuickImageFolder(val_folder, transform=transform, two_crop=args.moco)
        loader = torch.utils.data.DataLoader(
                    dataset, batch_size=args.batch_size, 
                    shuffle=False, num_workers=args.num_workers, pin_memory=True)
            
    if not isinstance(args.epochs, list):
        args.epochs = [args.epochs]

    print("Generating for epochs: ", args.epochs)
    
    for epoch in args.epochs:
        print("====> Working on epoch: {}".format(epoch))
        epoch_name = "/ckpt_epoch_{}.pth".format(epoch)

        model_path = my_path + model_ckpt_path + model_name + epoch_name

        model = InsResNet50()
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['model'])
        model.cuda()

        #idx_to_feat, idx_to_label= img_to_feat(model, loader)
        idx_to_feat = img_to_feat(model, loader)
        save_obj(idx_to_feat, my_path+"/pkl"+"/{}_{}_to_feat_epoch{}".format(args.dataset, args.folder, epoch))