Exemplo n.º 1
0
def model_fn(model_dir):
    """Load the TorchScript model stored in ``model_dir``.

    Configures ``neopytorch`` for ``model_dir`` with ``neo_runtime=True``,
    then loads ``model.pth`` with ``torch.jit.load``, mapping it onto CUDA
    when ``torch.cuda.is_available()`` is true and onto the CPU otherwise.

    Args:
        model_dir: Directory containing the ``model.pth`` artifact.

    Returns:
        The module returned by ``torch.jit.load``.
    """
    logger.info('model_fn')
    neopytorch.config(model_dir=model_dir, neo_runtime=True)

    # NOTE(review): an earlier comment here called the artifact "compiled.pt",
    # but the file actually loaded is 'model.pth' -- confirm which name the
    # packaging step produces.
    artifact_path = os.path.join(model_dir, 'model.pth')

    # Prefer the GPU when one is visible to this process.
    if torch.cuda.is_available():
        target_device = torch.device("cuda")
    else:
        target_device = torch.device("cpu")

    return torch.jit.load(artifact_path, map_location=target_device)
Exemplo n.º 2
0
def model_fn(model_dir):
    """Load the compiled TorchScript model and run one warm-up inference.

    Configures ``neopytorch`` for ``model_dir`` with ``neo_runtime=True``,
    loads ``compiled.pt`` with ``torch.jit.load`` onto CUDA when available
    (CPU otherwise), then unpickles ``sample_input.pkl`` from the same
    directory and feeds it through the model once.

    The warm-up runs only when the sample is a single tensor or a tuple made
    up entirely of tensors. Any other sample (including a tuple containing
    non-tensor elements) is reported through ``logger.warning`` and the
    warm-up is skipped, so positional arguments are never silently dropped.

    Args:
        model_dir: Directory containing ``compiled.pt`` and
            ``sample_input.pkl``.

    Returns:
        The module returned by ``torch.jit.load``.

    Raises:
        OSError: If ``sample_input.pkl`` cannot be opened (for example
            ``FileNotFoundError`` when it is missing).
    """
    logger.info("model_fn")
    neopytorch.config(model_dir=model_dir, neo_runtime=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # The compiled model is saved as "compiled.pt"
    model = torch.jit.load(os.path.join(model_dir, "compiled.pt"), map_location=device)

    # It is recommended to run warm-up inference during model load
    sample_input_path = os.path.join(model_dir, "sample_input.pkl")
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; this is only safe when the model archive comes from a trusted
    # source.
    with open(sample_input_path, "rb") as input_file:
        model_input = pickle.load(input_file)

    if torch.is_tensor(model_input):
        model(model_input.to(device))
    elif isinstance(model_input, tuple) and all(
        torch.is_tensor(inp) for inp in model_input
    ):
        # Every element is a tensor, so argument positions are preserved.
        model(*(inp.to(device) for inp in model_input))
    else:
        # Warm-up is best effort: report the unsupported sample and continue.
        logger.warning("Only supports a torch tensor or a tuple of torch tensors")

    return model