class SentLengthEvaluator(DialogEvaluator):
    vocabs: VocabSet
    _lens: List[int] = utils.private_field(default_factory=list)
    _lens_spkr: dict = utils.private_field(default_factory=dict)

    def reset(self):
        self._lens.clear()
        self._lens_spkr.clear()

    def update(self, samples: Sequence) -> Optional[TensorMap]:
        lens = [len(turn.text.split())
                for sample in samples for turn in sample.output.turns]
        self._lens.extend(lens)
        for spkr in self.vocabs.speaker.f2i:
            if spkr == "<unk>":
                continue
            if spkr not in self._lens_spkr:
                self._lens_spkr[spkr] = list()
            lens = [len(turn.text.split())
                    for sample in samples for turn in sample.output.turns
                    if turn.speaker == spkr]
            self._lens_spkr[spkr].extend(lens)
        return

    def get(self) -> TensorMap:
        stats = {"sent-len": torch.tensor(np.mean(self._lens))}
        for spkr in self.vocabs.speaker.f2i:
            if spkr == "<unk>":
                continue
            stats[f"sent-len-{spkr}"] = \
                torch.tensor(np.mean(self._lens_spkr[spkr]))
        return stats
class DistinctEvaluator(DialogEvaluator):
    vocabs: VocabSet
    ngrams: Sequence[int] = frozenset({1, 2})
    _values: Dict[int, List[float]] = \
        utils.private_field(default_factory=dict)
    _values_spkr: Dict[str, Dict[int, List[float]]] = \
        utils.private_field(default_factory=dict)

    def reset(self):
        self._values.clear()
        self._values_spkr.clear()

    @staticmethod
    def compute_distinct(tokens, n):
        if len(tokens) == 0:
            return 0.0
        vocab = set(ngrams(tokens, n))
        return len(vocab) / len(tokens)

    def compute(self, samples: Sequence, spkr=None):
        # distinct-n is defined over word tokens, so split the utterance
        # first; passing the raw string would count character n-grams
        return {
            i: [
                self.compute_distinct(turn.text.split(), i)
                for sample in samples for turn in sample.output.turns
                if spkr is None or turn.speaker == spkr
            ] for i in self.ngrams
        }

    def update(self, samples: Sequence) -> Optional[TensorMap]:
        res = self.compute(samples)
        for i, values in res.items():
            if i not in self._values:
                self._values[i] = list()
            self._values[i].extend(values)
        for spkr in self.vocabs.speaker.f2i:
            if spkr == "<unk>":
                continue
            if spkr not in self._values_spkr:
                self._values_spkr[spkr] = dict()
            res = self.compute(samples, spkr)
            for i, values in res.items():
                if i not in self._values_spkr[spkr]:
                    self._values_spkr[spkr][i] = list()
                self._values_spkr[spkr][i].extend(values)
        return

    def get(self) -> TensorMap:
        stats = {
            f"dist-{i}": torch.tensor(np.mean(vs))
            for i, vs in self._values.items()
        }
        stats.update({
            f"dist-{i}-{spkr}": torch.tensor(np.mean(vs))
            for spkr, values in self._values_spkr.items()
            for i, vs in values.items()
        })
        return stats
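# A minimal standalone check of the distinct-n statistic used above
# (unique n-grams over token count), with a made-up token list; this is
# an illustration, not part of the evaluator.
from nltk import ngrams


def distinct_n(tokens, n):
    if not tokens:
        return 0.0
    return len(set(ngrams(tokens, n))) / len(tokens)


# 4 unique unigrams ("i", "am", "sure", "here") over 6 tokens
assert abs(distinct_n("i am sure i am here".split(), 1) - 4 / 6) < 1e-9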
class DialogPreprocessor:
    lowercase: bool = True
    replace_number: Optional[str] = None
    tokenizer: str = "corenlp"
    special_chars: Optional[str] = ",.?'"
    _charset: set = utils.private_field(default_factory=set)
    _replace_token: ClassVar[Mapping[str, str]] = {"&": "and"}

    def __post_init__(self):
        self._charset.update(
            (list(self.special_chars)
             if self.special_chars is not None else [])
            + list(string.ascii_letters))

    def filter_word(self, word: str) -> Optional[str]:
        if not word:
            return
        if word.isdigit():
            if self.replace_number is not None:
                word = self.replace_number
            return word
        if self.lowercase:
            word = word.lower()
        word = self._replace_token.get(word, word)
        if any(c not in self._charset for c in word):
            return
        return word

    def preprocess_sent(self, sent: str) -> str:
        tokens = tokenize(sent, self.tokenizer)
        tokens = filter(None, map(self.filter_word, tokens))
        return " ".join(tokens)

    def preprocess_state(self, state: DialogState) -> DialogState:
        def preprocess_token(token):
            if self.lowercase:
                token = token.lower()
            return token.strip()

        new_state = DialogState()
        for asv in state:
            new_state.add(ActSlotValue(*map(preprocess_token, asv.values)))
        return new_state

    def preprocess_turn(self, turn: Turn) -> Turn:
        return Turn(text=self.preprocess_sent(turn.text),
                    speaker=turn.speaker,
                    goal=self.preprocess_state(turn.goal),
                    state=self.preprocess_state(turn.state),
                    asr={
                        self.preprocess_sent(sent): score
                        for sent, score in turn.asr.items()
                    },
                    meta=turn.meta)

    def preprocess(self, dialog: Dialog) -> Dialog:
        return Dialog(turns=list(map(self.preprocess_turn, dialog.turns)),
                      meta=dialog.meta)
class InterpolateInferencer:
    model: models.AbstractTDA
    processor: DialogProcessor
    device: torch.device = torch.device("cpu")
    asv_tensor: utils.Stacked1DTensor = None
    _num_instances: int = utils.private_field(default=None)

    def __post_init__(self):
        if self.asv_tensor is None:
            self.asv_tensor = self.processor.tensorize_state_vocab(
                "goal_state")
        self.asv_tensor = self.asv_tensor.to(self.device)

    def prepare_data_batch(self, batch: BatchData) -> dict:
        return {
            "conv_lens": batch.conv_lens,
            "sent": batch.sent.value,
            "sent_lens": batch.sent.lens1,
            "speaker": batch.speaker.value,
            "goal": batch.goal.value,
            "goal_lens": batch.goal.lens1,
            "state": batch.state.value,
            "state_lens": batch.state.lens1,
            "asv": self.asv_tensor.value,
            "asv_lens": self.asv_tensor.lens
        }

    def prepare_z_batch(self, batch: torch.Tensor) -> dict:
        return {
            "zconv": batch,
            "asv": self.asv_tensor.value,
            "asv_lens": self.asv_tensor.lens
        }

    def encode(self, dataloader) -> torch.Tensor:
        self.model.eval()
        self.model.encode()
        zconv = []
        for batch in dataloader:
            batch = batch.to(self.device)
            zconv.append(self.model(self.prepare_data_batch(batch)).mu)
        return torch.cat(zconv, 0)

    def generate(self, dataloader) -> Sequence[Dialog]:
        self.model.eval()
        self.model.decode_optimal()
        dialogs = []
        for batch in dataloader:
            batch = batch.to(self.device)
            pred, _ = self.model(self.prepare_z_batch(batch),
                                 spkr_scale=0.0, goal_scale=1.0,
                                 state_scale=0.0, sent_scale=1.0)
            dialogs.extend(map(self.processor.lexicalize_global, pred))
        return dialogs
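# A usage sketch for InterpolateInferencer, assuming a trained `model`,
# a `processor`, and a `dataloader` built elsewhere in the repo: encode
# two conversations into latent means, then decode points on the line
# between them. The interpolation grid and the list-of-batches input to
# generate() are illustrative assumptions.
inferencer = InterpolateInferencer(model=model, processor=processor,
                                   device=torch.device("cpu"))
with torch.no_grad():
    zconv = inferencer.encode(dataloader)        # [num_dialogs x zdim]
    z_a, z_b = zconv[0], zconv[1]
    steps = torch.linspace(0, 1, 5).unsqueeze(-1)
    z_interp = z_a.unsqueeze(0) * (1 - steps) + z_b.unsqueeze(0) * steps
    dialogs = inferencer.generate([z_interp])    # one batch of z vectors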
class DistinctStateEvaluator(DialogEvaluator):
    vocabs: VocabSet
    _values: dict = utils.private_field(default_factory=dict)

    def reset(self):
        self._values.clear()

    @property
    def speakers(self):
        return set(spkr for spkr in self.vocabs.speaker.f2i
                   if spkr != "<unk>")

    @staticmethod
    def compute_distinct(tokens):
        if len(tokens) == 0:
            return torch.tensor(0.0)
        return torch.tensor(len(set(tokens)) / len(tokens))

    def update(self, samples: Sequence) -> Optional[TensorMap]:
        for sample in samples:
            asvs = [
                asv for turn in sample.output if turn.speaker != "<unk>"
                for asv in turn.state
            ]
            # per-speaker ratios must be computed over that speaker's
            # turns only
            spkr_asvs = {
                spkr: [
                    asv for turn in sample.output if turn.speaker == spkr
                    for asv in turn.state
                ] for spkr in self.speakers
            }
            stats = {"dist-a": self.compute_distinct(asvs)}
            stats.update({
                f"dist-a-{spkr}": self.compute_distinct(spkr_asvs[spkr])
                for spkr in self.speakers
            })
            for k, v in stats.items():
                if k not in self._values:
                    self._values[k] = list()
                self._values[k].append(v.item())
        return

    def get(self) -> Optional[TensorMap]:
        return {k: torch.tensor(v).mean() for k, v in self._values.items()}
class DSTDialogDataset:
    dialogs: Sequence[DSTDialog]
    processor: DSTDialogProcessor
    _num_turns: Sequence[int] = utils.private_field()
    _cum_turns: Sequence[int] = utils.private_field()

    def __post_init__(self):
        self._num_turns = [len(dialog.dst_turns) for dialog in self.dialogs]
        self._cum_turns = [0] + np.cumsum(self._num_turns).tolist()

    def __len__(self):
        return self._cum_turns[-1]

    def __getitem__(self, item):
        idx = bisect.bisect_right(self._cum_turns, item) - 1
        if idx >= len(self.dialogs):
            raise ValueError(f"must be less than dataset size: "
                             f"{item} >= {len(self)}")
        turn_idx = item - self._cum_turns[idx]
        dialog = self.dialogs[idx]
        return self.processor.tensorize_dst_turn(dialog.dst_turns[turn_idx])
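# A minimal sketch of the flat-index-to-(dialog, turn) mapping used by
# DSTDialogDataset.__getitem__: cumulative turn counts plus bisect_right
# locate the owning dialog in O(log n). The turn counts are made up.
import bisect

num_turns = [3, 1, 4]                      # turns per dialog
cum_turns = [0, 3, 4, 8]                   # [0] + running sums


def locate(item):
    idx = bisect.bisect_right(cum_turns, item) - 1
    return idx, item - cum_turns[idx]      # (dialog index, turn index)


assert locate(0) == (0, 0)
assert locate(3) == (1, 0)                 # first turn of second dialog
assert locate(7) == (2, 3)                 # last turn overall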
class DialogLengthEvaluator(DialogEvaluator):
    _lens: List[int] = utils.private_field(default_factory=list)

    def reset(self):
        self._lens.clear()

    def update(self, samples: Sequence) -> Optional[TensorMap]:
        lens = [len(sample.output) for sample in samples]
        self._lens.extend(lens)
        return {"conv-len": torch.tensor(lens).float().mean()}

    def get(self) -> TensorMap:
        return {"conv-len": torch.tensor(self._lens).float().mean()}
class StateNoveltyEvaluator(DialogEvaluator):
    dataset: DialogDataset
    _turn_states: Set[frozenset] = utils.private_field(default_factory=set)
    _values: dict = utils.private_field(default_factory=dict)

    def __post_init__(self):
        self.prepare_turn_states()

    def prepare_turn_states(self):
        for dialog in self.dataset.data:
            for turn in dialog:
                state = frozenset(turn.state)
                self._turn_states.add(state)

    def reset(self):
        self._values.clear()

    @property
    def vocabs(self):
        return self.dataset.processor.vocabs

    @property
    def speakers(self):
        return set(spkr for spkr in self.vocabs.speaker.f2i
                   if spkr != "<unk>")

    def update(self, samples: Sequence) -> Optional[TensorMap]:
        for sample in samples:
            stats = {
                "novel-a": torch.tensor([
                    frozenset(turn.state) not in self._turn_states
                    for turn in sample.output
                ]).float().mean()
            }
            for k, v in stats.items():
                if k not in self._values:
                    self._values[k] = list()
                self._values[k].append(v.item())
        return

    def get(self) -> Optional[TensorMap]:
        return {k: torch.tensor(v).mean() for k, v in self._values.items()}
class DataAdapter:
    _logger: logging.Logger = utils.private_field(default=None)

    def __post_init__(self):
        self._logger = logging.getLogger(self.__class__.__name__)

    @staticmethod
    def serialize_semantics(state: DialogState):
        def serialize_asv(asv: ActSlotValue):
            return {"slots": [[asv.slot, asv.value]], "act": asv.act}

        return list(map(serialize_asv, state))

    @staticmethod
    def parse_semantics(data) -> DialogState:
        state = DialogState()
        for asv_data in data:
            for s, v in asv_data["slots"]:
                state.add(ActSlotValue(act=str(asv_data["act"]).strip(),
                                       slot=str(s).strip(),
                                       value=str(v).strip()))
        return state

    def load(self, path: pathlib.Path,
             split: str = None) -> Mapping[str, Sequence[Dialog]]:
        """Use the `split` option to load a specific split. The underlying
        implementation may choose not to support this."""
        raise NotImplementedError

    def save_imp(self, data: Mapping[str, Sequence[Dialog]],
                 path: pathlib.Path):
        """Save data into the directory/path. May assume that the path
        is clean."""
        raise NotImplementedError

    def save(self, data: Mapping[str, Sequence[Dialog]],
             path: pathlib.Path, overwrite: bool = False):
        if ((path.is_file() and path.exists()
             or path.is_dir() and utils.has_element(path.glob("*")))
                and not overwrite):
            raise FileExistsError(f"file exists or directory is "
                                  f"not empty: {path}")
        shell = utils.ShellUtils()
        shell.remove(path, recursive=True, silent=True)
        return self.save_imp(data, path)
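# A round-trip sketch for the (de)serialization helpers above, assuming
# the repo's DialogState/ActSlotValue types and value-based equality on
# DialogState; the act-slot-value content is made up.
state = DialogState()
state.add(ActSlotValue(act="inform", slot="food", value="thai"))
data = DataAdapter.serialize_semantics(state)
# data == [{"slots": [["food", "thai"]], "act": "inform"}]
assert DataAdapter.parse_semantics(data) == state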
class StateCountEvaluator(DialogEvaluator):
    vocabs: VocabSet
    _values: dict = utils.private_field(default_factory=dict)

    def reset(self):
        self._values.clear()

    @property
    def speakers(self):
        return set(spkr for spkr in self.vocabs.speaker.f2i
                   if spkr != "<unk>")

    def update(self, samples: Sequence) -> Optional[TensorMap]:
        for sample in samples:
            for turn in sample.output:
                if turn.speaker == "<unk>":
                    continue
                spkr = turn.speaker
                stats = {
                    "state-cnt": torch.tensor(len(turn.state)).float(),
                    f"state-cnt-{spkr}":
                        torch.tensor(len(turn.state)).float()
                }
                for k, v in stats.items():
                    if k not in self._values:
                        self._values[k] = list()
                    self._values[k].append(v.item())
            stats = {
                "state-cnt-conv": torch.tensor(
                    sum(1 for turn in sample.output
                        for _ in turn.state)).float()
            }
            for k, v in stats.items():
                if k not in self._values:
                    self._values[k] = list()
                self._values[k].append(v.item())
        return

    def get(self) -> Optional[TensorMap]:
        return {k: torch.tensor(v).mean() for k, v in self._values.items()}
class EvaluatingInferencer(Inferencer):
    evaluators: Sequence[FinegrainedEvaluator] = tuple()
    _requires_lexical_form: bool = utils.private_field(default=False)

    def __post_init__(self):
        super().__post_init__()
        self._requires_lexical_form = any(e.requires_lexical_form
                                          for e in self.evaluators)

    def on_run_started(self, dataloader: td.DataLoader) -> td.DataLoader:
        dataloader = super().on_run_started(dataloader)
        for evaluator in self.evaluators:
            evaluator.reset()
        return dataloader

    def on_batch_ended(self, batch: BatchData, pred: BatchData, outputs
                       ) -> utils.TensorMap:
        stats = dict(super().on_batch_ended(batch, pred, outputs))
        batch_lex, pred_lex = None, None
        if self._requires_lexical_form:
            batch_lex = list(map(self.processor.lexicalize_global, batch))
            pred_lex = list(map(self.processor.lexicalize_global, pred))
        with torch.no_grad():
            for evaluator in self.evaluators:
                if evaluator.requires_lexical_form:
                    eval_stats = evaluator.update(batch_lex, pred_lex,
                                                  outputs)
                else:
                    eval_stats = evaluator.update(batch, pred, outputs)
                stats.update(eval_stats or dict())
        return stats

    def on_run_ended(self, stats: utils.TensorMap) -> utils.TensorMap:
        stats = dict(super().on_run_ended(stats))
        with torch.no_grad():
            for evaluator in self.evaluators:
                stats.update(evaluator.get() or dict())
        return stats
class TestDataloader:
    dialogs: Sequence[DSTDialog]
    processor: dst_datasets.DSTDialogProcessor
    max_batch_size: int = 32
    _collator: PaddingCollator = utils.private_field()

    def __post_init__(self):
        self._collator = PaddingCollator(
            frozenset(("sent", "system_acts", "belief_state",
                       "slot", "asr_score")))

    @property
    def total_items(self):
        return sum(len(dialog.dst_turns) for dialog in self.dialogs)

    def __len__(self):
        return len(self.dialogs)

    def create_batch(self, dialogs: Iterable[DSTDialog]):
        items = []
        for dialog in dialogs:
            for turn in dialog.dst_turns:
                items.append(self.processor.tensorize_dst_test_turn(turn))
        return DSTTestBatchData.from_dict(self._collator(items))

    def __iter__(self):
        bucket: List[DSTDialog] = []
        for dialog in self.dialogs:
            if (sum(len(d.dst_turns) for d in bucket)
                    + len(dialog.dst_turns)) > self.max_batch_size:
                if bucket:
                    yield bucket, self.create_batch(bucket)
                    bucket = []
            bucket.append(dialog)
        if bucket:
            yield bucket, self.create_batch(bucket)
class LanguageNoveltyEvaluator(DialogEvaluator):
    dataset: DialogDataset
    _sents: Set[str] = utils.private_field(default_factory=set)
    _bigrams: Set[Tuple[str, str]] = utils.private_field(default_factory=set)
    _trigrams: Set[Tuple[str, str, str]] = \
        utils.private_field(default_factory=set)
    _spkr_bigrams: Dict[str, Set[Tuple[str, str]]] = \
        utils.private_field(default_factory=dict)
    _spkr_trigrams: Dict[str, Set[Tuple[str, str, str]]] = \
        utils.private_field(default_factory=dict)
    _spkr_sents: Dict[str, Set[str]] = \
        utils.private_field(default_factory=dict)
    _values: Dict[str, List[float]] = \
        utils.private_field(default_factory=dict)

    def __post_init__(self):
        self.prepare_ngrams()

    @property
    def speakers(self):
        return set(spkr for spkr in self.vocabs.speaker.f2i
                   if spkr != "<unk>")

    @property
    def vocabs(self):
        return self.dataset.processor.vocabs

    def prepare_ngrams(self):
        for dialog in self.dataset.data:
            for turn in dialog.turns:
                spkr = turn.speaker
                if spkr == "<unk>":
                    continue
                if spkr not in self._spkr_bigrams:
                    self._spkr_bigrams[spkr] = set()
                    self._spkr_trigrams[spkr] = set()
                    self._spkr_sents[spkr] = set()
                tokens = \
                    self.dataset.processor.sent_processor.process(turn.text)
                tokens = utils.lstrip(tokens, "<bos>")
                tokens = utils.rstrip(tokens, "<eos>")
                for bigram in nltk.bigrams(tokens):
                    self._bigrams.add(tuple(bigram))
                    self._spkr_bigrams[spkr].add(tuple(bigram))
                for trigram in nltk.ngrams(tokens, 3):
                    self._trigrams.add(tuple(trigram))
                    self._spkr_trigrams[spkr].add(tuple(trigram))
                sent = " ".join(tokens)
                self._sents.add(sent)
                self._spkr_sents[spkr].add(sent)

    def reset(self):
        self._values.clear()

    def compute(self, text: str):
        stats = {
            "novel-2": torch.tensor(0.0),
            "novel-3": torch.tensor(0.0),
            "novel-utt": torch.tensor(0.0)
        }
        tokens = text.split()
        bigrams = list(map(tuple, nltk.bigrams(tokens)))
        trigrams = list(map(tuple, nltk.trigrams(tokens)))
        if bigrams:
            stats["novel-2"] = \
                (torch.tensor([w not in self._bigrams for w in bigrams])
                 .float().mean())
        if trigrams:
            stats["novel-3"] = \
                (torch.tensor([w not in self._trigrams for w in trigrams])
                 .float().mean())
        if text:
            stats["novel-utt"] = \
                torch.tensor(text not in self._sents).float()
        return stats

    def compute_spkr(self, text: str, spkr: str):
        stats = {
            "novel-2": torch.tensor(0.0),
            "novel-3": torch.tensor(0.0),
            "novel-utt": torch.tensor(0.0)
        }
        tokens = text.split()
        bigrams = list(map(tuple, nltk.bigrams(tokens)))
        trigrams = list(map(tuple, nltk.trigrams(tokens)))
        if bigrams:
            stats["novel-2"] = \
                (torch.tensor([w not in self._spkr_bigrams[spkr]
                               for w in bigrams]).float().mean())
        if trigrams:
            stats["novel-3"] = \
                (torch.tensor([w not in self._spkr_trigrams[spkr]
                               for w in trigrams]).float().mean())
        if text:
            stats["novel-utt"] = \
                torch.tensor(text not in self._spkr_sents[spkr]).float()
        return stats

    def update(self, samples: Sequence) -> Optional[TensorMap]:
        for sample in samples:
            for turn in sample.output:
                spkr = turn.speaker
                stats = self.compute(turn.text)
                if spkr != "<unk>":
                    stats.update({
                        f"{k}-{spkr}": v
                        for k, v in self.compute_spkr(turn.text,
                                                      spkr).items()
                    })
                for k, v in stats.items():
                    if k not in self._values:
                        self._values[k] = list()
                    self._values[k].append(v.item())
        return

    def get(self) -> Optional[TensorMap]:
        return {k: torch.tensor(v).mean() for k, v in self._values.items()}
class Runner:
    model: dst.AbstractDialogStateTracker
    processor: dst_datasets.DSTDialogProcessor
    save_dir: pathlib.Path = pathlib.Path("out")
    device: torch.device = torch.device("cpu")
    epochs: int = 30
    scheduler: Callable[[op.Optimizer], op.lr_scheduler._LRScheduler] = None
    loss: str = "sum"
    gradient_clip: Optional[float] = None
    l2norm: Optional[float] = None
    train_validate: bool = False
    inepoch_report_chance: float = 0.1
    early_stop: bool = False
    early_stop_criterion: str = "joint-goal"
    early_stop_patience: Optional[int] = None
    asr_method: str = "score"
    asr_sigmoid_sum_order: str = "sum-sigmoid"
    asr_topk: Optional[int] = None
    _logger: logging.Logger = utils.private_field(default=None)
    _user_tensor: utils.Stacked1DTensor = utils.private_field(default=None)
    _wizard_tensor: utils.Stacked1DTensor = utils.private_field(default=None)
    _bce: nn.BCEWithLogitsLoss = utils.private_field(default=None)
    _record: Record = utils.private_field(default=None)

    def __post_init__(self):
        self._logger = logging.getLogger(self.__class__.__name__)
        self._user_tensor = self.processor.tensorize_state_vocab(
            speaker="user",
            # tensorizer=self.processor.tensorize_turn_label_asv
        )
        self._user_tensor = self._user_tensor.to(self.device)
        self._wizard_tensor = self.processor.tensorize_state_vocab(
            speaker="wizard"
        )
        self._wizard_tensor = self._wizard_tensor.to(self.device)
        self._bce = nn.BCEWithLogitsLoss(reduction="none")
        utils.ShellUtils().mkdir(self.save_dir, True)
        assert self.asr_sigmoid_sum_order in {"sigmoid-sum", "sum-sigmoid"}

    @property
    def vocabs(self):
        return self.processor.vocabs

    @property
    def criterion(self):
        return (self.early_stop_criterion.lstrip("~"),
                self.early_stop_criterion.startswith("~"))

    def prepare_system_acts(self, s: utils.Stacked1DTensor):
        s, s_lens = s.value, s.lens
        if s_lens.max().item() == 0:
            return (self._wizard_tensor.value[s], s_lens,
                    self._wizard_tensor.lens[s])
        s_lats = s_lens
        s, s_lens = self._wizard_tensor.value[s], self._wizard_tensor.lens[s]
        return s[..., :s_lens.max()], s_lats, s_lens

    def compute_loss(self, batch: DSTBatchData, ontology):
        loss = None
        for act_slot, (ont_idx, ont) in ontology.items():
            as_idx = torch.tensor(self.vocabs.speaker_state["user"]
                                  .act_slot[act_slot]).to(self.device)
            ont_idx, ont = ont_idx.to(self.device), ont.to(self.device)
            logit = self.model(
                as_idx,
                *batch.sent.tensors,
                *self.prepare_system_acts(batch.system_acts),
                *ont.tensors
            )
            # ont_idx: [num_ont] -> [batch_size x num_ont x state_lat]
            # s: [batch_size x state_lat] ->
            #     [batch_size x num_ont x state_lat]
            # target: [batch_size x num_ont]
            s = batch.belief_state
            target = \
                ((ont_idx.unsqueeze(0).unsqueeze(-1) == s.value.unsqueeze(1))
                 .masked_fill(~utils.mask(s.lens).unsqueeze(1), 0).any(-1))
            current_loss = self._bce(logit, target.float())
            if self.loss == "mean":
                current_loss = current_loss.mean(-1)
            elif self.loss == "sum":
                current_loss = current_loss.sum(-1)
            else:
                raise ValueError(f"unsupported loss method: {self.loss}")
            if loss is None:
                loss = current_loss
            else:
                loss += current_loss
        return loss

    def make_record(self, epoch, stats):
        self._record = Record(
            epoch=epoch,
            value=stats.get(self.criterion[0], None),
            stats=stats,
            params={k: v.cpu().detach()
                    for k, v in self.model.state_dict().items()}
        )

    def predict(self, batch, ontology):
        pred = [list() for _ in range(batch.batch_size)]
        loss = None
        for act_slot, (ont_idx, ont) in ontology.items():
            as_idx = torch.tensor(self.vocabs.speaker_state["user"]
                                  .act_slot[act_slot]).to(self.device)
            ont_idx, ont = ont_idx.to(self.device), ont.to(self.device)
            logit = self.model(
                as_idx,
                *batch.sent.tensors,
                *self.prepare_system_acts(batch.system_acts),
                *ont.tensors
            )
            # ont_idx: [num_ont] -> [batch_size x num_ont x state_lat]
            # s: [batch_size x state_lat] ->
            #     [batch_size x num_ont x state_lat]
            # target: [batch_size x num_ont]
            s = batch.belief_state
            target = \
                ((ont_idx.unsqueeze(0).unsqueeze(-1) == s.value.unsqueeze(1))
                 .masked_fill(~utils.mask(s.lens).unsqueeze(1), 0).any(-1))
            current_loss = self._bce(logit, target.float())
            if self.loss == "mean":
                current_loss = current_loss.mean(-1)
            elif self.loss == "sum":
                current_loss = current_loss.sum(-1)
            else:
                raise ValueError(f"unsupported loss method: {self.loss}")
            if loss is None:
                loss = current_loss
            else:
                loss += current_loss
            for batch_idx, val_idx in \
                    (torch.sigmoid(logit) > 0.5).nonzero().tolist():
                pred[batch_idx].append(
                    (ont_idx[val_idx].item(), logit[batch_idx, val_idx]))

        def to_dialog_state(data: Sequence[Tuple[ActSlotValue, float]]):
            state = DialogState()
            as_map = collections.defaultdict(list)
            for asv, score in data:
                as_map[(asv.act, asv.slot)].append((asv, score))
            for (act, slt), data in as_map.items():
                if act == "request" and slt == "slot":
                    state.update(asv for asv, _ in data)
                elif act == "inform":
                    state.add(max(data, key=lambda x: x[1])[0])
            return state

        pred = [[(self.processor.vocabs.speaker_state["user"].asv[idx],
                  score) for idx, score in v] for v in pred]
        pred = list(map(to_dialog_state, pred))
        pred_inform = [{sv.slot: sv.value for sv in p.get("inform")}
                       for p in pred]
        pred_request = [{sv.value for sv in p.get("request")} for p in pred]
        # DSTC2: 'this' resolution
        pred = [
            (DSTTurn(turn.wizard, turn.user.clone(inform=pi, request=pr))
             .resolve_this().user.state)
            for turn, pi, pr in zip(batch.raw, pred_inform, pred_request)
        ]
        return loss, pred

    def predict_asr(self, batch, ontology):
        pred = [list() for _ in range(batch.batch_size)]
        batch_loss = []
        for batch_idx, (batch_asr, score) in enumerate(self.iter_asr(batch)):
            loss = None
            if self.asr_topk is not None:
                score_list = sorted(enumerate(score.tolist()),
                                    key=lambda x: x[1], reverse=True)
                score_list = list(utils.bucket(
                    score_list, compare_fn=lambda x, y: x[1] == y[1]))
                score_list = \
                    list(itertools.chain(*score_list[:self.asr_topk]))
                score_idx = [idx for idx, _ in score_list]
                score = score[score_idx]
                batch_asr = batch_asr[score_idx]
            if self.asr_method == "uniform":
                score.fill_(1 / len(score))
            elif self.asr_method == "ones":
                score.fill_(1)
            elif self.asr_method == "scaled":
                max_score = score.max()
                if max_score.item() == 0:
                    score.fill_(1 / len(score))
                else:
                    score = score * (1 / max_score.item())
            elif self.asr_method == "score":
                if score.sum().item() == 0:
                    score.fill_(1 / len(score))
                else:
                    score = score / score.sum()
            else:
                raise ValueError(f"unsupported method: {self.asr_method}")
            for act_slot, (ont_idx, ont) in ontology.items():
                as_idx = torch.tensor(self.vocabs.speaker_state["user"]
                                      .act_slot[act_slot]).to(self.device)
                ont_idx, ont = ont_idx.to(self.device), ont.to(self.device)
                logit = logit_raw = self.model(
                    as_idx,
                    *batch_asr.sent.tensors,
                    *self.prepare_system_acts(batch_asr.system_acts),
                    *ont.tensors
                )
                logit = torch.mm(score.unsqueeze(0), logit).squeeze(0)
                # ont_idx: [num_ont] -> [num_ont x state_lat]
                # s: [state_lat] -> [num_ont x state_lat]
                # target: [num_ont]
                s = batch_asr.belief_state[0]
                target = ((ont_idx.unsqueeze(-1) == s.unsqueeze(0)).any(-1))
                current_loss = self._bce(logit, target.float())
                if self.loss == "mean":
                    current_loss = current_loss.mean(-1)
                elif self.loss == "sum":
                    current_loss = current_loss.sum(-1)
                else:
                    raise ValueError(f"unsupported loss method: {self.loss}")
                if loss is None:
                    loss = current_loss
                else:
                    loss += current_loss
                if self.asr_sigmoid_sum_order == "sum-sigmoid":
                    current_pred = torch.sigmoid(logit) > 0.5
                elif self.asr_sigmoid_sum_order == "sigmoid-sum":
                    sigmoid = (torch.mm(score.unsqueeze(0),
                                        torch.sigmoid(logit_raw))
                               .squeeze(0).clamp_(0.0, 1.0))
                    current_pred = sigmoid > 0.5
                    logit = (sigmoid / (1 - sigmoid)).log()
                else:
                    raise ValueError(f"unsupported order: "
                                     f"{self.asr_sigmoid_sum_order}")
                for (val_idx,) in current_pred.nonzero().tolist():
                    pred[batch_idx] \
                        .append((ont_idx[val_idx].item(), logit[val_idx]))
            batch_loss.append(loss)

        def to_dialog_state(data: Sequence[Tuple[ActSlotValue, float]]):
            state = DialogState()
            as_map = collections.defaultdict(list)
            for asv, score in data:
                as_map[(asv.act, asv.slot)].append((asv, score))
            for (act, slt), data in as_map.items():
                if act == "request" and slt == "slot":
                    state.update(asv for asv, _ in data)
                elif act == "inform":
                    state.add(max(data, key=lambda x: x[1])[0])
            return state

        pred = [[(self.processor.vocabs.speaker_state["user"].asv[idx],
                  score) for idx, score in v] for v in pred]
        pred = list(map(to_dialog_state, pred))
        pred_inform = [{sv.slot: sv.value for sv in p.get("inform")}
                       for p in pred]
        pred_request = [{sv.value for sv in p.get("request")} for p in pred]
        # DSTC2: 'this' resolution
        pred = [
            (DSTTurn(turn.wizard, turn.user.clone(inform=pi, request=pr))
             .resolve_this().user.state)
            for turn, pi, pr in zip(batch.raw, pred_inform, pred_request)
        ]
        return torch.stack(batch_loss), pred

    def train(self, train_dataloader, dev_dataloader, test_fn=None):
        test_fn = test_fn or self.test
        writer = tb.SummaryWriter(log_dir=str(self.save_dir))
        ont = self.processor.tensorize_state_dict(self._user_tensor, "user")
        optimizer = op.Adam(p for p in self.model.parameters()
                            if p.requires_grad)
        scheduler = None
        if self.scheduler is not None:
            scheduler = self.scheduler(optimizer)
        global_step = 0
        final_stats = None
        for eidx in range(1, self.epochs + 1):
            progress_local = tqdm.tqdm(
                total=len(train_dataloader.dataset),
                dynamic_ncols=True,
                desc=f"training epoch-{eidx}"
            )
            cum_stats = collections.defaultdict(float)
            local_step = 0
            self.model.train()
            for batch in train_dataloader:
                batch = batch.to(self.device)
                batch_size = batch.batch_size
                global_step += batch_size
                local_step += batch_size
                progress_local.update(batch_size)
                optimizer.zero_grad()
                loss = self.compute_loss(batch, ont)
                loss_mean = loss.mean()
                stats = {"train-loss": loss_mean.item()}
                if self.l2norm is not None:
                    l2norm = sum(p.norm(2) for p in self.model.parameters()
                                 if p.requires_grad)
                    loss_mean += self.l2norm * l2norm
                loss_mean.backward()
                # clip after backward so the freshly computed gradients
                # are the ones being clipped
                if self.gradient_clip is not None:
                    nn.utils.clip_grad_norm_(
                        parameters=itertools.chain(
                            *(d["params"] for d in optimizer.param_groups)),
                        max_norm=self.gradient_clip
                    )
                optimizer.step()
                for k, v in stats.items():
                    cum_stats[k] += v * batch_size
                if self.inepoch_report_chance >= random.random():
                    for k, v in stats.items():
                        writer.add_scalar(k, v, global_step)
                progress_local.set_postfix({"loss": loss_mean.item()})
            stats = {f"{k}-epoch": v / local_step
                     for k, v in cum_stats.items()}
            progress_local.close()
            self._logger.info(f"epoch {eidx} train summary:")
            self._logger.info(
                f"\n"
                f"  * train-loss: {stats['train-loss-epoch']:.4f}"
            )
            if self.train_validate:
                with torch.no_grad():
                    train_stats = test_fn(TestDataloader(
                        dialogs=train_dataloader.dataset.dialogs,
                        processor=self.processor,
                        max_batch_size=train_dataloader.batch_size
                    ), ont, "train-validate")
                stats.update({f"val-train-{k}": v
                              for k, v in train_stats.items()})
                self._logger.info(f"epoch {eidx} train-validation summary:")
                self._logger.info(
                    f"\n"
                    f"  * val-train-loss: {train_stats['loss']:.4f}\n"
                    f"  * val-hmean: {train_stats['hmean-inform']:.4f}\n"
                    f"  * val-train-joint: {train_stats['joint-goal']:.4f}\n"
                    f"  * val-train-inform: "
                    f"{train_stats['turn-inform']:.4f}\n"
                    f"  * val-train-request: "
                    f"{train_stats['turn-request']:.4f}"
                )
            with torch.no_grad():
                val_stats = test_fn(dev_dataloader, ont, "validate")
            self._logger.info(f"epoch {eidx} validation summary:")
            self._logger.info(
                f"\n"
                f"  * val-loss: {val_stats['loss']:.4f}\n"
                f"  * val-hmean: {val_stats['hmean-inform']:.4f}\n"
                f"  * val-joint: {val_stats['joint-goal']:.4f}\n"
                f"  * val-inform: {val_stats['turn-inform']:.4f}\n"
                f"  * val-request: {val_stats['turn-request']:.4f}"
            )
            stats.update({f"val-{k}": v for k, v in val_stats.items()})
            if self.early_stop:
                if self._record is None:
                    self.make_record(eidx, val_stats)
                elif (self.early_stop_patience is not None and
                      eidx > self._record.epoch + self.early_stop_patience):
                    break
                else:
                    crit, neg = self.criterion
                    crit_value = val_stats[crit]
                    if (crit_value > self._record.value) != neg:
                        self._logger.info(f"new record made! "
                                          f"{crit}={crit_value:.3f}")
                        self.make_record(eidx, val_stats)
            for k, v in stats.items():
                writer.add_scalar(k, v, global_step)
            final_stats = stats
            if scheduler is not None:
                scheduler.step()
        if self._record is None and final_stats is not None:
            self.make_record(self.epochs, final_stats)
        if self._record is not None:
            self.model.load_state_dict(self._record.params)
        return self._record

    @staticmethod
    def evaluate_batch(pred: Sequence[DialogState],
                       gold: Sequence[DialogState]):
        pred_inform = [s.get("inform") for s in pred]
        gold_inform = [s.get("inform") for s in gold]
        pred_request = [s.get("request") for s in pred]
        gold_request = [s.get("request") for s in gold]

        def all_eq(p, q):
            return [x == y for x, y in zip(p, q)]

        stats = {
            "turn-acc": np.mean(all_eq(pred, gold)),
            "turn-inform": np.mean(all_eq(pred_inform, gold_inform)),
            "turn-request": np.mean(all_eq(pred_request, gold_request))
        }
        # joint-goal is cumulative and hence only defined at the dialog
        # level (see evaluate_dialog), so hmean-inform is not computed here
        return stats

    @staticmethod
    def evaluate_dialog(pred: Sequence[DialogState],
                        gold: Sequence[DialogState]):
        pred_inform = [s.get("inform") for s in pred]
        gold_inform = [s.get("inform") for s in gold]
        pred_request = [s.get("request") for s in pred]
        gold_request = [s.get("request") for s in gold]

        def cum_sum(data: Sequence[DialogAct]):
            goal = dict()
            ret = []
            for da in data:
                for sv in da:
                    goal[sv.slot] = sv.value
                ret.append(copy.copy(goal))
            return ret

        pred_goal = cum_sum(pred_inform)
        gold_goal = cum_sum(gold_inform)

        def all_eq(p, q):
            return [x == y for x, y in zip(p, q)]

        stats = {
            "joint-goal": np.mean(all_eq(pred_goal, gold_goal)),
            "turn-acc": np.mean(all_eq(pred, gold)),
            "turn-inform": np.mean(all_eq(pred_inform, gold_inform)),
            "turn-request": np.mean(all_eq(pred_request, gold_request))
        }
        stats["hmean-inform"] = harmonic_mean(np.array([
            stats["joint-goal"], stats["turn-inform"]
        ]))
        return stats

    def test(self, dataloader: TestDataloader, ontology=None, mode="test"):
        self.model.eval()
        if ontology is None:
            ontology = (self.processor
                        .tensorize_state_dict(self._user_tensor, "user"))
        cum_stats = collections.defaultdict(float)
        progress = tqdm.tqdm(
            total=dataloader.total_items,
            dynamic_ncols=True,
            desc=mode
        )
        local_step = 0
        for dialogs, batch in dataloader:
            batch = batch.to(self.device)
            batch_size = batch.batch_size
            local_step += batch_size
            progress.update(batch_size)
            loss, pred = self.predict(batch, ontology)
            loss_mean = loss.mean()
            pred = bucket_items(pred,
                                list(len(d.dst_turns) for d in dialogs))
            gold = [[turn.resolve_this().user.state
                     for turn in dialog.dst_turns] for dialog in dialogs]
            for p, g in zip(pred, gold):
                for k, v in self.evaluate_dialog(p, g).items():
                    cum_stats[k] += v * len(p)
            cum_stats["loss"] += loss_mean.item() * batch_size
        progress.close()
        return {k: v / local_step for k, v in cum_stats.items()}

    @staticmethod
    def iter_asr(batch: DSTTestBatchData) -> Iterable[DSTBatchData]:
        asr, asr_score = batch.asr, batch.asr_score
        assert (utils.compare_tensors(asr.lens, asr_score.lens)
                and asr.size(0) == asr_score.size(0))
        for i in range(asr.size(0)):
            asr_item = asr[i]
            asr_score_item = asr_score[i]
            num_asr = asr_item.size(0)
            yield DSTBatchData(
                sent=asr_item,
                system_acts=utils.pad_stack([batch.system_acts[i]]
                                            * num_asr),
                belief_state=utils.pad_stack([batch.belief_state[i]]
                                             * num_asr),
                slot=utils.pad_stack([batch.slot[i]] * num_asr),
                raw=[batch.raw[i]] * num_asr
            ), asr_score_item

    def test_asr(self, dataloader: TestDataloader, ontology=None,
                 mode="test"):
        self.model.eval()
        if ontology is None:
            ontology = (self.processor
                        .tensorize_state_dict(self._user_tensor, "user"))
        cum_stats = collections.defaultdict(float)
        progress = tqdm.tqdm(
            total=dataloader.total_items,
            dynamic_ncols=True,
            desc=mode
        )
        local_step = 0
        for dialogs, batch in dataloader:
            batch = batch.to(self.device)
            batch_size = batch.batch_size
            local_step += batch_size
            progress.update(batch_size)
            loss, pred = self.predict_asr(batch, ontology)
            loss_mean = loss.mean()
            pred = bucket_items(pred,
                                list(len(d.dst_turns) for d in dialogs))
            gold = [[turn.resolve_this().user.state
                     for turn in dialog.dst_turns] for dialog in dialogs]
            for p, g in zip(pred, gold):
                for k, v in self.evaluate_dialog(p, g).items():
                    cum_stats[k] += v * len(p)
            cum_stats["loss"] += loss_mean.item() * batch_size
            progress.set_postfix({
                "joint-goal": cum_stats["joint-goal"] / local_step})
        progress.close()
        return {k: v / local_step for k, v in cum_stats.items()}
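# A toy check of the multi-label target construction used in
# Runner.compute_loss and Runner.predict above: target[b, j] is True iff
# ontology value ont_idx[j] occurs among the unpadded belief-state
# indices of batch element b. Shapes and index values are made up.
import torch

ont_idx = torch.tensor([10, 20, 30])                 # [num_ont]
s_value = torch.tensor([[10, 30, 0], [20, 0, 0]])    # [batch x state_lat]
s_lens = torch.tensor([2, 1])
mask = torch.arange(s_value.size(1)).unsqueeze(0) < s_lens.unsqueeze(-1)
target = ((ont_idx.unsqueeze(0).unsqueeze(-1) == s_value.unsqueeze(1))
          .masked_fill(~mask.unsqueeze(1), False).any(-1))
assert target.tolist() == [[True, False, True], [False, True, False]]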
class EmbeddingEvaluator(DialogEvaluator):
    vocab: utils.Vocabulary
    embeds: Embeddings
    _emb: torch.Tensor = utils.private_field(default=None)
    _cache: Set[int] = utils.private_field(default_factory=set)
    _stats: dict = utils.private_field(default_factory=dict)
    _seen: int = utils.private_field(default=0)
    _logger: logging.Logger = utils.private_field(default=None)

    def __post_init__(self):
        self._emb = torch.zeros(len(self.vocab), self.embeds.dim,
                                requires_grad=False)
        self._logger = logging.getLogger(self.__class__.__name__)

    @staticmethod
    def compare_mean(pred: torch.Tensor, gold: torch.Tensor):
        return cosine_sim(pred.mean(0), gold.mean(0))

    @staticmethod
    def compare_extrema(pred: torch.Tensor, gold: torch.Tensor):
        def extrema(x):
            return x.gather(0, x.abs().max(0)[1].unsqueeze(0)).squeeze(0)

        return cosine_sim(extrema(pred), extrema(gold))

    @staticmethod
    def compare_greedy(pred: torch.Tensor, gold: torch.Tensor):
        dim = pred.size(-1)
        return cosine_sim(
            (pred.unsqueeze(1).expand(-1, gold.size(0), -1)
             .contiguous().view(-1, dim)),
            (gold.unsqueeze(0).expand(pred.size(0), -1, -1)
             .contiguous().view(-1, dim)),
            dim=1).view(pred.size(0), gold.size(0)).max(0)[0].mean()

    def to_embeddings(self, words: torch.Tensor):
        new_idx = set(e.item() for e in words.unique()) - self._cache
        for i in new_idx:
            w = self.vocab.i2f[i]
            if w not in self.embeds:
                self._logger.debug(f"word ({w}) not found in the embeddings")
            else:
                self._emb[i, :] = torch.tensor(self.embeds[w])
        if new_idx:
            self._cache.update(new_idx)
        exists = torch.tensor(
            [self.vocab.i2f[w.item()] in self.embeds for w in words]).bool()
        return self._emb[words.masked_select(exists)]

    def _eval(self, p, g, compare_fn) -> Optional[torch.Tensor]:
        pe = self.to_embeddings(p)
        ge = self.to_embeddings(g)
        if not len(ge):
            self._logger.debug(f"entire utterance omitted as no matching "
                               f"embeddings found for the gold sentence")
            return
        if not len(pe):
            self._logger.debug(
                f"no matching embeddings found for the pred sentence; "
                f"giving a zero score for the similarity score")
            return torch.tensor(0.0)
        res = compare_fn(pe, ge)
        if (res != res).any():
            self._logger.debug(f"NaN detected for pred = {p} and gold = {g}")
            return
        return res

    def reset(self):
        self._seen = 0
        self._stats.clear()

    def update(self, samples: Sequence) -> Optional[TensorMap]:
        pred = [
            torch.tensor([
                self.vocab[w] for turn in sample.output.turns
                for w in turn.text.split()
            ]).long() for sample in samples
        ]
        batch = [
            torch.tensor([
                self.vocab[w] for turn in sample.input.turns
                for w in turn.text.split()
            ]).long() for sample in samples
        ]
        for p, g in zip(pred, batch):
            stats = {
                "emb-mean": self._eval(p, g, self.compare_mean),
                "emb-extrema": self._eval(p, g, self.compare_extrema),
                "emb-greedy": self._eval(p, g, self.compare_greedy)
            }
            for k, v in stats.items():
                if v is None:
                    continue
                if k not in self._stats:
                    self._stats[k] = list()
                self._stats[k].append(v.item())
        return

    def get(self) -> TensorMap:
        return {k: torch.tensor(v).mean() for k, v in self._stats.items()}
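# A toy check of the greedy-matching comparator above, assuming the
# module-level cosine_sim helper the class already uses: for each gold
# vector, take the best cosine match among predicted vectors, then
# average. The vectors are made up.
pred = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
gold = torch.tensor([[1.0, 0.0]])
score = EmbeddingEvaluator.compare_greedy(pred, gold)
assert torch.isclose(score, torch.tensor(1.0))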
class VocabSetFactory:
    """A helper class for managing vocabularies related to task-oriented
    dialogues. Note that dialogue acts (act-slot-value triples) are
    organized by the speakers.
    """
    tokenizer: Callable[[str], Sequence[str]] = None
    word: VocabularyFactory = VocabularyFactory()
    speaker: VocabularyFactory = VocabularyFactory()
    goal_cls: Callable[[], StateVocabSetFactory] = StateVocabSetFactory
    state_cls: Callable[[], StateVocabSetFactory] = StateVocabSetFactory
    goal_state_cls: Callable[[], StateVocabSetFactory] = StateVocabSetFactory
    _goal: StateVocabSetFactory = utils.private_field(default=None)
    _state: StateVocabSetFactory = utils.private_field(default=None)
    _goal_state: StateVocabSetFactory = utils.private_field(default=None)
    _goal_factories: Dict[str, StateVocabSetFactory] = \
        utils.private_field(default_factory=dict)
    _state_factories: Dict[str, StateVocabSetFactory] = \
        utils.private_field(default_factory=dict)
    _goal_state_factories: Dict[str, StateVocabSetFactory] = \
        utils.private_field(default_factory=dict)

    def __post_init__(self):
        self._goal = self.goal_cls()
        self._state = self.state_cls()
        self._goal_state = self.goal_state_cls()

        def tokenize(s):
            return s.split()

        self.tokenizer = self.tokenizer or tokenize

    def update_turn(self, turn: Turn):
        tokenizer = self.tokenizer
        self.word.update(tokenizer(turn.text))
        for sent in turn.asr:
            self.word.update(tokenizer(sent))
        for asv in turn.state:
            self.word.update(tokenizer(asv.act))
            self.word.update(tokenizer(asv.slot))
            self.word.update(tokenizer(asv.value))
        for asv in turn.goal:
            self.word.update(tokenizer(asv.act))
            self.word.update(tokenizer(asv.slot))
            self.word.update(tokenizer(asv.value))
        self.speaker.update((turn.speaker,))
        self._goal.update(turn.goal)
        self._state.update(turn.state)
        self._goal_state.update(turn.goal)
        self._goal_state.update(turn.state)
        if turn.speaker not in self._state_factories:
            self._state_factories[turn.speaker] = self.state_cls()
        self._state_factories[turn.speaker].update(turn.state)
        if turn.speaker not in self._goal_factories:
            self._goal_factories[turn.speaker] = self.goal_cls()
        self._goal_factories[turn.speaker].update(turn.goal)
        if turn.speaker not in self._goal_state_factories:
            self._goal_state_factories[turn.speaker] = self.goal_state_cls()
        self._goal_state_factories[turn.speaker].update(turn.goal)
        self._goal_state_factories[turn.speaker].update(turn.state)

    def update_turns(self, turns: Iterable[Turn]):
        for turn in turns:
            self.update_turn(turn)

    def update_words(self, words: Iterable[str]):
        self.word.update(words)

    def get_vocabs(self) -> VocabSet:
        return VocabSet(
            word=self.word.get_vocab(),
            speaker=self.speaker.get_vocab(),
            goal=self._goal.get_vocabs(),
            state=self._state.get_vocabs(),
            goal_state=self._goal_state.get_vocabs(),
            speaker_goal={spkr: factory.get_vocabs()
                          for spkr, factory in self._goal_factories.items()},
            speaker_state={spkr: factory.get_vocabs()
                           for spkr, factory
                           in self._state_factories.items()},
            speaker_goal_state={s: f.get_vocabs()
                                for s, f
                                in self._goal_state_factories.items()}
        )
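# A usage sketch for VocabSetFactory, assuming the repo's Turn and
# DialogState types (constructor arguments inferred from
# DialogPreprocessor.preprocess_turn) and that the default
# VocabularyFactory applies no frequency cutoff; the turn is made up.
factory = VocabSetFactory()
factory.update_turns([
    Turn(text="i want thai food", speaker="user",
         goal=DialogState(), state=DialogState(), asr={}, meta={}),
])
vocabs = factory.get_vocabs()
# the word and speaker vocabularies now cover the observed tokens
assert "thai" in vocabs.word.f2i and "user" in vocabs.speaker.f2i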
class LogInferencer(Inferencer):
    progress_stat: Optional[str] = None
    display_stats: Optional[Union[Set[str], Sequence[str]]] = None
    writer: Optional[torch.utils.tensorboard.SummaryWriter] = None
    stats_formatter: StatsFormatter = field(default_factory=StatsFormatter)
    dialog_formatter: DialogFormatter = \
        field(default_factory=DialogTableFormatter)
    report_every: Optional[int] = None
    run_end_report: bool = True
    _tqdm: tqdm.tqdm = utils.private_field(default=None)
    _last_report: int = utils.private_field(default=0)
    _dialog_md_formatter: ClassVar[utils.DialogMarkdownFormatter] = \
        utils.DialogMarkdownFormatter()

    def __post_init__(self):
        super().__post_init__()
        if self.display_stats is not None:
            self.display_stats = set(self.display_stats)

    def on_run_started(self, dataloader: td.DataLoader) -> td.DataLoader:
        dataloader = super(LogInferencer, self).on_run_started(dataloader)
        self._tqdm = tqdm.tqdm(
            total=len(dataloader.dataset),
            dynamic_ncols=True,
            desc=self.__class__.__name__,
        )
        return dataloader

    def on_batch_started(self, batch: BatchData) -> BatchData:
        batch = super().on_batch_started(batch)
        self._tqdm.update(batch.batch_size)
        return batch

    def on_batch_ended(self, batch: BatchData, pred: BatchData,
                       outputs) -> utils.TensorMap:
        stats = super().on_batch_ended(batch, pred, outputs)
        if self.progress_stat is not None and self.progress_stat in stats:
            self._tqdm.set_postfix(
                {self.progress_stat: stats[self.progress_stat].item()})
        if self.report_every is not None and \
                (self.global_step - self._last_report) >= self.report_every:
            idx = random.randint(0, batch.batch_size - 1)
            batch_sample = batch.raw[idx]
            pred_sample = self.processor.lexicalize_global(pred[idx])
            self.log_dialog("input-sample", batch_sample)
            self.log_dialog("pred-sample", pred_sample)
            self.log_stats("summary", stats)
            self._last_report = self.global_step
        return stats

    def log_dialog(self, tag: str, dialog: Dialog):
        if self.writer is not None:
            self.writer.add_text(
                tag=tag,
                text_string=self._dialog_md_formatter.format(dialog),
                global_step=self.global_step)
        self._logger.info(f"global step {self.global_step:,d} {tag}:\n"
                          f"{self.dialog_formatter.format(dialog)}")

    def log_stats(self, tag: str, stats: utils.TensorMap,
                  prefix=None, postfix=None):
        def wrap_key(key):
            if prefix is not None:
                key = f"{prefix}-{key}"
            if postfix is not None:
                key = f"{key}-{postfix}"
            return key

        if self.writer is not None:
            for k, v in stats.items():
                self.writer.add_scalar(wrap_key(k), v.item(),
                                       self.global_step)
        if self.display_stats is not None:
            display_stats = {
                k: v for k, v in stats.items() if k in self.display_stats
            }
        else:
            display_stats = stats
        display_stats = {wrap_key(k): v for k, v in display_stats.items()}
        self._logger.info(f"global step {self.global_step:,d} {tag}:\n"
                          f"{self.stats_formatter.format(display_stats)}")

    def on_run_ended(self, stats: utils.TensorMap) -> utils.TensorMap:
        stats = super().on_run_ended(stats)
        if self.run_end_report:
            self.log_stats("run-end-summary", stats, postfix="run")
        self._tqdm.close()
        return stats
class BeamSearcher:
    """
    Init Arguments:
        embedder (Callable): an embedding object that has the following
            call signature:
            Call Arguments:
                input (LongTensor): [N x seq_len]
                lens (optional, LongTensor): [N]
            Call Returns:
                output (FloatTensor): [N x seq_len x embed_dim]
        cell (Callable): a rnn-cell-like object that has the following
            call signature:
            Call Arguments:
                input (FloatTensor): [N x embed_dim]
                state (optional, FloatTensor or tuple of):
                    [N x hidden_dim], ...
            Call Returns:
                output (FloatTensor): [N x hidden_dim]
                next_state (FloatTensor, FloatTensor or tuple of):
                    [N x hidden_dim], ...
        classifier (Callable): a feedforward-like object that has the
            following call signature:
            Call Arguments:
                input (FloatTensor): [N1 x ... x Nn x hidden_dim]
            Call Returns:
                output (FloatTensor): [N1 x ... x Nn x vocab_size]
    """
    embedder: Callable
    cell: Callable
    classifier: Callable
    vocab: Vocabulary
    beam_size: int = 8
    max_len: int = 30
    bos: str = "<bos>"
    eos: str = "<eos>"
    _eos_idx: Optional[int] = private_field(default=None)

    def __post_init__(self):
        if self.bos not in self.vocab:
            raise ValueError(f"bos token {self.bos} not found in vocab.")
        if self.eos not in self.vocab:
            raise ValueError(f"eos token {self.eos} not found in vocab.")
        self._eos_idx = self.vocab[self.eos]

    def search(self, s0):
        """Perform beam search.

        Arguments:
            s0 (torch.FloatTensor or tuple of): initial cell state
                [N x hidden_dim], ...

        Returns:
            sample (torch.LongTensor): [N x beam_size x seq_len]
            lens (torch.LongTensor): [N x beam_size]
            prob (torch.FloatTensor): [N x beam_size]
        """
        s0_sample = s0[0] if isinstance(s0, tuple) else s0
        batch_size = s0_sample.size(0)
        logit0 = (torch.zeros(len(self.vocab)).to(s0_sample)
                  .fill_(float("-inf")))
        logit0[self.vocab[self.bos]] = 0
        word = s0_sample.new(batch_size, self.beam_size, 0).long()
        done = s0_sample.new(batch_size, self.beam_size).fill_(0).bool()
        prob = s0_sample.new(batch_size, self.beam_size).fill_(1.0)
        lens = s0_sample.new(batch_size, self.beam_size).fill_(0).long()
        # expand the initial state to beam size in both the tuple and
        # the plain-tensor cases
        if isinstance(s0, tuple):
            s = tuple(t.unsqueeze(1)
                      .expand(batch_size, self.beam_size, -1).contiguous()
                      for t in s0)
        else:
            s = (s0.unsqueeze(1)
                 .expand(batch_size, self.beam_size, -1).contiguous())
        while not done.all() and lens.max() < self.max_len:
            seq_len = word.size(2)
            if not seq_len:
                logit = logit0
                prob_prime, word_prime = \
                    torch.softmax(logit, 0).sort(0, True)
                prob_prime = prob_prime[:self.beam_size]
                word_prime = word_prime[:self.beam_size]
                prob = prob_prime.unsqueeze(0).expand(batch_size, -1)
                word = (word_prime.unsqueeze(0).unsqueeze(-1)
                        .expand(batch_size, -1, -1)).contiguous()
                lens += 1
                continue
            emb = self.embedder(word.view(-1, word.size(-1)), lens.view(-1))
            emb = emb[:, -1, :].view(batch_size, self.beam_size, -1)
            if isinstance(s, tuple):
                s_flat = tuple(t.view(-1, t.size(-1)) for t in s)
            else:
                s_flat = s.view(-1, s.size(-1))
            o, s_prime = self.cell(emb.view(-1, emb.size(-1)), s_flat)
            if isinstance(s_prime, tuple):
                s = tuple(t.view(batch_size, self.beam_size, -1)
                          for t in s_prime)
            else:
                s = s_prime.view(batch_size, self.beam_size, -1)
            logit = self.classifier(o).view(batch_size, self.beam_size, -1)
            if self._eos_idx is not None:
                # finished beams may only emit <eos>
                logit_eos = torch.full_like(logit, float("-inf"))
                logit_eos[:, :, self._eos_idx] = 0
                logit = logit.masked_scatter(done.unsqueeze(-1), logit_eos)
            vocab_size = logit.size(-1)
            prob_prime = prob.unsqueeze(-1) * torch.softmax(logit, 2)
            prob_prime, prob_idx = \
                prob_prime.view(batch_size, -1).sort(1, True)
            prob = prob_prime[:, :self.beam_size]
            prob_idx = prob_idx[:, :self.beam_size].long()
            # floor division recovers the source beam; modulo the new word
            beam_idx = prob_idx // vocab_size
            word_prime = prob_idx % vocab_size
            word = torch.cat([
                word.gather(1, beam_idx.unsqueeze(-1).expand_as(word)),
                word_prime.unsqueeze(-1)
            ], 2)
            lens = lens.gather(1, beam_idx) + (~done).long()
            if self._eos_idx is not None:
                done = ((word_prime == self._eos_idx)
                        | done.gather(1, beam_idx))
        return word, lens, prob
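# A minimal smoke-test sketch for BeamSearcher, assuming the
# dataclass-style keyword constructor used throughout this repo. ToyVocab
# is a stand-in for the repo's Vocabulary and only provides the lookups
# the searcher uses (`in`, `[]`, `len`); the dimensions are made up.
import torch
import torch.nn as nn


class ToyVocab:
    words = ["<bos>", "<eos>", "a", "b"]

    def __contains__(self, w):
        return w in self.words

    def __getitem__(self, w):
        return self.words.index(w)

    def __len__(self):
        return len(self.words)


vocab, dim, batch_size = ToyVocab(), 16, 3
embedding = nn.Embedding(len(vocab), dim)
gru = nn.GRUCell(dim, dim)


def cell(x, s):
    h = gru(x, s)
    return h, h          # (output, next_state)


searcher = BeamSearcher(embedder=lambda x, lens=None: embedding(x),
                        cell=cell,
                        classifier=nn.Linear(dim, len(vocab)),
                        vocab=vocab, beam_size=2, max_len=5)
word, lens, prob = searcher.search(torch.zeros(batch_size, dim))
print(word.size())       # [batch_size x beam_size x seq_len]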
class DialogStateEvaluator(FinegrainedEvaluator):
    vocabs: VocabSet
    return_update: bool = False
    _pred_goal: List[Stacked1DTensor] = \
        utils.private_field(default_factory=list)
    _gold_goal: List[Stacked1DTensor] = \
        utils.private_field(default_factory=list)
    _pred_state: List[Stacked1DTensor] = \
        utils.private_field(default_factory=list)
    _gold_state: List[Stacked1DTensor] = \
        utils.private_field(default_factory=list)
    _spkr: List[torch.Tensor] = utils.private_field(default_factory=list)

    def compute_accuracy(self, pred: DoublyStacked1DTensor,
                         gold: DoublyStacked1DTensor,
                         turn_mask=None) -> utils.TensorMap:
        batch_size = pred.size(0)
        pred_dense = utils.to_dense(pred.value, pred.lens1,
                                    max_size=len(self.vocabs.goal_state.asv))
        gold_dense = utils.to_dense(gold.value, gold.lens1,
                                    max_size=len(self.vocabs.goal_state.asv))
        crt = (pred_dense == gold_dense).all(-1)
        conv_mask = utils.mask(pred.lens, pred.size(1))
        if turn_mask is None:
            turn_mask = torch.ones_like(conv_mask).bool()
        turn_mask = turn_mask & conv_mask
        crt = crt & turn_mask
        num_turns = turn_mask.sum()
        stats = {
            "acc": (crt | ~turn_mask).all(-1).sum().float() / batch_size,
            "acc-turn": crt.sum().float() / num_turns,
        }
        return stats

    def reset(self):
        self._pred_goal, self._gold_goal = list(), list()
        self._pred_state, self._gold_state = list(), list()
        self._spkr = list()

    def update(self, batch: BatchData, pred: BatchData,
               outputs) -> Optional[TensorMap]:
        self._pred_goal.extend(pred.goal)
        self._gold_goal.extend(batch.goal)
        self._pred_state.extend(pred.state)
        self._gold_state.extend(batch.state)
        self._spkr.extend(batch.speaker)
        if not self.return_update:
            return
        stats = dict()
        stats.update({
            f"goal-{k}": v
            for k, v in self.compute_accuracy(pred.goal, batch.goal).items()
        })
        stats.update({
            f"state-{k}": v
            for k, v in self.compute_accuracy(pred.state,
                                              batch.state).items()
        })
        for spkr_idx, spkr in self.vocabs.speaker.i2f.items():
            if spkr == "<unk>":
                continue
            spkr_value = batch.speaker.value
            stats.update({
                f"goal-{k}-{spkr}": v
                for k, v in self.compute_accuracy(
                    pred.goal, batch.goal, spkr_value == spkr_idx).items()
            })
            stats.update({
                f"state-{k}-{spkr}": v
                for k, v in self.compute_accuracy(
                    pred.state, batch.state, spkr_value == spkr_idx).items()
            })
        return stats

    def get(self) -> TensorMap:
        pred_goal = utils.stack_stacked1dtensors(self._pred_goal)
        pred_state = utils.stack_stacked1dtensors(self._pred_state)
        gold_goal = utils.stack_stacked1dtensors(self._gold_goal)
        gold_state = utils.stack_stacked1dtensors(self._gold_state)
        spkr_value = utils.pad_stack(self._spkr).value
        stats = dict()
        stats.update({
            f"goal-{k}": v
            for k, v in self.compute_accuracy(pred_goal, gold_goal).items()
        })
        stats.update({
            f"state-{k}": v
            for k, v in self.compute_accuracy(pred_state,
                                              gold_state).items()
        })
        for spkr_idx, spkr in self.vocabs.speaker.i2f.items():
            if spkr == "<unk>":
                continue
            stats.update({
                f"goal-{k}-{spkr}": v
                for k, v in self.compute_accuracy(
                    pred_goal, gold_goal, spkr_value == spkr_idx).items()
            })
            stats.update({
                f"state-{k}-{spkr}": v
                for k, v in self.compute_accuracy(
                    pred_state, gold_state, spkr_value == spkr_idx).items()
            })
        return stats
class WordEntropyEvaluator(DialogEvaluator):
    """arXiv:1605.06069"""
    dataset: DialogDataset
    _prob: torch.Tensor = utils.private_field(default=None)
    _spkr_prob: dict = utils.private_field(default=None)
    _values: dict = utils.private_field(default_factory=dict)

    def __post_init__(self):
        self._prob, self._spkr_prob = self.compute_unigram_prob()

    @property
    def speakers(self):
        return set(spkr for spkr in self.vocabs.speaker.f2i
                   if spkr != "<unk>")

    @property
    def vocabs(self):
        return self.dataset.processor.vocabs

    def compute_unigram_prob(self):
        # +1 smoothing
        prob = torch.ones(len(self.vocabs.word)).long()
        spkr_prob = {spkr: torch.ones(len(self.vocabs.word)).long()
                     for spkr in self.speakers}
        for dialog in self.dataset.data:
            for turn in dialog.turns:
                if turn.speaker == "<unk>":
                    continue
                tokens = \
                    self.dataset.processor.sent_processor.process(turn.text)
                tokens = utils.lstrip(tokens, "<bos>")
                tokens = utils.rstrip(tokens, "<eos>")
                for word in tokens:
                    word_idx = self.vocabs.word[word]
                    spkr_prob[turn.speaker][word_idx] += 1
                    prob[word_idx] += 1
        prob = prob.float() / prob.sum()
        spkr_prob = {spkr: p.float() / p.sum()
                     for spkr, p in spkr_prob.items()}
        return prob, spkr_prob

    def reset(self):
        self._values.clear()

    def compute_entropy(self, text: str) -> TensorMap:
        tokens = self.dataset.processor.tokenize(text)
        if not tokens:
            return {
                "word-ent": torch.tensor(0.0),
                "word-ent-sent": torch.tensor(0.0)
            }
        words, counts = zip(*collections.Counter(tokens).items())
        words = [self.vocabs.word[w] for w in words]
        words, counts = (torch.tensor(words).long(),
                         torch.tensor(counts).float())
        text_prob = counts.float() / counts.sum()
        # cross-entropy of the utterance's empirical unigram distribution
        # against the corpus unigram distribution
        ent = -(text_prob * self._prob[words].log()).sum()
        return {
            "word-ent": ent,
            "word-ent-sent": len(tokens) * ent
        }

    def compute_entropy_spkr(self, text: str, spkr) -> TensorMap:
        tokens = self.dataset.processor.tokenize(text)
        if not tokens:
            return {
                "word-ent": torch.tensor(0.0),
                "word-ent-sent": torch.tensor(0.0)
            }
        words, counts = zip(*collections.Counter(tokens).items())
        words = [self.vocabs.word[w] for w in words]
        words, counts = (torch.tensor(words).long(),
                         torch.tensor(counts).float())
        text_prob = counts.float() / counts.sum()
        ent = -(text_prob * self._spkr_prob[spkr][words].log()).sum()
        return {
            "word-ent": ent,
            "word-ent-sent": len(tokens) * ent
        }

    def update(self, samples: Sequence) -> Optional[TensorMap]:
        for sample in samples:
            for turn in sample.output:
                spkr = turn.speaker
                sent = turn.text
                stats = self.compute_entropy(sent)
                if spkr != "<unk>":
                    stats.update({
                        f"{k}-{spkr}": v
                        for k, v in self.compute_entropy_spkr(sent,
                                                              spkr).items()
                    })
                for k, v in stats.items():
                    if k not in self._values:
                        self._values[k] = list()
                    self._values[k].append(v.item())
        return

    def get(self) -> Optional[TensorMap]:
        return {k: torch.tensor(v).mean() for k, v in self._values.items()}
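# A minimal numeric sketch of the corrected word-ent statistic above:
# the cross-entropy H = -sum_w p_text(w) * log p_corpus(w) between an
# utterance's empirical unigram distribution and the corpus unigram
# distribution. The probabilities and counts are made up.
import torch

p_corpus = torch.tensor([0.5, 0.25, 0.25])  # smoothed corpus unigrams
counts = torch.tensor([2.0, 1.0])           # counts of the occurring words
p_text = counts / counts.sum()
ent = -(p_text * p_corpus[:2].log()).sum()  # gather probs of those words
# (2/3) * log 2 + (1/3) * log 4 = (4/3) * log 2
assert torch.isclose(ent, 4 / 3 * torch.tensor(2.0).log())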
class StateEntropyEvaluator(DialogEvaluator):
    """(New!)"""
    dataset: DialogDataset
    _asv_prob: torch.Tensor = utils.private_field(default=None)
    _spkr_asv_prob: dict = utils.private_field(default_factory=dict)
    _values: dict = utils.private_field(default_factory=dict)

    def __post_init__(self):
        self.compute_distributions()

    @property
    def speakers(self):
        return set(spkr for spkr in self.vocabs.speaker.f2i
                   if spkr != "<unk>")

    @property
    def vocabs(self):
        return self.dataset.processor.vocabs

    def compute_distributions(self):
        # +1 smoothing
        prob = torch.ones(len(self.vocabs.state.asv)).long()
        spkr_prob = {
            spkr: torch.ones(len(self.vocabs.speaker_state[spkr].asv)).long()
            for spkr in self.speakers
        }
        for dialog in self.dataset.data:
            for turn in dialog.turns:
                if turn.speaker == "<unk>":
                    continue
                spkr = turn.speaker
                for asv in turn.state:
                    asv_idx = self.vocabs.state.asv[asv]
                    spkr_asv_idx = self.vocabs.speaker_state[spkr].asv[asv]
                    spkr_prob[spkr][spkr_asv_idx] += 1
                    prob[asv_idx] += 1
        prob = prob.float() / prob.sum()
        spkr_prob = {
            spkr: p.float() / p.sum() for spkr, p in spkr_prob.items()
        }
        self._asv_prob, self._spkr_asv_prob = prob, spkr_prob

    def reset(self):
        self._values.clear()

    def compute_entropy(self, state: Sequence[ActSlotValue]) -> TensorMap:
        if not state:
            return {
                "asv-ent": torch.tensor(0.0),
                "asv-ent-turn": torch.tensor(0.0)
            }
        asvs, counts = zip(*collections.Counter(state).items())
        asvs = [self.vocabs.state.asv[w] for w in asvs]
        asvs, counts = (torch.tensor(asvs).long(),
                        torch.tensor(counts).float())
        text_prob = counts.float() / counts.sum()
        # cross-entropy against the corpus-level asv distribution
        ent = -(text_prob * self._asv_prob[asvs].log()).sum()
        return {"asv-ent": ent, "asv-ent-turn": len(state) * ent}

    def compute_entropy_spkr(self, state: Sequence[ActSlotValue],
                             spkr: str) -> TensorMap:
        if not state:
            return {
                "asv-ent": torch.tensor(0.0),
                "asv-ent-turn": torch.tensor(0.0)
            }
        asvs, counts = zip(*collections.Counter(state).items())
        asvs = [self.vocabs.speaker_state[spkr].asv[w] for w in asvs]
        asvs, counts = (torch.tensor(asvs).long(),
                        torch.tensor(counts).float())
        text_prob = counts.float() / counts.sum()
        ent = -(text_prob * self._spkr_asv_prob[spkr][asvs].log()).sum()
        return {"asv-ent": ent, "asv-ent-turn": len(state) * ent}

    def update(self, samples: Sequence) -> Optional[TensorMap]:
        for sample in samples:
            for turn_gold, turn in zip(sample.input, sample.output):
                spkr = turn.speaker
                asvs = list(turn.state)
                stats = self.compute_entropy(asvs)
                if spkr != "<unk>":
                    stats.update({
                        f"{k}-{spkr}": v
                        for k, v in self.compute_entropy_spkr(asvs,
                                                              spkr).items()
                    })
                for k, v in stats.items():
                    if k not in self._values:
                        self._values[k] = list()
                    self._values[k].append(v.item())
        return

    def get(self) -> Optional[TensorMap]:
        return {k: torch.tensor(v).mean() for k, v in self._values.items()}
class BLEUEvaluator(DialogEvaluator):
    vocabs: VocabSet
    _ref: List[Sequence[Sequence[str]]] = \
        utils.private_field(default_factory=list)
    _hyp: List[Sequence[Sequence[str]]] = \
        utils.private_field(default_factory=list)
    _hyp_spkr: dict = utils.private_field(default_factory=dict)
    _ref_spkr: dict = utils.private_field(default_factory=dict)
    _logger: logging.Logger = utils.private_field(default=None)

    def __post_init__(self):
        self._logger = logging.getLogger(self.__class__.__name__)

    def reset(self):
        self._ref.clear()
        self._hyp.clear()
        self._ref_spkr.clear()
        self._hyp_spkr.clear()

    @staticmethod
    def closest_rlen(ref_lens, hyp_len):
        idx = bisect.bisect_left(ref_lens, hyp_len)
        if idx == 0:
            return ref_lens[0]
        elif idx == len(ref_lens):
            return ref_lens[idx - 1]
        else:
            len1, len2 = ref_lens[idx - 1], ref_lens[idx]
            if hyp_len - len1 <= len2 - hyp_len:
                return len1
            else:
                return len2

    def try_compute(self, *args, **kwargs) -> float:
        try:
            return bl.sentence_bleu(*args, **kwargs)
        except Exception as e:
            self._logger.debug(f"error in bleu: {e}")
            return 0.0

    def compute_bleu(self, ref: Sequence[Sequence[str]],
                     hyp: Sequence[Sequence[str]]):
        r_lens = list(sorted(set(map(len, ref))))
        chencherry = bl.SmoothingFunction()
        methods = (
            ("smooth0", chencherry.method0),
            ("smooth1", chencherry.method1),
            ("smooth2", chencherry.method2),
            ("smooth3", chencherry.method3),
            ("smooth4", chencherry.method4),
            ("smooth5", chencherry.method5),
            ("smooth6", chencherry.method6),
            ("smooth7", chencherry.method7)
        )
        if not hyp:
            return {name: 0.0 for name, _ in methods}
        stats = {
            name: np.mean([
                (self.try_compute(ref, h, smoothing_function=method)
                 * bl.brevity_penalty(self.closest_rlen(r_lens, len(h)),
                                      len(h)))
                for h in hyp
            ]) for name, method in methods
        }
        return stats

    def compute_bleu_corpus(self, refs: Sequence[Sequence[Sequence[str]]],
                            hyps: Sequence[Sequence[Sequence[str]]]):
        assert len(refs) == len(hyps)
        stats = collections.defaultdict(float)
        for ref, hyp in zip(refs, hyps):
            for k, v in self.compute_bleu(ref, hyp).items():
                stats[k] += v
        return {k: v / len(refs) for k, v in stats.items()}

    def update(self, samples: Sequence) -> Optional[TensorMap]:
        ref = [[turn.text.split() for turn in sample.input.turns]
               for sample in samples]
        hyp = [[turn.text.split() for turn in sample.output.turns]
               for sample in samples]
        self._ref.extend(ref)
        self._hyp.extend(hyp)
        for spkr in self.vocabs.speaker.f2i:
            if spkr == "<unk>":
                continue
            if spkr not in self._ref_spkr:
                self._ref_spkr[spkr] = list()
                self._hyp_spkr[spkr] = list()
            ref = [[turn.text.split() for turn in sample.input.turns
                    if turn.speaker == spkr] for sample in samples]
            hyp = [[turn.text.split() for turn in sample.output.turns
                    if turn.speaker == spkr] for sample in samples]
            self._ref_spkr[spkr].extend(ref)
            self._hyp_spkr[spkr].extend(hyp)

    def get(self) -> TensorMap:
        stats = {f"bleu-{k}": torch.tensor(v)
                 for k, v in self.compute_bleu_corpus(self._ref,
                                                      self._hyp).items()}
        for spkr in self.vocabs.speaker.f2i:
            if spkr == "<unk>":
                continue
            stats.update({
                f"bleu-{k}-{spkr}": torch.tensor(v)
                for k, v in self.compute_bleu_corpus(
                    self._ref_spkr[spkr], self._hyp_spkr[spkr]).items()
            })
        return stats
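# A small check of the closest-reference-length rule implemented above:
# bisect locates the neighbors of the hypothesis length, and ties break
# toward the shorter reference, matching BLEU's convention.
assert BLEUEvaluator.closest_rlen([3, 6, 9], 4) == 3    # 4-3 <= 6-4
assert BLEUEvaluator.closest_rlen([3, 6, 9], 5) == 6
assert BLEUEvaluator.closest_rlen([3, 6, 9], 1) == 3    # below range
assert BLEUEvaluator.closest_rlen([3, 6, 9], 10) == 9   # above range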
class Generator:
    model: AbstractTDA
    processor: DialogProcessor
    batch_size: int = 32
    device: torch.device = torch.device("cpu")
    global_step: int = 0
    asv_tensor: utils.Stacked1DTensor = None
    _num_instances: int = utils.private_field(default=None)

    def __post_init__(self):
        if self.asv_tensor is None:
            self.asv_tensor = self.processor.tensorize_state_vocab(
                "goal_state")
        self.asv_tensor = self.asv_tensor.to(self.device)

    def on_run_started(self):
        return

    def on_run_ended(self, samples: Sequence[Sample], stats: TensorMap
                     ) -> Tuple[Sequence[Sample], TensorMap]:
        return samples, stats

    def on_batch_started(self, batch: BatchData) -> BatchData:
        return batch

    def validate_sample(self, sample: Sample) -> bool:
        return True

    def on_batch_ended(self, samples: Sequence[Sample]) -> TensorMap:
        return dict()

    def generate_kwargs(self) -> dict:
        return dict()

    def prepare_batch(self, batch: BatchData) -> dict:
        return {
            "conv_lens": batch.conv_lens,
            "sent": batch.sent.value,
            "sent_lens": batch.sent.lens1,
            "speaker": batch.speaker.value,
            "goal": batch.goal.value,
            "goal_lens": batch.goal.lens1,
            "state": batch.state.value,
            "state_lens": batch.state.lens1,
            "asv": self.asv_tensor.value,
            "asv_lens": self.asv_tensor.lens
        }

    def __call__(
            self,
            data: Optional[Sequence[Dialog]] = None,
            num_instances: Optional[int] = None
    ) -> Tuple[Sequence[Sample], TensorMap]:
        if data is None and num_instances is None:
            raise ValueError("must provide a data source or "
                             "number of instances.")
        dataloader = None
        if data is not None:
            dataloader = create_dataloader(
                dataset=DialogDataset(data=data, processor=self.processor),
                batch_size=self.batch_size,
                drop_last=False,
                shuffle=False
            )
            if num_instances is None:
                num_instances = len(data)
        self._num_instances = num_instances
        self.on_run_started()
        # cycle the dataloader (posterior sampling) or repeat `None`
        # (prior sampling) until enough samples have been collected
        dataloader = (itertools.repeat(None) if dataloader is None
                      else itertools.cycle(dataloader))
        cum_stats = collections.defaultdict(float)
        samples = []
        for batch in dataloader:
            self.model.eval()
            if batch is None:
                batch_size = min(self.batch_size,
                                 num_instances - len(samples))
                self.model.genconv_prior()
                with torch.no_grad():
                    pred, info = self.model(
                        torch.tensor(batch_size).to(self.device),
                        **self.generate_kwargs())
            else:
                batch = batch.to(self.device)
                batch_size = batch.batch_size
                batch = self.on_batch_started(batch)
                self.model.genconv_post()
                with torch.no_grad():
                    pred, info = self.model(self.prepare_batch(batch),
                                            **self.generate_kwargs())
            batch_samples = list(filter(
                self.validate_sample,
                (Sample(*args) for args in
                 zip(map(self.processor.lexicalize_global, batch),
                     map(self.processor.lexicalize_global, pred),
                     info["logprob"]))))
            # discard surplus samples so exactly `num_instances` remain
            num_res = max(0, len(samples) + len(batch_samples)
                          - num_instances)
            if num_res > 0:
                batch_samples = random.sample(
                    batch_samples, num_instances - len(samples))
            # advance the step counter by the number of samples kept
            batch_size = len(batch_samples)
            self.global_step += batch_size
            stats = self.on_batch_ended(batch_samples)
            samples.extend(batch_samples)
            for k, v in stats.items():
                cum_stats[k] += v * batch_size
            if len(samples) >= num_instances:
                break
        assert len(samples) == num_instances
        cum_stats = {k: v / len(samples) for k, v in cum_stats.items()}
        return self.on_run_ended(samples, cum_stats)
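# A sketch of the control flow in `Generator.__call__`: iterate an endless
# source (`itertools.cycle` over a dataloader, or `itertools.repeat(None)`
# for prior sampling) and trim the final batch so exactly `num_instances`
# samples are kept. Toy integers stand in for model output; the helper name
# `_generator_loop_sketch` is hypothetical.
def _generator_loop_sketch(num_instances=5, batch_size=2):
    import itertools
    import random

    samples = []
    for _batch in itertools.repeat(None):      # prior mode: no data source
        produced = list(range(batch_size))     # stands in for model samples
        surplus = max(0, len(samples) + len(produced) - num_instances)
        if surplus > 0:
            # keep a random subset so the total lands on num_instances
            produced = random.sample(produced, num_instances - len(samples))
        samples.extend(produced)
        if len(samples) >= num_instances:
            break
    assert len(samples) == num_instances
    return samples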
class Inferencer:
    model: AbstractTDA
    processor: DialogProcessor
    device: torch.device = torch.device("cpu")
    global_step: int = 0
    asv_tensor: utils.Stacked1DTensor = None
    _logger: logging.Logger = utils.private_field(default=None)

    def __post_init__(self):
        self._logger = logging.getLogger(self.__class__.__name__)
        if self.asv_tensor is None:
            self.asv_tensor = self.processor.tensorize_state_vocab(
                "goal_state")
        self.asv_tensor = self.asv_tensor.to(self.device)

    def on_run_started(self, dataloader: td.DataLoader) -> td.DataLoader:
        return dataloader

    def on_run_ended(self, stats: utils.TensorMap) -> utils.TensorMap:
        return stats

    def on_batch_started(self, batch: BatchData) -> BatchData:
        return batch

    def on_batch_ended(self, batch: BatchData, pred: BatchData,
                       outputs) -> utils.TensorMap:
        return {}

    def model_kwargs(self) -> dict:
        return {}

    @staticmethod
    def predict(batch: BatchData, outputs) -> BatchData:
        def predict_state(state_logit, conv_lens):
            batch_size, max_conv_len, num_asv = state_logit.size()
            # +/-inf logits mark padded positions; mask them out before
            # thresholding the sigmoid at 0.5
            mask = ((state_logit == float("-inf")) |
                    (state_logit == float("inf")))
            pred = (torch.sigmoid(state_logit.masked_fill(mask, 0))
                    .masked_fill(mask, 0)) > 0.5
            pred = utils.to_sparse(pred.view(-1, num_asv))
            return utils.DoublyStacked1DTensor(
                value=pred.value.view(batch_size, max_conv_len, -1),
                lens=conv_lens,
                lens1=pred.lens.view(batch_size, max_conv_len)
            )

        logit, post, prior = outputs
        return BatchData(
            sent=utils.DoublyStacked1DTensor(
                value=torch.cat([
                    batch.sent.value[..., :1],
                    logit["sent"].max(-1)[1][..., :-1]
                ], 2),
                lens=batch.sent.lens,
                lens1=batch.sent.lens1
            ),
            speaker=utils.Stacked1DTensor(
                value=logit["speaker"].max(-1)[1],
                lens=batch.conv_lens
            ),
            goal=predict_state(logit["goal"], batch.conv_lens),
            state=predict_state(logit["state"], batch.conv_lens),
            raw=batch.raw
        )

    def prepare_batch(self, batch: BatchData) -> dict:
        return {
            "conv_lens": batch.conv_lens,
            "sent": batch.sent.value,
            "sent_lens": batch.sent.lens1,
            "speaker": batch.speaker.value,
            "goal": batch.goal.value,
            "goal_lens": batch.goal.lens1,
            "state": batch.state.value,
            "state_lens": batch.state.lens1,
            "asv": self.asv_tensor.value,
            "asv_lens": self.asv_tensor.lens
        }

    def __call__(self, dataloader):
        dataloader = self.on_run_started(dataloader)
        cum_stats = collections.defaultdict(float)
        total_steps = 0
        for batch in dataloader:
            batch = batch.to(self.device)
            total_steps += batch.batch_size
            self.global_step += batch.batch_size
            batch = self.on_batch_started(batch)
            self.model.inference()
            outputs = self.model(self.prepare_batch(batch),
                                 **self.model_kwargs())
            pred = self.predict(batch, outputs)
            stats = self.on_batch_ended(batch, pred, outputs)
            for k, v in stats.items():
                cum_stats[k] += v.detach() * batch.batch_size
        cum_stats = {k: v / total_steps for k, v in cum_stats.items()}
        return self.on_run_ended(cum_stats)
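# A toy sketch of the state decoding inside `Inferencer.predict`: positions
# whose logits are +/-inf are treated as padding and forced to a negative
# prediction before the 0.5 sigmoid threshold. The logit tensor below is
# made up, and `_predict_state_sketch` is a hypothetical helper name.
def _predict_state_sketch():
    import torch

    logit = torch.tensor([[2.0, -3.0, float("-inf")],
                          [0.1, float("inf"), -0.2]])
    mask = (logit == float("-inf")) | (logit == float("inf"))
    # fill masked logits with 0 so sigmoid is finite, then re-zero them so
    # masked positions can never cross the 0.5 threshold
    pred = (torch.sigmoid(logit.masked_fill(mask, 0))
            .masked_fill(mask, 0)) > 0.5
    return pred  # boolean multi-label prediction per act-slot-value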
class RougeEvaluator(DialogEvaluator):
    vocabs: VocabSet
    _hyp: List[str] = utils.private_field(default_factory=list)
    _ref: List[str] = utils.private_field(default_factory=list)
    _hyp_spkr: dict = utils.private_field(default_factory=dict)
    _ref_spkr: dict = utils.private_field(default_factory=dict)
    _rouge: rouge.Rouge = utils.private_field(default_factory=rouge.Rouge)
    _key_map: ClassVar[Mapping[str, str]] = {
        "f": "f1",
        "p": "prec",
        "r": "rec"
    }

    def reset(self):
        self._hyp.clear()
        self._ref.clear()
        self._hyp_spkr.clear()
        self._ref_spkr.clear()

    def try_rouge(self, hyp: Sequence[str], ref: Sequence[str]):
        try:
            return {
                f"{k}-{self._key_map.get(k2, k2)}": v2
                for k, v in self._rouge.get_scores(hyp, ref, avg=True).items()
                for k2, v2 in v.items()
            }
        except Exception:
            return {
                f"rouge-{k}-{k2}": 0.0
                for k in ("1", "2", "l")
                for k2 in ("f1", "prec", "rec")
            }

    def update(self, samples: Sequence) -> Optional[TensorMap]:
        hyps, refs = list(), list()
        for sample in samples:
            hyp, ref = "", ""
            for hyp_turn in sample.output.turns:
                hyp += " " + hyp_turn.text.strip()
            for ref_turn in sample.input.turns:
                ref += " " + ref_turn.text.strip()
            hyp, ref = hyp.strip(), ref.strip()
            if not ref:
                continue
            if not hyp:
                hyp = "<pad>"
            hyps.append(hyp)
            refs.append(ref)
        self._hyp.extend(hyps)
        self._ref.extend(refs)
        for spkr in self.vocabs.speaker.f2i:
            if spkr == "<unk>":
                continue
            if spkr not in self._hyp_spkr:
                self._hyp_spkr[spkr] = list()
            if spkr not in self._ref_spkr:
                self._ref_spkr[spkr] = list()
            hyps, refs = list(), list()
            for sample in samples:
                hyp, ref = "", ""
                for hyp_turn in sample.output.turns:
                    if hyp_turn.speaker != spkr:
                        continue
                    hyp += " " + hyp_turn.text.strip()
                for ref_turn in sample.input.turns:
                    if ref_turn.speaker != spkr:
                        continue
                    ref += " " + ref_turn.text.strip()
                hyp, ref = hyp.strip(), ref.strip()
                if not ref:
                    continue
                if not hyp:
                    hyp = "<pad>"
                hyps.append(hyp)
                refs.append(ref)
            self._hyp_spkr[spkr].extend(hyps)
            self._ref_spkr[spkr].extend(refs)
        return

    def get(self) -> TensorMap:
        stats = self.try_rouge(self._hyp, self._ref)
        for spkr in self.vocabs.speaker.f2i:
            if spkr == "<unk>":
                continue
            stats.update({
                f"{k}-{spkr}": v
                for k, v in self.try_rouge(
                    self._hyp_spkr.get(spkr, list()),
                    self._ref_spkr.get(spkr, list())).items()
            })
        return {k: torch.tensor(v) for k, v in stats.items()}
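# A hedged sketch of the key flattening in `RougeEvaluator.try_rouge`: the
# `rouge` package returns nested {"rouge-1": {"f": ..., "p": ..., "r": ...},
# ...} dicts, which are flattened into "rouge-1-f1"-style keys. The input
# strings are illustrative and `_rouge_sketch` is a hypothetical helper name.
def _rouge_sketch():
    import rouge

    key_map = {"f": "f1", "p": "prec", "r": "rec"}
    scores = rouge.Rouge().get_scores(["the cat sat"],
                                      ["a cat sat there"], avg=True)
    # e.g. {"rouge-1-f1": ..., "rouge-1-prec": ..., ..., "rouge-l-rec": ...}
    return {f"{k}-{key_map.get(k2, k2)}": v2
            for k, v in scores.items() for k2, v2 in v.items()}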