Exemple #1
0
    def __init__(self, encoder, checkpoint_path, cuda, num_classes=400):
        """Build a ``<encoder>_vtn`` model and load its weights from a checkpoint.

        Args:
            encoder: Encoder name; used as the prefix of the ``"{}_vtn"``
                model type string.
            checkpoint_path: Location of the checkpoint file. The loaded
                object must provide a ``'state_dict'`` entry.
            cuda: When truthy the model is moved to the GPU; otherwise the
                checkpoint is mapped onto the CPU.
            num_classes: Forwarded to ``generate_args`` as ``n_classes``.
        """
        model_type = "{}_vtn".format(encoder)

        args, _ = generate_args(model=model_type, n_classes=num_classes,
                                layer_norm=False, cuda=cuda)
        args.pretrain_path = checkpoint_path
        self.model, _ = create_model(args, model_type)
        self.device = torch.device('cuda' if cuda else 'cpu')

        if cuda:
            # Only the CUDA path unwraps ``.module``.
            # NOTE(review): presumably create_model returns a wrapped
            # (e.g. DataParallel) network when CUDA is enabled -- confirm.
            self.model = self.model.module
            self.model.eval()
            self.model.cuda()
            # The checkpoint was produced by GPU training, so a default load
            # already places its tensors on the GPU.
            checkpoint = torch.load(str(checkpoint_path))
        else:
            self.model.eval()
            # Remap any GPU-saved tensors onto the CPU while loading.
            checkpoint = torch.load(str(checkpoint_path),
                                    map_location=self.device)

        self.model.load_checkpoint(checkpoint['state_dict'])
        self.preprocessing = make_preprocessing(args)

        # Bounded buffer: holds at most sample_duration * temporal_stride items.
        self.embeds = deque(maxlen=(args.sample_duration * args.temporal_stride))
Exemple #2
0
    def test_create_resnet34_vtn(self):
        """'resnet34_vtn' must yield a VideoTransformer with 36 Conv2d modules in ``resnet``."""
        model, _ = create_model(_make_args(), 'resnet34_vtn')
        # Count every Conv2d module found under the ``resnet`` attribute.
        conv_count = sum(
            isinstance(module, nn.Conv2d) for module in model.resnet.modules()
        )

        assert isinstance(model, VideoTransformer)
        assert conv_count == 36
Exemple #3
0
    def test_select_encoder_from_args(self):
        """With ``encoder='mobilenetv2'`` in args, the generic 'vtn' model uses that encoder.

        Checks the Conv2d and InvertedResidual module counts under ``resnet``.
        """
        model, _ = create_model(_make_args(encoder='mobilenetv2'), 'vtn')

        def count_modules(module_type):
            # Number of modules of the given type under the ``resnet`` attribute.
            return sum(
                isinstance(module, module_type)
                for module in model.resnet.modules()
            )

        conv_count = count_modules(nn.Conv2d)
        block_count = count_modules(InvertedResidual)

        assert isinstance(model, VideoTransformer)
        assert conv_count == 52
        assert block_count == 17
    def __init__(self, encoder, checkpoint_path, num_classes=400):
        """Create a ``<encoder>_vtn`` model on the GPU and restore checkpoint weights.

        Args:
            encoder: Encoder name; prefix of the ``"{}_vtn"`` model type.
            checkpoint_path: Checkpoint file; must contain ``'state_dict'``.
            num_classes: Forwarded to ``generate_args`` as ``n_classes``.
        """
        arch = "{}_vtn".format(encoder)
        cfg, _ = generate_args(model=arch, n_classes=num_classes, layer_norm=False)

        wrapped, _ = create_model(cfg, arch)
        # Keep the object stored under ``.module`` of what create_model returns.
        self.model = wrapped.module
        self.model.eval()
        self.model.cuda()

        state = torch.load(str(checkpoint_path))['state_dict']
        load_state(self.model, state)

        self.preprocessing = make_preprocessing(cfg)

        # Bounded buffer sized by sample_duration * temporal_stride.
        window = cfg.sample_duration * cfg.temporal_stride
        self.embeds = deque(maxlen=window)
Exemple #5
0
def main():
    """Run one forward pass on random input with statistics hooks attached.

    Builds the model named by ``args.model``, registers
    ``compute_layer_statistics_hook`` as a forward hook on every module,
    feeds a single random tensor of shape
    ``(1, sample_duration, 3, sample_size, sample_size)`` through the network
    on the CPU, and then prints the collected statistics.
    """
    args = parse_arguments()
    net, _ = create_model(args, args.model)

    # Bail out before touching the model: the guard must come ahead of the
    # first attribute access, otherwise a missing model raises AttributeError.
    if net is None:
        return

    net = net.module
    net.cpu()
    net.eval()

    h, w = args.sample_size, args.sample_size
    var = torch.randn(1, args.sample_duration, 3, h, w).to('cpu')

    # Module.apply visits every submodule, so each one gets the hook.
    net.apply(lambda m: m.register_forward_hook(compute_layer_statistics_hook))

    # Only the hook side effects matter; the network output is discarded.
    net(var)

    restore_module_names(net)
    print_statisctics()
def main():
    """Configure the run, then export to ONNX or train / validate / test.

    Steps, in order: read and log the configuration, build the model and
    solver, optionally restore a checkpoint from ``args.resume_path``, export
    to ONNX and stop if ``args.onnx`` is set, otherwise build the datasets and
    run whichever of training, validation and testing ``args`` enables.
    """
    args = configure()
    log_configuration(args)

    model, parameters = create_model(
        args, args.model, pretrain_path=args.pretrain_path)
    optimizer, scheduler = setup_solver(args, parameters)

    if args.resume_path:
        print('loading checkpoint {}'.format(args.resume_path))
        checkpoint = torch.load(str(args.resume_path))
        load_state(model, checkpoint['state_dict'])

        if args.resume_train:
            args.begin_epoch = checkpoint['epoch']
            # NOTE(review): optimizer state is restored only when
            # ``args.train`` is falsy -- confirm this is intended.
            if not args.train:
                saved_optimizer = checkpoint.get('optimizer')
                if saved_optimizer is not None:
                    optimizer.load_state_dict(saved_optimizer)

    if args.onnx:
        export_onnx(args, model, args.onnx)
        return

    criterion = create_criterion(args)
    train_data, val_data, test_data = setup_dataset(
        args, args.train, args.val, args.test)

    def make_loader(dataset, **extra):
        # Settings shared by every loader; read from ``args`` at call time.
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=args.batch_size,
                                           num_workers=args.n_threads,
                                           pin_memory=True,
                                           **extra)

    if args.train or args.val:
        val_loader = make_loader(val_data, shuffle=False, drop_last=False)

    if args.train:
        if args.weighted_sampling:
            class_weights = getattr(args, 'class_weights', None)
            sample_weights = train_data.get_sample_weights(class_weights)
            sampler = WeightedRandomSampler(sample_weights, len(train_data))
        elif len(train_data) < args.batch_size:
            # Fewer samples than one batch: draw with replacement so that a
            # full batch can still be formed.
            sampler = RandomSampler(train_data,
                                    replacement=True,
                                    num_samples=args.batch_size)
        else:
            sampler = RandomSampler(train_data, replacement=False)

        train_loader = make_loader(train_data,
                                   sampler=sampler,
                                   drop_last=args.sync_bn)

        log_training_setup(model, train_data, val_data, optimizer, scheduler)

        train(args, model, train_loader, val_loader, criterion, optimizer,
              scheduler, args.logger)

    # Stand-alone validation runs only when training was not requested.
    if args.val and not args.train:
        with torch.no_grad():
            validate(args, args.begin_epoch, val_loader, model, criterion,
                     args.logger)

    if args.test:
        test_loader = make_loader(test_data, shuffle=False)

        with torch.no_grad(), args.logger.scope():
            test(args, test_loader, model, args.logger)