class TestReCoSaInference(unittest.TestCase):
    def setUp(self):
        self.device = torch.device("cpu")
        self.config = Config()
        self.config.add_dataset("./conf/dataset/ubuntu_test.yml")
        self.config.add_model("./conf/model/ReCoSa_test.yml")
        self.config.add_api("./conf/api/ReCoSa.yml")
        self.recosa = RecoSAPL.load_from_checkpoint(
            checkpoint_path=self.config.api.model_path, config=self.config)
        data = UbuntuDataSet(
            self.config.dataset.root + self.config.dataset.target,
            self.config.dataset.raw.train,
            self.config.model.max_seq,
            _max_turns=self.config.model.max_turns,
        )
        dataloader = UbuntuDataLoader(
            data,
            batch_size=1,
            shuffle=False,
            num_workers=8,
            collate_fn=collate,
        )
        sample = iter(dataloader)
        sample_data = next(sample)
        self.ctx = sample_data[0].to(self.device)
        self.response = sample_data[1].to(self.device)
        self.target = sample_data[2].to(self.device)
        pytorch_lightning.seed_everything(SEED_NUM)

    def test_inputs(self):
        batch_idx = 0
        self.assertEqual(
            "<|start|> thanks!  How the heck did you figure that out?. <|end|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|>",
            self.recosa.model.tokenizer.decode(self.ctx[batch_idx][-2]),
        )
        self.assertEqual(
            "<|start|> https://bugs.launchpad.net/lightdm/+bug/864109/comments/3. <|end|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|>",
            self.recosa.model.tokenizer.decode(self.ctx[batch_idx][-1]),
        )
        self.assertEqual(
            "<|start|> nice thanks!. <|end|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|>",
            self.recosa.model.tokenizer.decode(self.response[batch_idx]),
        )
        self.assertEqual(
            "nice thanks!. <|end|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|end|>",
            self.recosa.model.tokenizer.decode(self.target[batch_idx]),
        )

    def test_forward_recosa(self):
        dec_res = self.recosa.model(self.ctx, self.target)
        dec_res = torch.argmax(dec_res[0], dim=0)
        res_decoded = self.recosa.model.tokenizer.decode(dec_res)
        logger.debug(res_decoded)
        self.assertEqual(res_decoded.split()[0], "thanks!.")

    def test_generate_recosa(self):
        _, res = self.recosa.generate(self.ctx)
        res_decoded = self.recosa.model.tokenizer.decode(res[0])
        logger.debug(res_decoded)
        self.assertEqual(res_decoded.split()[0], "thanks.")
Пример #2
0
def main(
    config_data_file: str,
    config_model_file: str,
    config_trainer_file: str,
    config_api_file: str,
    version: str,
) -> None:
    """Run a checkpointed ReCoSa model over the validation split and log BLEU.

    Args:
        config_data_file: Dataset configuration file, given to ``build`` and
            to ``Config.add_dataset``.
        config_model_file: Model configuration file.
        config_trainer_file: Trainer configuration file; its ``pl`` section
            is expanded into the ``pl.Trainer`` keyword arguments.
        config_api_file: API configuration file; supplies the checkpoint path.
        version: Version value forwarded to ``build``.
    """
    # TODO: to be removed
    build({"data_config": config_data_file, "version": version})

    cfg = Config()
    cfg.add_dataset(config_data_file)
    cfg.add_model(config_model_file)
    cfg.add_api(config_api_file)
    cfg.add_trainer(config_trainer_file)

    dataset_cfg = cfg.dataset
    model_cfg = cfg.model

    validation_set = UbuntuDataSet(
        dataset_cfg.root + dataset_cfg.target,
        dataset_cfg.raw.val,
        model_cfg.max_seq,
        dataset_cfg.target,
        model_cfg.max_turns,
    )
    validation_loader = UbuntuDataLoader(
        validation_set,
        batch_size=model_cfg.batch_size,
        shuffle=False,
        num_workers=8,
        collate_fn=collate,
    )

    model = RecoSAPL.load_from_checkpoint(
        checkpoint_path=cfg.api.model_path,
        config=cfg,
    )

    # Override the configured epoch count, then build a trainer with
    # logging and checkpointing switched off.
    cfg.trainer.pl.max_epochs = 1
    trainer = pl.Trainer(
        **cfg.trainer.pl,
        logger=False,
        checkpoint_callback=False,
    )

    logger.info(trainer.test(model, test_dataloaders=validation_loader))

    # BLEU is computed from model.pred / model.target after the test run,
    # 4-gram first and then 2-gram, and logged in that order.
    bleu_scores = (
        bleuS_4(model.pred, model.target),
        bleuS_2(model.pred, model.target),
    )
    for score in bleu_scores:
        logger.info(score)
Пример #3
0
from pathlib import Path

from pydantic import BaseModel

from infer import Predictor
from serving.app_factory import create_app
from src.core.build_data import Config
from train import RecoSAPL

# Import-time setup: build the configuration from the model and API YAML files.
config = Config()
config.add_model("./conf/model/ReCoSa.yml")
config.add_api("./conf/api/ReCoSa.yml")

# Module-level predictor created once at import and reused by `handler`.
predictor = Predictor.from_checkpoint(RecoSAPL, config)


# Request schema for the serving endpoint.
# Kept as comments rather than a docstring: pydantic copies a model's
# docstring into the generated schema description.
class Request(BaseModel):
    input_text: str  # forwarded unchanged to predictor.generate by handler


# Response schema for the serving endpoint.
# Kept as comments rather than a docstring: pydantic copies a model's
# docstring into the generated schema description.
class Response(BaseModel):
    prediction: str  # value returned by predictor.generate for the request text


def handler(request: Request) -> Response:
    # Run the module-level predictor on the request text and wrap whatever
    # it returns in the response schema.
    return Response(prediction=predictor.generate(request.input_text))


# Application object built at import time from the handler and its request/response schemas.
app = create_app(handler, Request, Response)