Example #1
0
def set_model(args):
    """Create the CMC encoder and its linear classifier, then load the encoder.

    The architecture is selected from ``args.model``:

    * names starting with ``alexnet`` -> ``MyAlexNetCMC`` plus a max-pooled
      ``LinearClassifierAlexNet`` on ``args.layer``;
    * names starting with ``resnet`` -> ``MyResNetsCMC`` plus an avg-pooled
      ``LinearClassifierResNet`` whose last argument is 1 / 2 / 4 for a
      ``v1`` / ``v2`` / ``v3`` suffix; otherwise, names containing ``ttt``
      get a classifier on layer 10 with argument 1.

    Encoder weights are read from ``torch.load(args.model_path)['model']``.
    Both networks are moved to CUDA and the encoder is switched to eval mode.

    Returns:
        ``(model, classifier)``.

    Raises:
        NotImplementedError: if ``args.model`` matches none of the cases above.
    """
    arch = args.model
    unsupported = 'model not supported {}'.format(arch)

    if arch.startswith('alexnet'):
        model = MyAlexNetCMC()
        classifier = LinearClassifierAlexNet(
            layer=args.layer, n_label=args.n_label, pool_type='max')
    elif arch.startswith('resnet'):
        model = MyResNetsCMC(name=arch, view=args.view)
        # First matching suffix wins, mirroring the original if/elif order.
        width = next(
            (w for suffix, w in (('v1', 1), ('v2', 2), ('v3', 4))
             if arch.endswith(suffix)),
            None)
        if width is not None:
            classifier = LinearClassifierResNet(args.layer, args.n_label, 'avg', width)
        elif 'ttt' in arch:
            classifier = LinearClassifierResNet(10, args.n_label, 'avg', 1)
        else:
            raise NotImplementedError(unsupported)
    else:
        raise NotImplementedError(unsupported)

    # load pre-trained model
    print('==> loading pre-trained model')
    checkpoint = torch.load(args.model_path)
    model.load_state_dict(checkpoint['model'])
    print("==> loaded checkpoint '{}' (epoch {})".format(args.model_path, checkpoint['epoch']))
    print('==> done')

    model, classifier = model.cuda(), classifier.cuda()
    model.eval()
    return model, classifier
def set_model(args):
    """Build the AlexNet encoder and a 1000-way linear classifier, loading the encoder.

    Args:
        args: options namespace; reads ``model``, ``layer`` and ``model_path``.

    Returns:
        ``(model, classifier, criterion)`` where ``criterion`` is a
        ``nn.CrossEntropyLoss``. All three are moved to CUDA, and cuDNN
        autotuning is enabled, only when CUDA is available.

    Raises:
        NotImplementedError: if ``args.model`` is not ``'alexnet'``.
    """
    if args.model == 'alexnet':
        model = alexnet()
        classifier = LinearClassifierAlexNet(layer=args.layer, n_label=1000, pool_type='max')
    else:
        raise NotImplementedError(args.model)

    print('==> loading pre-trained model')
    # Load onto the CPU: a checkpoint saved from GPU tensors would otherwise
    # fail to load on a CPU-only host, which this function explicitly supports
    # (see the is_available() guard below). Weights are copied into the model's
    # own parameters and moved to CUDA afterwards if available.
    ckpt = torch.load(args.model_path, map_location='cpu')
    model.load_state_dict(ckpt['model'])
    print('==> done')
    model.eval()

    criterion = nn.CrossEntropyLoss()
    if torch.cuda.is_available():
        model = model.cuda()
        classifier = classifier.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True

    return model, classifier, criterion
def _build_linear_eval_loaders(args):
    """Return ``(train_loader, val_loader)`` for ``args.dataset``.

    ImageNet loaders read ``<args.data_folder>/train`` and ``/val``; Places205
    loaders read hard-coded list files. CIFAR and SVHN are delegated to
    ``cifar.get_linear_dataloader`` / ``svhn.get_linear_dataloader``.

    Raises:
        NotImplementedError: for an unsupported ``args.aug`` (ImageNet/Places
            only) or an unsupported ``args.dataset``.
    """
    if args.dataset.startswith(('imagenet', 'places')):
        image_size = 224
        crop_padding = 32
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        if args.aug == 'NULL':
            train_transform = transforms.Compose([
                transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
        elif args.aug == 'CJ':
            train_transform = transforms.Compose([
                transforms.RandomResizedCrop(image_size, scale=(args.crop, 1.)),
                transforms.RandomGrayscale(p=0.2),
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
        else:
            # NotImplementedError, not NotImplemented: the latter is a constant
            # and calling it would raise an unrelated TypeError.
            raise NotImplementedError('augmentation not supported: {}'.format(
                args.aug))

        val_transform = transforms.Compose([
            transforms.Resize(image_size + crop_padding),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            normalize,
        ])

        if args.dataset.startswith('imagenet'):
            train_dataset = datasets.ImageFolder(
                os.path.join(args.data_folder, 'train'), train_transform)
            val_dataset = datasets.ImageFolder(
                os.path.join(args.data_folder, 'val'), val_transform)
        else:  # places
            train_dataset = ImageList(
                '/data/trainvalsplit_places205/train_places205.csv',
                '/data/data/vision/torralba/deeplearning/images256',
                transform=train_transform,
                symbol_split=' ')
            val_dataset = ImageList(
                '/data/trainvalsplit_places205/val_places205.csv',
                '/data/data/vision/torralba/deeplearning/images256',
                transform=val_transform,
                symbol_split=' ')

        print(len(train_dataset))

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.n_workers,
            pin_memory=False)
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.n_workers,
            pin_memory=False)
        return train_loader, val_loader

    if args.dataset.startswith('cifar'):
        return cifar.get_linear_dataloader(args)
    if args.dataset.startswith('svhn'):
        return svhn.get_linear_dataloader(args)
    # Fail fast instead of hitting a NameError on the loaders later on.
    raise NotImplementedError('dataset not supported: {}'.format(args.dataset))


def _build_encoder_and_classifier(args):
    """Instantiate the backbone named by ``args.model`` and its linear classifier.

    All backbones except ``resnet50x2`` / ``resnet50x4`` are wrapped in
    ``nn.DataParallel``. For the AlexNet variants a requested ``args.layer`` of
    6 is rewritten to 5 *in place* on ``args`` (the value is reused later in
    the saved checkpoint name).

    Returns:
        ``(model, classifier)``.

    Raises:
        NotImplementedError: if ``args.model`` is not supported.
    """
    if args.model == 'alexnet':
        if args.layer == 6:
            args.layer = 5
        model = nn.DataParallel(AlexNet(128))
        classifier = LinearClassifierAlexNet(args.layer, args.n_label, 'avg')
    elif args.model == 'alexnet_cifar':
        if args.layer == 6:
            args.layer = 5
        model = nn.DataParallel(AlexNet_cifar(128))
        classifier = LinearClassifierAlexNet(args.layer, args.n_label, 'avg',
                                             cifar=True)
    elif args.model == 'resnet50':
        model = nn.DataParallel(resnet50(non_linear_head=False))
        classifier = LinearClassifierResNet(args.layer, args.n_label, 'avg', 1)
    elif args.model == 'resnet18':
        model = nn.DataParallel(resnet18())
        classifier = LinearClassifierResNet(args.layer, args.n_label, 'avg', 1,
                                            bottleneck=False)
    elif args.model == 'resnet18_cifar':
        model = nn.DataParallel(resnet18_cifar())
        classifier = LinearClassifierResNet(args.layer, args.n_label, 'avg', 1,
                                            bottleneck=False)
    elif args.model == 'resnet50_cifar':
        model = nn.DataParallel(resnet50_cifar())
        classifier = LinearClassifierResNet(args.layer, args.n_label, 'avg', 1)
    elif args.model == 'resnet50x2':
        model = InsResNet50(width=2)
        classifier = LinearClassifierResNet(args.layer, args.n_label, 'avg', 2)
    elif args.model == 'resnet50x4':
        model = InsResNet50(width=4)
        classifier = LinearClassifierResNet(args.layer, args.n_label, 'avg', 4)
    elif args.model == 'shufflenet':
        model = nn.DataParallel(
            shufflenet_v2_x1_0(num_classes=128, non_linear_head=False))
        classifier = LinearClassifierResNet(args.layer, args.n_label, 'avg', 0.5)
    else:
        raise NotImplementedError('model not supported {}'.format(args.model))
    return model, classifier


def _moco_encoder_q_state_dict(state_dict):
    """Extract the MoCo query-encoder weights, renamed to match ``model``.

    Keys ``module.encoder_q.<name>`` become ``module.<name>``; keys under
    ``module.encoder_q.fc`` and all other keys are dropped. A new dict is
    built, so the input is not mutated and a renamed key can never be
    overwritten by, or deleted as, an original key with the same name.
    """
    prefix = 'module.encoder_q.'
    return {
        'module.' + key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix) and not key.startswith(prefix + 'fc')
    }


def main():
    """Linear evaluation: train a classifier on top of a frozen pre-trained encoder.

    Parses options with ``parse_option()``, builds the data loaders and the
    encoder/classifier pair, and loads encoder weights from ``args.model_path``
    (remapping a MoCo checkpoint when ``args.moco`` is set). It then optionally
    resumes classifier/optimizer state from ``args.resume`` and runs
    ``args.epochs`` epochs in which the optimizer updates only the classifier's
    parameters. Progress is logged via ``logging`` and TensorBoard. The
    best-accuracy classifier is saved as ``<model>_layer<layer>.pth``, and a
    regular checkpoint every ``args.save_freq`` epochs, under ``args.save_folder``.
    """
    global best_acc1
    best_acc1 = 0

    args = parse_option()

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # NOTE(review): the returned logger is unused; this call presumably sets up
    # the handlers behind the logging.info() calls below — confirm.
    getLogger(args.save_folder)

    # set the data loader
    train_loader, val_loader = _build_linear_eval_loaders(args)

    # create model and optimizer
    model, classifier = _build_encoder_and_classifier(args)

    print('==> loading pre-trained model')
    ckpt = torch.load(args.model_path)
    if not args.moco:
        model.load_state_dict(ckpt['state_dict'])
    else:
        # Best effort, as before, but no longer silent: a failed load means the
        # encoder may not hold the pre-trained weights.
        try:
            model.load_state_dict(_moco_encoder_q_state_dict(ckpt['state_dict']))
        except Exception as err:
            print('==> WARNING: failed to load MoCo weights from {} ({!r}); '
                  'encoder weights may not have been loaded'.format(
                      args.model_path, err))
    print("==> loaded checkpoint '{}' (epoch {})".format(
        args.model_path, ckpt['epoch']))
    print('==> done')

    model = model.cuda()
    classifier = classifier.cuda()

    criterion = torch.nn.CrossEntropyLoss().cuda(args.gpu)

    if not args.adam:
        optimizer = torch.optim.SGD(classifier.parameters(),
                                    lr=args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.Adam(classifier.parameters(),
                                     lr=args.learning_rate,
                                     betas=(args.beta1, args.beta2),
                                     weight_decay=args.weight_decay,
                                     eps=1e-8)

    model.eval()
    cudnn.benchmark = True

    # optionally resume from a checkpoint
    args.start_epoch = 1
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch'] + 1
            classifier.load_state_dict(checkpoint['classifier'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # as_tensor: a checkpoint written before any epoch beat the initial
            # 0 stores a plain int, which has no .item()/.cuda().
            best_acc1 = torch.as_tensor(checkpoint['best_acc1'])
            print(best_acc1.item())
            best_acc1 = best_acc1.cuda()
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            if 'opt' in checkpoint:
                # resume optimization hyper-parameters
                print('=> resume hyper parameters')
                saved_opt = checkpoint['opt']
                if 'bn' in vars(saved_opt):
                    print('using bn: ', saved_opt.bn)
                if 'adam' in vars(saved_opt):
                    print('using adam: ', saved_opt.adam)
                # learning_rate and lr_decay_epochs are intentionally NOT
                # restored; the command-line values are kept.
                args.lr_decay_rate = saved_opt.lr_decay_rate
                args.momentum = saved_opt.momentum
                args.weight_decay = saved_opt.weight_decay
                args.beta1 = saved_opt.beta1
                args.beta2 = saved_opt.beta2
            del checkpoint
            torch.cuda.empty_cache()
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # tensorboard
    tblogger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)

    # routine
    # Per-run best used only for logging; unlike best_acc1 it is not restored
    # on resume.
    best_acc = 0.0
    for epoch in range(args.start_epoch, args.epochs + 1):

        adjust_learning_rate(epoch, args, optimizer)
        print("==> training...")

        time1 = time.time()
        train_acc, train_acc5, train_loss = train(epoch, train_loader, model,
                                                  classifier, criterion,
                                                  optimizer, args)
        time2 = time.time()
        logging.info('train epoch {}, total time {:.2f}'.format(
            epoch, time2 - time1))

        logging.info(
            'Epoch: {}, lr:{} , train_loss: {:.4f}, train_acc: {:.4f}/{:.4f}'.
            format(epoch, optimizer.param_groups[0]['lr'], train_loss,
                   train_acc, train_acc5))

        tblogger.log_value('train_acc', train_acc, epoch)
        tblogger.log_value('train_acc5', train_acc5, epoch)
        tblogger.log_value('train_loss', train_loss, epoch)
        tblogger.log_value('learning_rate', optimizer.param_groups[0]['lr'],
                           epoch)

        test_acc, test_acc5, test_loss = validate(val_loader, model,
                                                  classifier, criterion, args)

        if test_acc >= best_acc:
            best_acc = test_acc

        logging.info(
            colorful(
                'Epoch: {}, val_loss: {:.4f}, val_acc: {:.4f}/{:.4f}, best_acc: {:.4f}'
                .format(epoch, test_loss, test_acc, test_acc5, best_acc)))
        tblogger.log_value('test_acc', test_acc, epoch)
        tblogger.log_value('test_acc5', test_acc5, epoch)
        tblogger.log_value('test_loss', test_loss, epoch)

        # save the best model
        if test_acc > best_acc1:
            best_acc1 = test_acc
            state = {
                'opt': args,
                'epoch': epoch,
                'classifier': classifier.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }
            save_name = '{}_layer{}.pth'.format(args.model, args.layer)
            save_name = os.path.join(args.save_folder, save_name)
            print('saving best model!')
            torch.save(state, save_name)

        # save model
        if epoch % args.save_freq == 0:
            print('==> Saving...')
            state = {
                'opt': args,
                'epoch': epoch,
                'classifier': classifier.state_dict(),
                # Store the real best so far (not this epoch's accuracy), so
                # resuming from a regular checkpoint restores best_acc1 correctly.
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }
            save_name = 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch)
            save_name = os.path.join(args.save_folder, save_name)
            print('saving regular model!')
            torch.save(state, save_name)
Example #4
0
def set_model(args, ngpus_per_node):
    """Build the encoder and a 1000-way linear classifier, load the encoder, place on GPU(s).

    Args:
        args: options namespace; reads ``model``, ``layer``, ``model_path``,
            ``distributed``, ``gpu``, ``batch_size`` and ``num_workers``.
        ngpus_per_node: in distributed mode with a fixed ``args.gpu``,
            ``args.batch_size`` and ``args.num_workers`` are divided by this
            value *in place*.

    Returns:
        ``(model, classifier, criterion)``. Depending on ``args`` the networks
        are wrapped in ``DistributedDataParallel``, moved to ``args.gpu``, or
        wrapped in ``DataParallel``.

    Raises:
        NotImplementedError: if ``args.model`` is neither ``'alexnet'`` nor a
            ``resnet*`` name.
    """
    if args.model == 'alexnet':
        model = alexnet()
        classifier = LinearClassifierAlexNet(layer=args.layer,
                                             n_label=1000,
                                             pool_type='max')
    elif args.model.startswith('resnet'):
        model = ResNetV2(args.model)
        classifier = LinearClassifierResNetV2(layer=args.layer,
                                              n_label=1000,
                                              pool_type='avg')
    else:
        raise NotImplementedError(args.model)

    # load pre-trained model
    print('==> loading pre-trained model')
    # map_location='cpu': without it every (distributed) process materialises
    # the checkpoint on the GPU it was saved from. The weights are copied into
    # the model and moved to the right device below.
    ckpt = torch.load(args.model_path, map_location='cpu')
    state_dict = ckpt['model']

    # Checkpoints saved from a (Distributed)DataParallel wrapper prefix keys
    # with "module."; strip exactly that prefix, and only where it is present.
    prefix = 'module.'
    if any(k.startswith(prefix) for k in state_dict):
        from collections import OrderedDict
        state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items())
    model.load_state_dict(state_dict)

    print('==> done')
    model.eval()

    if args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            classifier.cuda(args.gpu)
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.num_workers = int(args.num_workers / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
            classifier = torch.nn.parallel.DistributedDataParallel(
                classifier, device_ids=[args.gpu])
        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(model)
            classifier.cuda()
            classifier = torch.nn.parallel.DistributedDataParallel(classifier)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
        classifier = classifier.cuda(args.gpu)
    else:
        model = torch.nn.DataParallel(model).cuda()
        classifier = torch.nn.DataParallel(classifier).cuda()

    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    return model, classifier, criterion
Example #5
0
def set_model(args, ngpus_per_node):
    """Build a 10-way encoder/classifier pair and load both from checkpoints.

    Args:
        args: options namespace; reads ``model``, ``view`` (alexnet only),
            ``layer``, ``model_path``, ``classifier_path`` and ``gpu``.
        ngpus_per_node: unused here; kept for interface compatibility.

    Returns:
        ``(model, classifier, criterion)`` with both networks in eval mode.
        Checkpoints are mapped to the CPU when CUDA is unavailable.

    Raises:
        NotImplementedError: if ``args.model`` is unsupported, or if
            ``args.view`` is not ``'Lab'``/``'Rot'`` for ``alexnet`` (this
            used to surface later as an UnboundLocalError).
    """
    alexnet_in_channels = {'Lab': (1, 2), 'Rot': (3, 3)}
    if args.model == 'alexnet':
        if args.view not in alexnet_in_channels:
            raise NotImplementedError('view not supported: {}'.format(args.view))
        model = alexnet(in_channel=alexnet_in_channels[args.view])
        classifier = LinearClassifierAlexNet(layer=args.layer,
                                             n_label=10,
                                             pool_type='max')
    elif args.model.startswith('resnet'):
        model = ResNetV2(args.model)
        classifier = LinearClassifierResNetV2(layer=args.layer,
                                              n_label=10,
                                              pool_type='avg')
    else:
        raise NotImplementedError(args.model)

    # None is torch.load's default (keep the saved device); fall back to the
    # CPU when no GPU is available.
    map_location = None if torch.cuda.is_available() else 'cpu'

    def _load_weights(module, path, key):
        # Load ckpt[key] into module, stripping a DataParallel "module."
        # prefix only from the keys that actually carry it.
        state_dict = torch.load(path, map_location=map_location)[key]
        prefix = 'module.'
        if any(k.startswith(prefix) for k in state_dict):
            from collections import OrderedDict
            state_dict = OrderedDict(
                (k[len(prefix):] if k.startswith(prefix) else k, v)
                for k, v in state_dict.items())
        module.load_state_dict(state_dict)

    # load pre-trained model
    print('==> loading pre-trained model')
    _load_weights(model, args.model_path, 'model')
    print('==> done')
    model.eval()

    # load pre-trained classifier
    print('==> loading pre-trained classifier')
    _load_weights(classifier, args.classifier_path, 'classifier')
    print('==> done')
    classifier.eval()

    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    return model, classifier, criterion