コード例 #1
0
def get_model(paths, feature_type):
    """Construct the feature-extraction backbone for ``feature_type`` and load
    its best fine-tuned V-COCO checkpoint from under ``paths.tmp_root``.

    Args:
        paths: object exposing ``tmp_root``, the checkpoint root directory.
        feature_type: one of 'vgg', 'resnet', 'densenet'.

    Returns:
        The network (wrapped in ``DataParallel``) moved to GPU, with the
        ``model_best.pth`` weights loaded.

    Raises:
        ValueError: if ``feature_type`` is not recognized.
    """
    num_classes = len(metadata.action_classes)
    builders = {
        'vgg': feature_model.Vgg16,
        'resnet': feature_model.Resnet152,
        'densenet': feature_model.Densenet,
    }
    if feature_type not in builders:
        raise ValueError('feature type not recognized')
    feature_network = builders[feature_type](num_classes=num_classes)

    if feature_type.startswith(('alexnet', 'vgg')):
        # AlexNet/VGG convention: parallelize only the convolutional
        # feature extractor; the classifier head stays on a single device.
        feature_network.features = torch.nn.DataParallel(
            feature_network.features)
        feature_network.cuda()
    else:
        feature_network = torch.nn.DataParallel(feature_network).cuda()

    best_model_file = os.path.join(
        paths.tmp_root, 'checkpoints', 'vcoco',
        'finetune_{}'.format(feature_type), 'model_best.pth')
    checkpoint = torch.load(best_model_file)
    feature_network.load_state_dict(checkpoint['state_dict'])
    return feature_network
コード例 #2
0
def get_model(paths):
    """Load the fine-tuned Resnet152 feature extractor (noisy variant) plus
    the matching ImageNet normalization transform.

    NOTE(review): reads the free variable ``feature_mode`` defined elsewhere
    in this module; returns ``None`` (no transform) when it equals 'None'.

    Returns:
        ``(feature_network, transform)`` with the network on GPU and the
        best-checkpoint weights loaded, or ``None`` when disabled.
    """
    # vgg16 = Vgg16(last_layer=1).cuda()

    if feature_mode == 'None':
        return

    feature_network = feature_model.Resnet152(num_classes=len(metadata.action_classes))
    feature_network.cuda()

    checkpoint_dir = os.path.join(os.path.dirname(__file__), '../../../../data/model_resnet_noisy/finetune_{}_noisy'.format(feature_mode))
    best_model_file = os.path.join(checkpoint_dir, 'model_best.pth')
    checkpoint = torch.load(best_model_file)

    # Checkpoints saved from a DataParallel wrapper prefix every key with
    # 'module.'; strip that prefix so the bare model accepts the weights.
    state_dict = {
        (name[7:] if name.startswith('module.') else name): tensor
        for name, tensor in checkpoint['state_dict'].items()
    }
    feature_network.load_state_dict(state_dict)

    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
    ])

    return feature_network, transform
コード例 #3
0
ファイル: finetune.py プロジェクト: tengyu-liu/Part-GPNN
def main(args):
    """Fine-tune an action-classification backbone on V-COCO.

    Builds the model selected by ``args.feature_type`` ('vgg', 'resnet' or
    'densenet'), optionally resumes from ``args.resume``, trains for
    ``args.epochs`` epochs, checkpoints the best validation precision, and
    finally reports test-set precision.

    Raises:
        ValueError: if ``args.feature_type`` is not one of the supported
            backbones (previously this fell through and crashed later with a
            ``NameError`` on the unbound ``model``).
    """
    best_prec1 = 0.0
    args.distributed = args.world_size > 1
    if args.distributed:
        torch.distributed.init_process_group(backend=args.dist_backend,
                                             init_method=args.dist_url,
                                             world_size=args.world_size)
    # create model
    if args.feature_type == 'vgg':
        model = feature_model.Vgg16(num_classes=len(metadata.action_classes))
    elif args.feature_type == 'resnet':
        model = feature_model.Resnet152(
            num_classes=len(metadata.action_classes))
    elif args.feature_type == 'densenet':
        model = feature_model.Densenet(
            num_classes=len(metadata.action_classes))
    else:
        # Fail fast with the same message the sibling loader uses instead of
        # hitting a NameError on `model` below.
        raise ValueError('feature type not recognized')
    # NOTE(review): (244, 244) looks like a typo for the standard ImageNet
    # input size (224, 224) — confirm before changing; existing checkpoints
    # were trained at this size.
    input_imsize = (244, 244)

    if not args.distributed:
        if args.feature_type.startswith(
                'alexnet') or args.feature_type.startswith('vgg'):
            # AlexNet/VGG: parallelize only the conv features; the
            # classifier head stays on one device.
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()
    else:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)

    # define loss function (criterion) and optimizer; per-class weights
    # compensate for the imbalanced action distribution
    criterion = torch.nn.CrossEntropyLoss(
        weight=torch.FloatTensor(metadata.action_class_weight)).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(os.path.join(args.resume, 'model_best.pth')):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(os.path.join(args.resume,
                                                 'model_best.pth'))
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(
                os.path.join(args.resume, 'model_best.pth')))

    torch.backends.cudnn.benchmark = True

    # Data loading code
    normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                 std=[0.229, 0.224, 0.225])
    train_transform = torchvision.transforms.Compose([
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.RandomVerticalFlip(),
        torchvision.transforms.ColorJitter(brightness=0.3,
                                           contrast=0.3,
                                           saturation=0.3,
                                           hue=0.15),
        torchvision.transforms.RandomAffine(degrees=15,
                                            shear=15,
                                            resample=PIL.Image.BILINEAR),
        torchvision.transforms.ToTensor(), normalize
    ])

    test_transform = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(), normalize])

    train_dataset = feature_model.VCOCO(args.data, input_imsize,
                                        train_transform, 'train')
    # NOTE(review): the val split deliberately(?) uses the augmenting
    # train_transform rather than test_transform — it is trained on below.
    val_dataset = feature_model.VCOCO(args.data, input_imsize, train_transform,
                                      'val')
    test_dataset = feature_model.VCOCO(args.data, input_imsize, test_transform,
                                       'test')

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=args.workers,
                                             pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.workers,
                                              pin_memory=False)

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        # NOTE(review): the model is also trained on the val split here, and
        # model selection is then done on the TEST split — confirm this is
        # intentional (it leaks test data into model selection).
        train(val_loader, model, criterion, optimizer, epoch)
        train(train_loader, model, criterion, optimizer, epoch)

        if epoch == 0 or epoch >= 3:
            # evaluate on validation set
            prec1 = validate(test_loader, model, criterion)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            print('Best precision: {:.03f}'.format(best_prec1))
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.feature_type,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)

    test_prec = validate(test_loader, model, criterion, test=True)
    print('Testing precision: {:.04f}'.format(test_prec))
コード例 #4
0
    elif len(img.shape) == 5:
        img = img.transpose([0, 1, 4, 2, 3])
    img = torch.autograd.Variable(torch.Tensor(img)).cuda()
    return img


# Module-level setup: paths, the fine-tuned Resnet152 feature extractor, and
# the ImageNet preprocessing transform, all initialized at import time.
# NOTE(review): the absolute /home/tengyu/... paths make this machine-specific.
meta_dir = os.path.join(os.path.dirname(__file__),
                        '../../../data/vcoco_features')
img_dir = '/home/tengyu/dataset/mscoco/images'
checkpoint_dir = '/home/tengyu/github/Part-GPNN/data/model_resnet_noisy/finetune_resnet'
vcoco_root = '/home/tengyu/dataset/v-coco/data'
save_data_path = '/home/tengyu/github/Part-GPNN/data/feature_resnet_tengyu2'

os.makedirs(save_data_path, exist_ok=True)

feature_network = feature_model.Resnet152(
    num_classes=len(metadata.action_classes))
feature_network.cuda()
best_model_file = os.path.join(checkpoint_dir, 'model_best.pth')
checkpoint = torch.load(best_model_file)
# Checkpoints saved from a DataParallel wrapper prefix keys with 'module.';
# strip the prefix (in place) so the bare model accepts the weights.
for k in list(checkpoint['state_dict'].keys()):
    if k[:7] == 'module.':
        checkpoint['state_dict'][k[7:]] = checkpoint['state_dict'][k]
        del checkpoint['state_dict'][k]

feature_network.load_state_dict(checkpoint['state_dict'])

# Standard ImageNet normalization, matching the backbone's training regime.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
])