コード例 #1
0
def test_serve():
    """Smoke-test serving a ``TranslationTask`` that has explicit pre/post-processing.

    Builds the task from ``TEST_BACKBONE``, manually attaches a
    ``TranslationPreprocess`` and a ``Seq2SeqPostprocess``, switches the task
    to eval mode and calls ``serve()`` on it. Whether ``serve()`` returns
    immediately or blocks depends on code outside this view
    (NOTE(review): confirm it is non-blocking in the test environment).
    """
    task = TranslationTask(TEST_BACKBONE)

    # TODO: Currently only servable once a preprocess and postprocess have been attached
    task._preprocess = TranslationPreprocess(backbone=TEST_BACKBONE)
    task._postprocess = Seq2SeqPostprocess()

    task.eval()
    task.serve()
コード例 #2
0
def test_jit(tmpdir):
    """Check that a ``TranslationTask`` survives a TorchScript trace/save/load round trip.

    The task, built with ``val_target_max_length=None``, is put in eval mode and
    traced on a small dummy batch. The traced module is saved under ``tmpdir``
    and loaded back. The loaded module is then called on the same batch, and its
    output must be a ``torch.Tensor``.

    Args:
        tmpdir: pytest's per-test temporary directory fixture.
    """
    input_ids = torch.randint(128, size=(1, 4))
    sample_input = {
        "input_ids": input_ids,
        # Attend to every token. ``torch.randint(1, ...)`` would only ever yield
        # zeros (``high`` is exclusive), i.e. a mask hiding the whole sequence.
        "attention_mask": torch.ones_like(input_ids),
    }
    path = os.path.join(tmpdir, "test.pt")

    model = TranslationTask(TEST_BACKBONE, val_target_max_length=None)
    model.eval()

    # Huggingface only supports `torch.jit.trace`
    traced = torch.jit.trace(model, [sample_input])

    torch.jit.save(traced, path)
    loaded = torch.jit.load(path)

    out = loaded(sample_input)
    assert isinstance(out, torch.Tensor)
コード例 #3
0
def test_serve():
    """Smoke-test serving a freshly built ``TranslationTask``.

    No pre/post-processing is attached by hand. The task is built from
    ``TEST_BACKBONE``, switched to eval mode, and ``serve()`` is called.
    NOTE(review): whether ``serve()`` blocks is decided outside this view;
    confirm it returns in the test environment.
    """
    task = TranslationTask(TEST_BACKBONE)
    task.eval()
    task.serve()