コード例 #1
0
ファイル: test_api.py プロジェクト: petoor/mobile-vision
    def test_model_info(self):
        """Round-trip a saved test model through ``ModelInfo`` + ``load_model``.

        Saves the test model into a temporary directory, describes it with a
        ``ModelInfo`` of type ``"torchscript"``, loads it back with an empty
        ``model_root`` and checks that feeding ``tensor(1)`` yields ``tensor(2)``.
        """
        # NOTE: decide if load_model is a public API or class method of ModelInfo
        from mobile_cv.predictor.model_wrappers import load_model

        with make_temp_directory("test_model_info") as workdir:
            _save_test_model(workdir)
            info = ModelInfo(path=workdir, type="torchscript")

            loaded = load_model(info, model_root="")
            expected, actual = torch.tensor(2), loaded(torch.tensor(1))
            self.assertEqual(expected, actual)
コード例 #2
0
def _create_predictor(info_json, model_root):
    """Build a ``PredictorWrapper`` from a predictor-info JSON file.

    Args:
        info_json: Path to the JSON file, opened via ``PathManager.open`` and
            decoded into a ``PredictorInfo`` with ``PredictorInfo.from_dict``.
        model_root: Forwarded as the second argument of every ``load_model``
            call (presumably the directory model paths are resolved against --
            confirm against ``load_model``).

    Returns:
        A ``PredictorWrapper`` whose ``model_or_models`` is either a single
        loaded model (when the info sets ``model``) or a dict with the same keys
        as ``models`` mapping to loaded models. The ``run_func``, ``preprocess``
        and ``postprocess`` values come from instantiating the corresponding
        ``*_info`` entries.

    Raises:
        ValueError: if the info sets both ``model`` and ``models``, or neither.
    """
    logger.info("Loading predictor info from %s", info_json)
    with PathManager.open(info_json) as f:
        info_dict = json.load(f)
    # Only the decoded dict is needed from here on; the file can be closed.
    predictor_info = PredictorInfo.from_dict(info_dict)

    # Exactly one of `model` / `models` must be set. This is checked explicitly
    # instead of with `assert`, which is stripped under `python -O`. A stripped
    # check would let a malformed file fall through to a confusing
    # AttributeError (None.items()) or silently ignore `models`.
    has_single_model = predictor_info.model is not None
    has_model_dict = predictor_info.models is not None
    if has_single_model == has_model_dict:
        raise ValueError(
            "Predictor info {} must set exactly one of 'model' or 'models'".format(
                info_json
            )
        )

    if has_single_model:
        model_or_models = load_model(predictor_info.model, model_root)
    else:
        model_or_models = {
            k: load_model(info, model_root)
            for k, info in predictor_info.models.items()
        }

    return PredictorWrapper(
        model_or_models=model_or_models,
        run_func=predictor_info.run_func_info.instantiate(),
        preprocess=predictor_info.preprocess_info.instantiate(),
        postprocess=predictor_info.postprocess_info.instantiate(),
    )
コード例 #3
0
    def load(cls, save_path, inputs_schema, outputs_schema, **load_kwargs):
        """Load a traced model and wrap it so it takes/returns structured data.

        Both schemas are first passed through ``instantiate``. The model is
        loaded with ``load_model(save_path, "torchscript")``. The returned
        ``nn.Module`` flattens its positional arguments with
        ``flatten_to_tuple`` and runs the loaded model on the flat values. It
        then rebuilds the output by calling the instantiated ``outputs_schema``
        on the flat results. ``load_kwargs`` is accepted but not used here.
        """
        in_schema = instantiate(inputs_schema)
        out_schema = instantiate(outputs_schema)
        scripted = load_model(save_path, "torchscript")

        class TracingAdapterWrapper(nn.Module):
            # Adapts a model that works on flat tuples to structured I/O.
            def __init__(self, traced_model, inputs_schema, outputs_schema):
                super().__init__()
                self.traced_model = traced_model
                # NOTE(review): kept on the module but never read by forward().
                self.inputs_schema = inputs_schema
                self.outputs_schema = outputs_schema

            def forward(self, *input_args):
                flat_args, _unused_schema = flatten_to_tuple(input_args)
                flat_results = self.traced_model(*flat_args)
                return self.outputs_schema(flat_results)

        return TracingAdapterWrapper(scripted, in_schema, out_schema)
コード例 #4
0
ファイル: caffe2.py プロジェクト: yeonh2/d2go
    def load(cls, save_path, **load_kwargs):
        """Load the model at ``save_path`` as model type ``"caffe2"``.

        Delegates to ``mobile_cv.predictor.model_wrappers.load_model``, imported
        at call time. ``load_kwargs`` is accepted but ignored.
        """
        from mobile_cv.predictor.model_wrappers import load_model as _load_model

        model_type = "caffe2"
        return _load_model(save_path, model_type)
コード例 #5
0
 def load(cls, save_path, **load_kwargs):
     """Load the model at ``save_path`` as type ``"torchscript"``; ``load_kwargs`` is ignored."""
     model_type = "torchscript"
     return load_model(save_path, model_type)
コード例 #6
0
 def load(cls, save_path, **load_kwargs):
     """Load the model at ``save_path`` as type ``"caffe2"``; ``load_kwargs`` is ignored."""
     model_type = "caffe2"
     return load_model(save_path, model_type)