Example #1
def get_model(args):
    assert args.type in ['2d', '3d', 's3d']
    if args.type == '2d':
        print('Loading 2D-ResNet-152 ...')
        model = models.resnet152(pretrained=True)
        model = nn.Sequential(*list(model.children())[:-2], GlobalAvgPool())
        model = model.cuda()
    elif args.type == '3d':
        print('Loading 3D-ResneXt-101 ...')
        model = resnext.resnet101(
            num_classes=400,
            shortcut_type='B',
            cardinality=32,
            sample_size=112,
            sample_duration=16,
            last_fc=False)
        model = model.cuda()
        model_data = th.load(args.resnext101_model_path)
        model.load_state_dict(model_data)
    else:
        print('Loading S3D ...')
        model = S3D(
            'model/s3d_dict.npy',
            num_classes=512)
        model = model.cuda()
        model_data = th.load(args.s3d_model_path)
        model.load_state_dict(model_data)
        # device = th.device('cuda:0')
        # model.to(device)
    model.eval()
    print('loaded')
    return model
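For reference, a minimal invocation sketch for get_model above. The flag names mirror the attributes the function reads (type, resnext101_model_path, s3d_model_path), but the original repo's actual CLI is an assumption:

# Hypothetical driver for get_model(); flag names are inferred from the
# attributes read above, not taken from the original repo's CLI.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--type', default='s3d', choices=['2d', '3d', 's3d'])
parser.add_argument('--resnext101_model_path', default='model/resnext101.pth')
parser.add_argument('--s3d_model_path', default='model/s3d_howto100m.pth')
args = parser.parse_args()

model = get_model(args)  # returns a CUDA model already in eval mode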
Example #2
def build_model(args):
    print(f'Loading S3D with checkpoint {args.s3d_ckpt}...')
    model = S3D()
    model = model.cuda()
    model_data = th.load(args.s3d_ckpt)
    model.load_state_dict(model_data, strict=False)

    model.eval()
    return model
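Note that strict=False above silently skips any checkpoint keys that do not match the model. A small sketch using PyTorch's documented return value to surface what was skipped instead of ignoring it:

# load_state_dict returns an IncompatibleKeys named tuple; with
# strict=False the mismatches are worth logging.
result = model.load_state_dict(model_data, strict=False)
if result.missing_keys:
    print('missing keys:', result.missing_keys)
if result.unexpected_keys:
    print('unexpected keys:', result.unexpected_keys)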
Example #3
def main():
    args = get_args()
    assert args.eval_video_root != ''
    checkpoint_path = './checkpoint/epoch0089.pth.tar'
    print("=> loading checkpoint '{}'".format(checkpoint_path))
    checkpoint = torch.load(checkpoint_path)
    if "state_dict" in checkpoint:
        model = S3D(args.num_class,
                    space_to_depth=False,
                    word2vec_path=args.word2vec_path)
        model = torch.nn.DataParallel(model)
        model.load_state_dict(checkpoint["state_dict"])
    else:  # load pre-trained model from https://github.com/antoine77340/S3D_HowTo100M
        model = S3D(args.num_class,
                    space_to_depth=True,
                    word2vec_path=args.word2vec_path)
        model = torch.nn.DataParallel(model)
        checkpoint_module = {'module.' + k: v for k, v in checkpoint.items()}
        model.load_state_dict(checkpoint_module)
    model.eval()
    model.cuda()

    # Data loading code
    dataset = HMDB_DataLoader(
        data=os.path.join(os.path.dirname(__file__), 'csv/hmdb51.csv'),
        num_clip=args.num_windows_test,
        video_root=args.eval_video_root,
        num_frames=args.num_frames,
        size=args.video_size,
        crop_only=False,
        center_crop=True,
        with_flip=True,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.num_thread_reader,
    )

    # run evaluation on HMDB51
    evaluate(dataloader, model, args)
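The 'module.' + k remapping above adapts a bare checkpoint to a DataParallel-wrapped model. The inverse direction comes up just as often; a sketch, where model and checkpoint_path are placeholders rather than names from this script:

# Sketch: loading a DataParallel checkpoint (keys prefixed with 'module.')
# into an unwrapped model by stripping the prefix.
checkpoint = torch.load(checkpoint_path)
state_dict = checkpoint.get('state_dict', checkpoint)
stripped = {k[len('module.'):]: v for k, v in state_dict.items()
            if k.startswith('module.')}
model.load_state_dict(stripped)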
Example #4
def main():
    # see model input data
    # data = np.load('s3d_dict.npy')

    # Instantiate the model
    net = S3D('s3d_dict.npy', 512)

    # Load the model weights
    net.load_state_dict(th.load('s3d_howto100m.pth'))

    # Video input should be of size Batch x 3 x T x H x W and normalized to [0, 1]
    # video1 = th.rand(2, 3, 32, 224, 224)
    # print(video1.shape)
    # print(type(video1))
    # Note: despite the filename, net() expects raw video of shape
    # B x 3 x T x H x W in [0, 1] (see the comment above), not
    # precomputed features.
    video = th.from_numpy(
        np.load("../video_feature_extractor/output/_0flfBHjVKU_features.npy"))
    print(video.shape)
    print(type(video))

    # Evaluation mode
    net = net.eval()

    # Video inference
    '''
    video_output is a dictionary containing two keys:

        video_embedding: This is the video embedding (size 512) from the joint text-video space. 
                        It should be used to compute similarity scores with text inputs using the text embedding.
        
        mixed_5c: This is the global average-pooled feature from S3D, of dimension 1024.
                It should be used for classification on downstream tasks.
    '''
    video_output = net(video)
    print(video_output['mixed_5c'])
    print(video_output['mixed_5c'].shape)
    print(type(video_output['mixed_5c']))
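Building on the docstring above, a sketch of scoring captions against video_embedding in the joint space. The net.text_module call and the 'text_embedding' key follow the S3D_HowTo100M demo and should be treated as assumptions about that checkout:

# Assumed API from the S3D_HowTo100M demo: the model's text branch maps
# raw strings into the joint 512-d space.
queries = ['spreading butter on bread', 'cutting an onion']
text_output = net.text_module(queries)
# Dot-product similarity in the joint space; row i scores query i.
scores = th.matmul(text_output['text_embedding'],
                   video_output['video_embedding'].t())
print(scores)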
Example #5
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu

    if args.distributed:
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            world_size=args.world_size,
            rank=args.rank,
        )
    # create model
    model = S3D(
        args.num_class,
        space_to_depth=False,
        word2vec_path=args.word2vec_path,
        init=args.weight_init,
    )

    if args.pretrain_cnn_path:
        net_data = torch.load(args.pretrain_cnn_path)
        model.load_state_dict(net_data)
    if args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
            args.num_thread_reader = int(args.num_thread_reader /
                                         ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        model = torch.nn.DataParallel(model).cuda()

    # Data loading code
    train_dataset = HT100M_DataLoader(
        csv=args.train_csv,
        video_root=args.video_path,
        caption_root=args.caption_root,
        min_time=args.min_time,
        fps=args.fps,
        num_frames=args.num_frames,
        size=args.video_size,
        crop_only=args.crop_only,
        center_crop=args.centercrop,
        random_left_right_flip=args.random_flip,
        num_candidates=args.num_candidates,
    )

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        drop_last=True,
        num_workers=args.num_thread_reader,
        pin_memory=args.pin_memory,
        sampler=train_sampler,
    )

    criterion = SOFTMAXMILNCELoss()

    if args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), args.lr)
    elif args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum)

    # scheduler = get_cosine_schedule_with_warmup(optimizer, args.warmup_steps, len(train_loader) * args.epochs)
    scheduler = None
    checkpoint_dir = os.path.join(os.path.dirname(__file__), 'checkpoint',
                                  args.checkpoint_dir)
    if args.checkpoint_dir != '' and not (
            os.path.isdir(checkpoint_dir)) and args.rank == 0:
        os.mkdir(checkpoint_dir)
    # optionally resume from a checkpoint
    if args.resume:
        checkpoint_path = './checkpoint/checkpoint.pth.tar'
        if os.path.isfile(checkpoint_path):
            log("=> loading checkpoint '{}'".format(checkpoint_path), args)
            checkpoint = torch.load(checkpoint_path)
            args.start_epoch = checkpoint["epoch"]
            model.load_state_dict(checkpoint["state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            if scheduler is not None and "scheduler" in checkpoint:
                scheduler.load_state_dict(checkpoint["scheduler"])
            log(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    checkpoint_path, checkpoint["epoch"]), args)
        else:
            log("=> no checkpoint found at '{}'".format(args.resume), args)
        # checkpoint_path = './checkpoint/s3d_howto100m.pth'
        # checkpoint = torch.load(checkpoint_path)
        # checkpoint_module = {'module.' + k:v for k,v in checkpoint.items()}
        # model.load_state_dict(checkpoint_module)
        # print('loaded')

    if args.cudnn_benchmark:
        cudnn.benchmark = True
    total_batch_size = args.world_size * args.batch_size
    log(
        "Starting training loop for rank: {}, total batch size: {}".format(
            args.rank, total_batch_size), args)
    for epoch in tqdm(range(args.start_epoch, args.epochs)):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        if epoch % max(1, total_batch_size // 512) == 0 and args.evaluate:
            # test_loader is assumed to be built elsewhere (not shown in this snippet)
            evaluate(test_loader, model, epoch, args, 'HMDB')
        # train for one epoch
        train(train_loader, model, criterion, optimizer, scheduler, epoch,
              train_dataset, args)
        if args.rank == 0:
            save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "state_dict": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    # "scheduler": scheduler.state_dict(),
                },
                checkpoint_dir,
                epoch + 1)
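SOFTMAXMILNCELoss is instantiated above but its definition is not shown. A minimal sketch of the MIL-NCE objective it presumably implements, in one direction only (the repo's class may also symmetrize over the text-to-video direction):

import torch

def mil_nce_loss(video_embd, text_embd):
    # video_embd: (B, D); text_embd: (B, C, D), C candidate captions per clip.
    b, c, d = text_embd.shape
    # Similarity of every video against every candidate caption: (B, B, C).
    x = torch.matmul(video_embd, text_embd.reshape(b * c, d).t()).view(b, b, c)
    # Numerator: the C captions actually paired with each video (the diagonal).
    pos = torch.logsumexp(x.diagonal(dim1=0, dim2=1).t(), dim=1)
    # Denominator: every caption in the batch.
    neg = torch.logsumexp(x.reshape(b, -1), dim=1)
    return (neg - pos).mean()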
Example #6
parser.add_argument('--ITP',
                    action='store_true',
                    default=False,
                    help='using ITP')
args, leftovers = parser.parse_known_args()
# python main.py --batch_size 32 --id2path test_id2path.csv --ann_file ../annotation/caption.json --data_root ./

# For ITP
if args.ITP:
    os.environ['CUDA_VISIBLE_DEVICES'] = os.environ[
        'OMPI_COMM_WORLD_LOCAL_RANK']
    args.id2path = args.id2path + os.environ['CUDA_VISIBLE_DEVICES'] + ".csv"
    print('{} {} {} {}'.format(os.environ['CUDA_VISIBLE_DEVICES'],
                               torch.cuda.current_device(),
                               torch.cuda.device_count(), args.id2path))
# Instantiate the model
net = S3D(f'{args.data_root}/s3d_dict.npy', 512)  # text module
net = net.cuda()
net.load_state_dict(torch.load(f'{args.data_root}/s3d_howto100m.pth'))  # S3D

# Video input should be of size Batch x 3 x T x H x W and normalized to [0, 1]
dataset = VideoClipDataset(
    args.id2path,
    args.ann_file,
    args.data_root,
    framerate=16,
    size=224,
    centercrop=True,  # TODO: use ?*224 or ?*224 + centercrop or 224*224
)

n_dataset = len(dataset)
sampler = RandomSequenceSampler(n_dataset, 10)
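A likely next step is feeding the sampler into a DataLoader. The worker count and pin_memory values here are placeholders, not taken from the original script:

loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=args.batch_size,
    sampler=sampler,   # RandomSequenceSampler supplies the clip ordering
    num_workers=4,     # placeholder
    pin_memory=True,   # placeholder
)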