def test_serve():
    """Smoke-test that a SummarizationTask can be served once pre/postprocess are attached."""
    task = SummarizationTask(TEST_BACKBONE)

    # TODO: Currently only servable once a preprocess and postprocess have been attached
    task._preprocess = SummarizationPreprocess(backbone=TEST_BACKBONE)
    task._postprocess = Seq2SeqPostprocess()

    task.eval()
    task.serve()
def test_jit(tmpdir):
    """Check that a SummarizationTask survives a trace -> save -> load round trip via TorchScript.

    The reloaded module is called on the same example input and must return a tensor.
    """
    sample_input = {
        "input_ids": torch.randint(1000, size=(1, 32)),
        # `torch.randint(1, ...)` draws from [0, 1) and so always yields an all-zeros mask,
        # which masks out every token. Use an explicit all-ones int64 mask (same dtype/shape
        # as the randint output) so that every token is attended to.
        "attention_mask": torch.ones(1, 32, dtype=torch.long),
    }
    path = os.path.join(tmpdir, "test.pt")

    model = SummarizationTask(TEST_BACKBONE)
    model.eval()

    # Huggingface only supports `torch.jit.trace`
    traced = torch.jit.trace(model, [sample_input])
    torch.jit.save(traced, path)
    loaded = torch.jit.load(path)

    out = loaded(sample_input)
    assert isinstance(out, torch.Tensor)