Exemplo n.º 1
0
class TrainArguments(utils.Arguments):
    """Typed container of the settings handed to a training run.

    ``model``, ``train_data``, ``valid_data`` and ``processor`` declare no
    default value; every other field falls back to the default shown below.

    NOTE(review): the per-field remarks in the comments are inferred from the
    field names and types only -- confirm against the code that consumes
    this class.
    """
    # --- fields without defaults ---
    model: models.AbstractTDA
    train_data: Sequence[Dialog]
    valid_data: Sequence[Dialog]
    processor: datasets.DialogProcessor
    # --- runtime / output ---
    device: torch.device = torch.device("cpu")  # defaults to CPU
    save_dir: pathlib.Path = pathlib.Path("out")  # defaults to ./out
    report_every: Optional[int] = None
    # --- optimization ---
    batch_size: int = 32
    valid_batch_size: int = 64
    optimizer: str = "adam"
    gradient_clip: Optional[float] = None
    l2norm_weight: Optional[float] = None
    learning_rate: float = 0.001
    num_epochs: int = 10
    # Both schedulers default to a constant 1.0. The default instances are
    # created once, when the class body is executed, and are therefore shared
    # by every TrainArguments object that does not override them.
    kld_schedule: utils.Scheduler = utils.ConstantScheduler(1.0)
    dropout_schedule: utils.Scheduler = utils.ConstantScheduler(1.0)
    # --- validation / generation ---
    validate_every: int = 1
    beam_size: int = 4
    max_gen_len: int = 30
    # --- early stopping / checkpointing ---
    early_stop: bool = False
    # presumably a "~" prefix marks a lower-is-better criterion -- TODO confirm
    early_stop_criterion: str = "~val-loss"
    early_stop_patience: Optional[int] = None
    save_every: Optional[int] = None
    # --- KL options ---
    disable_kl: bool = False
    kl_mode: str = "kl-mi"
Exemplo n.º 2
0
class VHREDLoss(Loss):
    """Word-level reconstruction loss plus a sentence-latent KL term.

    ``compute`` reads only the ``"sent"`` entries of the model outputs.
    """
    vocabs: VocabSet
    enable_kl: bool = True
    kld_weight: utils.Scheduler = utils.ConstantScheduler(1.0)
    _ce: ClassVar[nn.CrossEntropyLoss] = nn.CrossEntropyLoss(
        reduction="none", ignore_index=-1)

    def compute(self,
                batch: BatchData,
                outputs,
                step: int = None) -> Tuple[torch.Tensor, utils.TensorMap]:
        """Return ``(mean loss, stats)`` for one batch.

        Args:
            batch: batch providing ``batch_size``, ``max_conv_len``,
                ``max_sent_len``, ``conv_lens`` and ``sent``.
            outputs: ``(logit, post, prior)`` triple; each is indexed by
                ``"sent"``.
            step: value passed to the KL-weight scheduler; a falsy value
                (including ``None``) is treated as 0.

        Returns:
            The batch-mean loss and a dict of scalar statistics.
        """
        if not step:
            step = 0
        logit, post, prior = outputs
        num_dialogs = batch.batch_size
        max_turns = batch.max_conv_len
        max_words = batch.max_sent_len
        w_logit = logit["sent"]
        z_post = post["sent"]
        z_prior = prior["sent"]
        conv_lens = batch.conv_lens
        turn_mask = utils.mask(conv_lens, max_turns)
        # Padding turns contribute no words.
        sent_lens = batch.sent.lens1.masked_fill(~turn_mask, 0)
        word_mask = utils.mask(sent_lens, max_words)
        kld_sent = z_post.kl_div(z_prior).masked_fill(~turn_mask, 0)

        # Targets are the inputs shifted left by one; padded positions are
        # set to -1, which the cross entropy ignores (ignore_index=-1).
        padded_words = batch.sent.value.masked_fill(~word_mask, -1)
        w_target = padded_words.view(-1, max_words)[..., 1:]
        flat_logit = \
            w_logit[:, :, :-1].contiguous().view(-1, len(self.vocabs.word))
        flat_target = w_target.contiguous().view(-1)
        token_loss = self._ce(flat_logit, flat_target)
        sent_loss = token_loss.view(num_dialogs, max_turns, -1).sum(-1)

        kld_weight = self.kld_weight.get(step)
        loss_kld = kld_sent.sum(-1)
        loss_recon = sent_loss.sum(-1)
        nll = loss_recon + loss_kld
        loss = (loss_recon + kld_weight * loss_kld
                if self.enable_kl else loss_recon)

        total_sent_loss = sent_loss.sum()
        num_turns = conv_lens.sum()
        num_words = sent_lens.sum()
        stats = {
            "nll": nll.mean(),
            "loss": loss.mean(),
            "loss-recon": loss_recon.mean(),
            "loss-sent": sent_loss.sum(-1).mean(),
            "loss-sent-turn": total_sent_loss / num_turns,
            "loss-sent-word": total_sent_loss / num_words,
            "ppl-turn": (total_sent_loss / num_turns).exp(),
            "ppl-word": (total_sent_loss / num_words).exp(),
            "kld-weight": torch.tensor(kld_weight),
            "kld-sent": kld_sent.sum(-1).mean(),
            "kld-sent-turn": kld_sent.sum() / num_turns,
            "kld": loss_kld.mean()
        }
        return loss.mean(), stats
Exemplo n.º 3
0
def create_loss(model, vocabs: datasets.VocabSet,
                kld_weight: utils.Scheduler = utils.ConstantScheduler(1.0),
                enable_kl=True, kl_mode="kl-mi") -> losses.Loss:
    """Build the loss object that matches the concrete type of ``model``.

    The model classes are tested in a fixed order and the first match wins.

    Args:
        model: model instance whose type selects the loss class.
        vocabs: vocabulary set forwarded to every loss.
        kld_weight: KL-weight scheduler, forwarded to every loss except
            ``HDALoss``.
        enable_kl: forwarded to every loss except ``HDALoss``.
        kl_mode: one of ``"kl"``, ``"kl-mi"``, ``"kl-mi+"``.

    Raises:
        AssertionError: if ``kl_mode`` is not one of the accepted values.
        RuntimeError: if ``model`` matches none of the supported classes.
    """
    assert kl_mode in {"kl", "kl-mi", "kl-mi+"}
    # Keyword arguments shared by every KL-aware loss.
    kl_kwargs = dict(vocabs=vocabs, kld_weight=kld_weight, enable_kl=enable_kl)
    if isinstance(model, models.VHDA):
        return losses.VHDALoss(kl_mode=kl_mode, **kl_kwargs)
    if isinstance(model, models.VHCR):
        return losses.VHCRLoss(kl_mode=kl_mode, **kl_kwargs)
    if isinstance(model, models.HDA):
        return losses.HDALoss(vocabs=vocabs)
    if isinstance(model, models.VHDAWithoutGoal):
        return losses.VHDAWithoutGoalLoss(kl_mode=kl_mode, **kl_kwargs)
    if isinstance(model, models.VHDAWithoutGoalAct):
        # This loss takes a boolean flag instead of the mode string.
        return losses.VHDAWithoutGoalActLoss(
            calibrate_mi="mi" in kl_mode, **kl_kwargs)
    if isinstance(model, models.VHRED):
        return losses.VHREDLoss(**kl_kwargs)
    if isinstance(model, models.VHUS):
        return losses.VHUSLoss(**kl_kwargs)
    raise RuntimeError(f"unsupported model: {type(model)}")
Exemplo n.º 4
0
class VHDAInferencer(Inferencer):
    """Inferencer that passes a sample scale and a scheduled dropout scale."""
    sample_scale: float = 1.0
    dropout_scale: utils.Scheduler = utils.ConstantScheduler(1.0)

    def model_kwargs(self) -> dict:
        """Extend the parent's kwargs with the sample and dropout scales."""
        extra = super().model_kwargs()
        extra["sample_scale"] = self.sample_scale
        current_scale = self.dropout_scale.get(self.global_step)
        extra["dropout_scale"] = current_scale
        return extra

    def on_batch_ended(self, batch: BatchData, pred: BatchData,
                       outputs) -> utils.TensorMap:
        """Return the parent's batch stats plus the current dropout scale."""
        parent_stats = super().on_batch_ended(batch, pred, outputs)
        stats = dict(parent_stats)
        scale_tensor = torch.tensor(self.dropout_scale.get(self.global_step))
        stats["dropout-scale"] = scale_tensor.to(self.device)
        return stats
Exemplo n.º 5
0
def main():
    """Train a dialog generator, evaluate it, then run augmented DST trials.

    Steps, in order:
      1. Parse arguments, configure logging, and refuse to reuse a save
         directory that already holds ``*.json`` files unless
         ``args.overwrite`` is set.
      2. Load ``train.json`` / ``dev.json`` / ``test.json`` from
         ``args.data_dir``, build and pickle the dialog processor.
      3. Build the generator model, train it and save the training record.
      4. Evaluate the generator on every split named in ``args.eval_splits``.
      5. For each of ``args.gen_runs`` generation trials: sample synthetic
         dialogs, add them to the original training set, and train/test a
         DST model ``args.dst_runs`` times on the result.
      6. Aggregate and save per-generation and per-DST-trial summaries.

    Raises:
        FileExistsError: if the save directory already contains JSON files
            and ``args.overwrite`` is false.
    """
    args = utils.parse_args(create_parser())
    if args.logging_config is not None:
        logging.config.dictConfig(utils.load_yaml(args.logging_config))
    save_dir = pathlib.Path(args.save_dir)
    if (not args.overwrite and save_dir.exists()
            and utils.has_element(save_dir.glob("*.json"))):
        raise FileExistsError(f"save directory ({save_dir}) is not empty")
    shell = utils.ShellUtils()
    engine = inflect.engine()
    shell.mkdir(save_dir, silent=True)
    logger = logging.getLogger("train")
    utils.seed(args.seed)

    logger.info("loading data...")
    load_fn = utils.chain_func(lambda x: list(map(Dialog.from_json, x)),
                               utils.load_json)
    data_dir = pathlib.Path(args.data_dir)
    train_data = load_fn(str(data_dir.joinpath("train.json")))
    valid_data = load_fn(str(data_dir.joinpath("dev.json")))
    test_data = load_fn(str(data_dir.joinpath("test.json")))
    processor = datasets.DialogProcessor(
        sent_processor=datasets.SentProcessor(
            bos=True, eos=True, lowercase=True, tokenizer="space", max_len=30),
        boc=True,
        eoc=True,
        state_order="randomized",
        max_len=30)
    processor.prepare_vocabs(
        list(itertools.chain(train_data, valid_data, test_data)))
    utils.save_pickle(processor, save_dir.joinpath("processor.pkl"))

    logger.info("preparing model...")
    utils.save_json(utils.load_yaml(args.gen_model_path),
                    save_dir.joinpath("model.json"))
    torchmodels.register_packages(models)
    model_cls = torchmodels.create_model_cls(models, args.gen_model_path)
    model: models.AbstractTDA = model_cls(processor.vocabs)
    model.reset_parameters()
    utils.report_model(logger, model)
    device = torch.device("cpu")
    if args.gpu is not None:
        device = torch.device(f"cuda:{args.gpu}")
    model = model.to(device)

    def create_scheduler(s):
        # SECURITY(review): eval() executes arbitrary code taken from a
        # command-line string. Left in place because the accepted schedule
        # syntax is not visible here; consider ast.literal_eval if schedules
        # are always plain literals.
        return utils.PiecewiseScheduler(
            [utils.Coordinate(*t) for t in eval(s)])

    train_args = train.TrainArguments(
        model=model,
        train_data=tuple(train_data),
        valid_data=tuple(valid_data),
        processor=processor,
        device=device,
        save_dir=save_dir,
        report_every=args.report_every,
        batch_size=args.batch_size,
        valid_batch_size=args.valid_batch_size,
        optimizer=args.optimizer,
        gradient_clip=args.gradient_clip,
        l2norm_weight=args.l2norm_weight,
        learning_rate=args.learning_rate,
        num_epochs=args.epochs,
        kld_schedule=(utils.ConstantScheduler(1.0) if args.kld_schedule is None
                      else create_scheduler(args.kld_schedule)),
        dropout_schedule=(utils.ConstantScheduler(1.0)
                          if args.dropout_schedule is None else
                          create_scheduler(args.dropout_schedule)),
        validate_every=args.validate_every,
        early_stop=args.early_stop,
        early_stop_criterion=args.early_stop_criterion,
        early_stop_patience=args.early_stop_patience,
        disable_kl=args.disable_kl,
        kl_mode=args.kl_mode)
    utils.save_json(train_args.to_json(), save_dir.joinpath("train-args.json"))
    record = train.train(train_args)
    utils.save_json(record.to_json(), save_dir.joinpath("final-summary.json"))

    eval_dir = save_dir.joinpath("eval")
    shell.mkdir(eval_dir, silent=True)
    eval_data = {
        name: dialogs
        for name, dialogs in (("train", train_data),
                              ("dev", valid_data),
                              ("test", test_data))
        if name in args.eval_splits
    }
    for split, data in eval_data.items():
        eval_args = evaluate.EvaluateArugments(
            model=model,
            train_data=tuple(train_data),
            test_data=tuple(data),
            processor=processor,
            embed_type=args.embed_type,
            embed_path=args.embed_path,
            device=device,
            batch_size=args.valid_batch_size,
            beam_size=args.beam_size,
            max_conv_len=args.max_conv_len,
            max_sent_len=args.max_sent_len)
        utils.save_json(eval_args.to_json(),
                        eval_dir.joinpath(f"eval-{split}-args.json"))
        # Evaluation needs no gradients; skip autograd bookkeeping.
        with torch.no_grad():
            eval_results = evaluate.evaluate(eval_args)
        save_path = eval_dir.joinpath(f"eval-{split}.json")
        utils.save_json(eval_results, save_path)
        logger.info(f"'{split}' results saved to {save_path}")

    logger.info(f"will run {args.gen_runs} generation trials...")
    gen_summary = []
    dst_summary = []
    for gen_idx in range(1, args.gen_runs + 1):
        logger.info(f"running {engine.ordinal(gen_idx)} generation trial...")
        gen_dir = save_dir.joinpath(f"gen-{gen_idx:03d}")
        shell.mkdir(gen_dir, silent=True)
        gen_args = generate.GenerateArguments(
            model=model,
            processor=processor,
            data=train_data,
            instances=int(round(len(train_data) * args.multiplier)),
            batch_size=args.valid_batch_size,
            conv_scale=args.conv_scale,
            spkr_scale=args.spkr_scale,
            goal_scale=args.goal_scale,
            state_scale=args.state_scale,
            sent_scale=args.sent_scale,
            validate_dst=True,
            validate_unique=args.validate_unique,
            device=device)
        utils.save_json(gen_args.to_json(), gen_dir.joinpath("gen-args.json"))
        with torch.no_grad():
            samples = generate.generate(gen_args)
        utils.save_json([sample.output.to_json() for sample in samples],
                        gen_dir.joinpath("gen-out.json"))
        utils.save_json([sample.input.to_json() for sample in samples],
                        gen_dir.joinpath("gen-in.json"))
        utils.save_lines([str(sample.log_prob) for sample in samples],
                         gen_dir.joinpath("logprob.txt"))
        da_data = [sample.output for sample in samples]
        # Build a NEW list for the augmented training set. An in-place
        # ``+=`` would extend ``train_data`` itself, so every later
        # generation trial would start from data already polluted with the
        # previous trials' synthetic dialogs.
        data = {
            "train": list(train_data) + da_data,
            "dev": valid_data,
            "test": test_data,
        }
        # convert dialogs to dst dialogs
        data = {
            split: list(map(datasets.DSTDialog.from_dialog, dialogs))
            for split, dialogs in data.items()
        }
        for split, dialogs in data.items():
            logger.info(f"verifying '{split}' dataset...")
            for dialog in dialogs:
                dialog.compute_user_goals()
                dialog.validate()

        logger.info("preparing dst environment...")
        dst_processor = dst_datasets.DSTDialogProcessor(
            sent_processor=datasets.SentProcessor(
                bos=True, eos=True, lowercase=True, max_len=30))
        dst_processor.prepare_vocabs(list(itertools.chain(*data.values())))
        train_dataset = dst_datasets.DSTDialogDataset(dialogs=data["train"],
                                                      processor=dst_processor)
        train_dataloader = dst_datasets.create_dataloader(
            train_dataset,
            batch_size=args.dst_batch_size,
            shuffle=True,
            pin_memory=True)
        dev_dataloader = dst_run.TestDataloader(
            dialogs=data["dev"],
            processor=dst_processor,
            max_batch_size=args.dst_batch_size)
        test_dataloader = dst_run.TestDataloader(
            dialogs=data["test"],
            processor=dst_processor,
            max_batch_size=args.dst_batch_size)
        logger.info("saving dst processor object...")
        utils.save_pickle(dst_processor, gen_dir.joinpath("processor.pkl"))
        torchmodels.register_packages(dst_models)
        dst_model_cls = torchmodels.create_model_cls(dst_pkg,
                                                     args.dst_model_path)
        dst_model = dst_model_cls(dst_processor.vocabs)
        dst_model = dst_model.to(device)
        # Log the DST model that was just built (not the generator).
        logger.info(str(dst_model))
        logger.info(f"number of parameters DST: "
                    f"{utils.count_parameters(dst_model):,d}")
        logger.info(f"running {args.dst_runs} trials...")
        all_results = []
        for idx in range(1, args.dst_runs + 1):
            logger.info(f"running {engine.ordinal(idx)} dst trial...")
            trial_dir = gen_dir.joinpath(f"dst-{idx:03d}")
            logger.info("resetting parameters...")
            dst_model.reset_parameters()
            logger.info("preparing trainer...")
            runner = dst_run.Runner(
                model=dst_model,
                processor=dst_processor,
                device=device,
                save_dir=trial_dir,
                # Scale the epoch budget down by the augmentation factor.
                epochs=int(round(args.dst_epochs / (1 + args.multiplier))),
                loss="sum",
                l2norm=args.dst_l2norm,
                gradient_clip=args.dst_gradient_clip,
                train_validate=False,
                early_stop=True,
                early_stop_criterion="joint-goal",
                early_stop_patience=None,
                asr_method="scaled",
                asr_sigmoid_sum_order="sigmoid-sum",
                asr_topk=5)

            logger.info("commencing training...")
            record = runner.train(train_dataloader=train_dataloader,
                                  dev_dataloader=dev_dataloader,
                                  test_fn=None)
            logger.info("final summary: ")
            logger.info(pprint.pformat(record.to_json()))
            utils.save_json(record.to_json(),
                            trial_dir.joinpath("summary.json"))
            if not args.dst_test_asr:
                logger.info("commencing testing...")
                with torch.no_grad():
                    eval_results = runner.test(test_dataloader)
                logger.info("test results: ")
                logger.info(pprint.pformat(eval_results))
            else:
                logger.info("commencing testing (asr)...")
                with torch.no_grad():
                    eval_results = runner.test_asr(test_dataloader)
                logger.info("test(asr) results: ")
                logger.info(pprint.pformat(eval_results))
            eval_results["epoch"] = int(record.epoch)
            logger.info("test evaluation: ")
            logger.info(pprint.pformat(eval_results))
            utils.save_json(eval_results, trial_dir.joinpath("eval.json"))
            all_results.append(eval_results)
            dst_summary.append(eval_results)

        logger.info("aggregating results...")
        summary = reduce_json(all_results)
        logger.info("aggregated results: ")
        # Keep the aggregate as a dict: it is reduced again below together
        # with the other generation trials, so it must not be a formatted
        # string.
        agg_summary = {k: v["stats"]["mean"] for k, v in summary.items()}
        logger.info(pprint.pformat(agg_summary))
        gen_summary.append(agg_summary)
        utils.save_json(summary, gen_dir.joinpath("summary.json"))

    gen_summary = reduce_json(gen_summary)
    dst_summary = reduce_json(dst_summary)
    logger.info(f"aggregating generation trials ({args.gen_runs})...")
    logger.info(
        pprint.pformat({k: v["stats"]["mean"]
                        for k, v in gen_summary.items()}))
    logger.info(f"aggregating dst trials ({args.gen_runs * args.dst_runs})...")
    logger.info(
        pprint.pformat({k: v["stats"]["mean"]
                        for k, v in dst_summary.items()}))
    utils.save_json(gen_summary, save_dir.joinpath("gen-summary.json"))
    utils.save_json(dst_summary, save_dir.joinpath("dst-summary.json"))
    logger.info("done!")
Exemplo n.º 6
0
class VHUSLoss(Loss):
    """Dialog-state prediction loss with a state-latent KL term.

    ``compute`` reads only the ``"state"`` entries of the model outputs.
    When ``enable_kl`` is false the KL term is left out of the optimized
    loss, matching the other loss classes that expose the same flag.
    """
    vocabs: VocabSet
    enable_kl: bool = True
    kld_weight: utils.Scheduler = utils.ConstantScheduler(1.0)
    _ce: ClassVar[nn.CrossEntropyLoss] = \
        nn.CrossEntropyLoss(reduction="none", ignore_index=-1)
    _bce: ClassVar[nn.BCEWithLogitsLoss] = \
        nn.BCEWithLogitsLoss(reduction="none")

    @property
    def num_asv(self):
        """Number of entries in the ``goal_state.asv`` vocabulary."""
        return len(self.vocabs.goal_state.asv)

    def compute(self,
                batch: BatchData,
                outputs,
                step: int = None) -> Tuple[torch.Tensor, utils.TensorMap]:
        """Return ``(mean loss, stats)`` for one batch.

        Args:
            batch: batch providing ``batch_size``, ``max_conv_len``,
                ``conv_lens``, ``state`` and ``speaker``.
            outputs: ``(logit, post, prior)`` triple; each is indexed by
                ``"state"``.
            step: value passed to the KL-weight scheduler; a falsy value
                (including ``None``) is treated as 0.

        Returns:
            The mean of the per-turn loss and a dict of scalar statistics,
            including one set of ``-<speaker>`` suffixed entries for every
            speaker in the vocabulary other than ``"<unk>"``.
        """
        step = step or 0
        logit, post, prior = outputs
        batch_size = batch.batch_size
        max_conv_len = batch.max_conv_len
        s_logit, zstate_post, zstate_prior = \
            logit["state"], post["state"], prior["state"]
        conv_lens = batch.conv_lens
        conv_mask = utils.mask(conv_lens, max_conv_len)
        # Only finite logits inside real (non-padding) turns are scored.
        state_logit_mask = \
            (((s_logit != float("-inf")) & (s_logit != float("inf")))
             .masked_fill(~conv_mask.unsqueeze(-1), 0))
        kld_state = zstate_post.kl_div(zstate_prior).masked_fill(~conv_mask, 0)
        s_target = utils.to_dense(idx=batch.state.value,
                                  lens=batch.state.lens1,
                                  max_size=self.num_asv)
        p_target = batch.speaker.value.masked_fill(~conv_mask, -1)
        state_loss = (self._bce(s_logit, s_target.float()).masked_fill(
            ~state_logit_mask, 0)).sum(-1)
        kld_weight = self.kld_weight.get(step)
        nll = state_loss + kld_state
        # Honor ``enable_kl``: previously the flag was accepted but ignored,
        # so the KL term could not be switched off for this loss.
        if self.enable_kl:
            loss = state_loss + kld_weight * kld_state
        else:
            loss = state_loss
        state_mi = \
            (estimate_mi(zstate_post.view(batch_size * max_conv_len, -1))
             .view(batch_size, max_conv_len).masked_fill(~conv_mask, 0).sum(-1))
        stats = {
            "nll": nll.mean(),
            "state-mi": state_mi.mean(),
            "loss-state": state_loss.sum(-1).mean(),
            "loss-state-turn": state_loss.sum() / conv_lens.sum(),
            "loss-state-asv": state_loss.sum() / state_logit_mask.sum(),
            "kld-weight": torch.tensor(kld_weight),
            "kld-state": kld_state.sum(-1).mean(),
            "kld-state-turn": kld_state.sum() / conv_lens.sum(),
            "kld": kld_state.sum(-1).mean()
        }
        # Per-speaker breakdown of the same statistics.
        for spkr_idx, spkr in self.vocabs.speaker.i2f.items():
            if spkr == "<unk>":
                continue
            spkr_mask = p_target == spkr_idx
            spkr_state_mask = \
                state_logit_mask.masked_fill(~spkr_mask.unsqueeze(-1), 0)
            spkr_state_loss = state_loss.masked_fill(~spkr_mask, 0).sum()
            spkr_kld_state = kld_state.masked_fill(~spkr_mask, 0).sum()
            spkr_stats = {
                "loss-state": spkr_state_loss / batch_size,
                "loss-state-turn": spkr_state_loss / spkr_mask.sum(),
                "loss-state-asv": spkr_state_loss / spkr_state_mask.sum(),
                "kld-state": spkr_kld_state / batch_size,
                "kld-state-turn": spkr_kld_state / spkr_mask.sum(),
            }
            stats.update({f"{k}-{spkr}": v for k, v in spkr_stats.items()})
        return loss.mean(), stats
Exemplo n.º 7
0
def main():
    """Train a dialog model and evaluate it on the requested splits.

    Parses arguments, loads ``train.json`` / ``dev.json`` / ``test.json``
    from ``args.data_dir``, builds and pickles the processor, trains the
    model, and writes evaluation results for every split named in
    ``args.eval_splits`` under ``<save_dir>/eval``.

    Raises:
        FileExistsError: if the save directory already contains JSON files
            and ``args.overwrite`` is false.
    """
    args = utils.parse_args(create_parser())
    logging_config = args.logging_config
    if logging_config is not None:
        logging.config.dictConfig(utils.load_yaml(logging_config))

    out_dir = pathlib.Path(args.save_dir)
    if not args.overwrite:
        if out_dir.exists() and utils.has_element(out_dir.glob("*.json")):
            raise FileExistsError(f"save directory ({out_dir}) is not empty")
    shell = utils.ShellUtils()
    shell.mkdir(out_dir, silent=True)
    logger = logging.getLogger("train")
    utils.seed(args.seed)

    logger.info("loading data...")
    load_fn = utils.chain_func(
        lambda raw: [Dialog.from_json(item) for item in raw],
        utils.load_json)
    data_root = pathlib.Path(args.data_dir)
    train_data, valid_data, test_data = (
        load_fn(str(data_root / f"{name}.json"))
        for name in ("train", "dev", "test"))

    sent_processor = datasets.SentProcessor(
        bos=True, eos=True, lowercase=True, tokenizer="space", max_len=30)
    processor = datasets.DialogProcessor(
        sent_processor=sent_processor,
        boc=True,
        eoc=True,
        state_order="randomized",
        max_len=30)
    processor.prepare_vocabs([*train_data, *valid_data, *test_data])
    utils.save_pickle(processor, out_dir / "processor.pkl")

    logger.info("preparing model...")
    utils.save_json(utils.load_yaml(args.model_path), out_dir / "model.json")
    torchmodels.register_packages(models)
    model_cls = torchmodels.create_model_cls(models, args.model_path)
    model: models.AbstractTDA = model_cls(processor.vocabs)
    model.reset_parameters()
    utils.report_model(logger, model)
    device = (torch.device("cpu") if args.gpu is None
              else torch.device(f"cuda:{args.gpu}"))
    model = model.to(device)

    def scheduler_or_constant(spec):
        # A missing schedule means a constant weight of 1.0.
        if spec is None:
            return utils.ConstantScheduler(1.0)
        # NOTE(review): eval() runs arbitrary code from a command-line string.
        points = [utils.Coordinate(*t) for t in eval(spec)]
        return utils.PiecewiseScheduler(points)

    train_args = train.TrainArguments(
        model=model,
        train_data=tuple(train_data),
        valid_data=tuple(valid_data),
        processor=processor,
        device=device,
        save_dir=out_dir,
        report_every=args.report_every,
        batch_size=args.batch_size,
        valid_batch_size=args.valid_batch_size,
        optimizer=args.optimizer,
        gradient_clip=args.gradient_clip,
        l2norm_weight=args.l2norm_weight,
        learning_rate=args.learning_rate,
        num_epochs=args.epochs,
        kld_schedule=scheduler_or_constant(args.kld_schedule),
        dropout_schedule=scheduler_or_constant(args.dropout_schedule),
        validate_every=args.validate_every,
        early_stop=args.early_stop,
        early_stop_criterion=args.early_stop_criterion,
        early_stop_patience=args.early_stop_patience,
        disable_kl=args.disable_kl,
        kl_mode=args.kl_mode)
    utils.save_json(train_args.to_json(), out_dir / "train-args.json")
    record = train.train(train_args)
    utils.save_json(record.to_json(), out_dir / "final-summary.json")

    eval_dir = out_dir / "eval"
    shell.mkdir(eval_dir, silent=True)
    candidate_splits = (("train", train_data),
                        ("dev", valid_data),
                        ("test", test_data))
    eval_data = {name: dialogs for name, dialogs in candidate_splits
                 if name in args.eval_splits}
    for split, data in eval_data.items():
        eval_args = evaluate.EvaluateArugments(
            model=model,
            train_data=tuple(train_data),
            test_data=tuple(data),
            processor=processor,
            embed_type=args.embed_type,
            embed_path=args.embed_path,
            device=device,
            batch_size=args.valid_batch_size,
            beam_size=args.beam_size,
            max_conv_len=args.max_conv_len,
            max_sent_len=args.max_sent_len)
        utils.save_json(eval_args.to_json(),
                        eval_dir / f"eval-{split}-args.json")
        with torch.no_grad():
            eval_results = evaluate.evaluate(eval_args)
        save_path = eval_dir / f"eval-{split}.json"
        utils.save_json(eval_results, save_path)
        logger.info(f"'{split}' results saved to {save_path}")
    logger.info("done!")
Exemplo n.º 8
0
class VHDAWithoutGoalLoss(Loss):
    """Sentence, speaker and state reconstruction loss with KL terms.

    KL terms are computed for the ``conv``, ``state``, ``sent`` and
    ``speaker`` latents. ``kl_mode`` selects which mutual-information
    estimates are subtracted from the KL term before weighting:

    * ``"kl"``     -- none
    * ``"kl-mi"``  -- the conversation-level estimate only
    * ``"kl-mi+"`` -- conversation, sentence, speaker and state estimates
    """
    vocabs: VocabSet
    enable_kl: bool = True
    kl_mode: str = "kl"
    kld_weight: utils.Scheduler = utils.ConstantScheduler(1.0)
    _ce: ClassVar[nn.CrossEntropyLoss] = \
        nn.CrossEntropyLoss(reduction="none", ignore_index=-1)
    _bce: ClassVar[nn.BCEWithLogitsLoss] = \
        nn.BCEWithLogitsLoss(reduction="none")

    def __post_init__(self):
        assert self.kl_mode in {"kl", "kl-mi", "kl-mi+"}

    @property
    def num_asv(self):
        """Number of entries in the ``goal_state.asv`` vocabulary."""
        return len(self.vocabs.goal_state.asv)

    def compute(self,
                batch: BatchData,
                outputs,
                step: int = None) -> Tuple[torch.Tensor, utils.TensorMap]:
        """Return ``(mean loss, stats)`` for one batch.

        Args:
            batch: batch providing ``batch_size``, ``max_conv_len``,
                ``max_sent_len``, ``conv_lens``, ``sent``, ``state`` and
                ``speaker``.
            outputs: ``(logit, post, prior)`` triple. ``logit`` is indexed by
                ``"sent"``, ``"speaker"``, ``"state"``; ``post`` and ``prior``
                by ``"conv"``, ``"state"``, ``"sent"``, ``"speaker"``.
            step: value passed to the KL-weight scheduler; a falsy value
                (including ``None``) is treated as 0.

        Returns:
            The batch-mean loss and a dict of scalar statistics, including
            one set of ``-<speaker>`` suffixed entries for every speaker in
            the vocabulary other than ``"<unk>"``.
        """
        step = step or 0
        logit, post, prior = outputs
        batch_size = batch.batch_size
        max_conv_len = batch.max_conv_len
        max_sent_len = batch.max_sent_len
        w_logit, p_logit, s_logit = \
            (logit[k] for k in ("sent", "speaker", "state"))
        zconv_post, zstate_post, zsent_post, zspkr_post = \
            (post[k] for k in ("conv", "state", "sent", "speaker"))
        # The conversation prior is not needed: its KL is taken without an
        # explicit prior argument below.
        _, zstate_prior, zsent_prior, zspkr_prior = \
            (prior[k] for k in ("conv", "state", "sent", "speaker"))
        conv_lens, sent_lens = batch.conv_lens, batch.sent.lens1
        conv_mask = utils.mask(conv_lens, max_conv_len)
        # Padding turns contribute no words.
        sent_lens = sent_lens.masked_fill(~conv_mask, 0)
        sent_mask = utils.mask(sent_lens, max_sent_len)
        # Only finite logits inside real (non-padding) turns are scored.
        state_logit_mask = \
            (((s_logit != float("-inf")) & (s_logit != float("inf")))
             .masked_fill(~conv_mask.unsqueeze(-1), 0))
        kld_conv = zconv_post.kl_div()
        kld_state = zstate_post.kl_div(zstate_prior).masked_fill(~conv_mask, 0)
        kld_sent = zsent_post.kl_div(zsent_prior).masked_fill(~conv_mask, 0)
        kld_spkr = zspkr_post.kl_div(zspkr_prior).masked_fill(~conv_mask, 0)
        # Targets are the inputs shifted left by one; padded positions are
        # set to -1, which the cross entropy ignores (ignore_index=-1).
        w_target = (batch.sent.value
                    .masked_fill(~sent_mask, -1)
                    .view(-1, max_sent_len))[..., 1:]
        s_target = utils.to_dense(idx=batch.state.value,
                                  lens=batch.state.lens1,
                                  max_size=self.num_asv)
        p_target = batch.speaker.value.masked_fill(~conv_mask, -1)
        state_loss = (self._bce(s_logit, s_target.float()).masked_fill(
            ~state_logit_mask, 0)).sum(-1)
        spkr_loss = self._ce(p_logit.view(-1, self.vocabs.num_speakers),
                             p_target.view(-1)).view(batch_size, max_conv_len)
        sent_loss = self._ce(
            w_logit[:, :, :-1].contiguous().view(-1, len(self.vocabs.word)),
            w_target.contiguous().view(-1)
        ).view(batch_size, max_conv_len, -1).sum(-1)
        kld_weight = self.kld_weight.get(step)
        loss_kld = (kld_conv + kld_sent.sum(-1) + kld_state.sum(-1) +
                    kld_spkr.sum(-1))
        loss_recon = (sent_loss.sum(-1) + state_loss.sum(-1) +
                      spkr_loss.sum(-1))
        nll = loss_recon + loss_kld

        def turn_mi(z_post):
            # MI estimate for a per-turn latent: flatten the turns, estimate,
            # zero the padding turns, then sum over each dialog's turns.
            flat = estimate_mi(z_post.view(batch_size * max_conv_len, -1))
            return (flat.view(batch_size, max_conv_len)
                    .masked_fill(~conv_mask, 0).sum(-1))

        conv_mi = estimate_mi(zconv_post)
        sent_mi = turn_mi(zsent_post)
        spkr_mi = turn_mi(zspkr_post)
        state_mi = turn_mi(zstate_post)
        if self.enable_kl:
            if self.kl_mode == "kl-mi":
                loss = loss_recon + kld_weight * (loss_kld - conv_mi)
            elif self.kl_mode == "kl-mi+":
                loss = loss_recon + kld_weight * (loss_kld - conv_mi -
                                                  sent_mi - spkr_mi - state_mi)
            else:
                loss = loss_recon + kld_weight * loss_kld
        else:
            loss = loss_recon
        stats = {
            "nll": nll.mean(),
            "conv-mi": conv_mi.mean(),
            "sent-mi": sent_mi.mean(),
            "state-mi": state_mi.mean(),
            "spkr-mi": spkr_mi.mean(),
            "loss": loss.mean(),
            "loss-recon": loss_recon.mean(),
            "loss-sent": sent_loss.sum(-1).mean(),
            "loss-sent-turn": sent_loss.sum() / conv_lens.sum(),
            "loss-sent-word": sent_loss.sum() / sent_lens.sum(),
            "ppl-turn": (sent_loss.sum() / conv_lens.sum()).exp(),
            "ppl-word": (sent_loss.sum() / sent_lens.sum()).exp(),
            "loss-state": state_loss.sum(-1).mean(),
            "loss-state-turn": state_loss.sum() / conv_lens.sum(),
            "loss-state-asv": state_loss.sum() / state_logit_mask.sum(),
            "loss-spkr": spkr_loss.sum(-1).mean(),
            "loss-spkr-turn": spkr_loss.sum() / conv_lens.sum(),
            "kld-weight": torch.tensor(kld_weight),
            "kld-sent": kld_sent.sum(-1).mean(),
            "kld-sent-turn": kld_sent.sum() / conv_lens.sum(),
            # kld_conv is added directly to per-dialog sums above, so it
            # already holds one value per dialog; ``.sum(-1).mean()`` would
            # report the batch total instead of the per-dialog mean that
            # every other entry reports.
            "kld-conv": kld_conv.mean(),
            "kld-state": kld_state.sum(-1).mean(),
            "kld-state-turn": kld_state.sum() / conv_lens.sum(),
            "kld-spkr": kld_spkr.sum(-1).mean(),
            "kld-spkr-turn": kld_spkr.sum() / conv_lens.sum(),
            "kld": loss_kld.mean()
        }
        # Per-speaker breakdown of the same statistics.
        for spkr_idx, spkr in self.vocabs.speaker.i2f.items():
            if spkr == "<unk>":
                continue
            spkr_mask = p_target == spkr_idx
            spkr_sent_lens = sent_lens.masked_fill(~spkr_mask, 0)
            spkr_state_mask = \
                state_logit_mask.masked_fill(~spkr_mask.unsqueeze(-1), 0)
            spkr_sent_loss = sent_loss.masked_fill(~spkr_mask, 0).sum()
            spkr_state_loss = state_loss.masked_fill(~spkr_mask, 0).sum()
            spkr_spkr_loss = spkr_loss.masked_fill(~spkr_mask, 0).sum()
            spkr_kld_sent = kld_sent.masked_fill(~spkr_mask, 0).sum()
            spkr_kld_state = kld_state.masked_fill(~spkr_mask, 0).sum()
            spkr_kld_spkr = kld_spkr.masked_fill(~spkr_mask, 0).sum()
            spkr_stats = {
                "loss-sent": spkr_sent_loss / batch_size,
                "loss-sent-turn": spkr_sent_loss / spkr_mask.sum(),
                "loss-sent-word": spkr_sent_loss / spkr_sent_lens.sum(),
                "ppl-turn": (spkr_sent_loss / spkr_mask.sum()).exp(),
                "ppl-word": (spkr_sent_loss / spkr_sent_lens.sum()).exp(),
                "loss-state": spkr_state_loss / batch_size,
                "loss-state-turn": spkr_state_loss / spkr_mask.sum(),
                "loss-state-asv": spkr_state_loss / spkr_state_mask.sum(),
                "loss-spkr": spkr_spkr_loss / batch_size,
                "loss-spkr-turn": spkr_spkr_loss / spkr_mask.sum(),
                "kld-sent": spkr_kld_sent / batch_size,
                "kld-sent-turn": spkr_kld_sent / spkr_mask.sum(),
                "kld-state": spkr_kld_state / batch_size,
                "kld-state-turn": spkr_kld_state / spkr_mask.sum(),
                "kld-spkr": spkr_kld_spkr / batch_size,
                "kld-spkr-turn": spkr_kld_spkr / spkr_mask.sum(),
            }
            stats.update({f"{k}-{spkr}": v for k, v in spkr_stats.items()})
        return loss.mean(), stats
Exemplo n.º 9
0
Arquivo: vhcr.py Projeto: kaniblu/vhda
class VHCRLoss(Loss):
    """Loss combining sentence reconstruction with latent KL terms.

    The reconstruction term is a token-level cross entropy over the
    utterances; the regulariser sums a conversation-level and a
    sentence-level KL divergence.  ``kl_mode`` selects whether estimated
    mutual information is subtracted from the KL term.
    """
    # Vocabulary bundle; ``vocabs.word`` gives the size of the word logits.
    vocabs: VocabSet
    # When False, the optimised loss is the reconstruction term alone.
    enable_kl: bool = True
    # One of "kl", "kl-mi" or "kl-mi+" (checked in ``__post_init__``).
    kl_mode: str = "kl"
    # Scheduler queried with the current step to obtain the KL multiplier.
    kld_weight: utils.Scheduler = utils.ConstantScheduler(1.0)
    # Per-token cross entropy; targets equal to -1 are ignored.
    _ce: ClassVar[nn.CrossEntropyLoss] = nn.CrossEntropyLoss(
        reduction="none", ignore_index=-1)
    _bce: ClassVar[nn.BCEWithLogitsLoss] = nn.BCEWithLogitsLoss(
        reduction="none")

    def __post_init__(self):
        """Check that ``kl_mode`` names a supported objective variant."""
        supported_modes = {"kl", "kl-mi", "kl-mi+"}
        assert self.kl_mode in supported_modes

    @property
    def num_asv(self):
        """Number of entries in ``vocabs.goal_state.asv``."""
        asv_vocab = self.vocabs.goal_state.asv
        return len(asv_vocab)

    def compute(self, batch: BatchData, outputs, step: int = None
                ) -> Tuple[torch.Tensor, utils.TensorMap]:
        """Compute the batch loss and a dictionary of monitoring statistics.

        Args:
            batch: batch providing ``batch_size``, ``max_conv_len``,
                ``max_sent_len``, ``conv_lens`` and ``sent`` (``value`` and
                ``lens1``).
            outputs: triple ``(logit, post, prior)`` of mappings; ``logit``
                is indexed by "sent", "speaker", "goal" and "state",
                ``post`` and ``prior`` by "conv" and "sent".
            step: training step passed to the KL-weight scheduler; a falsy
                value is treated as 0.

        Returns:
            ``(loss.mean(), stats)`` where ``stats`` maps statistic names
            to tensors.

        Raises:
            ValueError: if KL is enabled and ``kl_mode`` is not recognised.
        """
        step = step or 0
        logit, post, prior = outputs
        num_convs = batch.batch_size
        conv_len_cap = batch.max_conv_len
        sent_len_cap = batch.max_sent_len
        # Only the word logits are used below; the other heads are still
        # looked up so that a missing key fails exactly as before.
        word_logit, _spkr_logit, _goal_logit, _state_logit = (
            logit[name] for name in ("sent", "speaker", "goal", "state"))
        conv_posterior, sent_posterior = (
            post[name] for name in ("conv", "sent"))
        _conv_prior, sent_prior = (prior[name] for name in ("conv", "sent"))

        conv_lens = batch.conv_lens
        sent_lens = batch.sent.lens1
        turn_mask = utils.mask(conv_lens, conv_len_cap)
        padded_turns = ~turn_mask
        # Turns beyond the conversation length contribute no tokens.
        sent_lens = sent_lens.masked_fill(padded_turns, 0)
        token_mask = utils.mask(sent_lens, sent_len_cap)

        kld_conv = conv_posterior.kl_div()
        kld_sent = sent_posterior.kl_div(sent_prior)
        kld_sent = kld_sent.masked_fill(padded_turns, 0)

        # Padding positions become -1 so the cross entropy ignores them; the
        # target drops its first token and the logits their last step.
        targets = batch.sent.value.masked_fill(~token_mask, -1)
        targets = targets.view(-1, sent_len_cap)[..., 1:]
        flat_logits = (word_logit[:, :, :-1].contiguous()
                       .view(-1, len(self.vocabs.word)))
        token_loss = self._ce(flat_logits, targets.contiguous().view(-1))
        sent_loss = token_loss.view(num_convs, conv_len_cap, -1).sum(-1)

        kld_weight = self.kld_weight.get(step)
        loss_kld = kld_sent.sum(-1) + kld_conv
        loss_recon = sent_loss.sum(-1)
        nll = loss_recon + loss_kld

        conv_mi = estimate_mi(conv_posterior)
        sent_mi = estimate_mi(
            sent_posterior.view(num_convs * conv_len_cap, -1))
        sent_mi = (sent_mi.view(num_convs, conv_len_cap)
                   .masked_fill(padded_turns, 0)
                   .sum(-1))

        mode = self.kl_mode
        if not self.enable_kl:
            loss = loss_recon
        elif mode == "kl":
            loss = loss_recon + kld_weight * loss_kld
        elif mode == "kl-mi":
            loss = loss_recon + kld_weight * (loss_kld - conv_mi)
        elif mode == "kl-mi+":
            loss = loss_recon + kld_weight * (loss_kld - conv_mi - sent_mi)
        else:
            raise ValueError(f"unexpected kl mode: {self.kl_mode}")

        total_sent_loss = sent_loss.sum()
        total_turns = conv_lens.sum()
        total_words = sent_lens.sum()
        loss_per_turn = total_sent_loss / total_turns
        loss_per_word = total_sent_loss / total_words
        stats = {
            "nll": nll.mean(),
            "conv-mi": conv_mi.mean(),
            "sent-mi": sent_mi.mean(),
            "loss": loss.mean(),
            "loss-recon": loss_recon.mean(),
            "loss-sent": sent_loss.sum(-1).mean(),
            "loss-sent-turn": loss_per_turn,
            "loss-sent-word": loss_per_word,
            "ppl-turn": loss_per_turn.exp(),
            "ppl-word": loss_per_word.exp(),
            "kld-weight": torch.tensor(kld_weight),
            "kld-sent": kld_sent.sum(-1).mean(),
            "kld-sent-turn": kld_sent.sum() / total_turns,
            "kld-conv": kld_conv.sum(-1).mean(),
            "kld": loss_kld.mean()
        }
        return loss.mean(), stats
Exemplo n.º 10
0
    def main(self):
        """Load the dataset, build the model and run the train/val/test loop.

        Reads its configuration from ``self.args`` (a mapping).  For dataset
        names starting with "syn" the train/val/test sets are read from
        fixed pickle files in the working directory; otherwise a
        ``TUDataset`` is constructed.  A 10-fold stratified split is set up,
        but the loop ends with ``break`` so only the first fold is run.

        Side effects: sets ``utils.writer`` and ``utils.e``, assigns
        ``self.optimizer``, ``self.scheduler`` and ``self.vg``, prints the
        configuration and model, and writes logs through ``utils.writer``.
        """
        args = self.args
        device = "cuda" if not args['no_cuda'] and torch.cuda.is_available(
        ) else "cpu"

        ################## Setup Dataset ##################################
        if args['dataset'].lower().startswith("syn"):
            # Manual toggle of different datasets:
            #d_k_a = {'feature_mode': 'default', 'assign_feat': 'id'}
            #enzyme_gen = DiffpoolDataset('ENZYMES', use_node_attr=True,
            #                             use_node_label=False,
            #                             mode='train',
            #                             train_ratio=0.8,
            #                             test_ratio=0.1,
            #                             **d_k_a)
            #clique_gen = DummyClique(600, [10, 20, 30, 40, 50], [20, 30, 40, 50, 60],
            #                         5)
            #dataset = SyncPoolDataset(600, graph_dataset=clique_gen,
            #                          num_sub_graphs=10, mode='train')

            # Fix a random dataset
            # Each file is opened in a ``with`` block so the handle is closed
            # as soon as the pickle has been read.
            # NOTE: ``pickle.load`` can execute arbitrary code; only load
            # pickle files from a trusted source.
            with open('sync_fix_H.pickle', 'rb') as pickle_in:
                #pickle_in = open('sync_fix.pickle', 'rb')
                dataset = pickle.load(pickle_in)
            with open('sync_fix_H_val.pickle', 'rb') as pickle_in:
                dataset_val = pickle.load(pickle_in)
            with open('sync_fix_H_test.pickle', 'rb') as pickle_in:
                dataset_test = pickle.load(pickle_in)

            # Hijack node feature
            # (the constant features are built but only applied by the
            # commented-out assignment below)
            adversial_feature = []
            for i in range(len(dataset.features)):
                feat = np.ones(dataset.features[i].shape)
                adversial_feature.append(feat)
            # Pad to the largest graph across all three splits.
            max_num_nodes_candidate = []
            for ds in (dataset, dataset_val, dataset_test):
                max_num_nodes = max(
                    np.array([item[0][0].shape[0] for item in ds]))
                max_num_nodes_candidate.append(max_num_nodes)
            max_num_nodes = max(max_num_nodes_candidate)
            #dataset.features = adversial_feature
            #print('manually set node features to constant')
        else:
            dataset = TUDataset(args['dataset'])
            max_num_nodes = max(
                np.array([item[0][0].shape[0] for item in dataset]))

        # Turn this off if dataset not pre_separated
        # \TODO move this to argparser
        # NOTE(review): ``dataset_val`` and ``dataset_test`` are only defined
        # in the "syn" branch above, so this must be False for other datasets.
        pre_separated = True

        dataset_size = len(dataset)
        train_size = int(dataset_size * args['train_ratio'])
        val_size = dataset_size - train_size
        mean_num_nodes = int(
            np.array([item[0][0].shape[0] for item in dataset]).mean())
        n_classes = int(max([item[1] for item in dataset])) + 1

        skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
        labels = np.array(list(zip(*dataset))[1])
        for train_index, test_index in skf.split(np.zeros_like(labels),
                                                 labels):
            # THIS K-FOLD IS BROKEN
            #\TODO Fix
            utils.writer = utils.CustomLogger(
                comment='|'.join([args['dataset'], args['logname']]))
            utils.writer.add_text("args", str(args))
            utils.writer.log_args(args)
            utils.writer.log_py_file()

            train_val_data = torch.utils.data.Subset(dataset, train_index)
            train_val_labels = np.array(list(zip(*train_val_data))[1])
            train_data, val_data = train_test_split(train_val_data,
                                                    train_size=0.8,
                                                    stratify=train_val_labels)
            test_data = torch.utils.data.Subset(dataset, test_index)
            viz_data = torch.utils.data.Subset(test_data, [0])

            input_shape = int(dataset[0][0][1].shape[-1])
            if pre_separated:
                #\TODO not k-fold ready yet!
                # The pre-separated splits are used as-is; the fold indices
                # computed above are ignored on this path.
                train_loader = DataLoader(dataset,
                                          batch_size=args['batch_size'],
                                          shuffle=True,
                                          collate_fn=CollateFn(
                                              max_num_nodes, device))
                val_loader = DataLoader(dataset_val,
                                        batch_size=args['batch_size'],
                                        shuffle=True,
                                        collate_fn=CollateFn(
                                            max_num_nodes, device))
                test_loader = DataLoader(dataset_test,
                                         batch_size=args['batch_size'],
                                         shuffle=True,
                                         collate_fn=CollateFn(
                                             max_num_nodes, device))

            else:
                train_loader = DataLoader(train_data,
                                          batch_size=args['batch_size'],
                                          shuffle=True,
                                          collate_fn=CollateFn(
                                              max_num_nodes, device))
                val_loader = DataLoader(val_data,
                                        batch_size=args['batch_size'],
                                        shuffle=False,
                                        collate_fn=CollateFn(
                                            max_num_nodes, device))
                test_loader = DataLoader(test_data,
                                         batch_size=args['batch_size'],
                                         shuffle=False,
                                         collate_fn=CollateFn(
                                             max_num_nodes, device))

            viz_loader = DataLoader(viz_data,
                                    batch_size=1,
                                    collate_fn=CollateFn(
                                        max_num_nodes, device))
            ############### Record Config and setup model ###############################
            config = {}
            config.update(args)
            if args.get('pool_size', None) is not None:
                pool_size = args['pool_size']
            else:
                pool_size = int(mean_num_nodes * args['pool_ratio'])
            # Copy selected local variables into the model configuration;
            # the names listed here must match the local variable names.
            for k, v in locals().copy().items():
                if k in [
                        'device', 'args', 'pool_size', 'input_shape',
                        'n_classes'
                ]:
                    config[k] = v
            config['rtn'] = args['rtn']
            config['link_pred'] = args['link_pred']
            config['min_cut'] = True
            print("############################")
            print(config)
            if args['rtn'] > 0:
                tqdm.write("Using Routing Model")
            else:
                tqdm.write("Using DiffPool")
            model = BatchedModel(**config).to(device)
            print(model)
            ############### Optimizer and Scheduler #################################
            self.optimizer = optim.Adam(model.parameters())
            if config.get("scheduler", False):
                self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                    self.optimizer)
            else:
                self.scheduler = utils.ConstantScheduler()
            self.vg = True

            for e in tqdm(range(args['epochs'])):
                utils.e = e
                if args['viz']:
                    self.visualize(e, model, viz_loader)
                self.train(args, e, model, train_loader)
                self.val(args, e, model, val_loader)
                self.test(args, e, model, test_loader)
                utils.writer.log_epochs(e)

            utils.writer.log_tfboard()
            # Only the first fold is run (see the K-fold TODO above).
            break