示例#1
0
    def __init__(self, encoder, checkpoint_path, cuda, num_classes=400):
        """Create the ``<encoder>_vtn`` model and load weights from a checkpoint.

        Args:
            encoder: Encoder name; the model type is built as
                ``"{}_vtn".format(encoder)``.
            checkpoint_path: Location of the checkpoint file. It is also
                stored on ``args.pretrain_path`` before the model is created,
                and the loaded object must expose a ``'state_dict'`` entry.
            cuda: When truthy, the created model is replaced by its
                ``.module`` attribute and moved to the GPU; otherwise the
                model is used as created and the checkpoint is mapped to CPU.
            num_classes: Forwarded to ``generate_args`` as ``n_classes``.
        """
        model_type = "{}_vtn".format(encoder)

        args, _ = generate_args(model=model_type, n_classes=num_classes,
                                layer_norm=False, cuda=cuda)
        args.pretrain_path = checkpoint_path
        self.model, _ = create_model(args, model_type)

        if cuda:
            # NOTE(review): presumably create_model wraps the network (e.g. in
            # DataParallel) when cuda is set -- confirm; only the inner module
            # is kept here.
            self.model = self.model.module
            self.model.cuda()
        self.model.eval()

        self.device = torch.device('cuda' if cuda else 'cpu')
        # Training happens on GPU, so with CUDA the checkpoint is loaded as
        # stored (map_location=None is torch.load's default). Without CUDA
        # every storage has to be remapped onto the CPU.
        checkpoint = torch.load(str(checkpoint_path),
                                map_location=None if cuda else self.device)
        self.model.load_checkpoint(checkpoint['state_dict'])

        self.preprocessing = make_preprocessing(args)

        # Bounded buffer holding at most sample_duration * temporal_stride items.
        self.embeds = deque(maxlen=(args.sample_duration * args.temporal_stride))
    def __init__(self, encoder, checkpoint_path, num_classes=400):
        """Create the ``<encoder>_vtn`` model on the GPU and load its weights.

        Args:
            encoder: Encoder name; the model type is built as
                ``"{}_vtn".format(encoder)``.
            checkpoint_path: Location of the checkpoint file; the loaded
                object must expose a ``'state_dict'`` entry.
            num_classes: Forwarded to ``generate_args`` as ``n_classes``.
        """
        vtn_name = "{}_vtn".format(encoder)
        args, _ = generate_args(model=vtn_name, n_classes=num_classes, layer_norm=False)
        self.model, _ = create_model(args, vtn_name)

        # Keep only the inner module of whatever create_model returned, then
        # switch it to evaluation mode and move it to the GPU.
        self.model = self.model.module
        self.model.eval()
        self.model.cuda()

        weights = torch.load(str(checkpoint_path))['state_dict']
        load_state(self.model, weights)

        self.preprocessing = make_preprocessing(args)

        # Bounded buffer holding at most sample_duration * temporal_stride items.
        buffer_size = args.sample_duration * args.temporal_stride
        self.embeds = deque(maxlen=buffer_size)
示例#3
0
    def test_bool_flags(self):
        """Passing ``--no-val`` must yield ``args.val`` set to ``False``."""
        parsed, _ = generate_args('--no-val')
        val_flag = parsed.val

        assert val_flag is False
示例#4
0
    def test_kwargs_is_setting_args(self):
        """A keyword argument must show up as the same-named attribute."""
        parsed, _ = generate_args(encoder='test')
        encoder_value = parsed.encoder

        assert encoder_value == 'test'
示例#5
0
    def test_returns_namespace(self):
        """The first value returned by ``generate_args`` is a ``Namespace``."""
        parsed, _ = generate_args()
        is_namespace = isinstance(parsed, Namespace)

        assert is_namespace