def load_model(model_config, device):
    """Build a semantic segmentation model, load its checkpoint and move it to ``device``.

    The model is first requested from ``get_semantic_segmentation_model``. If that
    returns ``None``, the model is instead built with ``get_model`` from
    ``model_config['name']``, the optional ``model_config['repo_or_dir']`` and the
    keyword arguments in ``model_config['params']``.

    Args:
        model_config: mapping that must contain ``'ckpt'``; ``'name'`` and
            ``'params'`` are also required when the fallback path is taken.
        device: value forwarded to ``model.to``.

    Returns:
        The result of ``model.to(device)`` after the checkpoint has been loaded.
    """
    seg_model = get_semantic_segmentation_model(model_config)
    if seg_model is None:
        # Fallback: resolve the model by name, optionally from a repo/dir source.
        hub_source = model_config.get('repo_or_dir')
        arch_name = model_config['name']
        seg_model = get_model(arch_name, hub_source, **model_config['params'])

    # strict=True is forwarded to load_ckpt (presumably strict state-dict key
    # matching -- confirm in load_ckpt).
    load_ckpt(model_config['ckpt'], model=seg_model, strict=True)
    return seg_model.to(device)
# Example 2 (score: 0)
def load_model(model_config, device, distributed):
    """Build an image classification model, load its checkpoint and move it to ``device``.

    The model is first requested from ``get_image_classification_model`` (which
    also receives ``distributed``). If that returns ``None``, the model is instead
    built with ``get_model`` from ``model_config['name']``, the optional
    ``model_config['repo_or_dir']`` and the keyword arguments in
    ``model_config['params']``.

    Args:
        model_config: mapping that must contain ``'ckpt'``; ``'name'`` and
            ``'params'`` are also required when the fallback path is taken.
        device: value forwarded to ``model.to``.
        distributed: flag passed through to ``get_image_classification_model``.

    Returns:
        The result of ``model.to(device)`` after the checkpoint has been loaded.
    """
    cls_model = get_image_classification_model(model_config, distributed)
    if cls_model is None:
        # Fallback: resolve the model by name, optionally from a repo/dir source.
        hub_source = model_config.get('repo_or_dir')
        arch_name = model_config['name']
        cls_model = get_model(arch_name, hub_source, **model_config['params'])

    # strict=True is forwarded to load_ckpt (presumably strict state-dict key
    # matching -- confirm in load_ckpt).
    load_ckpt(model_config['ckpt'], model=cls_model, strict=True)
    return cls_model.to(device)
# Example 3 (score: 0)
 def test_torch_hub(self):
     """Check that ``get_model`` builds a MobileNetV3 from a remote repo.

     Requests ``tf_mobilenetv3_large_100`` from ``'rwightman/pytorch-image-models'``
     with ``pretrained=True``. This presumably downloads weights via torch.hub
     (per the test name), so it needs network access.
     """
     hub_kwargs = dict(pretrained=True)
     model = get_model('tf_mobilenetv3_large_100', 'rwightman/pytorch-image-models', **hub_kwargs)
     class_name = type(model).__name__
     assert class_name == 'MobileNetV3'