class TestReCoSaInference(unittest.TestCase):
    def setUp(self):
        self.device = torch.device("cpu")
        self.config = Config()
        self.config.add_dataset("./conf/dataset/ubuntu_test.yml")
        self.config.add_model("./conf/model/ReCoSa_test.yml")
        self.config.add_api("./conf/api/ReCoSa.yml")
        self.recosa = RecoSAPL.load_from_checkpoint(
            checkpoint_path=self.config.api.model_path, config=self.config)
        data = UbuntuDataSet(
            self.config.dataset.root + self.config.dataset.target,
            self.config.dataset.raw.train,
            self.config.model.max_seq,
            _max_turns=self.config.model.max_turns,
        )
        dataloader = UbuntuDataLoader(
            data,
            batch_size=1,
            shuffle=False,
            num_workers=8,
            collate_fn=collate,
        )
        sample = iter(dataloader)
        sample_data = next(sample)
        self.ctx = sample_data[0].to(self.device)
        self.response = sample_data[1].to(self.device)
        self.target = sample_data[2].to(self.device)
        pytorch_lightning.seed_everything(SEED_NUM)

    def test_inputs(self):
        batch_idx = 0
        self.assertEqual(
            "<|start|> thanks!  How the heck did you figure that out?. <|end|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|>",
            self.recosa.model.tokenizer.decode(self.ctx[batch_idx][-2]),
        )
        self.assertEqual(
            "<|start|> https://bugs.launchpad.net/lightdm/+bug/864109/comments/3. <|end|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|>",
            self.recosa.model.tokenizer.decode(self.ctx[batch_idx][-1]),
        )
        self.assertEqual(
            "<|start|> nice thanks!. <|end|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|>",
            self.recosa.model.tokenizer.decode(self.response[batch_idx]),
        )
        self.assertEqual(
            "nice thanks!. <|end|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|pad|> <|end|>",
            self.recosa.model.tokenizer.decode(self.target[batch_idx]),
        )

    def test_forward_recosa(self):
        dec_res = self.recosa.model(self.ctx, self.target)
        dec_res = torch.argmax(dec_res[0], dim=0)
        res_decoded = self.recosa.model.tokenizer.decode(dec_res)
        logger.debug(res_decoded)
        self.assertEqual(res_decoded.split()[0], "thanks!.")

    def test_generate_recosa(self):
        _, res = self.recosa.generate(self.ctx)
        res_decoded = self.recosa.model.tokenizer.decode(res[0])
        logger.debug(res_decoded)
        self.assertEqual(res_decoded.split()[0], "thanks.")
class TestConifg(unittest.TestCase):
    """Checks the layout of ``conf/dataset/ubuntu.yml`` as loaded by ``Config``."""

    # NOTE(review): the class name looks like a typo for "TestConfig"; it is
    # kept unchanged so existing test selectors that reference it still work.

    def setUp(self):
        self.cfg = Config()
        self.cfg.add_dataset("./conf/dataset/ubuntu.yml")

    def test_keys(self):
        """The dataset section exposes exactly these top-level keys, in order."""
        self.assertEqual(
            list(self.cfg.dataset.keys()),
            ["root", "target", "curl", "raw", "refinement"],
        )

    def test_download_keys(self):
        """``curl`` maps onto the fields of ``DownloadableFile``."""
        down_file = DownloadableFile(**self.cfg.dataset.curl)
        self.assertEqual(list(down_file.__dict__.keys()),
                         ["url", "file_name", "hashcode"])

    def test_root(self):
        self.assertEqual(self.cfg.dataset.root, "./data/")

    def test_target(self):
        self.assertEqual(self.cfg.dataset.target, "Ubuntu")

    @pytest.mark.skip(reason="The test-data file to be changed")
    def test_download(self):
        """Downloading/extracting the dataset through ``build`` succeeds."""
        # build() reads the dataset config path from the "data_config" key;
        # passing it as "config" would raise KeyError once this is un-skipped.
        build({"data_config": "./conf/dataset/ubuntu.yml", "version": "test"})
Exemplo n.º 3
0
def build(opt: dict):
    """Load the dataset config and make sure its raw data is present on disk.

    Args:
        opt: mapping with ``"data_config"`` (path of the dataset YAML) and
            ``"version"`` (version string checked against / written to the
            build marker in the data directory).

    Returns:
        The ``Config`` with the dataset section loaded. If the dataset config
        has no ``curl`` entry nothing is downloaded and the config is returned
        right away (off-line data is used).
    """
    logger.debug("Read Config")
    cfg = Config()
    cfg.add_dataset(opt["data_config"])

    dpath = os.path.join(cfg.dataset.root, cfg.dataset.target)
    curl = cfg.dataset.curl
    if curl is None:
        logger.warning(
            "DownloadableFile does not exist!; Off-line Data will be used.")
        return cfg
    resources = [DownloadableFile(**curl)]

    logger.debug("Check built")
    version = opt["version"]
    if not build_data.built(dpath, version_string=version):
        _rebuild_data_dir(dpath, resources, curl, version)
    logger.debug("Done")
    return cfg


def _rebuild_data_dir(dpath, resources, curl, version):
    """Recreate ``dpath``, download ``resources`` into it, unpack, and mark done."""
    logger.info("[building data: " + dpath + "]")
    if build_data.built(dpath):
        # A marker from an older version is present: discard the stale files.
        build_data.remove_dir(dpath)
    build_data.make_dir(dpath)

    for resource in resources:
        resource.download_file(dpath)

    build_data.untar(dpath, curl.file_name)

    # Record the version so the next call can skip the download.
    build_data.mark_done(dpath, version_string=version)
Exemplo n.º 4
0
def main(
    config_data_file: str,
    config_model_file: str,
    config_trainer_file: str,
    config_api_file: str,
    version: str,
) -> None:
    """Evaluate a ReCoSa checkpoint on the validation split and log BLEU scores.

    Args:
        config_data_file: dataset YAML; also used to fetch/verify the raw data.
        config_model_file: model YAML (``max_seq``, ``max_turns``, ``batch_size``).
        config_trainer_file: trainer YAML; its ``pl`` section is forwarded to
            ``pl.Trainer``.
        config_api_file: API YAML; ``api.model_path`` is the checkpoint to load.
        version: data version string passed to ``build``.
    """
    # TODO: to be removed
    build({"data_config": config_data_file, "version": version})

    cfg = Config()
    cfg.add_dataset(config_data_file)
    cfg.add_model(config_model_file)
    cfg.add_api(config_api_file)
    cfg.add_trainer(config_trainer_file)

    # Same data directory that build() prepares; os.path.join does not depend
    # on ``root`` ending with a separator.
    val_data = UbuntuDataSet(
        os.path.join(cfg.dataset.root, cfg.dataset.target),
        cfg.dataset.raw.val,
        cfg.model.max_seq,
        cfg.dataset.target,
        cfg.model.max_turns,
    )

    val_dataloader = UbuntuDataLoader(
        val_data,
        batch_size=cfg.model.batch_size,
        shuffle=False,
        num_workers=8,
        collate_fn=collate,
    )

    model = RecoSAPL.load_from_checkpoint(checkpoint_path=cfg.api.model_path,
                                          config=cfg)
    cfg.trainer.pl.max_epochs = 1

    trainer = pl.Trainer(**cfg.trainer.pl,
                         logger=False,
                         checkpoint_callback=False)
    test_result = trainer.test(model, test_dataloaders=val_dataloader)
    logger.info(test_result)

    # model.pred / model.target are read after trainer.test(); presumably the
    # module fills them in its test step - confirm in RecoSAPL.
    bleu_score_4 = bleuS_4(model.pred, model.target)
    bleu_score_2 = bleuS_2(model.pred, model.target)
    # Label the scores so the two numbers can be told apart in the log.
    logger.info(f"BLEU-4: {bleu_score_4}")
    logger.info(f"BLEU-2: {bleu_score_2}")
Exemplo n.º 5
0
class TestDataSet(unittest.TestCase):
    """Shape/layout checks for ``UbuntuDataSet`` and ``UbuntuDataLoader``."""

    def setUp(self):
        self.cfg = Config()
        self.cfg.add_dataset("./conf/dataset/ubuntu_test.yml")
        self.cfg.add_model("./conf/model/ReCoSa_test.yml")
        # NOTE(review): max_seq is not passed here, so the dataset's default
        # must equal cfg.model.max_seq for the length checks below to hold.
        self.data = UbuntuDataSet(
            folderpath=self.cfg.dataset.root + self.cfg.dataset.target,
            filepath=self.cfg.dataset.raw.train,
            _max_turns=self.cfg.model.max_turns,
        )
        self.test_batch_size = 2

    def test_config(self):
        """The test dataset config has no download (``curl``) section."""
        self.assertEqual(list(self.cfg.dataset.keys()),
                         ["root", "target", "raw"])

    def test_model(self):
        """The test model config exposes exactly these keys, in order."""
        self.assertEqual(
            list(self.cfg.model.keys()),
            [
                "output_size",
                "vocab_size",
                "embed_size",
                "utter_hidden_size",
                "utter_n_layer",
                "batch_size",
                "dropout",
                "out_size",
                "max_seq",
                "max_turns",
            ],
        )

    def test_dataset(self):
        """A single instance is (ctxs, response, target), each padded to max_seq."""
        instance_len = 3  # ctxs, response, target
        instance = self.data[0]
        ctxs = instance[0]
        response = instance[1]
        target = instance[2]

        # instance
        self.assertEqual(instance_len, len(instance))

        # ctx: exactly max_turns turns, each max_seq tokens long
        self.assertEqual(self.cfg.model.max_turns, len(ctxs))

        for ctx in ctxs:
            self.assertEqual(self.cfg.model.max_seq, len(ctx))

        # response
        self.assertEqual(self.cfg.model.max_seq, len(response))

        # target
        self.assertEqual(self.cfg.model.max_seq, len(target))

    def test_dataloader(self):
        """The first collated batch has (batch, turns, seq) / (batch, seq) shapes."""
        dataloader = UbuntuDataLoader(
            self.data,
            batch_size=self.test_batch_size,
            shuffle=False,
            num_workers=1,
            collate_fn=collate,
        )
        # Check only the first batch, but fail loudly if there is none instead
        # of letting an empty loader skip every assertion.
        batch = next(iter(dataloader), None)
        self.assertIsNotNone(batch, "dataloader produced no batches")
        ctx, response, target = batch

        self.assertEqual(
            ctx.shape,
            torch.Size([
                self.test_batch_size,
                self.cfg.model.max_turns,
                self.cfg.model.max_seq,
            ]),
        )
        self.assertEqual(
            response.shape,
            torch.Size([self.test_batch_size, self.cfg.model.max_seq]),
        )
        self.assertEqual(
            target.shape,
            torch.Size([self.test_batch_size, self.cfg.model.max_seq]),
        )