Пример #1
0
class SentLengthEvaluator(DialogEvaluator):
    """Tracks average utterance length (whitespace tokens), overall and per speaker."""
    vocabs: VocabSet
    _lens: List[int] = utils.private_field(default_factory=list)
    _lens_spkr: dict = utils.private_field(default_factory=dict)

    def reset(self):
        """Drop all accumulated lengths."""
        self._lens.clear()
        self._lens_spkr.clear()

    def _turn_lengths(self, samples, spkr=None):
        # Token counts for every generated turn, optionally restricted
        # to a single speaker.
        return [len(turn.text.split())
                for sample in samples
                for turn in sample.output.turns
                if spkr is None or turn.speaker == spkr]

    def update(self, samples: Sequence) -> Optional[TensorMap]:
        """Accumulate sentence lengths from a batch of samples."""
        self._lens.extend(self._turn_lengths(samples))
        for spkr in self.vocabs.speaker.f2i:
            if spkr == "<unk>":
                continue
            self._lens_spkr.setdefault(spkr, []).extend(
                self._turn_lengths(samples, spkr))
        return

    def get(self) -> TensorMap:
        """Return mean sentence length, overall and for each known speaker."""
        stats = {"sent-len": torch.tensor(np.mean(self._lens))}
        for spkr in self.vocabs.speaker.f2i:
            if spkr != "<unk>":
                stats[f"sent-len-{spkr}"] = \
                    torch.tensor(np.mean(self._lens_spkr[spkr]))
        return stats
Пример #2
0
class DistinctEvaluator(DialogEvaluator):
    """Computes distinct-n (unique n-grams over token count) per turn,
    overall and per speaker."""
    vocabs: VocabSet
    ngrams: Sequence[int] = frozenset({1, 2})
    _values: Dict[int, List[float]] = utils.private_field(default_factory=dict)
    _values_spkr: Dict[str, Dict[int, List[float]]] = \
        utils.private_field(default_factory=dict)

    def reset(self):
        """Drop all accumulated values."""
        self._values.clear()
        self._values_spkr.clear()

    @staticmethod
    def compute_distinct(tokens, n):
        """Ratio of unique n-grams to token count; 0.0 for empty input."""
        if not tokens:
            return 0.0
        return len(set(ngrams(tokens, n))) / len(tokens)

    def compute(self, samples: Sequence, spkr=None):
        """Per-n lists of distinct-n scores, one entry per matching turn."""
        result = {}
        for n in self.ngrams:
            result[n] = [self.compute_distinct(turn.text, n)
                         for sample in samples
                         for turn in sample.output.turns
                         if spkr is None or turn.speaker == spkr]
        return result

    def update(self, samples: Sequence) -> Optional[TensorMap]:
        """Accumulate distinct-n scores from a batch of samples."""
        for n, values in self.compute(samples).items():
            self._values.setdefault(n, []).extend(values)
        for spkr in self.vocabs.speaker.f2i:
            if spkr == "<unk>":
                continue
            store = self._values_spkr.setdefault(spkr, {})
            for n, values in self.compute(samples, spkr).items():
                store.setdefault(n, []).extend(values)
        return

    def get(self) -> TensorMap:
        """Return mean distinct-n scores, overall and per speaker."""
        stats = {f"dist-{n}": torch.tensor(np.mean(vs))
                 for n, vs in self._values.items()}
        for spkr, values in self._values_spkr.items():
            for n, vs in values.items():
                stats[f"dist-{n}-{spkr}"] = torch.tensor(np.mean(vs))
        return stats
Пример #3
0
class DialogPreprocessor:
    """Normalizes dialog text and annotations: casing, tokenization,
    token substitution, and character filtering."""
    lowercase: bool = True
    replace_number: Optional[str] = None
    tokenizer: str = "corenlp"
    special_chars: Optional[str] = ",.?'"
    _charset: set = utils.private_field(default_factory=set)
    _replace_token: ClassVar[Mapping[str, str]] = {"&": "and"}

    def __post_init__(self):
        # Allowed characters are ascii letters plus configured special chars.
        if self.special_chars is not None:
            self._charset.update(self.special_chars)
        self._charset.update(string.ascii_letters)

    def filter_word(self, word: str) -> Optional[str]:
        """Normalize a single token; return None when it should be dropped."""
        if not word:
            return
        if word.isdigit():
            # Numeric tokens bypass the charset filter and may be replaced
            # by a configured placeholder.
            return self.replace_number if self.replace_number is not None \
                else word
        if self.lowercase:
            word = word.lower()
        word = self._replace_token.get(word, word)
        # Drop tokens containing characters outside the allowed set.
        if any(c not in self._charset for c in word):
            return
        return word

    def preprocess_sent(self, sent: str) -> str:
        """Tokenize and normalize a sentence, dropping filtered tokens."""
        cleaned = (self.filter_word(tok)
                   for tok in tokenize(sent, self.tokenizer))
        return " ".join(tok for tok in cleaned if tok)

    def preprocess_state(self, state: DialogState) -> DialogState:
        """Normalize every act-slot-value triple in a dialog state."""
        def clean(token):
            return (token.lower() if self.lowercase else token).strip()

        new_state = DialogState()
        for asv in state:
            new_state.add(ActSlotValue(*(clean(v) for v in asv.values)))
        return new_state

    def preprocess_turn(self, turn: Turn) -> Turn:
        """Return a normalized copy of a turn (text, goal, state, asr)."""
        asr = {self.preprocess_sent(sent): score
               for sent, score in turn.asr.items()}
        return Turn(text=self.preprocess_sent(turn.text),
                    speaker=turn.speaker,
                    goal=self.preprocess_state(turn.goal),
                    state=self.preprocess_state(turn.state),
                    asr=asr,
                    meta=turn.meta)

    def preprocess(self, dialog: Dialog) -> Dialog:
        """Return a normalized copy of an entire dialog."""
        turns = [self.preprocess_turn(t) for t in dialog.turns]
        return Dialog(turns=turns, meta=dialog.meta)
Пример #4
0
class InterpolateInferencer:
    """Encodes dialogs to latent codes and generates dialogs from them."""
    model: models.AbstractTDA
    processor: DialogProcessor
    device: torch.device = torch.device("cpu")
    asv_tensor: utils.Stacked1DTensor = None
    _num_instances: int = utils.private_field(default=None)

    def __post_init__(self):
        # Build the act-slot-value tensor from the processor when absent,
        # then move it to the target device.
        if self.asv_tensor is None:
            self.asv_tensor = self.processor.tensorize_state_vocab(
                "goal_state")
        self.asv_tensor = self.asv_tensor.to(self.device)

    def prepare_data_batch(self, batch: BatchData) -> dict:
        """Flatten a data batch into the keyword dict the model expects."""
        kwargs = {
            "conv_lens": batch.conv_lens,
            "sent": batch.sent.value,
            "sent_lens": batch.sent.lens1,
            "speaker": batch.speaker.value,
            "goal": batch.goal.value,
            "goal_lens": batch.goal.lens1,
            "state": batch.state.value,
            "state_lens": batch.state.lens1,
        }
        kwargs["asv"] = self.asv_tensor.value
        kwargs["asv_lens"] = self.asv_tensor.lens
        return kwargs

    def prepare_z_batch(self, batch: torch.Tensor) -> dict:
        """Wrap a latent batch together with the shared asv tensors."""
        return {"zconv": batch,
                "asv": self.asv_tensor.value,
                "asv_lens": self.asv_tensor.lens}

    def encode(self, dataloader) -> torch.Tensor:
        """Encode every batch into conversation-level latent means."""
        self.model.eval()
        self.model.encode()
        chunks = []
        for batch in dataloader:
            moved = batch.to(self.device)
            chunks.append(self.model(self.prepare_data_batch(moved)).mu)
        return torch.cat(chunks, 0)

    def generate(self, dataloader) -> Sequence[Dialog]:
        """Decode latent batches into lexicalized dialogs.

        Only the goal and sentence factors are active (scale 1.0); speaker
        and state factors are disabled (scale 0.0).
        """
        self.model.eval()
        self.model.decode_optimal()
        dialogs = []
        for batch in dataloader:
            moved = batch.to(self.device)
            pred, _ = self.model(self.prepare_z_batch(moved),
                                 spkr_scale=0.0,
                                 goal_scale=1.0,
                                 state_scale=0.0,
                                 sent_scale=1.0)
            dialogs.extend(self.processor.lexicalize_global(p) for p in pred)
        return dialogs
Пример #5
0
class DistinctStateEvaluator(DialogEvaluator):
    """Measures distinctness of act-slot-values within each generated dialog,
    overall and per speaker.

    Note: a dead `compute` method (copy-paste residue from DistinctEvaluator
    that called `compute_distinct` with the wrong arity and referenced a
    nonexistent `self.ngrams`) has been removed — it could never execute.
    """
    vocabs: VocabSet
    _values: dict = utils.private_field(default_factory=dict)

    def reset(self):
        """Drop all accumulated values."""
        self._values.clear()

    @property
    def speakers(self):
        """All known speakers except the unknown token."""
        return set(spkr for spkr in self.vocabs.speaker.f2i if spkr != "<unk>")

    @staticmethod
    def compute_distinct(tokens):
        """Unique/total ratio as a scalar tensor; 0.0 for empty input."""
        if len(tokens) == 0:
            return torch.tensor(0.0)
        return torch.tensor(len(set(tokens)) / len(tokens))

    def update(self, samples: Sequence) -> Optional[TensorMap]:
        """Accumulate per-dialog distinct-asv ratios, overall and per speaker."""
        for sample in samples:
            asvs = [
                asv for turn in sample.output if turn.speaker != "<unk>"
                for asv in turn.state
            ]
            # BUG FIX: the per-speaker lists previously reused the same
            # speaker-agnostic filter as `asvs`, so every dist-a-{spkr}
            # merely duplicated dist-a. Filter on the actual speaker.
            spkr_asvs = {
                spkr: [
                    asv for turn in sample.output if turn.speaker == spkr
                    for asv in turn.state
                ]
                for spkr in self.speakers
            }
            stats = {"dist-a": self.compute_distinct(asvs)}
            stats.update({
                f"dist-a-{spkr}": self.compute_distinct(spkr_asvs[spkr])
                for spkr in self.speakers
            })
            for k, v in stats.items():
                self._values.setdefault(k, []).append(v.item())
        return

    def get(self) -> Optional[TensorMap]:
        """Mean of each accumulated statistic."""
        return {k: torch.tensor(v).mean() for k, v in self._values.items()}
Пример #6
0
class DSTDialogDataset:
    """Flattens a collection of DST dialogs into a turn-level dataset."""
    dialogs: Sequence[DSTDialog]
    processor: DSTDialogProcessor
    _num_turns: Sequence[int] = utils.private_field()
    _cum_turns: Sequence[int] = utils.private_field()

    def __post_init__(self):
        # _cum_turns[i] holds the number of turns in dialogs[:i], so the
        # last entry is the total dataset size.
        self._num_turns = [len(dialog.dst_turns) for dialog in self.dialogs]
        self._cum_turns = [0] + np.cumsum(self._num_turns).tolist()

    def __len__(self):
        return self._cum_turns[-1]

    def __getitem__(self, item):
        """Tensorize the `item`-th turn counted across all dialogs.

        Raises:
            IndexError: if `item` is outside [0, len(self)).
        """
        idx = bisect.bisect_right(self._cum_turns, item) - 1
        # BUG FIX: previously raised ValueError and did not reject negative
        # indices (bisect yields idx == -1 there, silently wrapping around).
        # IndexError is required for the sequence-iteration protocol.
        if not 0 <= idx < len(self.dialogs):
            raise IndexError(f"index out of range: "
                             f"{item} not in [0, {len(self)})")
        turn_idx = item - self._cum_turns[idx]
        dialog = self.dialogs[idx]
        return self.processor.tensorize_dst_turn(dialog.dst_turns[turn_idx])
Пример #7
0
class DialogLengthEvaluator(DialogEvaluator):
    """Tracks the number of turns per generated dialog."""
    _lens: List[int] = utils.private_field(default_factory=list)

    def reset(self):
        """Drop all accumulated lengths."""
        self._lens.clear()

    def update(self, samples: Sequence) -> Optional[TensorMap]:
        """Record dialog lengths for one batch and return the batch mean."""
        batch_lens = [len(sample.output) for sample in samples]
        self._lens.extend(batch_lens)
        return {"conv-len": torch.tensor(batch_lens).float().mean()}

    def get(self) -> TensorMap:
        """Mean dialog length over everything seen since the last reset."""
        return {"conv-len": torch.tensor(self._lens).float().mean()}
Пример #8
0
class StateNoveltyEvaluator(DialogEvaluator):
    """Measures how often generated turn states were never seen in the
    reference dataset."""
    dataset: DialogDataset
    _turn_states: Set[frozenset] = utils.private_field(default_factory=set)
    _values: dict = utils.private_field(default_factory=dict)

    def __post_init__(self):
        self.prepare_turn_states()

    def prepare_turn_states(self):
        """Index every reference turn state as an order-insensitive frozenset."""
        for dialog in self.dataset.data:
            for turn in dialog:
                self._turn_states.add(frozenset(turn.state))

    def reset(self):
        """Drop accumulated statistics (the reference index is kept)."""
        self._values.clear()

    @property
    def vocabs(self):
        # BUG FIX: `speakers` referenced `self.vocabs`, which this class never
        # defined (AttributeError on access). Delegate to the dataset's
        # processor, mirroring LanguageNoveltyEvaluator.
        return self.dataset.processor.vocabs

    @property
    def speakers(self):
        """All known speakers except the unknown token."""
        return set(spkr for spkr in self.vocabs.speaker.f2i if spkr != "<unk>")

    def update(self, samples: Sequence) -> Optional[TensorMap]:
        """Accumulate, per dialog, the fraction of turn states not present
        in the reference index."""
        for sample in samples:
            novel = torch.tensor([
                frozenset(turn.state) not in self._turn_states
                for turn in sample.output
            ]).float().mean()
            self._values.setdefault("novel-a", []).append(novel.item())
        return

    def get(self) -> Optional[TensorMap]:
        """Mean of each accumulated statistic."""
        return {k: torch.tensor(v).mean() for k, v in self._values.items()}
Пример #9
0
class DataAdapter:
    """Base class for loading and saving dialog corpora in external formats."""
    _logger: logging.Logger = utils.private_field(default=None)

    def __post_init__(self):
        self._logger = logging.getLogger(self.__class__.__name__)

    @staticmethod
    def serialize_semantics(state: DialogState):
        """Convert a DialogState into a list of {"slots", "act"} dicts."""
        return [{"slots": [[asv.slot, asv.value]], "act": asv.act}
                for asv in state]

    @staticmethod
    def parse_semantics(data) -> DialogState:
        """Build a DialogState from serialized {"slots", "act"} dicts."""
        state = DialogState()
        for asv_data in data:
            act = str(asv_data["act"]).strip()
            for slot, value in asv_data["slots"]:
                state.add(ActSlotValue(act=act,
                                       slot=str(slot).strip(),
                                       value=str(value).strip()))
        return state

    def load(self,
             path: pathlib.Path,
             split: str = None) -> Mapping[str, Sequence[Dialog]]:
        """Use `split` option to specify specific split to be loaded.
        The underlying implementation might not choose to support this."""
        raise NotImplementedError

    def save_imp(self, dat: Mapping[str, Sequence[Dialog]],
                 path: pathlib.Path):
        """Save data into the directory/path.
        May assume that the path is clean."""
        raise NotImplementedError

    def save(self,
             data: Mapping[str, Sequence[Dialog]],
             path: pathlib.Path,
             overwrite: bool = False):
        """Clear the target path (unless occupied and overwrite is False)
        and delegate to `save_imp`."""
        occupied = (path.is_file() and path.exists()
                    or path.is_dir() and utils.has_element(path.glob("*")))
        if occupied and not overwrite:
            raise FileExistsError(f"file exists or directory is "
                                  f"not empty: {path}")
        shell = utils.ShellUtils()
        shell.remove(path, recursive=True, silent=True)
        return self.save_imp(data, path)
Пример #10
0
class StateCountEvaluator(DialogEvaluator):
    """Tracks act-slot-value counts per turn, per speaker, and per dialog."""
    vocabs: VocabSet
    _values: dict = utils.private_field(default_factory=dict)

    def reset(self):
        """Drop accumulated statistics."""
        self._values.clear()

    @property
    def speakers(self):
        """All known speakers except the unknown token."""
        return set(spkr for spkr in self.vocabs.speaker.f2i if spkr != "<unk>")

    def _record(self, key, value):
        # Append a scalar tensor's value to the running list for `key`.
        self._values.setdefault(key, []).append(value.item())

    def update(self, samples: Sequence) -> Optional[TensorMap]:
        """Accumulate state counts from a batch of samples."""
        for sample in samples:
            for turn in sample.output:
                if turn.speaker == "<unk>":
                    continue
                count = torch.tensor(len(turn.state)).float()
                self._record("state-cnt", count)
                self._record(f"state-cnt-{turn.speaker}", count)
            # Dialog-level total counts every asv, including <unk> turns.
            total = torch.tensor(
                sum(1 for turn in sample.output for _ in turn.state)).float()
            self._record("state-cnt-conv", total)
        return

    def get(self) -> Optional[TensorMap]:
        """Mean of each accumulated statistic."""
        return {k: torch.tensor(v).mean() for k, v in self._values.items()}
Пример #11
0
class EvaluatingInferencer(Inferencer):
    """Inferencer that additionally runs a set of fine-grained evaluators."""
    evaluators: Sequence[FinegrainedEvaluator] = tuple()
    _requires_lexical_form: bool = utils.private_field(default=False)

    def __post_init__(self):
        super().__post_init__()
        # Lexicalization happens once per batch, and only when some
        # evaluator actually needs it.
        self._requires_lexical_form = any(e.requires_lexical_form
                                          for e in self.evaluators)

    def on_run_started(self, dataloader: td.DataLoader) -> td.DataLoader:
        """Reset every evaluator before a run begins."""
        dataloader = super().on_run_started(dataloader)
        for evaluator in self.evaluators:
            evaluator.reset()
        return dataloader

    def on_batch_ended(self, batch: BatchData, pred: BatchData, outputs
                       ) -> utils.TensorMap:
        """Update every evaluator with this batch and merge their stats."""
        stats = dict(super().on_batch_ended(batch, pred, outputs))
        batch_lex = pred_lex = None
        if self._requires_lexical_form:
            batch_lex = [self.processor.lexicalize_global(b) for b in batch]
            pred_lex = [self.processor.lexicalize_global(p) for p in pred]
        with torch.no_grad():
            for evaluator in self.evaluators:
                if evaluator.requires_lexical_form:
                    args = (batch_lex, pred_lex, outputs)
                else:
                    args = (batch, pred, outputs)
                stats.update(evaluator.update(*args) or dict())
        return stats

    def on_run_ended(self, stats: utils.TensorMap) -> utils.TensorMap:
        """Collect final statistics from every evaluator."""
        stats = dict(super().on_run_ended(stats))
        with torch.no_grad():
            for evaluator in self.evaluators:
                stats.update(evaluator.get() or dict())
        return stats
Пример #12
0
class TestDataloader:
    """Groups DST dialogs into batches bounded by a total turn budget."""
    dialogs: Sequence[DSTDialog]
    processor: dst_datasets.DSTDialogProcessor
    max_batch_size: int = 32
    _collator: PaddingCollator = utils.private_field()

    def __post_init__(self):
        self._collator = PaddingCollator(
            frozenset(("sent", "system_acts", "belief_state",
                       "slot", "asr_score")))

    @property
    def total_items(self):
        """Total number of turns across all dialogs."""
        return sum(len(dialog.dst_turns) for dialog in self.dialogs)

    def __len__(self):
        # Number of dialogs, not number of batches.
        return len(self.dialogs)

    def create_batch(self, dialogs: Iterable[DSTDialog]):
        """Tensorize and collate every turn of the given dialogs."""
        items = [self.processor.tensorize_dst_test_turn(turn)
                 for dialog in dialogs
                 for turn in dialog.dst_turns]
        return DSTTestBatchData.from_dict(self._collator(items))

    def __iter__(self):
        """Yield (dialogs, batch) pairs whose turn totals fit the budget.

        A dialog larger than the budget still forms its own batch.
        """
        bucket: List[DSTDialog] = []
        pending = 0  # turns currently in the bucket
        for dialog in self.dialogs:
            if pending + len(dialog.dst_turns) > self.max_batch_size:
                if bucket:
                    yield bucket, self.create_batch(bucket)
                bucket, pending = [], 0
            bucket.append(dialog)
            pending += len(dialog.dst_turns)
        if bucket:
            yield bucket, self.create_batch(bucket)
Пример #13
0
class LanguageNoveltyEvaluator(DialogEvaluator):
    """Measures bigram, trigram, and whole-utterance novelty of generated
    text against a reference dataset, overall and per speaker."""
    dataset: DialogDataset
    _sents: Set[str] = utils.private_field(default_factory=set)
    _bigrams: Set[Tuple[str, str]] = utils.private_field(default_factory=set)
    _trigrams: Set[Tuple[str, str, str]] = \
        utils.private_field(default_factory=set)
    _spkr_bigrams: Dict[str, Set[Tuple[str, str]]] = \
        utils.private_field(default_factory=dict)
    _spkr_trigrams: Dict[str, Set[Tuple[str, str, str]]] = \
        utils.private_field(default_factory=dict)
    _spkr_sents: Dict[str,
                      Set[str]] = utils.private_field(default_factory=dict)
    _values: Dict[str, List[float]] = \
        utils.private_field(default_factory=dict)

    def __post_init__(self):
        self.prepare_ngrams()

    @property
    def speakers(self):
        """All known speakers except the unknown token."""
        return set(spkr for spkr in self.vocabs.speaker.f2i if spkr != "<unk>")

    @property
    def vocabs(self):
        return self.dataset.processor.vocabs

    def prepare_ngrams(self):
        """Index reference bigrams, trigrams, and sentences, overall and
        per speaker, stripping <bos>/<eos> markers."""
        for dialog in self.dataset.data:
            for turn in dialog.turns:
                spkr = turn.speaker
                if spkr == "<unk>":
                    continue
                if spkr not in self._spkr_bigrams:
                    self._spkr_bigrams[spkr] = set()
                    self._spkr_trigrams[spkr] = set()
                    self._spkr_sents[spkr] = set()
                tokens = \
                    self.dataset.processor.sent_processor.process(turn.text)
                tokens = utils.lstrip(tokens, "<bos>")
                tokens = utils.rstrip(tokens, "<eos>")
                for bigram in nltk.bigrams(tokens):
                    self._bigrams.add(tuple(bigram))
                    self._spkr_bigrams[spkr].add(tuple(bigram))
                for trigram in nltk.ngrams(tokens, 3):
                    self._trigrams.add(tuple(trigram))
                    self._spkr_trigrams[spkr].add(tuple(trigram))
                sent = " ".join(tokens)
                self._sents.add(sent)
                self._spkr_sents[spkr].add(sent)

    def reset(self):
        """Drop accumulated statistics (the reference index is kept)."""
        self._values.clear()

    def _novelty(self, text: str, ref_bigrams, ref_trigrams, ref_sents):
        # Shared implementation behind compute/compute_spkr (previously two
        # near-identical copies). Statistics default to 0.0 whenever the
        # text is too short to produce n-grams of the relevant order.
        stats = {
            "novel-2": torch.tensor(0.0),
            "novel-3": torch.tensor(0.0),
            "novel-utt": torch.tensor(0.0)
        }
        tokens = text.split()
        bigrams = list(map(tuple, nltk.bigrams(tokens)))
        trigrams = list(map(tuple, nltk.trigrams(tokens)))
        if bigrams:
            stats["novel-2"] = \
                (torch.tensor([w not in ref_bigrams for w in bigrams])
                 .float().mean())
        if trigrams:
            stats["novel-3"] = \
                (torch.tensor([w not in ref_trigrams for w in trigrams])
                 .float().mean())
        if text:
            stats["novel-utt"] = torch.tensor(text not in ref_sents).float()
        return stats

    def compute(self, text: str):
        """Novelty of `text` against the full reference index."""
        return self._novelty(text, self._bigrams, self._trigrams, self._sents)

    def compute_spkr(self, text: str, spkr: str):
        """Novelty of `text` against one speaker's reference index."""
        return self._novelty(text, self._spkr_bigrams[spkr],
                             self._spkr_trigrams[spkr],
                             self._spkr_sents[spkr])

    def update(self, samples: Sequence) -> Optional[TensorMap]:
        """Accumulate novelty statistics for every generated turn."""
        for sample in samples:
            for turn in sample.output:
                spkr = turn.speaker
                stats = self.compute(turn.text)
                if spkr != "<unk>":
                    stats.update({
                        f"{k}-{spkr}": v
                        for k, v in self.compute_spkr(turn.text, spkr).items()
                    })
                for k, v in stats.items():
                    self._values.setdefault(k, []).append(v.item())
        return

    def get(self) -> Optional[TensorMap]:
        """Mean of each accumulated statistic."""
        return {k: torch.tensor(v).mean() for k, v in self._values.items()}
Пример #14
0
class Runner:
    model: dst.AbstractDialogStateTracker
    processor: dst_datasets.DSTDialogProcessor
    save_dir: pathlib.Path = pathlib.Path("out")
    device: torch.device = torch.device("cpu")
    epochs: int = 30
    scheduler: Callable[[op.Optimizer], op.lr_scheduler._LRScheduler] = None
    loss: str = "sum"
    gradient_clip: Optional[float] = None
    l2norm: Optional[float] = None
    train_validate: bool = False
    inepoch_report_chance: float = 0.1
    early_stop: bool = False
    early_stop_criterion: str = "joint-goal"
    early_stop_patience: Optional[int] = None
    asr_method: str = "score"
    asr_sigmoid_sum_order: str = "sum-sigmoid"
    asr_topk: Optional[int] = None
    _logger: logging.Logger = utils.private_field(default=None)
    _user_tensor: utils.Stacked1DTensor = utils.private_field(default=None)
    _wizard_tensor: utils.Stacked1DTensor = utils.private_field(default=None)
    _bce: nn.BCEWithLogitsLoss = utils.private_field(default=None)
    _record: Record = utils.private_field(default=None)

    def __post_init__(self):
        """Initialize the logger, state-vocab tensors, BCE loss, and the
        output directory; validate the configured asr order."""
        self._logger = logging.getLogger(self.__class__.__name__)
        self._user_tensor = self.processor.tensorize_state_vocab(
            speaker="user",
            # tensorizer=self.processor.tensorize_turn_label_asv
        )
        self._user_tensor = self._user_tensor.to(self.device)
        self._wizard_tensor = self.processor.tensorize_state_vocab(
            speaker="wizard"
        )
        self._wizard_tensor = self._wizard_tensor.to(self.device)
        # Per-element loss; reduction is applied later in compute_loss.
        self._bce = nn.BCEWithLogitsLoss(reduction="none")
        utils.ShellUtils().mkdir(self.save_dir, True)
        assert self.asr_sigmoid_sum_order in {"sigmoid-sum", "sum-sigmoid"}

    @property
    def vocabs(self):
        # Vocabulary set shared with the dialog processor.
        return self.processor.vocabs

    @property
    def criterion(self):
        # Returns (name, flag): the criterion string with any leading "~"
        # stripped, and whether it started with "~" (presumably marking a
        # metric to minimize rather than maximize — confirm against usage).
        return (self.early_stop_criterion.lstrip("~"),
                self.early_stop_criterion.startswith("~"))

    def prepare_system_acts(self, s: utils.Stacked1DTensor):
        """Look up system-act indices in the wizard state-vocab tensor.

        Returns a 3-tuple of (token tensor, act lengths, token lengths);
        the token tensor is trimmed to the longest token sequence present.
        """
        s, s_lens = s.value, s.lens
        if s_lens.max().item() == 0:
            # No system acts anywhere in the batch: return the (padding-only)
            # lookups untrimmed.
            return (self._wizard_tensor.value[s], s_lens,
                    self._wizard_tensor.lens[s])
        s_lats = s_lens
        s, s_lens = self._wizard_tensor.value[s], self._wizard_tensor.lens[s]
        return s[..., :s_lens.max()], s_lats, s_lens

    def compute_loss(self, batch: DSTBatchData, ontology):
        """Sum per-example BCE losses over all (act, slot) ontology entries.

        For each act-slot pair, the model scores every candidate ontology
        value; the target is whether that value appears in the example's
        gold belief state.
        """
        loss = None
        for act_slot, (ont_idx, ont) in ontology.items():
            as_idx = torch.tensor(self.vocabs.speaker_state["user"]
                                  .act_slot[act_slot]).to(self.device)
            ont_idx, ont = ont_idx.to(self.device), ont.to(self.device)
            logit = self.model(
                as_idx,
                *batch.sent.tensors,
                *self.prepare_system_acts(batch.system_acts),
                *ont.tensors
            )
            # ont_idx: [num_ont] -> [batch_size x num_ont x state_lat]
            # s: [batch_size x state_lat] ->
            #    [batch_size x num_ont x state_lat]
            # target: [batch_size x num_ont]
            s = batch.belief_state
            # target[b, o] is True iff ontology value o occurs among the
            # unpadded gold state values of example b.
            target = \
                ((ont_idx.unsqueeze(0).unsqueeze(-1) == s.value.unsqueeze(1))
                 .masked_fill(~utils.mask(s.lens).unsqueeze(1), 0).any(-1))
            current_loss = self._bce(logit, target.float())
            # Reduce over ontology values per the configured loss mode.
            if self.loss == "mean":
                current_loss = current_loss.mean(-1)
            elif self.loss == "sum":
                current_loss = current_loss.sum(-1)
            else:
                raise ValueError(f"unsupported loss method: {self.loss}")
            if loss is None:
                loss = current_loss
            else:
                loss += current_loss
        return loss

    def make_record(self, epoch, stats):
        """Snapshot this epoch's stats and model parameters for early stopping."""
        self._record = Record(
            epoch=epoch,
            # The criterion metric may be absent from stats; store None then.
            value=stats.get(self.criterion[0], None),
            stats=stats,
            # Detach and move to CPU so the snapshot survives later updates.
            params={k: v.cpu().detach()
                    for k, v in self.model.state_dict().items()}
        )

    def predict(self, batch, ontology):
        """Score every ontology value for each example, accumulate the BCE
        loss, and decode per-example dialog states.

        Returns (loss, pred) where pred is a list of per-example user
        states after DSTC2-style 'this' resolution.
        """
        pred = [list() for _ in range(batch.batch_size)]
        loss = None
        for act_slot, (ont_idx, ont) in ontology.items():
            as_idx = torch.tensor(self.vocabs.speaker_state["user"]
                                  .act_slot[act_slot]).to(self.device)
            ont_idx, ont = ont_idx.to(self.device), ont.to(self.device)
            logit = self.model(
                as_idx,
                *batch.sent.tensors,
                *self.prepare_system_acts(batch.system_acts),
                *ont.tensors
            )
            # ont_idx: [num_ont] -> [batch_size x num_ont x state_lat]
            # s: [batch_size x state_lat] ->
            #    [batch_size x num_ont x state_lat]
            # target: [batch_size x num_ont]
            s = batch.belief_state
            target = \
                ((ont_idx.unsqueeze(0).unsqueeze(-1) == s.value.unsqueeze(1))
                 .masked_fill(~utils.mask(s.lens).unsqueeze(1), 0).any(-1))
            current_loss = self._bce(logit, target.float())
            if self.loss == "mean":
                current_loss = current_loss.mean(-1)
            elif self.loss == "sum":
                current_loss = current_loss.sum(-1)
            else:
                raise ValueError(f"unsupported loss method: {self.loss}")
            if loss is None:
                loss = current_loss
            else:
                loss += current_loss
            # Collect (value index, logit) for every positive prediction.
            for batch_idx, val_idx in \
                    (torch.sigmoid(logit) > 0.5).nonzero().tolist():
                pred[batch_idx].append(
                    (ont_idx[val_idx].item(), logit[batch_idx, val_idx]))

        def to_dialog_state(data: Sequence[Tuple[ActSlotValue, float]]):
            # Collapse scored asvs into a DialogState: keep all requests,
            # but only the highest-scoring value per informed act-slot.
            state = DialogState()
            as_map = collections.defaultdict(list)
            for asv, score in data:
                as_map[(asv.act, asv.slot)].append((asv, score))
            for (act, slt), data in as_map.items():
                if act == "request" and slt == "slot":
                    state.update(asv for asv, _ in data)
                elif act == "inform":
                    state.add(max(data, key=lambda x: x[1])[0])
            return state

        # Map value indices back to asv objects, then decode states.
        pred = [[(self.processor.vocabs.speaker_state["user"].asv[idx], score)
                 for idx, score in v] for v in pred]
        pred = list(map(to_dialog_state, pred))
        pred_inform = [{sv.slot: sv.value for sv in p.get("inform")}
                       for p in pred]
        pred_request = [{sv.value for sv in p.get("request")} for p in pred]
        # DSTC2: 'this' resolution
        pred = [
            (DSTTurn(turn.wizard, turn.user.clone(inform=pi, request=pr))
             .resolve_this().user.state)
            for turn, pi, pr in zip(batch.raw, pred_inform, pred_request)
        ]
        return loss, pred

    def predict_asr(self, batch, ontology):
        pred = [list() for _ in range(batch.batch_size)]
        batch_loss = []
        for batch_idx, (batch_asr, score) in enumerate(self.iter_asr(batch)):
            loss = None
            if self.asr_topk is not None:
                score_list = sorted(enumerate(score.tolist()),
                                    key=lambda x: x[1],
                                    reverse=True)
                score_list = list(utils.bucket(
                    score_list,
                    compare_fn=lambda x, y: x[1] == y[1]))
                score_list = list(itertools.chain(*score_list[:self.asr_topk]))
                score_idx = [idx for idx, _ in score_list]
                score = score[score_idx]
                batch_asr = batch_asr[score_idx]
            if self.asr_method == "uniform":
                score.fill_(1 / len(score))
            elif self.asr_method == "ones":
                score.fill_(1)
            elif self.asr_method == "scaled":
                max_score = score.max()
                if max_score.item() == 0:
                    score.fill_(1 / len(score))
                else:
                    score = score * (1 / max_score.item())
            elif self.asr_method == "score":
                if score.sum().item() == 0:
                    score.fill_(1 / len(score))
                else:
                    score = score / score.sum()
            else:
                raise ValueError(f"unsupported method: {self.asr_method}")
            for act_slot, (ont_idx, ont) in ontology.items():
                as_idx = torch.tensor(self.vocabs.speaker_state["user"]
                                      .act_slot[act_slot]).to(self.device)
                ont_idx, ont = ont_idx.to(self.device), ont.to(self.device)
                logit = logit_raw = self.model(
                    as_idx,
                    *batch_asr.sent.tensors,
                    *self.prepare_system_acts(batch_asr.system_acts),
                    *ont.tensors
                )
                logit = torch.mm(score.unsqueeze(0), logit).squeeze(0)
                # ont_idx: [num_ont] -> [num_ont x state_lat]
                # s: [state_lat] -> [num_ont x state_lat]
                # target: [num_ont]
                s = batch_asr.belief_state[0]
                target = ((ont_idx.unsqueeze(-1) == s.unsqueeze(0)).any(-1))
                current_loss = self._bce(logit, target.float())
                if self.loss == "mean":
                    current_loss = current_loss.mean(-1)
                elif self.loss == "sum":
                    current_loss = current_loss.sum(-1)
                else:
                    raise ValueError(f"unsupported loss method: {self.loss}")
                if loss is None:
                    loss = current_loss
                else:
                    loss += current_loss
                if self.asr_sigmoid_sum_order == "sum-sigmoid":
                    current_pred = torch.sigmoid(logit) > 0.5
                elif self.asr_sigmoid_sum_order == "sigmoid-sum":
                    sigmoid = (torch.mm(score.unsqueeze(0),
                                        torch.sigmoid(logit_raw))
                               .squeeze(0).clamp_(0.0, 1.0))
                    current_pred = sigmoid > 0.5
                    logit = (sigmoid / (1 - sigmoid)).log()
                else:
                    raise ValueError(f"unsupported order: "
                                     f"{self.asr_sigmoid_sum_order}")
                for (val_idx,) in current_pred.nonzero().tolist():
                    pred[batch_idx] \
                        .append((ont_idx[val_idx].item(), logit[val_idx]))
            batch_loss.append(loss)

        def to_dialog_state(data: Sequence[Tuple[ActSlotValue, float]]):
            state = DialogState()
            as_map = collections.defaultdict(list)
            for asv, score in data:
                as_map[(asv.act, asv.slot)].append((asv, score))
            for (act, slt), data in as_map.items():
                if act == "request" and slt == "slot":
                    state.update(asv for asv, _ in data)
                elif act == "inform":
                    state.add(max(data, key=lambda x: x[1])[0])
            return state

        pred = [[(self.processor.vocabs.speaker_state["user"].asv[idx], score)
                 for idx, score in v] for v in pred]
        pred = list(map(to_dialog_state, pred))
        pred_inform = [{sv.slot: sv.value for sv in p.get("inform")}
                       for p in pred]
        pred_request = [{sv.value for sv in p.get("request")} for p in pred]
        # DSTC2: 'this' resolution
        pred = [
            (DSTTurn(turn.wizard, turn.user.clone(inform=pi, request=pr))
             .resolve_this().user.state)
            for turn, pi, pr in zip(batch.raw, pred_inform, pred_request)
        ]
        return torch.stack(batch_loss), pred

    def train(self, train_dataloader, dev_dataloader, test_fn=None):
        """Run the training loop with per-epoch validation, tensorboard
        logging and optional early stopping.

        Arguments:
            train_dataloader: dataloader over the training dialogs.
            dev_dataloader: dataloader used for per-epoch validation.
            test_fn: evaluation callable with signature
                ``(dataloader, ontology, mode) -> stats``; defaults to
                ``self.test``.

        Returns:
            The best record (per ``self.criterion``) when early stopping is
            enabled, otherwise the record made from the final epoch.
        """
        test_fn = test_fn or self.test
        writer = tb.SummaryWriter(log_dir=str(self.save_dir))
        ont = self.processor.tensorize_state_dict(self._user_tensor, "user")
        # optimize only the trainable parameters
        optimizer = op.Adam(p for p in self.model.parameters()
                            if p.requires_grad)
        scheduler = None
        if self.scheduler is not None:
            scheduler = self.scheduler(optimizer)
        global_step = 0
        final_stats = None
        for eidx in range(1, self.epochs + 1):
            progress_local = tqdm.tqdm(
                total=len(train_dataloader.dataset),
                dynamic_ncols=True,
                desc=f"training epoch-{eidx}"
            )
            cum_stats = collections.defaultdict(float)
            local_step = 0
            self.model.train()
            for batch in train_dataloader:
                batch = batch.to(self.device)
                batch_size = batch.batch_size
                global_step += batch_size
                local_step += batch_size
                progress_local.update(batch_size)
                optimizer.zero_grad()
                loss = self.compute_loss(batch, ont)
                loss_mean = loss.mean()
                # reported train-loss intentionally excludes the L2 term below
                stats = {"train-loss": loss_mean.item()}
                if self.l2norm is not None:
                    l2norm = sum(p.norm(2) for p in
                                 self.model.parameters() if p.requires_grad)
                    loss_mean += self.l2norm * l2norm
                loss_mean.backward()
                # BUG FIX: clip gradients *after* backward() computes them.
                # The original called clip_grad_norm_ before backward()
                # (right after zero_grad()), which made the clip a no-op.
                if self.gradient_clip is not None:
                    nn.utils.clip_grad_norm_(
                        parameters=itertools.chain(
                            *(d["params"] for d in optimizer.param_groups)),
                        max_norm=self.gradient_clip
                    )
                optimizer.step()
                for k, v in stats.items():
                    cum_stats[k] += v * batch_size
                # stochastic in-epoch logging to keep tensorboard light
                if self.inepoch_report_chance >= random.random():
                    for k, v in stats.items():
                        writer.add_scalar(k, v, global_step)
                    progress_local.set_postfix({"loss": loss_mean.item()})
            stats = {f"{k}-epoch": v / local_step for k, v in cum_stats.items()}
            progress_local.close()
            self._logger.info(f"epoch {eidx} train summary:")
            self._logger.info(
                f"\n"
                f"  * train-loss: {stats['train-loss-epoch']:.4f}"
            )
            if self.train_validate:
                # optionally re-evaluate on the training set in test mode
                with torch.no_grad():
                    train_stats = test_fn(TestDataloader(
                        dialogs=train_dataloader.dataset.dialogs,
                        processor=self.processor,
                        max_batch_size=train_dataloader.batch_size
                    ), ont, "train-validate")
                stats.update({f"val-train-{k}": v
                              for k, v in train_stats.items()})
                self._logger.info(f"epoch {eidx} train-validation summary:")
                self._logger.info(
                    f"\n"
                    f"  * val-train-loss: {train_stats['loss']:.4f}\n"
                    f"  * val-hmean: {train_stats['hmean-inform']:.4f}\n"
                    f"  * val-train-joint: {train_stats['joint-goal']:.4f}\n"
                    f"  * val-train-inform: {train_stats['turn-inform']:.4f}\n"
                    f"  * val-train-request: {train_stats['turn-request']:.4f}"
                )
            with torch.no_grad():
                val_stats = test_fn(dev_dataloader, ont, "validate")
            self._logger.info(f"epoch {eidx} validation summary:")
            self._logger.info(
                f"\n"
                f"  * val-loss: {val_stats['loss']:.4f}\n"
                f"  * val-hmean: {val_stats['hmean-inform']:.4f}\n"
                f"  * val-joint: {val_stats['joint-goal']:.4f}\n"
                f"  * val-inform: {val_stats['turn-inform']:.4f}\n"
                f"  * val-request: {val_stats['turn-request']:.4f}"
            )
            stats.update({f"val-{k}": v for k, v in val_stats.items()})
            if self.early_stop:
                if self._record is None:
                    self.make_record(eidx, val_stats)
                elif (self.early_stop_patience is not None and eidx >
                      self._record.epoch + self.early_stop_patience):
                    # no new record within the patience window: stop early
                    break
                else:
                    crit, neg = self.criterion
                    crit_value = val_stats[crit]
                    # `neg` flips the comparison for criteria where lower
                    # is better (e.g. loss)
                    if (crit_value > self._record.value) != neg:
                        self._logger.info(f"new record made! "
                                          f"{crit}={crit_value:.3f}")
                        self.make_record(eidx, val_stats)
            for k, v in stats.items():
                writer.add_scalar(k, v, global_step)
            final_stats = stats
            if scheduler is not None:
                scheduler.step()
        if self._record is None and final_stats is not None:
            self.make_record(self.epochs, final_stats)
        if self._record is not None:
            # restore the best (or last recorded) parameters before returning
            self.model.load_state_dict(self._record.params)
        return self._record

    @staticmethod
    def evaluate_batch(pred: Sequence[DialogState],
                       gold: Sequence[DialogState]):
        """Compute accuracy statistics over a batch of dialog-state turns.

        BUG FIX: the original read ``stats["joint-goal"]`` for the harmonic
        mean without ever computing that key, raising ``KeyError`` on every
        call. Joint-goal accuracy is now computed from the cumulative inform
        states, mirroring ``evaluate_dialog``.

        NOTE(review): this assumes `pred`/`gold` are consecutive turns of a
        single dialog (required for the cumulative goal) — confirm against
        callers.
        """
        pred_inform = [s.get("inform") for s in pred]
        gold_inform = [s.get("inform") for s in gold]
        pred_request = [s.get("request") for s in pred]
        gold_request = [s.get("request") for s in gold]

        def cum_sum(data: Sequence[DialogAct]):
            # running goal (slot -> latest informed value), snapshot per turn
            goal = dict()
            ret = []
            for da in data:
                for sv in da:
                    goal[sv.slot] = sv.value
                ret.append(copy.copy(goal))
            return ret

        def all_eq(p, q):
            # element-wise equality over aligned turn sequences
            return [x == y for x, y in zip(p, q)]

        stats = {
            "joint-goal": np.mean(all_eq(cum_sum(pred_inform),
                                         cum_sum(gold_inform))),
            "turn-acc": np.mean(all_eq(pred, gold)),
            "turn-inform": np.mean(all_eq(pred_inform, gold_inform)),
            "turn-request": np.mean(all_eq(pred_request, gold_request))
        }
        stats["hmean-inform"] = harmonic_mean(np.array([
            stats["joint-goal"],
            stats["turn-inform"]
        ]))
        return stats

    @staticmethod
    def evaluate_dialog(pred: Sequence[DialogState],
                        gold: Sequence[DialogState]):
        """Score a full dialog: joint-goal, per-turn accuracies, and the
        harmonic mean of joint-goal and turn-inform accuracy."""
        def acts(states, key):
            # extract one dialog-act slot ("inform"/"request") per turn
            return [s.get(key) for s in states]

        def accumulate(informs: Sequence[DialogAct]):
            # running goal: slot -> latest informed value, snapshot per turn
            goal, snapshots = dict(), []
            for act in informs:
                for sv in act:
                    goal[sv.slot] = sv.value
                snapshots.append(copy.copy(goal))
            return snapshots

        def pairwise_eq(xs, ys):
            return [a == b for a, b in zip(xs, ys)]

        pred_inform = acts(pred, "inform")
        gold_inform = acts(gold, "inform")
        stats = {
            "joint-goal": np.mean(pairwise_eq(accumulate(pred_inform),
                                              accumulate(gold_inform))),
            "turn-acc": np.mean(pairwise_eq(pred, gold)),
            "turn-inform": np.mean(pairwise_eq(pred_inform, gold_inform)),
            "turn-request": np.mean(pairwise_eq(acts(pred, "request"),
                                                acts(gold, "request")))
        }
        stats["hmean-inform"] = harmonic_mean(
            np.array([stats["joint-goal"], stats["turn-inform"]]))
        return stats

    def test(self, dataloader: TestDataloader, ontology=None, mode="test"):
        """Evaluate the model over `dataloader` and return stats averaged
        per processed turn/batch item."""
        self.model.eval()
        if ontology is None:
            ontology = (self.processor
                        .tensorize_state_dict(self._user_tensor, "user"))
        totals = collections.defaultdict(float)
        num_items = 0
        progress = tqdm.tqdm(
            total=dataloader.total_items,
            dynamic_ncols=True,
            desc=mode
        )
        for dialogs, batch in dataloader:
            batch = batch.to(self.device)
            size = batch.batch_size
            num_items += size
            progress.update(size)
            loss, pred = self.predict(batch, ontology)
            # regroup flat per-turn predictions back into dialogs
            pred = bucket_items(pred, [len(d.dst_turns) for d in dialogs])
            gold = [[t.resolve_this().user.state for t in d.dst_turns]
                    for d in dialogs]
            for p_turns, g_turns in zip(pred, gold):
                dialog_stats = self.evaluate_dialog(p_turns, g_turns)
                for key, value in dialog_stats.items():
                    totals[key] += value * len(p_turns)
            totals["loss"] += loss.mean().item() * size
        progress.close()
        return {key: value / num_items for key, value in totals.items()}

    @staticmethod
    def iter_asr(batch: DSTTestBatchData) -> Iterable[DSTBatchData]:
        """Yield one batch per sample, where each batch packs that sample's
        ASR hypotheses (with the rest of the sample replicated to match),
        paired with the hypotheses' scores."""
        asr, scores = batch.asr, batch.asr_score
        assert (utils.compare_tensors(asr.lens, scores.lens)
                and asr.size(0) == scores.size(0))
        for idx in range(asr.size(0)):
            hyps, hyp_scores = asr[idx], scores[idx]
            num_hyps = hyps.size(0)
            replicated = DSTBatchData(
                sent=hyps,
                system_acts=utils.pad_stack(
                    [batch.system_acts[idx]] * num_hyps),
                belief_state=utils.pad_stack(
                    [batch.belief_state[idx]] * num_hyps),
                slot=utils.pad_stack([batch.slot[idx]] * num_hyps),
                raw=[batch.raw[idx]] * num_hyps
            )
            yield replicated, hyp_scores

    def test_asr(self, dataloader: TestDataloader, ontology=None, mode="test"):
        """Evaluate the model using ASR-hypothesis prediction
        (`predict_asr`) and return stats averaged per processed item."""
        self.model.eval()
        if ontology is None:
            ontology = (self.processor
                        .tensorize_state_dict(self._user_tensor, "user"))
        totals = collections.defaultdict(float)
        num_items = 0
        progress = tqdm.tqdm(
            total=dataloader.total_items,
            dynamic_ncols=True,
            desc=mode
        )
        for dialogs, batch in dataloader:
            batch = batch.to(self.device)
            size = batch.batch_size
            num_items += size
            progress.update(size)
            loss, pred = self.predict_asr(batch, ontology)
            # regroup flat per-turn predictions back into dialogs
            pred = bucket_items(pred, [len(d.dst_turns) for d in dialogs])
            gold = [[t.resolve_this().user.state for t in d.dst_turns]
                    for d in dialogs]
            for p_turns, g_turns in zip(pred, gold):
                dialog_stats = self.evaluate_dialog(p_turns, g_turns)
                for key, value in dialog_stats.items():
                    totals[key] += value * len(p_turns)
            totals["loss"] += loss.mean().item() * size
            progress.set_postfix({
                "joint-goal": totals["joint-goal"] / num_items})
        progress.close()
        return {key: value / num_items for key, value in totals.items()}
Example #15
0
class EmbeddingEvaluator(DialogEvaluator):
    """Embedding-based similarity metrics between predicted and input
    utterances: average-embedding, vector-extrema and greedy-matching
    cosine similarity.
    """
    vocab: utils.Vocabulary
    embeds: Embeddings
    # lazily populated embedding matrix, one row per vocab index
    _emb: torch.Tensor = utils.private_field(default=None)
    # vocab indices whose rows in `_emb` have already been filled
    _cache: Set[int] = utils.private_field(default_factory=set)
    # metric name -> list of per-utterance similarity scores
    _stats: dict = utils.private_field(default_factory=dict)
    _seen: int = utils.private_field(default=0)
    _logger: logging.Logger = utils.private_field(default=None)

    def __post_init__(self):
        self._emb = torch.zeros(len(self.vocab),
                                self.embeds.dim,
                                requires_grad=False)
        self._logger = logging.getLogger(self.__class__.__name__)

    @staticmethod
    def compare_mean(pred: torch.Tensor, gold: torch.Tensor):
        """Cosine similarity of the mean word embeddings."""
        return cosine_sim(pred.mean(0), gold.mean(0))

    @staticmethod
    def compare_extrema(pred: torch.Tensor, gold: torch.Tensor):
        """Cosine similarity of the per-dimension extrema (max-|.|) vectors."""
        def extrema(x):
            return x.gather(0, x.abs().max(0)[1].unsqueeze(0)).squeeze(0)

        return cosine_sim(extrema(pred), extrema(gold))

    @staticmethod
    def compare_greedy(pred: torch.Tensor, gold: torch.Tensor):
        """Greedy matching: for each gold word, take its best cosine match
        among pred words, then average over gold words."""
        dim = pred.size(-1)
        return cosine_sim(
            (pred.unsqueeze(1).expand(-1, gold.size(0), -1).contiguous().view(
                -1, dim)), (gold.unsqueeze(0).expand(
                    pred.size(0), -1, -1).contiguous().view(-1, dim)),
            dim=1).view(pred.size(0), gold.size(0)).max(0)[0].mean()

    def to_embeddings(self, words: torch.Tensor):
        """Look up embeddings for `words` (vocab indices), filling the
        lazy embedding matrix on first sight of each index and dropping
        words that have no pretrained embedding."""
        new_idx = set(e.item() for e in words.unique()) - self._cache
        for i in new_idx:
            w = self.vocab.i2f[i]
            if w not in self.embeds:
                self._logger.debug(f"word ({w}) not found in the embeddings")
            else:
                self._emb[i, :] = torch.tensor(self.embeds[w])
        if new_idx:
            self._cache.update(new_idx)
        exists = torch.tensor(
            [self.vocab.i2f[w.item()] in self.embeds for w in words]).bool()
        return self._emb[words.masked_select(exists)]

    def _eval(self, p, g, compare_fn) -> Optional[torch.Tensor]:
        """Compare one pred/gold index-tensor pair; returns None when the
        utterance must be skipped."""
        pe = self.to_embeddings(p)
        ge = self.to_embeddings(g)
        if not len(ge):
            self._logger.debug(f"entire utterance omitted as no matching "
                               f"embeddings found for the gold sentence")
            return
        if not len(pe):
            self._logger.debug(
                f"no matching embeddings found for the pred sentence; "
                f"giving a zero score for the similarity score")
            return torch.tensor(0.0)
        res = compare_fn(pe, ge)
        if (res != res).any():
            # BUG FIX: the original referenced undefined names `pu`/`gu`
            # here, raising NameError whenever a NaN was detected.
            self._logger.debug(f"NaN detected for pred = {p} and gold = {g}")
            return
        return res

    def reset(self):
        self._seen = 0
        self._stats.clear()

    def update(self, samples: Sequence) -> Optional[TensorMap]:
        # tokenize output (pred) and input (reference) turns to vocab indices
        pred = [
            torch.tensor([
                self.vocab[w] for turn in sample.output.turns
                for w in turn.text.split()
            ]).long() for sample in samples
        ]
        batch = [
            torch.tensor([
                self.vocab[w] for turn in sample.input.turns
                for w in turn.text.split()
            ]).long() for sample in samples
        ]
        for p, g in zip(pred, batch):
            stats = {
                "emb-mean": self._eval(p, g, self.compare_mean),
                "emb-extrema": self._eval(p, g, self.compare_extrema),
                "emb-greedy": self._eval(p, g, self.compare_greedy)
            }
            for k, v in stats.items():
                if v is None:
                    continue
                if k not in self._stats:
                    self._stats[k] = list()
                self._stats[k].append(v.item())
        return

    def get(self) -> TensorMap:
        """Return the mean of each accumulated metric."""
        return {k: torch.tensor(v).mean() for k, v in self._stats.items()}
Example #16
0
class VocabSetFactory:
    """A helper class for managing vocabularies related to task-oriented
    dialogues. Note that dialogue acts (act-slot-value triples) are organized
    by the speakers.
    """
    tokenizer: Callable[[str], Sequence[str]] = None
    # BUG FIX: the original defaults were shared class-level instances
    # (`VocabularyFactory()`), so every VocabSetFactory accumulated words
    # and speakers into the same two factories. Create one per instance.
    word: VocabularyFactory = field(default_factory=VocabularyFactory)
    speaker: VocabularyFactory = field(default_factory=VocabularyFactory)
    goal_cls: Callable[[], StateVocabSetFactory] = StateVocabSetFactory
    state_cls: Callable[[], StateVocabSetFactory] = StateVocabSetFactory
    goal_state_cls: Callable[[], StateVocabSetFactory] = StateVocabSetFactory
    _goal: StateVocabSetFactory = utils.private_field(default=None)
    _state: StateVocabSetFactory = utils.private_field(default=None)
    _goal_state: StateVocabSetFactory = utils.private_field(default=None)
    _goal_factories: Dict[str, StateVocabSetFactory] = \
        utils.private_field(default_factory=dict)
    _state_factories: Dict[str, StateVocabSetFactory] = \
        utils.private_field(default_factory=dict)
    _goal_state_factories: Dict[str, StateVocabSetFactory] = \
        utils.private_field(default_factory=dict)

    def __post_init__(self):
        self._goal = self.goal_cls()
        self._state = self.state_cls()
        self._goal_state = self.goal_state_cls()

        def tokenize(s):
            return s.split()

        # default tokenizer: whitespace split
        self.tokenizer = self.tokenizer or tokenize

    def update_turn(self, turn: Turn):
        """Update all vocabularies with one dialog turn: its text, ASR
        hypotheses, state/goal act-slot-value tokens, and speaker."""
        tokenizer = self.tokenizer
        self.word.update(tokenizer(turn.text))
        for sent in turn.asr:
            self.word.update(tokenizer(sent))
        for asv in turn.state:
            self.word.update(tokenizer(asv.act))
            self.word.update(tokenizer(asv.slot))
            self.word.update(tokenizer(asv.value))
        for asv in turn.goal:
            self.word.update(tokenizer(asv.act))
            self.word.update(tokenizer(asv.slot))
            self.word.update(tokenizer(asv.value))
        self.speaker.update((turn.speaker,))
        self._goal.update(turn.goal)
        self._state.update(turn.state)
        self._goal_state.update(turn.goal)
        self._goal_state.update(turn.state)
        # per-speaker factories are created lazily on first sight
        if turn.speaker not in self._state_factories:
            self._state_factories[turn.speaker] = self.state_cls()
        self._state_factories[turn.speaker].update(turn.state)
        if turn.speaker not in self._goal_factories:
            self._goal_factories[turn.speaker] = self.goal_cls()
        self._goal_factories[turn.speaker].update(turn.goal)
        if turn.speaker not in self._goal_state_factories:
            self._goal_state_factories[turn.speaker] = self.goal_state_cls()
        self._goal_state_factories[turn.speaker].update(turn.goal)
        self._goal_state_factories[turn.speaker].update(turn.state)

    def update_turns(self, turns: Iterable[Turn]):
        for turn in turns:
            self.update_turn(turn)

    def update_words(self, words: Iterable[str]):
        self.word.update(words)

    def get_vocabs(self) -> VocabSet:
        """Freeze all factories into a VocabSet."""
        return VocabSet(
            word=self.word.get_vocab(),
            speaker=self.speaker.get_vocab(),
            goal=self._goal.get_vocabs(),
            state=self._state.get_vocabs(),
            goal_state=self._goal_state.get_vocabs(),
            speaker_goal={spkr: factory.get_vocabs()
                          for spkr, factory in self._goal_factories.items()},
            speaker_state={spkr: factory.get_vocabs()
                           for spkr, factory in self._state_factories.items()},
            speaker_goal_state={s: f.get_vocabs()
                                for s, f in self._goal_state_factories.items()}
        )
Example #17
0
class LogInferencer(Inferencer):
    """An Inferencer that adds progress-bar display, periodic dialog/stat
    logging, and tensorboard reporting on top of the base run hooks.
    """
    # stat key shown live on the tqdm progress bar (e.g. "loss")
    progress_stat: Optional[str] = None
    # if given, only these stat keys are printed by log_stats
    display_stats: Optional[Union[Set[str], Sequence[str]]] = None
    writer: Optional[torch.utils.tensorboard.SummaryWriter] = None
    stats_formatter: StatsFormatter = field(default_factory=StatsFormatter)
    dialog_formatter: DialogFormatter = \
        field(default_factory=DialogTableFormatter)
    # emit a sample/summary report every N global steps (None = never)
    report_every: Optional[int] = None
    run_end_report: bool = True
    _tqdm: tqdm.tqdm = utils.private_field(default=None)
    _last_report: int = utils.private_field(default=0)
    _dialog_md_formatter: ClassVar[utils.DialogMarkdownFormatter] = \
        utils.DialogMarkdownFormatter()

    def __post_init__(self):
        super().__post_init__()
        if self.display_stats is not None:
            # normalize to a set for O(1) membership checks in log_stats
            self.display_stats = set(self.display_stats)

    def on_run_started(self, dataloader: td.DataLoader) -> td.DataLoader:
        """Create the progress bar sized to the dataset before the run."""
        dataloader = super(LogInferencer, self).on_run_started(dataloader)
        self._tqdm = tqdm.tqdm(
            total=len(dataloader.dataset),
            dynamic_ncols=True,
            desc=self.__class__.__name__,
        )
        return dataloader

    def on_batch_started(self, batch: BatchData) -> BatchData:
        batch = super().on_batch_started(batch)
        self._tqdm.update(batch.batch_size)
        return batch

    def on_batch_ended(self, batch: BatchData, pred: BatchData,
                       outputs) -> utils.TensorMap:
        """Update the progress bar and, every `report_every` global steps,
        log a random input/prediction sample pair plus a stats summary."""
        stats = super().on_batch_ended(batch, pred, outputs)
        if self.progress_stat is not None and self.progress_stat in stats:
            self._tqdm.set_postfix(
                {self.progress_stat: stats[self.progress_stat].item()})
        if self.report_every is not None and \
                (self.global_step - self._last_report) >= self.report_every:
            # pick one random sample from the batch to report
            idx = random.randint(0, batch.batch_size - 1)
            batch_sample = batch.raw[idx]
            pred_sample = self.processor.lexicalize_global(pred[idx])
            self.log_dialog("input-sample", batch_sample)
            self.log_dialog("pred-sample", pred_sample)
            self.log_stats("summary", stats)
            self._last_report = self.global_step
        return stats

    def log_dialog(self, tag: str, dialog: Dialog):
        """Log a dialog to tensorboard (markdown) and to the logger (table)."""
        if self.writer is not None:
            self.writer.add_text(
                tag=tag,
                text_string=self._dialog_md_formatter.format(dialog),
                global_step=self.global_step)
        self._logger.info(f"global step {self.global_step:,d} {tag}:\n"
                          f"{self.dialog_formatter.format(dialog)}")

    def log_stats(self,
                  tag: str,
                  stats: utils.TensorMap,
                  prefix=None,
                  postfix=None):
        """Log scalar stats to tensorboard and the logger, with optional
        key prefix/postfix decoration; display is filtered by
        `display_stats` when set."""
        def wrap_key(key):
            if prefix is not None:
                key = f"{prefix}-{key}"
            if postfix is not None:
                key = f"{key}-{postfix}"
            return key

        if self.writer is not None:
            for k, v in stats.items():
                self.writer.add_scalar(wrap_key(k), v.item(), self.global_step)
        if self.display_stats is not None:
            display_stats = {
                k: v
                for k, v in stats.items() if k in self.display_stats
            }
        else:
            display_stats = stats
        display_stats = {wrap_key(k): v for k, v in display_stats.items()}
        self._logger.info(f"global step {self.global_step:,d} {tag}:\n"
                          f"{self.stats_formatter.format(display_stats)}")

    def on_run_ended(self, stats: utils.TensorMap) -> utils.TensorMap:
        """Optionally emit a final stats summary, then close the bar."""
        stats = super().on_run_ended(stats)
        if self.run_end_report:
            self.log_stats("run-end-summary", stats, postfix="run")
        self._tqdm.close()
        return stats
Example #18
0
class BeamSearcher:
    """
    Init Arguments:
        embedder (Callable): an embedding object that has the following
            call signature:
                Call Arguments:
                    input (LongTensor): [N x seq_len]
                    lens (optional, LongTensor): [N]
                Call Returns:
                    output (FloatTensor): [N x seq_lenx embed_dim]
        cell (Callable): a rnn-cell-like object that has the following call
            signature:
                Call Arguments:
                    input (FloatTensor): [N x embed_dim]
                    state (optional, FloatTensor or tuple of):
                        [N x hidden_dim], ...
                Call Returns:
                    output (FloatTensor): [N x hidden_dim]
                    next_state (FloatTensor, FloatTensor or tuple of):
                        [N x hidden_dim], ...
        classifier (Callable): a feedforward-like object that has the following
            call signature:
                Call Arguments:
                    input (FloatTensor): [N1 x ... x Nn x hidden_dim]
                Call Returns:
                    output (FloatTensor): [N1 x ... x Nn x vocab_size]
    """
    embedder: Callable
    cell: Callable
    classifier: Callable
    vocab: Vocabulary
    beam_size: int = 8
    max_len: int = 30
    bos: str = "<bos>"
    eos: str = "<eos>"
    _eos_idx: Optional[int] = private_field(default=None)

    def __post_init__(self):
        if self.bos not in self.vocab:
            raise ValueError(f"bos token {self.bos} not found in vocab.")
        if self.eos not in self.vocab:
            raise ValueError(f"eos token {self.eos} not found in vocab.")
        self._eos_idx = self.vocab[self.eos]

    def search(self, s0):
        """Perform beam search.

        Arguments:
            s0 (torch.FloatTensor or tuple of): initial cell state
                [N x hidden_dim], ...

        Returns:
            sample (torch.LongTensor): [N x beam_size x seq_len]
            lens (torch.LongTensor): [N x beam_size]
            prob (torch.FloatTensor): [N x beam_size]
        """
        s0_sample = s0[0] if isinstance(s0, tuple) else s0
        batch_size = s0_sample.size(0)
        # initial distribution forces the first token to be <bos>
        logit0 = torch.zeros(len(self.vocab)).to(s0_sample).fill_(float("-inf"))
        logit0[self.vocab[self.bos]] = 0
        word = s0_sample.new(batch_size, self.beam_size, 0).long()
        done = s0_sample.new(batch_size, self.beam_size).fill_(0).bool()
        prob = s0_sample.new(batch_size, self.beam_size).fill_(1.0)
        lens = s0_sample.new(batch_size, self.beam_size).fill_(0).long()
        # replicate the initial state across beams
        s = (tuple(t.unsqueeze(1)
                   .expand(batch_size, self.beam_size, -1).contiguous()
                   for t in s0)
             if isinstance(s0, tuple) else s0)
        while not done.all() and lens.max() < self.max_len:
            seq_len = word.size(2)
            if not seq_len:
                # first step: seed all beams from the <bos>-forced logits
                logit = logit0
                prob_prime, word_prime = torch.softmax(logit, 0).sort(0, True)
                prob_prime = prob_prime[:self.beam_size]
                word_prime = word_prime[:self.beam_size]
                prob = prob_prime.unsqueeze(0).expand(batch_size, -1)
                word = (word_prime.unsqueeze(0).unsqueeze(-1)
                        .expand(batch_size, -1, -1)).contiguous()
                lens += 1
                continue
            emb = self.embedder(word.view(-1, word.size(-1)), lens.view(-1))
            emb = emb[:, -1, :].view(batch_size, self.beam_size, -1)
            # flatten (batch x beam) for the cell, then restore
            if isinstance(s, tuple):
                s_flat = tuple(t.view(-1, t.size(-1)) for t in s)
            else:
                s_flat = s.view(-1, s.size(-1))
            o, s_prime = self.cell(emb.view(-1, emb.size(-1)), s_flat)
            if isinstance(s_prime, tuple):
                s = tuple(t.view(batch_size, self.beam_size, -1)
                          for t in s_prime)
            else:
                s = s_prime.view(batch_size, self.beam_size, -1)
            logit = self.classifier(o).view(batch_size, self.beam_size, -1)
            if self._eos_idx is not None:
                # finished beams may only emit <eos> again
                logit_eos = torch.full_like(logit, float("-inf"))
                logit_eos[:, :, self._eos_idx] = 0
                logit = logit.masked_scatter(done.unsqueeze(-1), logit_eos)
            vocab_size = logit.size(-1)
            prob_prime = prob.unsqueeze(-1) * torch.softmax(logit, 2)
            prob_prime, prob_idx = \
                prob_prime.view(batch_size, -1).sort(1, True)
            prob = prob_prime[:, :self.beam_size]
            prob_idx = prob_idx[:, :self.beam_size].long()
            # BUG FIX: use floor division; true division (`/`) on a Long
            # tensor yields a float tensor in modern PyTorch, breaking the
            # subsequent `gather` index math.
            beam_idx = prob_idx // vocab_size
            word_prime = prob_idx % vocab_size
            word = torch.cat([
                word.gather(1, beam_idx.unsqueeze(-1).expand_as(word)),
                word_prime.unsqueeze(-1)
            ], 2)
            # only unfinished beams grow in length
            lens = lens.gather(1, beam_idx) + (~done).long()
            if self._eos_idx is not None:
                done = (word_prime == self._eos_idx) | done.gather(1, beam_idx)
        return word, lens, prob
Example #19
0
class DialogStateEvaluator(FinegrainedEvaluator):
    """Accumulates goal/state predictions and computes dialog-level and
    turn-level accuracies, both overall and per speaker."""
    vocabs: VocabSet
    return_update: bool = False
    # Accumulators filled by update() and consumed by get().
    _pred_goal: List[Stacked1DTensor] = \
        utils.private_field(default_factory=list)
    _gold_goal: List[Stacked1DTensor] = \
        utils.private_field(default_factory=list)
    _pred_state: List[Stacked1DTensor] = \
        utils.private_field(default_factory=list)
    _gold_state: List[Stacked1DTensor] = \
        utils.private_field(default_factory=list)
    _spkr: List[torch.Tensor] = utils.private_field(default_factory=list)

    def compute_accuracy(self,
                         pred: DoublyStacked1DTensor,
                         gold: DoublyStacked1DTensor,
                         turn_mask=None) -> utils.TensorMap:
        """Return dialog-level ("acc") and turn-level ("acc-turn") accuracy.

        A turn is correct when its densified label set matches the gold
        set exactly; a dialog is correct when all unmasked turns are.
        ``turn_mask`` optionally restricts which turns are scored (it is
        always intersected with the valid-conversation-length mask).
        """
        batch_size = pred.size(0)
        pred_dense = utils.to_dense(pred.value,
                                    pred.lens1,
                                    max_size=len(self.vocabs.goal_state.asv))
        gold_dense = utils.to_dense(gold.value,
                                    gold.lens1,
                                    max_size=len(self.vocabs.goal_state.asv))
        crt = (pred_dense == gold_dense).all(-1)
        conv_mask = utils.mask(pred.lens, pred.size(1))
        if turn_mask is None:
            turn_mask = torch.ones_like(conv_mask).bool()
        turn_mask = turn_mask & conv_mask
        crt = crt & turn_mask
        num_turns = turn_mask.sum()
        stats = {
            # masked-out turns count as correct for the dialog-level score
            "acc": (crt | ~turn_mask).all(-1).sum().float() / batch_size,
            "acc-turn": crt.sum().float() / num_turns,
        }
        return stats

    def _collect_stats(self, pred_goal, gold_goal, pred_state, gold_state,
                       spkr_value) -> dict:
        """Assemble goal/state accuracy stats, overall and per speaker.

        Shared by update() and get(), which previously duplicated this
        logic verbatim.
        """
        stats = dict()
        stats.update({
            f"goal-{k}": v
            for k, v in self.compute_accuracy(pred_goal, gold_goal).items()
        })
        stats.update({
            f"state-{k}": v
            for k, v in self.compute_accuracy(pred_state, gold_state).items()
        })
        for spkr_idx, spkr in self.vocabs.speaker.i2f.items():
            if spkr == "<unk>":
                continue
            spkr_mask = spkr_value == spkr_idx
            stats.update({
                f"goal-{k}-{spkr}": v
                for k, v in self.compute_accuracy(
                    pred_goal, gold_goal, spkr_mask).items()
            })
            stats.update({
                f"state-{k}-{spkr}": v
                for k, v in self.compute_accuracy(
                    pred_state, gold_state, spkr_mask).items()
            })
        return stats

    def reset(self):
        """Clear all accumulators."""
        self._pred_goal, self._gold_goal = list(), list()
        self._pred_state, self._gold_state = list(), list()
        self._spkr = list()

    def update(self, batch: BatchData, pred: BatchData,
               outputs) -> Optional[TensorMap]:
        """Accumulate one batch; return its stats if return_update is set."""
        self._pred_goal.extend(pred.goal)
        self._gold_goal.extend(batch.goal)
        self._pred_state.extend(pred.state)
        self._gold_state.extend(batch.state)
        self._spkr.extend(batch.speaker)
        if not self.return_update:
            return
        return self._collect_stats(pred.goal, batch.goal,
                                   pred.state, batch.state,
                                   batch.speaker.value)

    def get(self) -> TensorMap:
        """Compute accuracies over everything accumulated since reset()."""
        pred_goal = utils.stack_stacked1dtensors(self._pred_goal)
        pred_state = utils.stack_stacked1dtensors(self._pred_state)
        gold_goal = utils.stack_stacked1dtensors(self._gold_goal)
        gold_state = utils.stack_stacked1dtensors(self._gold_state)
        spkr_value = utils.pad_stack(self._spkr).value
        return self._collect_stats(pred_goal, gold_goal,
                                   pred_state, gold_state, spkr_value)
Пример #20
0
class WordEntropyEvaluator(DialogEvaluator):
    """Per-word entropy of generated text under an add-one-smoothed
    unigram model estimated from the training data (arXiv:1605.06069)."""
    dataset: DialogDataset
    # Unigram distribution over the word vocab (set in __post_init__).
    _prob: torch.Tensor = utils.private_field(default=None)
    # Per-speaker unigram distributions, keyed by speaker name.
    _spkr_prob: dict = utils.private_field(default=None)
    # Running lists of per-turn metric values, keyed by stat name.
    _values: dict = utils.private_field(default_factory=dict)

    def __post_init__(self):
        self._prob, self._spkr_prob = self.compute_unigram_prob()

    @property
    def speakers(self):
        """All known speakers except the unknown-speaker placeholder."""
        return set(spkr for spkr in self.vocabs.speaker.f2i if spkr != "<unk>")

    @property
    def vocabs(self):
        return self.dataset.processor.vocabs

    def compute_unigram_prob(self):
        """Estimate add-one-smoothed unigram distributions (overall and
        per speaker) from the dataset's turns, excluding <bos>/<eos>."""
        prob = torch.ones(len(self.vocabs.word)).long()  # +1 smoothing
        spkr_prob = {spkr: torch.ones(len(self.vocabs.word)).long()
                     for spkr in self.speakers}
        for dialog in self.dataset.data:
            for turn in dialog.turns:
                if turn.speaker == "<unk>":
                    continue
                tokens = \
                    self.dataset.processor.sent_processor.process(turn.text)
                tokens = utils.lstrip(tokens, "<bos>")
                tokens = utils.rstrip(tokens, "<eos>")
                for word in tokens:
                    word_idx = self.vocabs.word[word]
                    spkr_prob[turn.speaker][word_idx] += 1
                    prob[word_idx] += 1
        prob = prob.float() / prob.sum()
        spkr_prob = {spkr: p.float() / p.sum() for spkr, p in spkr_prob.items()}
        return prob, spkr_prob

    def reset(self):
        self._values.clear()

    def _entropy_stats(self, text: str,
                       unigram_prob: torch.Tensor) -> TensorMap:
        """Cross-entropy of *text* under *unigram_prob*.

        Returns per-word entropy ("word-ent") and total sentence entropy
        ("word-ent-sent"); zeros for empty text.
        """
        tokens = self.dataset.processor.tokenize(text)
        if not tokens:
            return {
                "word-ent": torch.tensor(0.0),
                "word-ent-sent": torch.tensor(0.0)
            }
        words, counts = zip(*collections.Counter(tokens).items())
        words = [self.vocabs.word[w] for w in words]
        words, counts = torch.tensor(words).long(), torch.tensor(counts).float()
        text_prob = counts.float() / counts.sum()
        # H = -sum_w p_text(w) * log p_model(w).  The original computed
        # (p_text * p_model).sum(), omitting the negative log, which is
        # not an entropy.
        ent = -(text_prob * unigram_prob[words].log()).sum()
        return {
            "word-ent": ent,
            "word-ent-sent": len(tokens) * ent
        }

    def compute_entropy(self, text: str) -> TensorMap:
        """Entropy of *text* under the global unigram model."""
        return self._entropy_stats(text, self._prob)

    def compute_entropy_spkr(self, text: str, spkr) -> TensorMap:
        """Entropy of *text* under the given speaker's unigram model."""
        return self._entropy_stats(text, self._spkr_prob[spkr])

    def update(self, samples: Sequence) -> Optional[TensorMap]:
        """Accumulate entropy stats for each output turn."""
        for sample in samples:
            for turn in sample.output:
                spkr = turn.speaker
                sent = turn.text
                stats = self.compute_entropy(sent)
                if spkr != "<unk>":
                    stats.update({
                        f"{k}-{spkr}": v for k, v in
                        self.compute_entropy_spkr(sent, spkr).items()
                    })
                for k, v in stats.items():
                    if k not in self._values:
                        self._values[k] = list()
                    self._values[k].append(v.item())
        return

    def get(self) -> Optional[TensorMap]:
        """Mean of each stat accumulated since reset()."""
        return {k: torch.tensor(v).mean() for k, v in self._values.items()}
Пример #21
0
class StateEntropyEvaluator(DialogEvaluator):
    """Entropy of generated dialog states (act-slot-value triples) under
    add-one-smoothed distributions estimated from the training data.

    (New!)
    """
    dataset: DialogDataset
    # Smoothed distribution over the global asv vocab (set in __post_init__).
    _asv_prob: torch.Tensor = utils.private_field(default=None)
    # Per-speaker asv distributions, keyed by speaker name.
    _spkr_asv_prob: dict = utils.private_field(default_factory=dict)
    # Running lists of per-turn metric values, keyed by stat name.
    _values: dict = utils.private_field(default_factory=dict)

    def __post_init__(self):
        self.compute_distributions()

    @property
    def speakers(self):
        """All known speakers except the unknown-speaker placeholder."""
        return set(spkr for spkr in self.vocabs.speaker.f2i if spkr != "<unk>")

    @property
    def vocabs(self):
        return self.dataset.processor.vocabs

    def compute_distributions(self):
        """Estimate add-one-smoothed asv distributions, overall and per
        speaker, from the dataset's turn states."""
        # +1 smoothing
        prob = torch.ones(len(self.vocabs.state.asv)).long()
        spkr_prob = {
            spkr: torch.ones(len(self.vocabs.speaker_state[spkr].asv)).long()
            for spkr in self.speakers
        }
        for dialog in self.dataset.data:
            for turn in dialog.turns:
                if turn.speaker == "<unk>":
                    continue
                spkr = turn.speaker
                for asv in turn.state:
                    asv_idx = self.vocabs.state.asv[asv]
                    spkr_asv_idx = self.vocabs.speaker_state[spkr].asv[asv]
                    spkr_prob[spkr][spkr_asv_idx] += 1
                    prob[asv_idx] += 1
        prob = prob.float() / prob.sum()
        spkr_prob = {
            spkr: p.float() / p.sum()
            for spkr, p in spkr_prob.items()
        }
        self._asv_prob, self._spkr_asv_prob = prob, spkr_prob

    def reset(self):
        self._values.clear()

    def _entropy_stats(self, state, asv_vocab, prob) -> TensorMap:
        """Cross-entropy of *state* under *prob*, indexing via *asv_vocab*.

        Returns per-asv entropy ("asv-ent") and total turn entropy
        ("asv-ent-turn"); zeros for an empty state.
        """
        if not state:
            return {
                "asv-ent": torch.tensor(0.0),
                "asv-ent-turn": torch.tensor(0.0)
            }
        asvs, counts = zip(*collections.Counter(state).items())
        asvs = [asv_vocab[w] for w in asvs]
        asvs, counts = torch.tensor(asvs).long(), torch.tensor(counts).float()
        text_prob = counts.float() / counts.sum()
        # H = -sum_a p_state(a) * log p_model(a).  The original computed
        # (p_state * p_model).sum(), omitting the negative log, which is
        # not an entropy.
        ent = -(text_prob * prob[asvs].log()).sum()
        return {"asv-ent": ent, "asv-ent-turn": len(state) * ent}

    def compute_entropy(self,
                        state: Sequence[ActSlotValue]) -> TensorMap:
        """Entropy of *state* under the global asv distribution."""
        return self._entropy_stats(state, self.vocabs.state.asv,
                                   self._asv_prob)

    def compute_entropy_spkr(self, state: Sequence[ActSlotValue],
                             spkr: str) -> TensorMap:
        """Entropy of *state* under the given speaker's asv distribution."""
        return self._entropy_stats(state,
                                   self.vocabs.speaker_state[spkr].asv,
                                   self._spkr_asv_prob[spkr])

    def update(self, samples: Sequence) -> Optional[TensorMap]:
        """Accumulate entropy stats for each output turn's state."""
        for sample in samples:
            # zip keeps output turns aligned with the input turns; the
            # gold turn itself is not needed here.
            for _, turn in zip(sample.input, sample.output):
                spkr = turn.speaker
                asvs = list(turn.state)
                stats = self.compute_entropy(asvs)
                if spkr != "<unk>":
                    stats.update({
                        f"{k}-{spkr}": v
                        for k, v in self.compute_entropy_spkr(asvs,
                                                              spkr).items()
                    })
                for k, v in stats.items():
                    if k not in self._values:
                        self._values[k] = list()
                    self._values[k].append(v.item())
        return

    def get(self) -> Optional[TensorMap]:
        """Mean of each stat accumulated since reset()."""
        return {k: torch.tensor(v).mean() for k, v in self._values.items()}
Пример #22
0
class BLEUEvaluator(DialogEvaluator):
    """BLEU between generated and reference dialog turns, reported for
    all NLTK chencherry smoothing methods, overall and per speaker."""
    vocabs: VocabSet
    _ref: List[Sequence[Sequence[str]]] = \
        utils.private_field(default_factory=list)
    _hyp: List[Sequence[Sequence[str]]] = \
        utils.private_field(default_factory=list)
    _hyp_spkr: dict = utils.private_field(default_factory=dict)
    _ref_spkr: dict = utils.private_field(default_factory=dict)
    _logger: logging.Logger = utils.private_field(default=None)

    def __post_init__(self):
        self._logger = logging.getLogger(self.__class__.__name__)

    def reset(self):
        """Clear all accumulators."""
        self._ref.clear()
        self._hyp.clear()
        self._ref_spkr.clear()
        self._hyp_spkr.clear()

    @staticmethod
    def closest_rlen(ref_lens, hyp_len):
        """Return the reference length closest to *hyp_len*.

        *ref_lens* must be sorted ascending and non-empty; ties break
        toward the shorter length.
        """
        idx = bisect.bisect_left(ref_lens, hyp_len)
        if idx == 0:
            return ref_lens[0]
        elif idx == len(ref_lens):
            return ref_lens[idx - 1]
        else:
            len1, len2 = ref_lens[idx - 1], ref_lens[idx]
            if hyp_len - len1 <= len2 - hyp_len:
                return len1
            else:
                return len2

    def try_compute(self, *args, **kwargs) -> float:
        """Best-effort sentence BLEU; returns 0.0 on any nltk error."""
        try:
            return bl.sentence_bleu(*args, **kwargs)
        except Exception as e:
            self._logger.debug(f"error in bleu: {e}")
            return 0.0

    def compute_bleu(self, ref: Sequence[Sequence[str]],
                     hyp: Sequence[Sequence[str]]):
        """Mean smoothed sentence-BLEU (times brevity penalty) of each
        hypothesis turn against all reference turns of one dialog."""
        chencherry = bl.SmoothingFunction()
        methods = (
            ("smooth0", chencherry.method0),
            ("smooth1", chencherry.method1),
            ("smooth2", chencherry.method2),
            ("smooth3", chencherry.method3),
            ("smooth4", chencherry.method4),
            ("smooth5", chencherry.method5),
            ("smooth6", chencherry.method6),
            ("smooth7", chencherry.method7)
        )
        # Guard against empty inputs: with an empty reference list the
        # original crashed with IndexError in closest_rlen (outside
        # try_compute's catch) when looking up the brevity penalty.
        if not hyp or not ref:
            return {name: 0.0 for name, _ in methods}
        r_lens = list(sorted(set(map(len, ref))))
        stats = {
            name: np.mean([
                (self.try_compute(ref, h, smoothing_function=method) *
                 bl.brevity_penalty(self.closest_rlen(r_lens, len(h)), len(h)))
                for h in hyp
            ]) for name, method in methods
        }
        return stats

    def compute_bleu_corpus(self, refs: Sequence[Sequence[Sequence[str]]],
                            hyps: Sequence[Sequence[Sequence[str]]]):
        """Average per-dialog BLEU stats over the corpus; empty corpus
        yields an empty dict (previously a ZeroDivisionError)."""
        assert len(refs) == len(hyps)
        if not refs:
            return {}
        stats = collections.defaultdict(float)
        for ref, hyp in zip(refs, hyps):
            for k, v in self.compute_bleu(ref, hyp).items():
                stats[k] += v
        return {k: v / len(refs) for k, v in stats.items()}

    def update(self, samples: Sequence) -> Optional[TensorMap]:
        """Accumulate tokenized hypothesis/reference turns per dialog."""
        ref = [[turn.text.split() for turn in sample.input.turns]
               for sample in samples]
        hyp = [[turn.text.split() for turn in sample.output.turns]
               for sample in samples]
        self._ref.extend(ref)
        self._hyp.extend(hyp)
        for spkr in self.vocabs.speaker.f2i:
            if spkr == "<unk>":
                continue
            if spkr not in self._ref_spkr:
                self._ref_spkr[spkr] = list()
                self._hyp_spkr[spkr] = list()
            ref = [[turn.text.split() for turn in sample.input.turns
                    if turn.speaker == spkr]
                   for sample in samples]
            hyp = [[turn.text.split() for turn in sample.output.turns
                    if turn.speaker == spkr]
                   for sample in samples]
            self._ref_spkr[spkr].extend(ref)
            self._hyp_spkr[spkr].extend(hyp)

    def get(self) -> TensorMap:
        """Corpus BLEU over everything accumulated since reset()."""
        stats = {f"bleu-{k}": torch.tensor(v) for k, v in
                 self.compute_bleu_corpus(self._ref, self._hyp).items()}
        for spkr in self.vocabs.speaker.f2i:
            if spkr == "<unk>":
                continue
            stats.update({
                f"bleu-{k}-{spkr}": torch.tensor(v) for k, v in
                self.compute_bleu_corpus(self._ref_spkr[spkr],
                                         self._hyp_spkr[spkr]).items()
            })
        return stats
Пример #23
0
class Generator:
    """Generates dialog samples from an AbstractTDA model, either
    conditioned on a data source (posterior) or from the prior when only
    a target number of instances is given."""
    model: AbstractTDA
    processor: DialogProcessor
    batch_size: int = 32
    device: torch.device = torch.device("cpu")
    global_step: int = 0
    # Tensorized act-slot-value vocabulary; built from the processor in
    # __post_init__ when not supplied by the caller.
    asv_tensor: utils.Stacked1DTensor = None
    _num_instances: int = utils.private_field(default=None)

    def __post_init__(self):
        if self.asv_tensor is None:
            self.asv_tensor = self.processor.tensorize_state_vocab(
                "goal_state")
        self.asv_tensor = self.asv_tensor.to(self.device)

    def on_run_started(self):
        """Hook invoked once before generation starts (no-op by default)."""
        return

    def on_run_ended(self, samples: Sequence[Sample],
                     stats: TensorMap) -> Tuple[Sequence[Sample], TensorMap]:
        """Hook invoked once after generation; may transform the results."""
        return samples, stats

    def on_batch_started(self, batch: BatchData) -> BatchData:
        """Hook invoked before each batch; may transform the batch."""
        return batch

    def validate_sample(self, sample: Sample) -> bool:
        """Filter predicate for generated samples; default keeps all."""
        return True

    def on_batch_ended(self, samples: Sequence[Sample]) -> TensorMap:
        """Hook invoked after each batch; returned stats are averaged."""
        return dict()

    def generate_kwargs(self) -> dict:
        """Extra keyword arguments forwarded to the model call."""
        return dict()

    def prepare_batch(self, batch: BatchData) -> dict:
        """Flatten a BatchData into the keyword dict the model expects."""
        return {
            "conv_lens": batch.conv_lens,
            "sent": batch.sent.value,
            "sent_lens": batch.sent.lens1,
            "speaker": batch.speaker.value,
            "goal": batch.goal.value,
            "goal_lens": batch.goal.lens1,
            "state": batch.state.value,
            "state_lens": batch.state.lens1,
            "asv": self.asv_tensor.value,
            "asv_lens": self.asv_tensor.lens
        }

    def __call__(
        self,
        data: Optional[Sequence[Dialog]] = None,
        num_instances: Optional[int] = None
    ) -> Tuple[Sequence[Sample], TensorMap]:
        """Generate exactly *num_instances* validated samples.

        With *data*, batches are drawn (cyclically) from the dataset and
        the model runs in posterior mode; without it the model samples
        from the prior.  Raises ValueError if neither argument is given.
        """
        if data is None and num_instances is None:
            raise ValueError(f"must provide a data source or "
                             f"number of instances.")
        dataloader = None
        if data is not None:
            dataloader = create_dataloader(dataset=DialogDataset(
                data=data, processor=self.processor),
                                           batch_size=self.batch_size,
                                           drop_last=False,
                                           shuffle=False)
            if num_instances is None:
                num_instances = len(data)
        self._num_instances = num_instances
        self.on_run_started()
        # With no data source, iterate None forever (prior sampling);
        # otherwise cycle the dataloader until enough samples are kept.
        dataloader = (itertools.repeat(None)
                      if dataloader is None else itertools.cycle(dataloader))
        cum_stats = collections.defaultdict(float)
        samples = []
        for batch in dataloader:
            self.model.eval()
            if batch is None:
                # Prior sampling: only a batch size is passed to the model.
                batch_size = min(self.batch_size, num_instances - len(samples))
                self.model.genconv_prior()
                with torch.no_grad():
                    pred, info = self.model(
                        torch.tensor(batch_size).to(self.device),
                        **self.generate_kwargs())
            else:
                batch = batch.to(self.device)
                batch_size = batch.batch_size
                # NOTE(review): global_step is also incremented after
                # filtering below -- looks double-counted on this path;
                # confirm intended.
                self.global_step += batch_size
                batch = self.on_batch_started(batch)
                self.model.genconv_post()
                with torch.no_grad():
                    pred, info = self.model(self.prepare_batch(batch),
                                            **self.generate_kwargs())
            # NOTE(review): on the prior path `batch` is None here, so
            # mapping lexicalize_global over it would raise TypeError --
            # confirm whether prior-only generation is exercised.
            batch_samples = list(
                filter(self.validate_sample, (Sample(*args) for args in zip(
                    map(self.processor.lexicalize_global, batch),
                    map(self.processor.lexicalize_global, pred),
                    info["logprob"]))))
            # Drop surplus samples at random so exactly num_instances remain.
            num_res = max(0, len(samples) + len(batch_samples) - num_instances)
            if num_res > 0:
                batch_samples = random.sample(batch_samples,
                                              num_instances - len(samples))
            batch_size = len(batch_samples)
            self.global_step += batch_size
            stats = self.on_batch_ended(batch_samples)
            samples.extend(batch_samples)
            for k, v in stats.items():
                cum_stats[k] += v * batch_size
            if len(samples) >= num_instances:
                break
        assert len(samples) == num_instances
        # Sample-weighted average of the per-batch stats.
        cum_stats = {k: v / len(samples) for k, v in cum_stats.items()}
        return self.on_run_ended(samples, cum_stats)
Пример #24
0
class Inferencer:
    """Runs an AbstractTDA model over a dataloader in inference mode and
    aggregates per-batch stats into dataset-level averages."""
    model: AbstractTDA
    processor: DialogProcessor
    device: torch.device = torch.device("cpu")
    global_step: int = 0
    # Tensorized act-slot-value vocabulary; built from the processor in
    # __post_init__ when not supplied by the caller.
    asv_tensor: utils.Stacked1DTensor = None
    _logger: logging.Logger = utils.private_field(default=None)

    def __post_init__(self):
        self._logger = logging.getLogger(self.__class__.__name__)
        if self.asv_tensor is None:
            self.asv_tensor = self.processor.tensorize_state_vocab(
                "goal_state")
        self.asv_tensor = self.asv_tensor.to(self.device)

    def on_run_started(self, dataloader: td.DataLoader) -> td.DataLoader:
        """Hook invoked once before the run; may wrap the dataloader."""
        return dataloader

    def on_run_ended(self, stats: utils.TensorMap) -> utils.TensorMap:
        """Hook invoked once after the run; may transform the stats."""
        return stats

    def on_batch_started(self, batch: BatchData) -> BatchData:
        """Hook invoked before each batch; may transform the batch."""
        return batch

    def on_batch_ended(self, batch: BatchData, pred: BatchData,
                       outputs) -> utils.TensorMap:
        """Hook invoked after each batch; returned stats are averaged."""
        return {}

    def model_kwargs(self) -> dict:
        """Extra keyword arguments forwarded to the model call."""
        return {}

    @staticmethod
    def predict(batch: BatchData, outputs) -> BatchData:
        """Convert raw model outputs into a predicted BatchData."""
        def predict_state(state_logit, conv_lens):
            # Threshold sigmoid(logit) at 0.5 per asv; +/-inf logits are
            # zeroed before sigmoid (and re-masked after) so they cannot
            # produce NaNs or spurious positives.
            batch_size, max_conv_len, num_asv = state_logit.size()
            mask = ((state_logit == float("-inf")) |
                    (state_logit == float("inf")))
            pred = (torch.sigmoid(state_logit.masked_fill(
                mask, 0)).masked_fill(mask, 0)) > 0.5
            pred = utils.to_sparse(pred.view(-1, num_asv))
            return utils.DoublyStacked1DTensor(
                value=pred.value.view(batch_size, max_conv_len, -1),
                lens=conv_lens,
                lens1=pred.lens.view(batch_size, max_conv_len))

        logit, post, prior = outputs
        # NOTE(review): sent keeps the first gold token and appends the
        # argmax of the sent logits shifted by one position -- presumably
        # gold-BOS + predicted continuation; confirm against the model.
        return BatchData(
            sent=utils.DoublyStacked1DTensor(value=torch.cat([
                batch.sent.value[..., :1], logit["sent"].max(-1)[1][..., :-1]
            ], 2),
                                             lens=batch.sent.lens,
                                             lens1=batch.sent.lens1),
            speaker=utils.Stacked1DTensor(value=logit["speaker"].max(-1)[1],
                                          lens=batch.conv_lens),
            goal=predict_state(logit["goal"], batch.conv_lens),
            state=predict_state(logit["state"], batch.conv_lens),
            raw=batch.raw)

    def prepare_batch(self, batch: BatchData) -> dict:
        """Flatten a BatchData into the keyword dict the model expects."""
        return {
            "conv_lens": batch.conv_lens,
            "sent": batch.sent.value,
            "sent_lens": batch.sent.lens1,
            "speaker": batch.speaker.value,
            "goal": batch.goal.value,
            "goal_lens": batch.goal.lens1,
            "state": batch.state.value,
            "state_lens": batch.state.lens1,
            "asv": self.asv_tensor.value,
            "asv_lens": self.asv_tensor.lens
        }

    def __call__(self, dataloader):
        """Run inference over *dataloader*; return size-weighted mean stats."""
        dataloader = self.on_run_started(dataloader)
        cum_stats = collections.defaultdict(float)
        total_steps = 0
        # NOTE(review): unlike Generator, this loop does not wrap the
        # model call in torch.no_grad() -- confirm whether gradients are
        # intentionally kept here.
        for batch in dataloader:
            batch = batch.to(self.device)
            total_steps += batch.batch_size
            self.global_step += batch.batch_size
            batch = self.on_batch_started(batch)
            self.model.inference()
            outputs = self.model(self.prepare_batch(batch),
                                 **self.model_kwargs())
            pred = self.predict(batch, outputs)
            stats = self.on_batch_ended(batch, pred, outputs)
            for k, v in stats.items():
                cum_stats[k] += v.detach() * batch.batch_size
        cum_stats = {k: v / total_steps for k, v in cum_stats.items()}
        return self.on_run_ended(cum_stats)
Пример #25
0
class RougeEvaluator(DialogEvaluator):
    """ROUGE between whole generated and reference dialogs (turns joined
    into one string per sample), overall and per speaker."""
    vocabs: VocabSet
    _hyp: List[str] = utils.private_field(default_factory=list)
    _ref: List[str] = utils.private_field(default_factory=list)
    _hyp_spkr: dict = utils.private_field(default_factory=dict)
    _ref_spkr: dict = utils.private_field(default_factory=dict)
    _rouge: rouge.Rouge = utils.private_field(default_factory=rouge.Rouge)
    # Maps the rouge package's short metric keys to the names we report.
    _key_map: ClassVar[Mapping[str, str]] = {
        "f": "f1",
        "p": "prec",
        "r": "rec"
    }

    def reset(self):
        """Clear all accumulators."""
        self._hyp.clear()
        self._ref.clear()
        self._hyp_spkr.clear()
        self._ref_spkr.clear()

    def try_rouge(self, hyp: Sequence[str], ref: Sequence[str]):
        """Best-effort averaged ROUGE scores; all-zero stats on failure."""
        try:
            return {
                f"{k}-{self._key_map.get(k2, k2)}": v2
                for k, v in self._rouge.get_scores(hyp, ref, avg=True).items()
                for k2, v2 in v.items()
            }
        except Exception:
            return {
                f"rouge-{k}-{k2}": 0.0
                for k in ("1", "2", "l") for k2 in ("f1", "prec", "rec")
            }

    @staticmethod
    def _collect(samples: Sequence, spkr=None):
        """Join each sample's turn texts into one (hyp, ref) string pair.

        Optionally restricts to one speaker.  Samples with no reference
        text are skipped; empty hypotheses become "<pad>" so the rouge
        package does not reject them.
        """
        hyps, refs = list(), list()
        for sample in samples:
            hyp, ref = "", ""
            for hyp_turn in sample.output.turns:
                if spkr is not None and hyp_turn.speaker != spkr:
                    continue
                hyp += " " + hyp_turn.text.strip()
            for ref_turn in sample.input.turns:
                if spkr is not None and ref_turn.speaker != spkr:
                    continue
                ref += " " + ref_turn.text.strip()
            hyp, ref = hyp.strip(), ref.strip()
            if not ref:
                continue
            if not hyp:
                hyp = "<pad>"
            hyps.append(hyp)
            refs.append(ref)
        return hyps, refs

    def update(self, samples: Sequence) -> Optional[TensorMap]:
        """Accumulate joined hyp/ref strings, overall and per speaker."""
        hyps, refs = self._collect(samples)
        self._hyp.extend(hyps)
        self._ref.extend(refs)
        for spkr in self.vocabs.speaker.f2i:
            if spkr == "<unk>":
                continue
            self._hyp_spkr.setdefault(spkr, list())
            self._ref_spkr.setdefault(spkr, list())
            hyps, refs = self._collect(samples, spkr)
            self._hyp_spkr[spkr].extend(hyps)
            self._ref_spkr[spkr].extend(refs)
        return

    def get(self) -> TensorMap:
        """ROUGE stats over everything accumulated since reset()."""
        stats = self.try_rouge(self._hyp, self._ref)
        for spkr in self.vocabs.speaker.f2i:
            if spkr == "<unk>":
                continue
            stats.update({
                f"{k}-{spkr}": v
                for k, v in self.try_rouge(self._hyp_spkr.get(
                    spkr, list()), self._ref_spkr.get(spkr, list())).items()
            })
        return {k: torch.tensor(v) for k, v in stats.items()}