class TrainArguments(utils.Arguments):
    """Configuration bundle consumed by the training entry point.

    Only ``model``, ``train_data``, ``valid_data`` and ``processor`` are
    required; everything else has a default suitable for CPU debugging.
    """
    model: models.AbstractTDA
    train_data: Sequence[Dialog]
    valid_data: Sequence[Dialog]
    processor: datasets.DialogProcessor
    device: torch.device = torch.device("cpu")
    save_dir: pathlib.Path = pathlib.Path("out")
    # Log training stats every N steps; None disables periodic reporting.
    report_every: Optional[int] = None
    batch_size: int = 32
    valid_batch_size: int = 64
    optimizer: str = "adam"
    # Max gradient norm for clipping; None disables clipping.
    gradient_clip: Optional[float] = None
    # L2 regularization coefficient; None disables weight decay.
    l2norm_weight: Optional[float] = None
    learning_rate: float = 0.001
    num_epochs: int = 10
    # Schedulers are queried per global step, e.g. for KL annealing.
    kld_schedule: utils.Scheduler = utils.ConstantScheduler(1.0)
    dropout_schedule: utils.Scheduler = utils.ConstantScheduler(1.0)
    validate_every: int = 1
    beam_size: int = 4
    max_gen_len: int = 30
    early_stop: bool = False
    # NOTE(review): the "~" prefix presumably marks a criterion to minimize
    # (negated) — confirm against the early-stopping implementation.
    early_stop_criterion: str = "~val-loss"
    early_stop_patience: Optional[int] = None
    # Checkpoint every N epochs; None means no intermediate checkpoints.
    save_every: Optional[int] = None
    # When True the loss falls back to pure reconstruction (no KL term).
    disable_kl: bool = False
    # One of {"kl", "kl-mi", "kl-mi+"}; selects mutual-information
    # calibration of the KL term in the loss.
    kl_mode: str = "kl-mi"
class VHREDLoss(Loss):
    """ELBO-style loss for the VHRED baseline.

    Combines a per-word cross-entropy reconstruction term over utterances
    with the KL divergence between the per-turn posterior and prior of the
    sentence-level latent.
    """
    vocabs: VocabSet
    # When False, only the reconstruction term is optimized.
    enable_kl: bool = True
    # Annealing schedule for the KL coefficient, queried by global step.
    kld_weight: utils.Scheduler = utils.ConstantScheduler(1.0)
    # ignore_index=-1 matches the padding sentinel written by masked_fill below.
    _ce: ClassVar[nn.CrossEntropyLoss] = \
        nn.CrossEntropyLoss(reduction="none", ignore_index=-1)

    def compute(self, batch: BatchData, outputs, step: int = None
                ) -> Tuple[torch.Tensor, utils.TensorMap]:
        """Return (scalar loss, stats dict) for one batch.

        ``outputs`` is a (logit, posterior, prior) triple as produced by the
        model's forward pass; ``step`` is the global step used to look up the
        current KL weight (None is treated as step 0).
        """
        step = step or 0
        logit, post, prior = outputs
        batch_size = batch.batch_size
        max_conv_len = batch.max_conv_len
        max_sent_len = batch.max_sent_len
        w_logit, zsent_post, zsent_prior = \
            logit["sent"], post["sent"], prior["sent"]
        conv_lens, sent_lens = batch.conv_lens, batch.sent.lens1
        # Mask out padded turns, then padded words within each real turn.
        conv_mask = utils.mask(conv_lens, max_conv_len)
        sent_lens = sent_lens.masked_fill(~conv_mask, 0)
        sent_mask = utils.mask(sent_lens, max_sent_len)
        # Per-turn KL; padded turns contribute zero.
        kld_sent = zsent_post.kl_div(zsent_prior).masked_fill(~conv_mask, 0)
        # Teacher forcing: predict token t+1 from logits at position t,
        # hence logits drop the last step and targets drop the first.
        # Padding becomes -1, which cross-entropy ignores.
        w_target = (batch.sent.value.masked_fill(~sent_mask, -1)
                    .view(-1, max_sent_len))[..., 1:]
        sent_loss = self._ce(
            w_logit[:, :, :-1].contiguous().view(-1, len(self.vocabs.word)),
            w_target.contiguous().view(-1)
        ).view(batch_size, max_conv_len, -1).sum(-1)
        kld_weight = self.kld_weight.get(step)
        # Sum over turns -> one value per conversation.
        loss_kld = kld_sent.sum(-1)
        loss_recon = sent_loss.sum(-1)
        nll = loss_recon + loss_kld
        if self.enable_kl:
            loss = loss_recon + kld_weight * loss_kld
        else:
            loss = loss_recon
        stats = {
            "nll": nll.mean(),
            "loss": loss.mean(),
            "loss-recon": loss_recon.mean(),
            "loss-sent": sent_loss.sum(-1).mean(),
            "loss-sent-turn": sent_loss.sum() / conv_lens.sum(),
            "loss-sent-word": sent_loss.sum() / sent_lens.sum(),
            # Perplexity = exp(mean NLL) at turn/word granularity.
            "ppl-turn": (sent_loss.sum() / conv_lens.sum()).exp(),
            "ppl-word": (sent_loss.sum() / sent_lens.sum()).exp(),
            "kld-weight": torch.tensor(kld_weight),
            "kld-sent": kld_sent.sum(-1).mean(),
            "kld-sent-turn": kld_sent.sum() / conv_lens.sum(),
            "kld": loss_kld.mean()
        }
        return loss.mean(), stats
def create_loss(model, vocabs: datasets.VocabSet,
                kld_weight: utils.Scheduler = utils.ConstantScheduler(1.0),
                enable_kl=True, kl_mode="kl-mi") -> losses.Loss:
    """Instantiate the loss object matching *model*'s concrete type.

    Args:
        model: the generative model instance; its type selects the loss.
        vocabs: vocabulary set shared by model and loss.
        kld_weight: KL-annealing schedule passed to variational losses.
        enable_kl: whether the KL term is included in the optimized loss.
        kl_mode: one of {"kl", "kl-mi", "kl-mi+"}; selects MI calibration.

    Returns:
        A configured ``losses.Loss`` subclass instance.

    Raises:
        ValueError: if ``kl_mode`` is not a recognized mode.
        RuntimeError: if the model type is unsupported.
    """
    # Validate eagerly with a real exception: an `assert` here would be
    # silently stripped under `python -O`.
    if kl_mode not in {"kl", "kl-mi", "kl-mi+"}:
        raise ValueError(f"unsupported kl mode: {kl_mode}")
    # NOTE: isinstance order matters — keep the base VHDA check first in
    # case ablation variants subclass it.
    if isinstance(model, models.VHDA):
        return losses.VHDALoss(
            vocabs=vocabs,
            kld_weight=kld_weight,
            enable_kl=enable_kl,
            kl_mode=kl_mode
        )
    elif isinstance(model, models.VHCR):
        return losses.VHCRLoss(
            vocabs=vocabs,
            enable_kl=enable_kl,
            kld_weight=kld_weight,
            kl_mode=kl_mode
        )
    elif isinstance(model, models.HDA):
        # HDA is non-variational: no KL-related arguments.
        return losses.HDALoss(
            vocabs=vocabs
        )
    elif isinstance(model, models.VHDAWithoutGoal):
        return losses.VHDAWithoutGoalLoss(
            vocabs=vocabs,
            kld_weight=kld_weight,
            enable_kl=enable_kl,
            kl_mode=kl_mode
        )
    elif isinstance(model, models.VHDAWithoutGoalAct):
        # This loss takes a boolean flag instead of the full mode string.
        return losses.VHDAWithoutGoalActLoss(
            vocabs=vocabs,
            kld_weight=kld_weight,
            enable_kl=enable_kl,
            calibrate_mi="mi" in kl_mode
        )
    elif isinstance(model, models.VHRED):
        return losses.VHREDLoss(
            vocabs=vocabs,
            enable_kl=enable_kl,
            kld_weight=kld_weight
        )
    elif isinstance(model, models.VHUS):
        return losses.VHUSLoss(
            vocabs=vocabs,
            enable_kl=enable_kl,
            kld_weight=kld_weight
        )
    else:
        raise RuntimeError(f"unsupported model: {type(model)}")
class VHDAInferencer(Inferencer):
    """Inferencer variant that forwards VHDA-specific sampling knobs.

    Adds a fixed latent sampling scale and a per-step dropout scale (looked
    up from a scheduler) to the model call, and reports the dropout scale
    in the per-batch stats.
    """
    sample_scale: float = 1.0
    dropout_scale: utils.Scheduler = utils.ConstantScheduler(1.0)

    def model_kwargs(self) -> dict:
        """Extend the base kwargs with sampling/dropout scales."""
        base = super().model_kwargs()
        base.update(
            sample_scale=self.sample_scale,
            dropout_scale=self.dropout_scale.get(self.global_step),
        )
        return base

    def on_batch_ended(self, batch: BatchData, pred: BatchData,
                       outputs) -> utils.TensorMap:
        """Record the current dropout scale alongside the base stats."""
        stats = dict(super().on_batch_ended(batch, pred, outputs))
        current_scale = self.dropout_scale.get(self.global_step)
        stats["dropout-scale"] = torch.tensor(current_scale).to(self.device)
        return stats
def main():
    """End-to-end experiment driver: train a generative dialog model,
    evaluate it, synthesize augmented dialogs, and measure the effect of
    the augmentation on downstream DST training.

    Reads all settings from the CLI (``create_parser``); writes every
    artifact (configs, checkpoints, generated data, summaries) under
    ``args.save_dir``.
    """
    args = utils.parse_args(create_parser())
    if args.logging_config is not None:
        logging.config.dictConfig(utils.load_yaml(args.logging_config))
    save_dir = pathlib.Path(args.save_dir)
    # Refuse to clobber a previous run unless --overwrite was given.
    if (not args.overwrite and save_dir.exists() and
            utils.has_element(save_dir.glob("*.json"))):
        raise FileExistsError(f"save directory ({save_dir}) is not empty")
    shell = utils.ShellUtils()
    engine = inflect.engine()  # for "1st", "2nd", ... in log messages
    shell.mkdir(save_dir, silent=True)
    logger = logging.getLogger("train")
    utils.seed(args.seed)
    logger.info("loading data...")
    # load_fn: path -> list[Dialog] (json load composed with Dialog.from_json)
    load_fn = utils.chain_func(lambda x: list(map(Dialog.from_json, x)),
                               utils.load_json)
    data_dir = pathlib.Path(args.data_dir)
    train_data = load_fn(str(data_dir.joinpath("train.json")))
    valid_data = load_fn(str(data_dir.joinpath("dev.json")))
    test_data = load_fn(str(data_dir.joinpath("test.json")))
    processor = datasets.DialogProcessor(
        sent_processor=datasets.SentProcessor(bos=True, eos=True,
                                              lowercase=True,
                                              tokenizer="space",
                                              max_len=30),
        boc=True, eoc=True, state_order="randomized", max_len=30)
    # Vocabs are built over all three splits.
    processor.prepare_vocabs(
        list(itertools.chain(train_data, valid_data, test_data)))
    utils.save_pickle(processor, save_dir.joinpath("processor.pkl"))
    logger.info("preparing model...")
    utils.save_json(utils.load_yaml(args.gen_model_path),
                    save_dir.joinpath("model.json"))
    torchmodels.register_packages(models)
    model_cls = torchmodels.create_model_cls(models, args.gen_model_path)
    model: models.AbstractTDA = model_cls(processor.vocabs)
    model.reset_parameters()
    utils.report_model(logger, model)
    device = torch.device("cpu")
    if args.gpu is not None:
        device = torch.device(f"cuda:{args.gpu}")
    model = model.to(device)

    def create_scheduler(s):
        # NOTE(review): `eval` on a CLI-supplied string — fine for a trusted
        # research CLI, but unsafe if the argument can come from untrusted
        # input; consider ast.literal_eval.
        return utils.PiecewiseScheduler(
            [utils.Coordinate(*t) for t in eval(s)])

    save_dir = pathlib.Path(args.save_dir)
    train_args = train.TrainArguments(
        model=model,
        train_data=tuple(train_data),
        valid_data=tuple(valid_data),
        processor=processor,
        device=device,
        save_dir=save_dir,
        report_every=args.report_every,
        batch_size=args.batch_size,
        valid_batch_size=args.valid_batch_size,
        optimizer=args.optimizer,
        gradient_clip=args.gradient_clip,
        l2norm_weight=args.l2norm_weight,
        learning_rate=args.learning_rate,
        num_epochs=args.epochs,
        kld_schedule=(utils.ConstantScheduler(1.0)
                      if args.kld_schedule is None
                      else create_scheduler(args.kld_schedule)),
        dropout_schedule=(utils.ConstantScheduler(1.0)
                          if args.dropout_schedule is None
                          else create_scheduler(args.dropout_schedule)),
        validate_every=args.validate_every,
        early_stop=args.early_stop,
        early_stop_criterion=args.early_stop_criterion,
        early_stop_patience=args.early_stop_patience,
        disable_kl=args.disable_kl,
        kl_mode=args.kl_mode)
    utils.save_json(train_args.to_json(), save_dir.joinpath("train-args.json"))
    record = train.train(train_args)
    utils.save_json(record.to_json(), save_dir.joinpath("final-summary.json"))
    # ---- evaluation of the generative model on the requested splits ----
    eval_dir = save_dir.joinpath("eval")
    shell.mkdir(eval_dir, silent=True)
    # Keep only the splits named in args.eval_splits.
    eval_data = dict(
        list(
            filter(None, [
                ("train", train_data) if "train" in args.eval_splits else None,
                ("dev", valid_data) if "dev" in args.eval_splits else None,
                ("test", test_data) if "test" in args.eval_splits else None
            ])))
    for split, data in eval_data.items():
        eval_args = evaluate.EvaluateArugments(
            model=model,
            train_data=tuple(train_data),
            test_data=tuple(data),
            processor=processor,
            embed_type=args.embed_type,
            embed_path=args.embed_path,
            device=device,
            batch_size=args.valid_batch_size,
            beam_size=args.beam_size,
            max_conv_len=args.max_conv_len,
            max_sent_len=args.max_sent_len)
        utils.save_json(eval_args.to_json(),
                        eval_dir.joinpath(f"eval-{split}-args.json"))
        # NOTE(review): unlike the sibling driver, this call is not wrapped
        # in torch.no_grad() — confirm evaluate() disables grads internally.
        eval_results = evaluate.evaluate(eval_args)
        save_path = eval_dir.joinpath(f"eval-{split}.json")
        utils.save_json(eval_results, save_path)
        logger.info(f"'{split}' results saved to {save_path}")
    # ---- generation + downstream DST trials ----
    logger.info(f"will run {args.gen_runs} generation trials...")
    gen_summary = []
    dst_summary = []
    for gen_idx in range(1, args.gen_runs + 1):
        logger.info(f"running {engine.ordinal(gen_idx)} generation trial...")
        gen_dir = save_dir.joinpath(f"gen-{gen_idx:03d}")
        shell.mkdir(gen_dir, silent=True)
        gen_args = generate.GenerateArguments(
            model=model,
            processor=processor,
            data=train_data,
            # target number of synthetic dialogs, proportional to train size
            instances=int(round(len(train_data) * args.multiplier)),
            batch_size=args.valid_batch_size,
            conv_scale=args.conv_scale,
            spkr_scale=args.spkr_scale,
            goal_scale=args.goal_scale,
            state_scale=args.state_scale,
            sent_scale=args.sent_scale,
            validate_dst=True,
            validate_unique=args.validate_unique,
            device=device)
        utils.save_json(gen_args.to_json(), gen_dir.joinpath("gen-args.json"))
        with torch.no_grad():
            samples = generate.generate(gen_args)
        utils.save_json([sample.output.to_json() for sample in samples],
                        gen_dir.joinpath("gen-out.json"))
        utils.save_json([sample.input.to_json() for sample in samples],
                        gen_dir.joinpath("gen-in.json"))
        utils.save_lines([str(sample.log_prob) for sample in samples],
                         gen_dir.joinpath("logprob.txt"))
        da_data = [sample.output for sample in samples]
        data = {"train": train_data, "dev": valid_data, "test": test_data}
        # NOTE(review): `+=` extends train_data IN PLACE, so augmented
        # dialogs accumulate across generation trials and the next trial
        # generates from (and counts) the already-augmented set — confirm
        # this is intended; `data["train"] = train_data + da_data` would
        # keep trials independent.
        data["train"] += da_data
        # Convert dialogs to DST dialogs.
        data = {
            split: list(map(datasets.DSTDialog.from_dialog, dialogs))
            for split, dialogs in data.items()
        }
        for split, dialogs in data.items():
            logger.info(f"verifying '{split}' dataset...")
            for dialog in dialogs:
                dialog.compute_user_goals()
                dialog.validate()
        logger.info("preparing dst environment...")
        dst_processor = dst_datasets.DSTDialogProcessor(
            sent_processor=datasets.SentProcessor(
                bos=True, eos=True, lowercase=True, max_len=30))
        dst_processor.prepare_vocabs(list(itertools.chain(*data.values())))
        train_dataset = dst_datasets.DSTDialogDataset(
            dialogs=data["train"], processor=dst_processor)
        train_dataloader = dst_datasets.create_dataloader(
            train_dataset, batch_size=args.dst_batch_size,
            shuffle=True, pin_memory=True)
        dev_dataloader = dst_run.TestDataloader(
            dialogs=data["dev"], processor=dst_processor,
            max_batch_size=args.dst_batch_size)
        test_dataloader = dst_run.TestDataloader(
            dialogs=data["test"], processor=dst_processor,
            max_batch_size=args.dst_batch_size)
        logger.info("saving dst processor object...")
        utils.save_pickle(dst_processor, gen_dir.joinpath("processor.pkl"))
        torchmodels.register_packages(dst_models)
        dst_model_cls = torchmodels.create_model_cls(dst_pkg,
                                                     args.dst_model_path)
        dst_model = dst_model_cls(dst_processor.vocabs)
        dst_model = dst_model.to(device)
        # NOTE(review): logs the generative `model`, not `dst_model` —
        # possibly a copy-paste slip; confirm which was intended.
        logger.info(str(model))
        logger.info(f"number of parameters DST: "
                    f"{utils.count_parameters(dst_model):,d}")
        logger.info(f"running {args.dst_runs} trials...")
        all_results = []
        for idx in range(1, args.dst_runs + 1):
            logger.info(f"running {engine.ordinal(idx)} dst trial...")
            trial_dir = gen_dir.joinpath(f"dst-{idx:03d}")
            logger.info("resetting parameters...")
            dst_model.reset_parameters()
            logger.info("preparing trainer...")
            runner = dst_run.Runner(
                model=dst_model,
                processor=dst_processor,
                device=device,
                save_dir=trial_dir,
                # Scale epochs down so total training compute stays roughly
                # constant as the augmented dataset grows.
                epochs=int(round(args.dst_epochs / (1 + args.multiplier))),
                loss="sum",
                l2norm=args.dst_l2norm,
                gradient_clip=args.dst_gradient_clip,
                train_validate=False,
                early_stop=True,
                early_stop_criterion="joint-goal",
                early_stop_patience=None,
                asr_method="scaled",
                asr_sigmoid_sum_order="sigmoid-sum",
                asr_topk=5)
            logger.info("commencing training...")
            record = runner.train(train_dataloader=train_dataloader,
                                  dev_dataloader=dev_dataloader,
                                  test_fn=None)
            logger.info("final summary: ")
            logger.info(pprint.pformat(record.to_json()))
            utils.save_json(record.to_json(),
                            trial_dir.joinpath("summary.json"))
            if not args.dst_test_asr:
                logger.info("commencing testing...")
                with torch.no_grad():
                    eval_results = runner.test(test_dataloader)
                logger.info("test results: ")
                logger.info(pprint.pformat(eval_results))
            else:
                logger.info("commencing testing (asr)...")
                with torch.no_grad():
                    eval_results = runner.test_asr(test_dataloader)
                logger.info("test(asr) results: ")
                logger.info(pprint.pformat(eval_results))
            eval_results["epoch"] = int(record.epoch)
            logger.info("test evaluation: ")
            logger.info(pprint.pformat(eval_results))
            utils.save_json(eval_results, trial_dir.joinpath("eval.json"))
            all_results.append(eval_results)
            dst_summary.append(eval_results)
        logger.info("aggregating results...")
        summary = reduce_json(all_results)
        logger.info("aggregated results: ")
        # NOTE(review): agg_summary is a pformat STRING, yet it is appended
        # to gen_summary which is later passed to reduce_json and iterated
        # with .items() — this looks broken; appending `summary` (the dict)
        # was probably intended. Confirm against reduce_json's contract.
        agg_summary = pprint.pformat(
            {k: v["stats"]["mean"] for k, v in summary.items()})
        logger.info(pprint.pformat(agg_summary))
        gen_summary.append(agg_summary)
        utils.save_json(summary, gen_dir.joinpath("summary.json"))
    gen_summary = reduce_json(gen_summary)
    dst_summary = reduce_json(dst_summary)
    logger.info(f"aggregating generation trials ({args.gen_runs})...")
    logger.info(
        pprint.pformat({k: v["stats"]["mean"]
                        for k, v in gen_summary.items()}))
    logger.info(f"aggregating dst trials ({args.gen_runs * args.dst_runs})...")
    logger.info(
        pprint.pformat({k: v["stats"]["mean"]
                        for k, v in dst_summary.items()}))
    utils.save_json(gen_summary, save_dir.joinpath("gen-summary.json"))
    utils.save_json(dst_summary, save_dir.joinpath("dst-summary.json"))
    logger.info("done!")
class VHUSLoss(Loss):
    """Loss for the VHUS baseline: BCE over dialog-state (action-slot-value)
    predictions plus KL between the state latent's posterior and prior.

    Also reports per-speaker breakdowns of the state loss/KL.
    """
    vocabs: VocabSet
    enable_kl: bool = True
    # Annealing schedule for the KL coefficient, queried by global step.
    kld_weight: utils.Scheduler = utils.ConstantScheduler(1.0)
    # NOTE(review): _ce appears unused in this class — possibly left over
    # from a sibling loss; confirm before removing.
    _ce: ClassVar[nn.CrossEntropyLoss] = \
        nn.CrossEntropyLoss(reduction="none", ignore_index=-1)
    _bce: ClassVar[nn.BCEWithLogitsLoss] = \
        nn.BCEWithLogitsLoss(reduction="none")

    @property
    def num_asv(self):
        # Size of the action-slot-value inventory (multi-label targets).
        return len(self.vocabs.goal_state.asv)

    def compute(self, batch: BatchData, outputs, step: int = None
                ) -> Tuple[torch.Tensor, utils.TensorMap]:
        """Return (scalar loss, stats dict) for one batch.

        ``outputs`` is a (logit, posterior, prior) triple; ``step`` feeds
        the KL weight schedule (None is treated as step 0).
        """
        step = step or 0
        logit, post, prior = outputs
        batch_size = batch.batch_size
        max_conv_len = batch.max_conv_len
        s_logit, zstate_post, zstate_prior = \
            logit["state"], post["state"], prior["state"]
        # sent_lens is bound here but not used in this loss.
        conv_lens, sent_lens = batch.conv_lens, batch.sent.lens1
        conv_mask = utils.mask(conv_lens, max_conv_len)
        # Valid logits are the finite ones; +/-inf entries mark
        # impossible ASVs, and padded turns are masked out too.
        state_logit_mask = \
            (((s_logit != float("-inf")) & (s_logit != float("inf")))
             .masked_fill(~conv_mask.unsqueeze(-1), 0))
        kld_state = zstate_post.kl_div(zstate_prior).masked_fill(~conv_mask, 0)
        # Sparse ASV indices -> dense multi-hot target per turn.
        s_target = utils.to_dense(idx=batch.state.value,
                                  lens=batch.state.lens1,
                                  max_size=self.num_asv)
        # Speaker ids, with padded turns set to -1 (used only for the
        # per-speaker stat masks below).
        p_target = batch.speaker.value.masked_fill(~conv_mask, -1)
        state_loss = (self._bce(s_logit, s_target.float()).masked_fill(
            ~state_logit_mask, 0)).sum(-1)
        kld_weight = self.kld_weight.get(step)
        # NOTE(review): unlike sibling losses, nll/loss here are per-turn
        # tensors (no sum over turns before .mean()), so the scalar loss
        # averages over batch*max_conv_len including padded turns — confirm
        # this normalization is intended.
        nll = state_loss + kld_state
        loss = state_loss + kld_weight * kld_state
        state_mi = \
            (estimate_mi(zstate_post.view(batch_size * max_conv_len, -1))
             .view(batch_size, max_conv_len).masked_fill(~conv_mask, 0)
             .sum(-1))
        stats = {
            "nll": nll.mean(),
            "state-mi": state_mi.mean(),
            "loss-state": state_loss.sum(-1).mean(),
            "loss-state-turn": state_loss.sum() / conv_lens.sum(),
            "loss-state-asv": state_loss.sum() / state_logit_mask.sum(),
            "kld-weight": torch.tensor(kld_weight),
            "kld-state": kld_state.sum(-1).mean(),
            "kld-state-turn": kld_state.sum() / conv_lens.sum(),
            # "kld" duplicates "kld-state" (the state latent is the only
            # latent in this model).
            "kld": kld_state.sum(-1).mean()
        }
        # Per-speaker breakdowns, skipping the unknown-speaker bucket.
        for spkr_idx, spkr in self.vocabs.speaker.i2f.items():
            if spkr == "<unk>":
                continue
            spkr_mask = p_target == spkr_idx
            spkr_state_mask = \
                state_logit_mask.masked_fill(~spkr_mask.unsqueeze(-1), 0)
            spkr_state_loss = state_loss.masked_fill(~spkr_mask, 0).sum()
            spkr_kld_state = kld_state.masked_fill(~spkr_mask, 0).sum()
            spkr_stats = {
                "loss-state": spkr_state_loss / batch_size,
                "loss-state-turn": spkr_state_loss / spkr_mask.sum(),
                "loss-state-asv": spkr_state_loss / spkr_state_mask.sum(),
                "kld-state": spkr_kld_state / batch_size,
                "kld-state-turn": spkr_kld_state / spkr_mask.sum(),
            }
            stats.update({f"{k}-{spkr}": v for k, v in spkr_stats.items()})
        return loss.mean(), stats
def main():
    """Train a generative dialog model from CLI arguments, then evaluate it
    on the requested splits, saving all artifacts under ``args.save_dir``.
    """
    args = utils.parse_args(create_parser())
    if args.logging_config is not None:
        logging.config.dictConfig(utils.load_yaml(args.logging_config))
    save_dir = pathlib.Path(args.save_dir)
    # Refuse to clobber a previous run unless --overwrite was given.
    if (not args.overwrite and save_dir.exists() and
            utils.has_element(save_dir.glob("*.json"))):
        raise FileExistsError(f"save directory ({save_dir}) is not empty")
    shell = utils.ShellUtils()
    shell.mkdir(save_dir, silent=True)
    logger = logging.getLogger("train")
    utils.seed(args.seed)
    logger.info("loading data...")
    # load_fn: path -> list[Dialog] (json load composed with Dialog.from_json)
    load_fn = utils.chain_func(lambda x: list(map(Dialog.from_json, x)),
                               utils.load_json)
    data_dir = pathlib.Path(args.data_dir)
    train_data = load_fn(str(data_dir.joinpath("train.json")))
    valid_data = load_fn(str(data_dir.joinpath("dev.json")))
    test_data = load_fn(str(data_dir.joinpath("test.json")))
    processor = datasets.DialogProcessor(
        sent_processor=datasets.SentProcessor(
            bos=True, eos=True, lowercase=True, tokenizer="space", max_len=30
        ),
        boc=True, eoc=True, state_order="randomized", max_len=30
    )
    # Vocabs are built over all three splits.
    processor.prepare_vocabs(
        list(itertools.chain(train_data, valid_data, test_data)))
    utils.save_pickle(processor, save_dir.joinpath("processor.pkl"))
    logger.info("preparing model...")
    utils.save_json(utils.load_yaml(args.model_path),
                    save_dir.joinpath("model.json"))
    torchmodels.register_packages(models)
    model_cls = torchmodels.create_model_cls(models, args.model_path)
    model: models.AbstractTDA = model_cls(processor.vocabs)
    model.reset_parameters()
    utils.report_model(logger, model)
    device = torch.device("cpu")
    if args.gpu is not None:
        device = torch.device(f"cuda:{args.gpu}")
    model = model.to(device)

    def create_scheduler(s):
        # NOTE(review): `eval` on a CLI-supplied string — acceptable for a
        # trusted research CLI but unsafe for untrusted input; consider
        # ast.literal_eval.
        return utils.PiecewiseScheduler(
            [utils.Coordinate(*t) for t in eval(s)])

    save_dir = pathlib.Path(args.save_dir)
    train_args = train.TrainArguments(
        model=model,
        train_data=tuple(train_data),
        valid_data=tuple(valid_data),
        processor=processor,
        device=device,
        save_dir=save_dir,
        report_every=args.report_every,
        batch_size=args.batch_size,
        valid_batch_size=args.valid_batch_size,
        optimizer=args.optimizer,
        gradient_clip=args.gradient_clip,
        l2norm_weight=args.l2norm_weight,
        learning_rate=args.learning_rate,
        num_epochs=args.epochs,
        # Fall back to a constant KL weight when no schedule is given.
        kld_schedule=(utils.ConstantScheduler(1.0)
                      if args.kld_schedule is None
                      else create_scheduler(args.kld_schedule)),
        dropout_schedule=(utils.ConstantScheduler(1.0)
                          if args.dropout_schedule is None
                          else create_scheduler(args.dropout_schedule)),
        validate_every=args.validate_every,
        early_stop=args.early_stop,
        early_stop_criterion=args.early_stop_criterion,
        early_stop_patience=args.early_stop_patience,
        disable_kl=args.disable_kl,
        kl_mode=args.kl_mode
    )
    utils.save_json(train_args.to_json(), save_dir.joinpath("train-args.json"))
    record = train.train(train_args)
    utils.save_json(record.to_json(), save_dir.joinpath("final-summary.json"))
    eval_dir = save_dir.joinpath("eval")
    shell.mkdir(eval_dir, silent=True)
    # Keep only the splits named in args.eval_splits.
    eval_data = dict(list(filter(None, [
        ("train", train_data) if "train" in args.eval_splits else None,
        ("dev", valid_data) if "dev" in args.eval_splits else None,
        ("test", test_data) if "test" in args.eval_splits else None
    ])))
    for split, data in eval_data.items():
        eval_args = evaluate.EvaluateArugments(
            model=model,
            train_data=tuple(train_data),
            test_data=tuple(data),
            processor=processor,
            embed_type=args.embed_type,
            embed_path=args.embed_path,
            device=device,
            batch_size=args.valid_batch_size,
            beam_size=args.beam_size,
            max_conv_len=args.max_conv_len,
            max_sent_len=args.max_sent_len
        )
        utils.save_json(eval_args.to_json(),
                        eval_dir.joinpath(f"eval-{split}-args.json"))
        with torch.no_grad():
            eval_results = evaluate.evaluate(eval_args)
        save_path = eval_dir.joinpath(f"eval-{split}.json")
        utils.save_json(eval_results, save_path)
        logger.info(f"'{split}' results saved to {save_path}")
    logger.info("done!")
class VHDAWithoutGoalLoss(Loss):
    """ELBO-style loss for the VHDA ablation without the goal latent.

    Reconstruction covers words (CE), dialog-state ASVs (BCE) and speakers
    (CE); the KL term sums the conversation, state, sentence and speaker
    latents. ``kl_mode`` optionally subtracts mutual-information estimates
    from the KL term ("kl-mi" uses only the conversation latent's MI,
    "kl-mi+" subtracts MI of all latents).
    """
    vocabs: VocabSet
    enable_kl: bool = True
    kl_mode: str = "kl"
    # Annealing schedule for the KL coefficient, queried by global step.
    kld_weight: utils.Scheduler = utils.ConstantScheduler(1.0)
    # ignore_index=-1 matches the padding sentinel written by masked_fill.
    _ce: ClassVar[nn.CrossEntropyLoss] = \
        nn.CrossEntropyLoss(reduction="none", ignore_index=-1)
    _bce: ClassVar[nn.BCEWithLogitsLoss] = \
        nn.BCEWithLogitsLoss(reduction="none")

    def __post_init__(self):
        assert self.kl_mode in {"kl", "kl-mi", "kl-mi+"}

    @property
    def num_asv(self):
        # Size of the action-slot-value inventory (multi-label targets).
        return len(self.vocabs.goal_state.asv)

    def compute(self, batch: BatchData, outputs, step: int = None
                ) -> Tuple[torch.Tensor, utils.TensorMap]:
        """Return (scalar loss, stats dict) for one batch.

        ``outputs`` is a (logit, posterior, prior) triple; ``step`` feeds
        the KL weight schedule (None is treated as step 0).
        """
        step = step or 0
        logit, post, prior = outputs
        batch_size = batch.batch_size
        max_conv_len = batch.max_conv_len
        max_sent_len = batch.max_sent_len
        # NOTE(review): max_goal_len/max_state_len are unused below —
        # likely left over from the full VHDA loss.
        max_goal_len = batch.max_goal_len
        max_state_len = batch.max_state_len
        w_logit, p_logit, s_logit = \
            (logit[k] for k in ("sent", "speaker", "state"))
        zconv_post, zstate_post, zsent_post, zspkr_post = \
            (post[k] for k in ("conv", "state", "sent", "speaker"))
        zconv_prior, zstate_prior, zsent_prior, zspkr_prior = \
            (prior[k] for k in ("conv", "state", "sent", "speaker"))
        conv_lens, sent_lens = batch.conv_lens, batch.sent.lens1
        # Mask out padded turns, then padded words within each real turn.
        conv_mask = utils.mask(conv_lens, max_conv_len)
        sent_lens = sent_lens.masked_fill(~conv_mask, 0)
        sent_mask = utils.mask(sent_lens, max_sent_len)
        # Valid state logits are the finite ones; +/-inf entries mark
        # impossible ASVs, and padded turns are masked out too.
        state_logit_mask = \
            (((s_logit != float("-inf")) & (s_logit != float("inf")))
             .masked_fill(~conv_mask.unsqueeze(-1), 0))
        # Conversation-level KL is against a standard prior (no argument);
        # per-turn KLs zero out padded turns.
        kld_conv = zconv_post.kl_div()
        kld_state = zstate_post.kl_div(zstate_prior).masked_fill(~conv_mask, 0)
        kld_sent = zsent_post.kl_div(zsent_prior).masked_fill(~conv_mask, 0)
        kld_spkr = zspkr_post.kl_div(zspkr_prior).masked_fill(~conv_mask, 0)
        # Teacher forcing: logits drop the last step, targets drop the
        # first; padding becomes -1, which cross-entropy ignores.
        w_target = (batch.sent.value.masked_fill(~sent_mask, -1)
                    .view(-1, max_sent_len))[..., 1:]
        # Sparse ASV indices -> dense multi-hot target per turn.
        s_target = utils.to_dense(idx=batch.state.value,
                                  lens=batch.state.lens1,
                                  max_size=self.num_asv)
        p_target = batch.speaker.value.masked_fill(~conv_mask, -1)
        state_loss = (self._bce(s_logit, s_target.float()).masked_fill(
            ~state_logit_mask, 0)).sum(-1)
        spkr_loss = self._ce(p_logit.view(-1, self.vocabs.num_speakers),
                             p_target.view(-1)).view(batch_size, max_conv_len)
        sent_loss = self._ce(
            w_logit[:, :, :-1].contiguous().view(-1, len(self.vocabs.word)),
            w_target.contiguous().view(-1)
        ).view(batch_size, max_conv_len, -1).sum(-1)
        kld_weight = self.kld_weight.get(step)
        # Per-conversation totals (sum over turns).
        loss_kld = (kld_conv + kld_sent.sum(-1) + kld_state.sum(-1) +
                    kld_spkr.sum(-1))
        loss_recon = (sent_loss.sum(-1) + state_loss.sum(-1) +
                      spkr_loss.sum(-1))
        nll = loss_recon + loss_kld
        # Mutual-information estimates used by the calibrated KL modes.
        conv_mi = estimate_mi(zconv_post)
        sent_mi = \
            (estimate_mi(zsent_post.view(batch_size * max_conv_len, -1))
             .view(batch_size, max_conv_len).masked_fill(~conv_mask, 0)
             .sum(-1))
        spkr_mi = \
            (estimate_mi(zspkr_post.view(batch_size * max_conv_len, -1))
             .view(batch_size, max_conv_len).masked_fill(~conv_mask, 0)
             .sum(-1))
        state_mi = \
            (estimate_mi(zstate_post.view(batch_size * max_conv_len, -1))
             .view(batch_size, max_conv_len).masked_fill(~conv_mask, 0)
             .sum(-1))
        if self.enable_kl:
            if self.kl_mode == "kl-mi":
                loss = loss_recon + kld_weight * (loss_kld - conv_mi)
            elif self.kl_mode == "kl-mi+":
                loss = loss_recon + kld_weight * (loss_kld - conv_mi -
                                                  sent_mi - spkr_mi -
                                                  state_mi)
            else:
                # plain "kl" mode
                loss = loss_recon + kld_weight * loss_kld
        else:
            loss = loss_recon
        stats = {
            "nll": nll.mean(),
            "conv-mi": conv_mi.mean(),
            "sent-mi": sent_mi.mean(),
            "state-mi": state_mi.mean(),
            "spkr-mi": spkr_mi.mean(),
            "loss": loss.mean(),
            "loss-recon": loss_recon.mean(),
            "loss-sent": sent_loss.sum(-1).mean(),
            "loss-sent-turn": sent_loss.sum() / conv_lens.sum(),
            "loss-sent-word": sent_loss.sum() / sent_lens.sum(),
            # Perplexity = exp(mean NLL) at turn/word granularity.
            "ppl-turn": (sent_loss.sum() / conv_lens.sum()).exp(),
            "ppl-word": (sent_loss.sum() / sent_lens.sum()).exp(),
            "loss-state": state_loss.sum(-1).mean(),
            "loss-state-turn": state_loss.sum() / conv_lens.sum(),
            "loss-state-asv": state_loss.sum() / state_logit_mask.sum(),
            "loss-spkr": spkr_loss.sum(-1).mean(),
            "loss-spkr-turn": spkr_loss.sum() / conv_lens.sum(),
            "kld-weight": torch.tensor(kld_weight),
            "kld-sent": kld_sent.sum(-1).mean(),
            "kld-sent-turn": kld_sent.sum() / conv_lens.sum(),
            "kld-conv": kld_conv.sum(-1).mean(),
            "kld-state": kld_state.sum(-1).mean(),
            "kld-state-turn": kld_state.sum() / conv_lens.sum(),
            "kld-spkr": kld_spkr.sum(-1).mean(),
            "kld-spkr-turn": kld_spkr.sum() / conv_lens.sum(),
            "kld": loss_kld.mean()
        }
        # Per-speaker breakdowns, skipping the unknown-speaker bucket.
        for spkr_idx, spkr in self.vocabs.speaker.i2f.items():
            if spkr == "<unk>":
                continue
            spkr_mask = p_target == spkr_idx
            spkr_sent_lens = sent_lens.masked_fill(~spkr_mask, 0)
            spkr_state_mask = \
                state_logit_mask.masked_fill(~spkr_mask.unsqueeze(-1), 0)
            spkr_sent_loss = sent_loss.masked_fill(~spkr_mask, 0).sum()
            spkr_state_loss = state_loss.masked_fill(~spkr_mask, 0).sum()
            spkr_spkr_loss = spkr_loss.masked_fill(~spkr_mask, 0).sum()
            spkr_kld_sent = kld_sent.masked_fill(~spkr_mask, 0).sum()
            spkr_kld_state = kld_state.masked_fill(~spkr_mask, 0).sum()
            spkr_kld_spkr = kld_spkr.masked_fill(~spkr_mask, 0).sum()
            spkr_stats = {
                "loss-sent": spkr_sent_loss / batch_size,
                "loss-sent-turn": spkr_sent_loss / spkr_mask.sum(),
                "loss-sent-word": spkr_sent_loss / spkr_sent_lens.sum(),
                "ppl-turn": (spkr_sent_loss / spkr_mask.sum()).exp(),
                "ppl-word": (spkr_sent_loss / spkr_sent_lens.sum()).exp(),
                "loss-state": spkr_state_loss / batch_size,
                "loss-state-turn": spkr_state_loss / spkr_mask.sum(),
                "loss-state-asv": spkr_state_loss / spkr_state_mask.sum(),
                "loss-spkr": spkr_spkr_loss / batch_size,
                "loss-spkr-turn": spkr_spkr_loss / spkr_mask.sum(),
                "kld-sent": spkr_kld_sent / batch_size,
                "kld-sent-turn": spkr_kld_sent / spkr_mask.sum(),
                "kld-state": spkr_kld_state / batch_size,
                "kld-state-turn": spkr_kld_state / spkr_mask.sum(),
                "kld-spkr": spkr_kld_spkr / batch_size,
                "kld-spkr-turn": spkr_kld_spkr / spkr_mask.sum(),
            }
            stats.update({f"{k}-{spkr}": v for k, v in spkr_stats.items()})
        return loss.mean(), stats
class VHCRLoss(Loss):
    """ELBO-style loss for the VHCR baseline.

    Word-level cross-entropy reconstruction plus KL terms for the
    conversation latent (vs. a standard prior) and the per-turn sentence
    latent. ``kl_mode`` optionally subtracts mutual-information estimates
    from the KL term.
    """
    vocabs: VocabSet
    enable_kl: bool = True
    kl_mode: str = "kl"
    # Annealing schedule for the KL coefficient, queried by global step.
    kld_weight: utils.Scheduler = utils.ConstantScheduler(1.0)
    # ignore_index=-1 matches the padding sentinel written by masked_fill.
    _ce: ClassVar[nn.CrossEntropyLoss] = \
        nn.CrossEntropyLoss(reduction="none", ignore_index=-1)
    # NOTE(review): _bce appears unused in this class — confirm before
    # removing.
    _bce: ClassVar[nn.BCEWithLogitsLoss] = \
        nn.BCEWithLogitsLoss(reduction="none")

    def __post_init__(self):
        assert self.kl_mode in {"kl", "kl-mi", "kl-mi+"}

    @property
    def num_asv(self):
        # Size of the action-slot-value inventory.
        return len(self.vocabs.goal_state.asv)

    def compute(self, batch: BatchData, outputs, step: int = None
                ) -> Tuple[torch.Tensor, utils.TensorMap]:
        """Return (scalar loss, stats dict) for one batch.

        ``outputs`` is a (logit, posterior, prior) triple; ``step`` feeds
        the KL weight schedule (None is treated as step 0).
        """
        step = step or 0
        logit, post, prior = outputs
        batch_size = batch.batch_size
        max_conv_len = batch.max_conv_len
        max_sent_len = batch.max_sent_len
        # NOTE(review): p_logit/g_logit/s_logit are unpacked but unused —
        # only the word logits feed this loss.
        w_logit, p_logit, g_logit, s_logit = \
            (logit[k] for k in ("sent", "speaker", "goal", "state"))
        zconv_post, zsent_post = (post[k] for k in ("conv", "sent"))
        zconv_prior, zsent_prior = (prior[k] for k in ("conv", "sent"))
        conv_lens, sent_lens = batch.conv_lens, batch.sent.lens1
        # Mask out padded turns, then padded words within each real turn.
        conv_mask = utils.mask(conv_lens, max_conv_len)
        sent_lens = sent_lens.masked_fill(~conv_mask, 0)
        sent_mask = utils.mask(sent_lens, max_sent_len)
        # Conversation KL is against a standard prior (no argument).
        kld_conv = zconv_post.kl_div()
        kld_sent = zsent_post.kl_div(zsent_prior).masked_fill(~conv_mask, 0)
        # Teacher forcing: logits drop the last step, targets drop the
        # first; padding becomes -1, which cross-entropy ignores.
        w_target = (batch.sent.value
                    .masked_fill(~sent_mask, -1)
                    .view(-1, max_sent_len))[..., 1:]
        sent_loss = self._ce(
            w_logit[:, :, :-1].contiguous().view(-1, len(self.vocabs.word)),
            w_target.contiguous().view(-1)
        ).view(batch_size, max_conv_len, -1).sum(-1)
        kld_weight = self.kld_weight.get(step)
        # Per-conversation totals (sum over turns).
        loss_kld = kld_sent.sum(-1) + kld_conv
        loss_recon = sent_loss.sum(-1)
        nll = loss_recon + loss_kld
        # Mutual-information estimates used by the calibrated KL modes.
        conv_mi = estimate_mi(zconv_post)
        sent_mi = \
            (estimate_mi(zsent_post.view(batch_size * max_conv_len, -1))
             .view(batch_size, max_conv_len).masked_fill(~conv_mask, 0)
             .sum(-1))
        if self.enable_kl:
            if self.kl_mode == "kl":
                loss = loss_recon + kld_weight * loss_kld
            elif self.kl_mode == "kl-mi":
                loss = loss_recon + kld_weight * (loss_kld - conv_mi)
            elif self.kl_mode == "kl-mi+":
                loss = loss_recon + kld_weight * (loss_kld - conv_mi -
                                                  sent_mi)
            else:
                raise ValueError(f"unexpected kl mode: {self.kl_mode}")
        else:
            loss = loss_recon
        stats = {
            "nll": nll.mean(),
            "conv-mi": conv_mi.mean(),
            "sent-mi": sent_mi.mean(),
            "loss": loss.mean(),
            "loss-recon": loss_recon.mean(),
            "loss-sent": sent_loss.sum(-1).mean(),
            "loss-sent-turn": sent_loss.sum() / conv_lens.sum(),
            "loss-sent-word": sent_loss.sum() / sent_lens.sum(),
            # Perplexity = exp(mean NLL) at turn/word granularity.
            "ppl-turn": (sent_loss.sum() / conv_lens.sum()).exp(),
            "ppl-word": (sent_loss.sum() / sent_lens.sum()).exp(),
            "kld-weight": torch.tensor(kld_weight),
            "kld-sent": kld_sent.sum(-1).mean(),
            "kld-sent-turn": kld_sent.sum() / conv_lens.sum(),
            "kld-conv": kld_conv.sum(-1).mean(),
            "kld": loss_kld.mean()
        }
        return loss.mean(), stats
def main(self):
    """Run the graph-classification experiment: load a dataset (synthetic
    pickles or TUDataset), build loaders, construct the pooling model, and
    run the train/val/test loop for one fold.

    NOTE(review): this is research scratch code — the k-fold loop is
    explicitly marked broken below and exits after the first fold via the
    trailing ``break``.
    """
    args = self.args
    device = "cuda" if not args['no_cuda'] and torch.cuda.is_available(
    ) else "cpu"
    ################## Setup Dataset ##################################
    if args['dataset'].lower().startswith("syn"):
        # Manual toggle of different datasets:
        #d_k_a = {'feature_mode': 'default', 'assign_feat': 'id'}
        #enzyme_gen = DiffpoolDataset('ENZYMES', use_node_attr=True,
        #                             use_node_label=False,
        #                             mode='train',
        #                             train_ratio=0.8,
        #                             test_ratio=0.1,
        #                             **d_k_a)
        #clique_gen = DummyClique(600, [10, 20, 30, 40, 50],
        #                         [20, 30, 40, 50, 60], 5)
        #dataset = SyncPoolDataset(600, graph_dataset=clique_gen,
        #                          num_sub_graphs=10, mode='train')
        # Fix a random dataset
        # NOTE(review): unpickling local fixture files; file handles are
        # never closed — consider `with open(...)` if this is revisited.
        pickle_in = open('sync_fix_H.pickle', 'rb')
        #pickle_in = open('sync_fix.pickle', 'rb')
        dataset = pickle.load(pickle_in)
        pickle_in = open('sync_fix_H_val.pickle', 'rb')
        dataset_val = pickle.load(pickle_in)
        pickle_in = open('sync_fix_H_test.pickle', 'rb')
        dataset_test = pickle.load(pickle_in)
        # Hijack node feature: build constant (all-ones) features; only
        # applied when the toggle below is uncommented.
        adversial_feature = []
        for i in range(len(dataset.features)):
            feat = np.ones(dataset.features[i].shape)
            adversial_feature.append(feat)
        # Global max node count across all three splits, so every loader
        # pads to the same size.
        max_num_nodes_candidate = []
        for ds in (dataset, dataset_val, dataset_test):
            max_num_nodes = max(
                np.array([item[0][0].shape[0] for item in ds]))
            max_num_nodes_candidate.append(max_num_nodes)
        max_num_nodes = max(max_num_nodes_candidate)
        #dataset.features = adversial_feature
        #print('manually set node features to constant')
    else:
        dataset = TUDataset(args['dataset'])
        max_num_nodes = max(
            np.array([item[0][0].shape[0] for item in dataset]))
    # Turn this off if dataset not pre_separated
    # \TODO move this to argparser
    pre_separated = True
    dataset_size = len(dataset)
    train_size = int(dataset_size * args['train_ratio'])
    val_size = dataset_size - train_size
    # item[0][0] is presumably the adjacency/feature matrix whose first
    # dim is the node count — TODO confirm against the dataset classes.
    mean_num_nodes = int(
        np.array([item[0][0].shape[0] for item in dataset]).mean())
    n_classes = int(max([item[1] for item in dataset])) + 1
    skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
    labels = np.array(list(zip(*dataset))[1])
    for train_index, test_index in skf.split(np.zeros_like(labels), labels):
        # THIS K-FOLD IS BROKEN
        # \TODO Fix
        utils.writer = utils.CustomLogger(
            comment='|'.join([args['dataset'], args['logname']]))
        utils.writer.add_text("args", str(args))
        utils.writer.log_args(args)
        utils.writer.log_py_file()
        train_val_data = torch.utils.data.Subset(dataset, train_index)
        train_val_labels = np.array(list(zip(*train_val_data))[1])
        # Stratified 80/20 split of the fold's training portion.
        train_data, val_data = train_test_split(train_val_data,
                                                train_size=0.8,
                                                stratify=train_val_labels)
        test_data = torch.utils.data.Subset(dataset, test_index)
        # Single graph used for visualization.
        viz_data = torch.utils.data.Subset(test_data, [0])
        input_shape = int(dataset[0][0][1].shape[-1])
        if pre_separated:
            # \TODO not k-fold ready yet!
            # Pre-separated path ignores the fold split and uses the fixed
            # train/val/test pickles directly.
            train_loader = DataLoader(dataset,
                                      batch_size=args['batch_size'],
                                      shuffle=True,
                                      collate_fn=CollateFn(
                                          max_num_nodes, device))
            val_loader = DataLoader(dataset_val,
                                    batch_size=args['batch_size'],
                                    shuffle=True,
                                    collate_fn=CollateFn(
                                        max_num_nodes, device))
            test_loader = DataLoader(dataset_test,
                                     batch_size=args['batch_size'],
                                     shuffle=True,
                                     collate_fn=CollateFn(
                                         max_num_nodes, device))
        else:
            train_loader = DataLoader(train_data,
                                      batch_size=args['batch_size'],
                                      shuffle=True,
                                      collate_fn=CollateFn(
                                          max_num_nodes, device))
            val_loader = DataLoader(val_data,
                                    batch_size=args['batch_size'],
                                    shuffle=False,
                                    collate_fn=CollateFn(
                                        max_num_nodes, device))
            test_loader = DataLoader(test_data,
                                     batch_size=args['batch_size'],
                                     shuffle=False,
                                     collate_fn=CollateFn(
                                         max_num_nodes, device))
        viz_loader = DataLoader(viz_data,
                                batch_size=1,
                                collate_fn=CollateFn(max_num_nodes, device))
        ############### Record Config and setup model #######################
        config = {}
        config.update(args)
        if args.get('pool_size', None) is not None:
            pool_size = args['pool_size']
        else:
            pool_size = int(mean_num_nodes * args['pool_ratio'])
        # Copy selected locals into the model config by name.
        for k, v in locals().copy().items():
            if k in [
                    'device', 'args', 'pool_size', 'input_shape', 'n_classes'
            ]:
                config[k] = v
        config['rtn'] = args['rtn']
        config['link_pred'] = args['link_pred']
        config['min_cut'] = True
        print("############################")
        print(config)
        if args['rtn'] > 0:
            tqdm.write("Using Routing Model")
        else:
            tqdm.write("Using DiffPool")
        model = BatchedModel(**config).to(device)
        print(model)
        ############### Optimizer and Scheduler #############################
        self.optimizer = optim.Adam(model.parameters())
        if config.get("scheduler", False):
            self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer)
        else:
            self.scheduler = utils.ConstantScheduler()
        self.vg = True
        for e in tqdm(range(args['epochs'])):
            utils.e = e
            if args['viz']:
                self.visualize(e, model, viz_loader)
            self.train(args, e, model, train_loader)
            self.val(args, e, model, val_loader)
            self.test(args, e, model, test_loader)
            utils.writer.log_epochs(e)
            utils.writer.log_tfboard()
        # Only the first fold is run (k-fold marked broken above).
        break