Пример #1
0
def main():
    """Run contrastive pre-training with ``train_moco`` on CelebA or InatAve.

    Steps, in order: parse options with ``parse_option()``, build the
    paired-image training dataset and its ``DataLoader``, create ``model`` and
    ``model_ema`` for the requested architecture, set up the ``MemoryMoCo``
    contrast memory, the ``NCESoftmaxLoss`` criterion and an SGD optimizer,
    optionally resume from ``args.resume``, then train from
    ``args.start_epoch`` through ``args.epochs`` inclusive.

    After every epoch the loss, probability and learning rate are sent to the
    tensorboard logger and a checkpoint is written to
    ``<args.model_folder>/current.pth``; every ``args.save_freq`` epochs the
    same checkpoint is also written to ``ckpt_epoch_<epoch>.pth``.

    Raises:
        NotImplementedError: if ``args.dataset`` or ``args.model`` is not one
            of the names handled below.
    """
    args = parse_option()

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # ---- data ---------------------------------------------------------
    if args.dataset == 'CelebA':
        train_dataset = CelebAPrunedAligned_MAFLVal(args.data_folder,
                                                    train=True,
                                                    pair_image=True,
                                                    do_augmentations=True,
                                                    imwidth=args.image_size,
                                                    crop=args.image_crop)
    elif args.dataset == 'InatAve':
        train_dataset = InatAve(args.data_folder,
                                train=True,
                                pair_image=True,
                                do_augmentations=True,
                                imwidth=args.image_size,
                                imagelist=args.imagelist)
    else:
        raise NotImplementedError('dataset not supported {}'.format(
            args.dataset))

    print(len(train_dataset))
    # No custom sampler is used, so the loader shuffles on its own.
    train_sampler = None
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.num_workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    # ---- model and optimizer -------------------------------------------
    n_data = len(train_dataset)

    # The image is cropped by `image_crop` on each side before entering the
    # network; the pool size is that input size divided by 2**5.
    input_size = args.image_size - 2 * args.image_crop
    pool_size = int(input_size /
                    2**5)  # 96x96 --> 3; 160x160 --> 5; 224x224 --> 7;

    # One zero-argument builder per supported architecture. Lambdas keep the
    # class lookup lazy, so only the selected constructor is ever resolved.
    model_builders = {
        'resnet50': lambda: InsResNet50(pool_size=pool_size),
        'resnet50x2': lambda: InsResNet50(width=2, pool_size=pool_size),
        'resnet50x4': lambda: InsResNet50(width=4, pool_size=pool_size),
        'resnet18': lambda: InsResNet18(width=1, pool_size=pool_size),
        'resnet34': lambda: InsResNet34(width=1, pool_size=pool_size),
        'resnet101': lambda: InsResNet101(width=1, pool_size=pool_size),
        'resnet152': lambda: InsResNet152(width=1, pool_size=pool_size),
    }
    if args.model not in model_builders:
        raise NotImplementedError('model not supported {}'.format(args.model))
    build_model = model_builders[args.model]
    model = build_model()
    model_ema = build_model()

    # copy weights from `model' to `model_ema'
    moment_update(model, model_ema, 0)

    # set the contrast memory and criterion
    contrast = MemoryMoCo(128,
                          n_data,
                          args.nce_k,
                          args.nce_t,
                          use_softmax=True).cuda(args.gpu)

    criterion = NCESoftmaxLoss()
    criterion = criterion.cuda(args.gpu)

    # NOTE(review): the two networks go to the current default CUDA device,
    # whereas `contrast` and `criterion` above are placed on `args.gpu`.
    # Confirm these agree whenever `args.gpu` is set.
    model = model.cuda()
    model_ema = model_ema.cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    cudnn.benchmark = True

    # ---- optionally resume from a checkpoint ----------------------------
    args.start_epoch = 1
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # NOTE(review): checkpoints written below store the `args` object
            # under 'opt', so this is a full pickle load; only resume from
            # files you trust.
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch'] + 1
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            contrast.load_state_dict(checkpoint['contrast'])
            model_ema.load_state_dict(checkpoint['model_ema'])
            print("=> loaded successfully '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            del checkpoint
            torch.cuda.empty_cache()
        else:
            # Best effort: a missing file is reported and training starts
            # from epoch 1.
            print("=> no checkpoint found at '{}'".format(args.resume))

    # tensorboard
    logger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)

    # ---- training routine -----------------------------------------------
    for epoch in range(args.start_epoch, args.epochs + 1):

        adjust_learning_rate(epoch, args, optimizer)
        print("==> training...")

        time1 = time.time()
        loss, prob = train_moco(epoch, train_loader, model, model_ema,
                                contrast, criterion, optimizer, args)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        # tensorboard logger
        logger.log_value('ins_loss', loss, epoch)
        logger.log_value('ins_prob', prob, epoch)
        logger.log_value('learning_rate', optimizer.param_groups[0]['lr'],
                         epoch)

        # Build the checkpoint once per epoch. It always goes to
        # `current.pth`; on every `save_freq`-th epoch the very same state is
        # additionally kept under an epoch-numbered file name.
        print('==> Saving...')
        state = {
            'opt': args,
            'model': model.state_dict(),
            'contrast': contrast.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
            'model_ema': model_ema.state_dict(),
        }
        torch.save(state, os.path.join(args.model_folder, 'current.pth'))
        if epoch % args.save_freq == 0:
            save_file = os.path.join(
                args.model_folder,
                'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)
        # help release GPU memory
        del state
        torch.cuda.empty_cache()
Пример #2
0
def main():
    """Run contrastive pre-training on an image folder.

    Steps, in order: parse options with ``parse_option()``, build the
    augmentation pipeline selected by ``args.aug`` and an
    ``ImageFolderInstance`` dataset over ``<args.data_folder>/train``, create
    the network for ``args.model``, set up the contrast memory, criterion and
    SGD optimizer, optionally wrap them with ``amp``, optionally resume from
    ``args.resume``, then train from ``args.start_epoch`` through
    ``args.epochs`` inclusive.

    ``args.moco`` switches between two paths: when set, a second network
    ``model_ema`` is created and ``MemoryMoCo`` / ``train_moco`` are used;
    otherwise ``MemoryInsDis`` / ``train_ins`` are used.

    After every epoch the loss, probability and learning rate are sent to the
    tensorboard logger and a checkpoint is written to
    ``<args.model_folder>/current.pth``; every ``args.save_freq`` epochs the
    same checkpoint is also written to ``ckpt_epoch_<epoch>.pth``.

    Raises:
        NotImplementedError: if ``args.aug`` or ``args.model`` is not one of
            the names handled below.
    """
    args = parse_option()

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # ---- data ---------------------------------------------------------
    data_folder = os.path.join(args.data_folder, 'train')

    image_size = 224
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=mean, std=std)

    if args.aug == 'NULL':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    elif args.aug == 'CJ':
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        # `NotImplemented` is a non-callable constant; the exception class is
        # `NotImplementedError`.
        raise NotImplementedError(
            'augmentation not supported: {}'.format(args.aug))

    train_dataset = ImageFolderInstance(data_folder,
                                        transform=train_transform,
                                        two_crop=args.moco)
    print(len(train_dataset))
    # No custom sampler is used, so the loader shuffles on its own.
    train_sampler = None
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.num_workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    # ---- model and optimizer -------------------------------------------
    n_data = len(train_dataset)

    # One zero-argument builder per supported architecture. Lambdas keep the
    # class lookup lazy until the selected builder is called.
    model_builders = {
        'resnet50': lambda: InsResNet50(),
        'resnet50x2': lambda: InsResNet50(width=2),
        'resnet50x4': lambda: InsResNet50(width=4),
    }
    if args.model not in model_builders:
        raise NotImplementedError('model not supported {}'.format(args.model))
    build_model = model_builders[args.model]
    model = build_model()
    if args.moco:
        model_ema = build_model()
        # copy weights from `model' to `model_ema'
        moment_update(model, model_ema, 0)

    # set the contrast memory and criterion
    if args.moco:
        contrast = MemoryMoCo(128, n_data, args.nce_k, args.nce_t,
                              args.softmax).cuda(args.gpu)
    else:
        contrast = MemoryInsDis(128, n_data, args.nce_k, args.nce_t,
                                args.nce_m, args.softmax).cuda(args.gpu)

    criterion = NCESoftmaxLoss() if args.softmax else NCECriterion(n_data)
    criterion = criterion.cuda(args.gpu)

    # NOTE(review): the networks go to the current default CUDA device,
    # whereas `contrast` and `criterion` above are placed on `args.gpu`.
    # Confirm these agree whenever `args.gpu` is set.
    model = model.cuda()
    if args.moco:
        model_ema = model_ema.cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    cudnn.benchmark = True

    if args.amp:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.opt_level)
        if args.moco:
            # This zero-lr, zero-momentum optimizer exists only to be handed
            # to `amp.initialize` together with `model_ema`; it is not used
            # again in this function.
            optimizer_ema = torch.optim.SGD(model_ema.parameters(),
                                            lr=0,
                                            momentum=0,
                                            weight_decay=0)
            model_ema, optimizer_ema = amp.initialize(model_ema,
                                                      optimizer_ema,
                                                      opt_level=args.opt_level)

    # ---- optionally resume from a checkpoint ----------------------------
    args.start_epoch = 1
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # NOTE(review): checkpoints written below store the `args` object
            # under 'opt', so this is a full pickle load; only resume from
            # files you trust.
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch'] + 1
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            contrast.load_state_dict(checkpoint['contrast'])
            if args.moco:
                model_ema.load_state_dict(checkpoint['model_ema'])

            # Restore amp state only when both this run and the run that
            # wrote the checkpoint had amp enabled.
            if args.amp and checkpoint['opt'].amp:
                print('==> resuming amp state_dict')
                amp.load_state_dict(checkpoint['amp'])

            print("=> loaded successfully '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            del checkpoint
            torch.cuda.empty_cache()
        else:
            # Best effort: a missing file is reported and training starts
            # from epoch 1.
            print("=> no checkpoint found at '{}'".format(args.resume))

    # tensorboard
    logger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)

    # ---- training routine -----------------------------------------------
    for epoch in range(args.start_epoch, args.epochs + 1):

        adjust_learning_rate(epoch, args, optimizer)
        print("==> training...")

        time1 = time.time()
        if args.moco:
            loss, prob = train_moco(epoch, train_loader, model, model_ema,
                                    contrast, criterion, optimizer, args)
        else:
            loss, prob = train_ins(epoch, train_loader, model, contrast,
                                   criterion, optimizer, args)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        # tensorboard logger
        logger.log_value('ins_loss', loss, epoch)
        logger.log_value('ins_prob', prob, epoch)
        logger.log_value('learning_rate', optimizer.param_groups[0]['lr'],
                         epoch)

        # Build the checkpoint once per epoch. It always goes to
        # `current.pth`; on every `save_freq`-th epoch the very same state is
        # additionally kept under an epoch-numbered file name.
        print('==> Saving...')
        state = {
            'opt': args,
            'model': model.state_dict(),
            'contrast': contrast.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
        }
        if args.moco:
            state['model_ema'] = model_ema.state_dict()
        if args.amp:
            state['amp'] = amp.state_dict()
        torch.save(state, os.path.join(args.model_folder, 'current.pth'))
        if epoch % args.save_freq == 0:
            save_file = os.path.join(
                args.model_folder,
                'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)
        # help release GPU memory
        del state
        torch.cuda.empty_cache()