Exemple #1
0
def create_exp_dir(path: str, script_path: str, overwrite=False) -> str:
    """
    Create experiment directory, and return path to log file.

    If ``path`` already exists, the process exits with status 1 unless
    ``overwrite`` is set; with ``overwrite`` the existing directory tree is
    deleted before a fresh one is created. The script at ``script_path`` is
    copied into the new directory.

    :param path: Directory to create for the experiment. Missing parent
        directories are created as well.
    :param script_path: File to copy into the experiment directory.
    :param overwrite: If ``True``, remove an existing ``path`` instead of
        aborting.
    :return: Path of a log file inside ``path``, named after the current
        local timestamp. The file itself is not created here.
    """
    if os.path.exists(path):
        if not overwrite:
            print(
                Logging.color(
                    col='red',
                    s=f"The experiment name: {path} already exists. "
                    f"Training will not proceed without the `--overwrite` flag."
                ))
            sys.exit(1)
        # Radical: the whole previous experiment is wiped, not merged.
        print(
            Logging.color(col='green',
                          s=f"Overwriting the experiment: {path} ..."))
        shutil.rmtree(path)

    # makedirs (rather than mkdir) so that an experiment path whose parent
    # directories do not exist yet does not fail with FileNotFoundError.
    os.makedirs(path)
    shutil.copy(script_path, path)
    timestamp = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    return os.path.join(path, f"{timestamp}.log")
Exemple #2
0
    def sampling_decode(self, vocab: Dict[str, Vocab], example: LRLMExample,
                        begin_symbol: int = 2, end_symbol: int = 5,
                        initial_hidden: Optional[HiddenState] = None, warm_up: Optional[int] = None,
                        max_length: int = 200, greedy: bool = False, topk: Optional[int] = None,
                        print_info: bool = True, color_outputs: bool = False, init_batch=None, **_kwargs) \
            -> SampledOutput:
        r"""
        Sample a token sequence from the model, one token per step.

        :param vocab: Vocabularies; only ``vocab["word"]`` is used here.
        :param example: Example whose first ``warm_up`` sentence tokens serve as context (only read when
            ``warm_up`` is given).
        :param begin_symbol: Token id the sequence starts from when there is no warm-up context.
        :param end_symbol: Token id that stops sampling once it has been generated.
        :param initial_hidden: Hidden state for the first step; only used when ``warm_up`` is ``None``.
        :param warm_up: Number of tokens of ``example.sentence`` to feed as context before sampling.
        :param max_length: Sampling stops once the sequence holds this many tokens.
        :param greedy: Forwarded to ``sample_utils.sample``.
        :param topk: Forwarded to ``sample_utils.sample``.
        :param print_info: If ``True``, print the sample loss and its perplexity.
        :param color_outputs: If ``True`` and a warm-up context was given, color the context tokens yellow.
        :param init_batch: Forwarded to :meth:`init_hidden`.

        :return: A :class:`SampledOutput` with the generated words and the per-token sample loss; both copy
            counters are always zero for this model.
        """
        to_tensor = functools.partial(sample_utils.tensor, device=self.device)
        draw = functools.partial(sample_utils.sample, greedy=greedy, topk=topk)

        self.eval()
        self.init_hidden(1, init_batch)

        if warm_up is not None:
            # Score the warm-up context in a single forward pass.
            token_ids = list(vocab["word"].numericalize(example.sentence[:warm_up]))
            context_loss, state = self.forward(to_tensor(token_ids[:-1]),
                                               target=to_tensor(token_ids[1:]))
            n_predicted = len(token_ids) - 1
            log_prob_sum = -torch.sum(context_loss).item() * n_predicted
        else:
            token_ids = [begin_symbol]
            state = initial_hidden
            log_prob_sum = 0.0

        while True:
            if len(token_ids) >= max_length or token_ids[-1] == end_symbol:
                break
            # Same forward pass as in training; the model was put in eval mode above.
            step_log_probs, next_state = self.forward(to_tensor(token_ids[-1]), state)
            sampled_id, sampled_log_prob = draw(step_log_probs)
            token_ids.append(sampled_id)
            state = next_state
            log_prob_sum += sampled_log_prob

        sample_loss = -log_prob_sum / (len(token_ids) - 1)
        if print_info:
            print(
                f"Sample loss: {sample_loss:.3f}, PPL: {math.exp(sample_loss):.3f}"
            )

        # Turn ids back into words, optionally highlighting the warm-up context.
        words = [vocab["word"].i2w[token_id] for token_id in token_ids]
        if color_outputs and warm_up is not None:
            words[:warm_up] = [Logging.color('yellow', word) for word in words[:warm_up]]

        return SampledOutput(sentence=words,
                             sample_loss=sample_loss,
                             complete_copies=0,
                             incomplete_copies=0)
Exemple #3
0
    def _try_load_cache(self, path: Path) -> bool:
        r"""Try loading a cached dataset from the data directory.

        :param path: The path to data directory.
        :return: Whether loading was successful.
        """
        cache_dir = path / '_cached'
        if not cache_dir.exists():
            return False
        params_index_path = cache_dir / 'params_index.pkl'
        params_index: List[Dict[str, Any]] = loadpkl(params_index_path)
        params = self._get_params()
        index = next(
            (idx for idx, p in enumerate(params_index)
             if (cache_dir /
                 f'{idx}.pkl').exists() and self._compare_params(p, params)),
            -1)
        if index != -1:
            load_path = cache_dir / f'{index}.pkl'
            self.batches = loadpkl(load_path)
            self.ntokens = {
                split: sum(batch.ntokens for _, batches in dataset
                           for batch in batches)
                for split, dataset in self.batches.items()
            }
            LOGGER.info(
                f"Cached dataset loaded from {load_path}, with settings: {params}"
            )
            # check for excluded keys and warn in case of mismatch
            load_params = params_index[index]
            for key in self.EXCLUDE_KEYS:
                if key in params or key in load_params:
                    current = params.get(key, "<does not exist>")
                    loaded = load_params.get(key, "<does not exist>")
                    if current != loaded:
                        LOGGER.info(
                            Logging.color(
                                'red', f"Ignored data param '{key}' mismatch "
                                f"(current: {current}, loaded: {loaded})"))
            return True
        return False
Exemple #4
0
    def sampling_decode(self, vocab: Dict[str, Vocab], example: LRLMExample,
                        begin_symbol: int = 2, end_symbol: int = 5,
                        initial_hidden: Optional[HiddenState] = None, warm_up: Optional[int] = None,
                        max_length: int = 200, greedy: bool = False, topk: Optional[int] = None,
                        print_info: bool = True, color_outputs: bool = False, show_rel_type: bool = True,
                        sanity_check: bool = False, unkinfo: Optional[Tuple[Tensor, List[str]]] = None, **kwargs) \
            -> SampledOutput:
        r"""
        Sampling for LRLM.

        Output format:
        - Red words:       Copied from canonical form of entity.
        - Green words:     Copied from alias form of entity.
        - Yellow words:    Warm-up context.
        - Blue words:      Word sampled from ``unkinfo`` in place of an unknown token.
        - word_[type]:     "word" is an entity of type "type".
        - @-@:             A dash in the original text without spaces around, e.g. M @-@ 82 => M-82.

        :param vocab: Vocabulary containing id2word mapping.
        :param example: The :class:`Example` object of the current topic.
        :param begin_symbol: Start of sentence symbol.
        :param end_symbol: End of sentence symbol. Sampling stops when this symbol is generated.
        :param initial_hidden: If not specified, default hidden states returned by :meth:`init_hidden` is used.
        :param warm_up: Number of tokens to provide as context before performing sampling.
        :param max_length: If generated sentence exceeds specified length, sampling is force terminated.
        :param greedy: If ``True``, use greedy decoding instead of sampling.
        :param topk: If not ``None``, only sample from indices with top-k probabilities.

        :param print_info: If ``True``, print information about sampled result.
        :param color_outputs: If ``True``, include annotations for each output token. Tokens from entities will be
            colored red.
        :param show_rel_type: If ``True``, show relation types for copied entities.

        :param sanity_check: If ``True``, perform sanity check on generated sample.

        :param unkinfo: Precomputed unkprobs and the index-to-vocabulary mapping.
        :param kwargs: Ignored.

        :return: A :class:`SampledOutput` holding the formatted list of words, the sample loss, and the number
            of copied entities (``complete_copies``); ``incomplete_copies`` is always zero.
        """
        if unkinfo is not None:
            unkprob, unki2w = unkinfo
            # Only the entries past the model vocabulary are candidates for replacing an unknown token.
            unkprob = unkprob[self._vocab_size:]
            unki2w = unki2w[self._vocab_size:]
            normalized_unkprob = F.log_softmax(unkprob, dim=0)

        # noinspection PyPep8Naming
        UNK, INVALID, CANONICAL_IDX, WORD_PREDICTOR, REL_PREDICTOR, EPS = -100, -1, 0, 0, 1, 1e-4

        self.eval()
        self.init_hidden(1, [example.relations])

        word_vocab, rel_vocab = vocab['word'], vocab['rel']

        tensor = functools.partial(sample_utils.tensor, device=self.device)
        sample = functools.partial(sample_utils.sample,
                                   greedy=greedy,
                                   topk=topk)
        np_sample = functools.partial(sample_utils.np_sample,
                                      greedy=greedy,
                                      topk=topk)

        # noinspection PyShadowingNames
        def compute_loss(
                inputs: List[int],
                spans: List[MatchedSpan],
                hidden: Optional[HiddenState] = None
        ) -> Tuple[float, HiddenState]:
            # Wrap the sequence in a single-example batch so the regular loss computation can score it.
            batch = SimpleNamespace(
                sequence=tensor(inputs[:-1]),
                target=tensor(inputs[1:]),
                spans=[spans],
                unkprob=None,
                lengths=torch.tensor([len(inputs) - 1], device=self.device),
                ntokens=len(inputs) - 1,
            )
            loss, next_hidden = self.calc_loss(batch,
                                               hidden=hidden)  # type: ignore
            return loss.item(), next_hidden

        # `rel_ids` and `surface_indices` run parallel to `inputs`, one entry per token.
        if warm_up is None:
            inputs = [begin_symbol]
            rel_ids = [INVALID]
            surface_indices = [INVALID]
            spans: List[MatchedSpan] = []
            total_log_prob = 0.0
            marginal_log_prob = 0.0
            hidden = initial_hidden
        else:
            inputs = list(word_vocab.numericalize(example.sentence[:warm_up]))
            rel_ids = [INVALID] * len(
                inputs)  # assume everything is generated from vocabulary
            surface_indices = [INVALID] * len(inputs)
            spans = [span for span in example.spans if span.end < warm_up]
            loss, hidden = compute_loss(inputs, spans, initial_hidden)
            total_log_prob = -loss * (len(inputs) - 1)
            marginal_log_prob = -loss * (len(inputs) - 1)

        while len(inputs) < max_length and inputs[-1] != end_symbol:
            computed_log_probs, new_hidden = self._compute_log_probs(
                tensor(inputs[-1]), hidden)
            predictor, selector_loss = sample(computed_log_probs.selector)

            if predictor == REL_PREDICTOR:
                rel_id, rel_loss = sample(computed_log_probs.rel[0])

                if self._alias_disamb is AliasDisamb.FastText:
                    assert computed_log_probs.alias_logits is not None
                    aliases = example.relations[rel_id].obj_alias
                    alias_vecs = self.alias_vec[aliases]
                    surface_log_prob = F.log_softmax(torch.mv(
                        alias_vecs, computed_log_probs.alias_logits.flatten()),
                                                     dim=0)
                    surface_idx, alias_loss = sample(surface_log_prob)
                    alias = self.alias_list[aliases[surface_idx]]
                else:
                    # can't tell which one under oracle, use the canonical (first) alias
                    surface_idx = 0
                    alias_loss = 0.0
                    alias = example.relations[rel_id].obj_alias[
                        0]  # type: ignore

                # forward the hidden state according to the generated in-vocab tokens
                raw_tokens: List[str] = alias.split()
                token_ids: List[int] = word_vocab.numericalize(raw_tokens)
                if len(raw_tokens) > 1:
                    _, new_hidden = self._compute_log_probs(
                        tensor(token_ids[:-1]), new_hidden)

                # compute marginal probability for current span
                span_inputs = tensor([inputs[-1]] + token_ids[:-1])
                span_computed_log_probs, _ = self._compute_log_probs(
                    span_inputs, hidden)
                word_gen_loss = torch.sum(
                    span_computed_log_probs.selector[0, :, WORD_PREDICTOR] +
                    torch.gather(span_computed_log_probs.word,
                                 index=tensor(token_ids).unsqueeze(-1),
                                 dim=2).flatten()).item()
                marginal_log_prob += torch.logsumexp(tensor(
                    [selector_loss + rel_loss + alias_loss, word_gen_loss]),
                                                     dim=1).item()

                spans.append(
                    MatchedSpan(
                        len(inputs) - 1,
                        len(inputs) + len(token_ids) - 1,
                        example.relations[rel_id].rel_typ, rel_id,
                        surface_idx))
                inputs.extend(token_ids)
                # Only the first token of a copied span carries the relation; the rest are marked INVALID.
                rel_ids.extend([rel_id] + [INVALID] * (len(token_ids) - 1))
                surface_indices.extend([surface_idx] + [INVALID] *
                                       (len(token_ids) - 1))

                total_log_prob += selector_loss + rel_loss + alias_loss
            elif predictor == WORD_PREDICTOR:
                word, word_loss = sample(computed_log_probs.word)
                total_log_prob += selector_loss + word_loss
                marginal_log_prob += selector_loss + word_loss

                if word == 0 and unkinfo is not None:  # unk
                    unk_idx, unk_loss = np_sample(normalized_unkprob)
                    total_log_prob += unk_loss
                    marginal_log_prob += unk_loss
                    # Ugly multi-purpose use of variables.
                    surface_indices.append(
                        unk_idx)  # Record unk word index in surface_indices.
                    rel_ids.append(UNK)  # Record UNK in rel_ids.
                else:
                    rel_ids.append(INVALID)
                    surface_indices.append(INVALID)

                inputs.append(word)
            else:
                raise ValueError

            hidden = new_hidden

        sample_loss = -total_log_prob / (len(inputs) - 1)
        marginal_loss = -marginal_log_prob / (len(inputs) - 1)
        if print_info:
            print(
                f"Sample loss: {sample_loss:.3f}, PPL: {math.exp(sample_loss):.3f}"
            )
            print(
                f"Marginal sample loss: {marginal_loss:.3f}, PPL: {math.exp(marginal_loss):.3f}"
            )
        # Sanity checks
        if sanity_check:
            # Re-score the whole sample in one pass and compare against the step-by-step values.
            # noinspection PyTypeChecker
            loss_val, gold_hidden = compute_loss(inputs, spans, initial_hidden)
            assert hidden is not None
            hidden_state_diff = max(
                torch.max(torch.abs(g - h)).item()
                for g, h in zip(gold_hidden, hidden))
            if hidden_state_diff > EPS:
                Logging.warn(
                    f"Hidden states do not match. Difference: {hidden_state_diff}"
                )
            if abs(marginal_loss - loss_val) > EPS:
                Logging.warn(
                    f"Marginal loss values do not match. "
                    f"Forward loss: {loss_val}, difference: {abs(marginal_loss - loss_val)}"
                )

        # `rel_ids` also stores the UNK sentinel for sampled unknown words; those are not relations and must
        # not be counted as such.
        num_rels_generated = sum(
            int(rel_id not in (INVALID, UNK)) for rel_id in rel_ids)
        if print_info:
            print(
                f"Relations [Generated / Annotated]: "
                f"[{num_rels_generated} / {len([s for s in example.spans if s.end < max_length])}]"
            )

        words = []
        idx = 0
        copy_count = 0
        while idx < len(inputs):
            is_warm_up = (warm_up is not None and idx < warm_up)
            token_id, rel_id, surface_idx = inputs[idx], rel_ids[
                idx], surface_indices[idx]
            if rel_id == INVALID:
                token = word_vocab.i2w[token_id]
                idx += 1
            elif rel_id == UNK:
                # For UNK entries, `surface_idx` holds the index of the sampled unk word.
                token = Logging.color('blue', unki2w[surface_idx])
                idx += 1
            else:
                copy_count += 1
                word_id = example.relations[rel_id].obj_alias[
                    surface_idx]  # multiple words
                token = self.alias_list[word_id]
                # A copied entity is emitted as one output item but spans several input tokens.
                idx += len(token.split())
                if show_rel_type:
                    token = f"{token}_[{rel_vocab.i2w[example.relations[rel_id].rel_typ]}]"
                if color_outputs and not is_warm_up:
                    token = Logging.color(
                        'red' if surface_idx == CANONICAL_IDX else 'green',
                        token)
            if color_outputs and is_warm_up:
                token = Logging.color('yellow', token)
            words.append(token)

        if print_info:
            print(f"# of copied entities: {copy_count}")

        output = SampledOutput(sentence=words,
                               sample_loss=sample_loss,
                               complete_copies=copy_count,
                               incomplete_copies=0)
        return output
Exemple #5
0
 def color_if_less(val: float,
                   threshold: float,
                   format_str: str = '{:.4f}',
                   color: str = 'yellow'):
     """Format ``val`` with ``format_str``, colored with ``color`` when it is below ``threshold``.

     :param val: Value to format.
     :param threshold: Values strictly below this are colored.
     :param format_str: ``str.format`` template applied to ``val``.
     :param color: Color name passed to ``Logging.color``.
     :return: The formatted value, wrapped by ``Logging.color`` only if ``val < threshold``.
     """
     formatted = format_str.format(val)
     if val < threshold:
         return Logging.color(color, formatted)
     return formatted
Exemple #6
0
    def span_log_probs(example,
                       _,
                       bptt_size=140,
                       ppl_threshold=200.0,
                       split_len=20,
                       n_context=5,
                       max_segments=-1):
        """Print each segment of ``example`` together with the posterior scores of its matched spans.

        The example is turned into a single batch sequence and run through ``compute_batch_loss`` with
        ``dump_posterior_probs`` enabled. For every segment this prints the tokens (span tokens in red) and,
        for each span that has posterior log-probabilities, ``exp(-rel_prob)`` and ``exp(-word_prob)``
        alongside a window of surrounding tokens.

        :param example: Example to analyze.
        :param _: Unused.
        :param bptt_size: Passed to ``dataset.create_one_batch``.
        :param ppl_threshold: Printed ``rel`` / ``word`` values below this are highlighted.
        :param split_len: Number of tokens printed per line.
        :param n_context: Number of context tokens shown on each side of a span.
        :param max_segments: Stop after this many segments; ``-1`` means no limit.
        """
        (init_batch, batches) = dataset.create_one_batch([example], bptt_size)
        relations: List[Relation] = init_batch[0]

        # One list per segment, holding one entry per span of that segment.
        segment_span_probs: List[List[Optional[Tuple[float, float]]]] = []
        segment_losses: List[float] = []

        # noinspection PyShadowingNames
        def callback(loss: Tensor, batch: BatchSequence) -> None:
            assert batch.spans is not None
            span_to_probs: Dict[MatchedSpan, Tuple[float, float]] = \
                model.model_cache['posterior_log_probs'][0]
            segment_span_probs.append(
                [span_to_probs.get(span) for span in batch.spans[0]])
            segment_losses.append(loss.item())

        def color_if_less(val: float,
                          threshold: float,
                          format_str: str = '{:.4f}',
                          color: str = 'yellow'):
            formatted = format_str.format(val)
            if val < threshold:
                return Logging.color(color, formatted)
            return formatted

        with torch.no_grad():
            model.eval()
            compute_batch_loss(model,
                               init_batch,
                               batches,
                               use_unk_probs=True,
                               callback=callback,
                               evaluate=True,
                               calc_loss_kwargs={'dump_posterior_probs': True})

        words_seen = 0
        for seg_no, (batch, span_probs) in enumerate(zip(batches, segment_span_probs)):
            if max_segments != -1 and seg_no >= max_segments:
                break

            n_tokens = batch.ntokens
            header = (f"Segment #{seg_no}: "
                      f"words {words_seen} - {words_seen + n_tokens}, "
                      f"ppl = {math.exp(segment_losses[seg_no]):.4f}")
            print(Logging.color('green', header))
            words_seen += n_tokens

            tokens = batch.raw_sequence[0][1:]
            spans = batch.spans[0]

            # Mark every token position covered by a span that lies inside this segment.
            covered = [False] * n_tokens
            for span in spans:
                if span.start <= span.end < n_tokens:
                    span_len = span.end - span.start + 1
                    covered[span.start:(span.end + 1)] = [True] * span_len

            # Dump the segment, `split_len` tokens per line.
            for offset in range(0, n_tokens, split_len):
                stop = offset + split_len
                line = ' '.join(
                    Logging.color('red', word) if inside else word
                    for word, inside in zip(tokens[offset:stop], covered[offset:stop]))
                print(f'{offset:3d}:', line)
            print()

            for span, prob in sorted(zip(spans, span_probs)):
                if prob is None:
                    continue
                rel_prob, word_prob = prob
                left = max(0, span.start - n_context)
                right = min(n_tokens, span.end + 1 + n_context)

                rel_name = dataset.rel_vocab.i2w[span.rel_typ]
                surface = Logging.color(
                    'red', relations[span.rel_idx].obj_alias[span.surface_idx])
                alias_marker = ' (alias)' if span.surface_idx > 0 else ''
                rel_score = color_if_less(math.exp(-rel_prob), ppl_threshold)
                word_score = color_if_less(math.exp(-word_prob), ppl_threshold)
                print(
                    f"[{span.start}, {span.end}]"
                    f" <{rel_name}>"
                    f" {surface}"
                    f"{alias_marker}"
                    f": rel = {rel_score}"
                    f", word = {word_score}"
                )

                context = ' '.join(
                    Logging.color('red', tokens[pos])
                    if span.start <= pos <= span.end else tokens[pos]
                    for pos in range(left, right))
                print('    ', '... ' if left > 0 else '', context,
                      ' ...' if right < n_tokens else '')
            print()
Exemple #7
0
    def posterior_log_probs(example,
                            _,
                            bptt_size=140,
                            n_context=5,
                            max_segments=-1,
                            ignore_single_relation=False,
                            file=sys.stdout):
        """Print, for each group of overlapping spans, the normalized score of every way to generate it.

        The example is run through ``compute_batch_loss`` twice, once on the original batches and once on
        the flipped batches, collecting values from ``model.model_cache``. For every segment the spans are
        merged into groups of overlapping spans; all paths through a group (from ``search_paths``) are
        scored, normalized with log-sum-exp, and printed under the token window of the group. The
        highest-scoring path is colored green.

        :param example: Example to analyze.
        :param _: Unused.
        :param bptt_size: Passed to ``dataset.create_one_batch``.
        :param n_context: Size of the token window printed around each span group.
        :param max_segments: Stop after this many segments; ``-1`` means no limit.
        :param ignore_single_relation: If ``True``, skip groups that consist of a single span.
        :param file: Stream the report is printed to.
        """
        from collections import defaultdict
        from dataset.utils import flip_batches, search_paths

        (init_batch, batches) = dataset.create_one_batch([example], bptt_size)
        flipped_batches = flip_batches(batches)

        # Per-segment values read from `model.model_cache` by the callbacks below.
        word_prob: Dict[str, List[np.ndarray]] = {
            k: []
            for k in ["forward", "backward"]
        }
        marginal_prob: Dict[str, List[np.ndarray]] = {
            k: []
            for k in ["forward", "backward"]
        }
        posterior_probs: List[Dict[MatchedSpan, Tuple[float, float]]] = [
        ]  # list(n_batches) of list(n_spans)
        seq_loss: List[float] = []

        # Called once per segment of the original (forward) batches.
        def callback(loss: Tensor, _batch: BatchSequence):
            posterior_probs.append(model.model_cache['posterior_log_probs'][0])
            word_prob['forward'].append(
                model.model_cache['target_cond_log_probs'][0])
            marginal_prob['forward'].append(
                model.model_cache['stacked_log_probs'][0])
            seq_loss.append(loss.item())

        # Called once per segment of the flipped batches; only the marginals are kept.
        def callback_flip(_loss: Tensor, _batch: BatchSequence):
            marginal_prob['backward'].append(
                model.model_cache['stacked_log_probs'][0])

        with torch.no_grad():
            model.eval()
            compute_batch_loss(model,
                               init_batch,
                               batches,
                               use_unk_probs=True,
                               callback=callback,
                               evaluate=True,
                               calc_loss_kwargs={'dump_posterior_probs': True})

            compute_batch_loss(model,
                               init_batch,
                               flipped_batches,
                               use_unk_probs=True,
                               callback=callback_flip,
                               evaluate=True,
                               calc_loss_kwargs={'dump_posterior_probs': True})

        # credit: http://bayesjumping.net/log-sum-exp-trick/ for without using scipy
        # NOTE(review): annotated as List[int], but it is called with the per-path log-probabilities below.
        def log_sum_exp(ns: List[int]):
            max_ = np.max(ns)
            sum_exp = np.exp(ns - max_).sum()
            return max_ + np.log(sum_exp)

        n_words = 0
        for seq_idx, (batch,
                      probs_dict) in enumerate(zip(batches, posterior_probs)):
            if max_segments != -1 and seq_idx >= max_segments:
                break

            tokens = batch.raw_sequence[0][1:]
            spans = batch.spans[0]
            # NOTE(review): segments without spans are skipped before `n_words` is advanced -- confirm that the
            # printed word offsets are meant to ignore them.
            if len(spans) == 0:
                continue
            # Merge spans that overlap into groups keyed by the (start, end) range the group covers.
            overlap_group = defaultdict(list)
            sorted_spans = sorted(spans, key=lambda x: (x.start, x.end))
            # NOTE(review): the first span seeds the first group without the validity check that is applied to
            # the remaining spans in the loop below.
            latest = (sorted_spans[0].start, sorted_spans[0].end
                      )  # The most recent span group
            overlap_group[latest] = [sorted_spans[0]]
            for sp in sorted_spans[1:]:
                if sp.start > sp.end or sp.end >= batch.ntokens:
                    continue
                if sp.start <= latest[1]:
                    # Overlaps the current group: re-key the group under its extended range.
                    grp = overlap_group[latest]
                    del overlap_group[latest]
                    latest = (latest[0], max(latest[1], sp.end))
                    overlap_group[latest] = grp + [sp]
                else:
                    latest = (sp.start, sp.end)
                    overlap_group[latest] = [sp]

            # The segment header is printed if at least one group has several spans, or if single-span groups
            # are not being ignored. `n_words` is only advanced when the header is printed.
            if np.any([len(g) > 1 for g in overlap_group.values()
                       ]) or not ignore_single_relation:
                print(Logging.color(
                    'green', f"Segment #{seq_idx}: "
                    f"words {n_words} - {n_words + batch.ntokens}, "
                    f"ppl = {math.exp(seq_loss[seq_idx]):.4f}"),
                      file=file)
                n_words += batch.ntokens

            for span, group in overlap_group.items():
                if ignore_single_relation and len(group) == 1:
                    continue

                # Marginals on either side of the group; the backward list is reversed before indexing by segment.
                # NOTE(review): presumably forward/backward scores up to the group boundaries -- confirm against
                # what the model stores in `stacked_log_probs`.
                alpha = marginal_prob["forward"][seq_idx][span[0] - 1]
                beta = marginal_prob["backward"][::-1][seq_idx][
                    batch.lengths[0] - (span[1] + 1) - 1]
                # Enumerate all the paths
                paths = search_paths(group, span[0], span[1])
                log_probs = []
                annotations = []
                delimiters = []
                for path in paths:
                    path_anno = []
                    path_delims = []
                    logprob = alpha + beta
                    for hop in path:
                        if hop.rel_typ == -100:  # dummy relation - word transition
                            logprob += word_prob["forward"][seq_idx][span[0]]
                            path_anno.append("word")
                        else:
                            logprob += probs_dict[hop][0]
                            path_anno.append(
                                dataset.rel_vocab.i2w[hop.rel_typ])
                        # Blank separators inside a hop, a bar between consecutive hops.
                        path_delims += ["   "] * (hop.end - hop.start)
                        if hop.end < span[1]:
                            path_delims.append(" | ")

                    delimiters.append(path_delims)
                    log_probs.append(logprob)
                    annotations.append(path_anno)
                # Normalize the path scores so that they sum to one within the group.
                log_denom = log_sum_exp(log_probs)
                normalized_probs = [
                    np.exp(log_prob - log_denom) for log_prob in log_probs
                ]

                # Token window: left context, the group's tokens (red, between bars), right context.
                l = max(0, span[0] - n_context)
                r = min(batch.ntokens, span[1] + n_context)
                token_string = " ".join([
                    '    ', '... ' if l > 0 else '',
                    ' '.join(tokens[idx]
                             for idx in range(l, span[0])), '|', '   '.join(
                                 Logging.color('red', tokens[idx])
                                 for idx in range(span[0], span[1] + 1)), '|',
                    ' '.join(tokens[idx] for idx in range(span[1] + 1, r)),
                    ' ...' if r < batch.ntokens else ''
                ])
                print(token_string, file=file)

                annotation_strings = [" => ".join(a) for a in annotations]
                max_anno_len = max([len(a) for a in annotation_strings])
                max_score_idx = np.argmax(normalized_probs)

                # One line per path: the group's tokens with the path's delimiters, its annotation and score.
                for idx, (delim, anno, prob) in enumerate(
                        zip(delimiters, annotation_strings, normalized_probs)):
                    # Indent up to the first bar of the token line, then join the tokens with " _ " placeholders.
                    matched_span_tokens = (
                        " " * token_string.index(" | ") + " | " +
                        ' _ '.join(tokens[idx]
                                   for idx in range(span[0], span[1] + 1)) +
                        " | ")
                    # Replace each " _ " placeholder with this path's delimiter; both are three characters wide,
                    # so the match positions found up front stay valid.
                    delim_positions = re.finditer(r" _ ", matched_span_tokens)
                    for d, match in zip(delim, delim_positions):
                        pos = match.start(0)
                        matched_span_tokens = matched_span_tokens[:
                                                                  pos] + d + matched_span_tokens[
                                                                      (pos +
                                                                       3):]
                    score = Logging.color(
                        "green", f"  {prob:1.4f}"
                    ) if idx == max_score_idx else f"  {prob:1.4f}"
                    matched_span_tokens += "  " + f"{anno}{' ' * (max_anno_len - len(anno))}" + score
                    print(matched_span_tokens, file=file)

                print(file=file)
Exemple #8
0
def repl(dataset: KBLMDataset, model: BaseLM):
    """
    Start an interactive IPython shell with helper commands for inspecting ``model`` on ``dataset``.

    The commands ``sample``, ``average_copies``, ``logprob_article``, ``posterior_log_probs`` and
    ``span_log_probs`` are defined as local functions and the function then calls ``IPython.embed()``.
    Every command is wrapped by ``get_example``, so it is invoked with ``name=...`` or with
    ``split=..., index=...`` (plus command-specific keyword arguments) instead of an example object.

    :param dataset: Dataset providing the examples (``dataset.data``), the vocabularies and batching.
    :param model: Model that is sampled from / evaluated by the commands.
    """
    import re
    from run import compute_batch_loss  # avoid circular import

    print(
        Logging.color(
            s="Execute `sample(name=\"Barack Obama\")` to generate samples for given entity.\n"
            "Execute `sample(split='test', index=1)` to generate samples for specific data entry.\n"
            "For more configurable settings, please refer to method `models.lrlm.LRLM.sampling_decode`.\n",
            col='green'))

    def get_example(func):
        """
        Decorator that turns ``func(example, start_symbol, **kwargs)`` into a command that selects the
        example by ``name`` or by ``split`` / ``index``.
        """

        def find_name(name: str) -> Tuple[str, int]:
            """
            Return ``(split, index)`` of the first example whose first 10 tokens (joined by spaces)
            contain ``name``.

            :raises ValueError: If no example in any split matches.
            """
            for split in dataset.data:
                try:
                    index = next(idx
                                 for idx, ex in enumerate(dataset.data[split])
                                 if name in ' '.join(ex.sentence[:10]))
                    return split, index
                except StopIteration:
                    continue
            raise ValueError("Name not found!")

        @functools.wraps(func)
        def wrapped(name: Optional[str] = None,
                    split: str = 'test',
                    index: int = 1,
                    start_symbol: str = '<s>',
                    **kwargs):
            if name is not None:
                # Prefer a match against a title of the form "= name =", then fall back to a plain
                # substring match. A second ValueError propagates to the caller.
                try:
                    split, index = find_name(f"= {name} =")
                except ValueError:
                    split, index = find_name(name)
            if start_symbol is None:
                start_symbol = '<s>'
            # Accept either a token string or an already numericalized symbol.
            start_symbol = dataset.word_vocab.w2i[start_symbol] if isinstance(
                start_symbol, str) else start_symbol
            example = dataset.data[split][index]
            # Recover a display name from the example itself (also when it was selected by index).
            try:
                name = ' '.join(
                    example.sentence[2:example.sentence.index('=', 2)])
            except ValueError:
                # probably WikiFacts
                pos = min((example.sentence.index(token)
                           for token in ['is', 'was', '(']
                           if token in example.sentence),
                          default=3)
                name = ' '.join(example.sentence[1:pos])
            # The header goes to the same stream the command was asked to write to; `file` itself
            # stays in kwargs and is forwarded to the command.
            file = kwargs.get("file", sys.stdout)
            print(f"Data split {split}, index {index}, name \"{name}\"",
                  file=file)
            func(example, start_symbol, **kwargs)

        return wrapped

    end_symbol = dataset.word_vocab.w2i['</s>']

    @get_example
    def sample(example,
               start_symbol,
               max_length=500,
               n_tries=1,
               bptt_size=150,
               **kwargs):
        """
        Sample ``n_tries`` sentences for the example and print the one with the lowest sample loss.

        Recognized extra keyword arguments: ``generate_unk`` (pass unknown-word information to the
        decoder) and ``print_info`` (defaults to ``True`` only when ``n_tries == 1``). Everything else
        is forwarded to ``model.sampling_decode``.
        """
        init_batch, _ = dataset.create_one_batch([example], bptt_size)

        # Both options are consumed here and must not be forwarded through **kwargs as well:
        # `print_info` is passed explicitly below, and a duplicate keyword raises TypeError.
        if kwargs.pop('generate_unk', False):
            fulli2w = [
                w for w, i in sorted(dataset.total_w2i.items(),
                                     key=lambda x: x[1])
            ]
            unkinfo = (dataset.unk_probs, fulli2w)
        else:
            unkinfo = None
        print_info = kwargs.pop('print_info', n_tries == 1)

        best_output = None
        for _ in range(n_tries):
            output: SampledOutput = model.sampling_decode(
                dataset.vocab,
                example,
                begin_symbol=
                start_symbol,  # this is the actual start symbol in all articles
                end_symbol=end_symbol,
                max_length=max_length,
                color_outputs=True,  # generate colored outputs in terminal
                print_info=print_info,
                unkinfo=unkinfo,
                init_batch=init_batch,
                **kwargs)
            if best_output is None or output.sample_loss < best_output.sample_loss:
                best_output = output

        # format title & subtitles
        if n_tries > 1:
            print(
                f"Sample loss: {best_output.sample_loss:.3f}, PPL: {math.exp(best_output.sample_loss):.3f}"
            )
            print(
                f"Complete / incomplete entities: {best_output.complete_copies} / {best_output.incomplete_copies}"
            )
        sentence = re.sub(r'= = (.*?) = = ', r'\n\n== \1 ==\n    ',
                          ' '.join(best_output.sentence))
        sentence = re.sub(r'= (.*?) = ', r'= \1 =\n    ', sentence)
        print(sentence)

    @get_example
    def average_copies(example,
                       start_symbol,
                       max_length=200,
                       n_tries=10,
                       progress=False,
                       show_samples=False,
                       **kwargs):
        """
        Draw ``n_tries`` samples and print the average ``complete_copies`` / ``incomplete_copies``
        reported by the decoder (or complete copies vs. gold spans for an ``LRLMExample``).
        """
        complete_copies = utils.SimpleAverage()
        incomplete_copies = utils.SimpleAverage()

        if show_samples:
            sample_kwargs = dict(print_info=True,
                                 color_outputs=True,
                                 color_incomplete=False)
        else:
            sample_kwargs = dict(print_info=False, color_outputs=False)
        # Caller-supplied options override the defaults above instead of colliding with them
        # (building a single dict(..., **kwargs) raises TypeError on a repeated key).
        sample_kwargs.update(kwargs)

        for _ in utils.progress(n_tries, verbose=progress):
            output: SampledOutput = model.sampling_decode(
                dataset.vocab,
                example,
                begin_symbol=start_symbol,
                end_symbol=end_symbol,
                max_length=max_length,
                **sample_kwargs)
            complete_copies.add(output.complete_copies)
            incomplete_copies.add(output.incomplete_copies)
            if show_samples:
                print(' '.join(output.sentence))
        if isinstance(example, LRLMExample):
            # Only gold spans that end before `max_length` are counted.
            n_gold_rels = sum(
                int(span.end < max_length) for span in example.spans)
            print(
                f"Complete / Gold entities: {complete_copies.value()} / {n_gold_rels}"
            )
        else:
            print(
                f"Complete / Incomplete entities: {complete_copies.value()} / {incomplete_copies.value()}"
            )

    @get_example
    def logprob_article(example, _start_symbol, bptt_size=150, **_kwargs):
        """
        Print, for every segment of the example, the number of relations and the segment perplexity,
        followed by the perplexity over the whole example.
        """
        (init_batch, batches) = dataset.create_one_batch([example], bptt_size)

        # log-prob calculation
        def callback(loss: Tensor, batch: BatchSequence, *_extra):
            if batch.spans is None:
                # NKLM
                # Pad with -1 on both sides and look at the value changes: the cumulative sum of the
                # non-zero differences equals (value of each run + 1), so counting its non-zero
                # entries counts the runs of identical relation ids other than -1.
                rel_ids = np.concatenate(
                    [[-1], batch.seqs['rel_ids'][0].cpu().numpy(), [-1]])
                diff = np.ediff1d(rel_ids)
                n_rels = np.count_nonzero(np.cumsum(diff[diff != 0]))
            else:
                n_rels = len(batch.spans[0])
            print(f"{n_rels:2d}\t{math.exp(loss.item()):2.3f}")

        print("#rels\tPPL")
        with torch.no_grad():
            model.eval()
            total_loss = compute_batch_loss(model,
                                            init_batch,
                                            batches,
                                            use_unk_probs=True,
                                            callback=callback,
                                            evaluate=True)
        total_loss /= sum(batch.ntokens for batch in batches)
        print(f"Total PPL: {math.exp(total_loss)}")

    @get_example
    def posterior_log_probs(example,
                            _,
                            bptt_size=140,
                            n_context=5,
                            max_segments=-1,
                            ignore_single_relation=False,
                            file=sys.stdout):
        """
        For every group of overlapping spans, enumerate the possible segmentation paths and print
        their normalized probabilities.

        :param bptt_size: Segment length passed to ``dataset.create_one_batch``.
        :param n_context: Number of context tokens printed on each side of a span group.
        :param max_segments: If not ``-1``, stop after this many segments.
        :param ignore_single_relation: If ``True``, skip groups that contain only one span.
        :param file: Stream the report is written to.
        """
        from collections import defaultdict
        from dataset.utils import flip_batches, search_paths

        (init_batch, batches) = dataset.create_one_batch([example], bptt_size)
        flipped_batches = flip_batches(batches)

        word_prob: Dict[str, List[np.ndarray]] = {
            k: []
            for k in ["forward", "backward"]
        }
        marginal_prob: Dict[str, List[np.ndarray]] = {
            k: []
            for k in ["forward", "backward"]
        }
        posterior_probs: List[Dict[MatchedSpan, Tuple[float, float]]] = [
        ]  # list(n_batches) of list(n_spans)
        seq_loss: List[float] = []

        def callback(loss: Tensor, _batch: BatchSequence):
            # Collect the per-segment quantities the model dumped into its cache.
            posterior_probs.append(model.model_cache['posterior_log_probs'][0])
            word_prob['forward'].append(
                model.model_cache['target_cond_log_probs'][0])
            marginal_prob['forward'].append(
                model.model_cache['stacked_log_probs'][0])
            seq_loss.append(loss.item())

        def callback_flip(_loss: Tensor, _batch: BatchSequence):
            marginal_prob['backward'].append(
                model.model_cache['stacked_log_probs'][0])

        with torch.no_grad():
            model.eval()
            compute_batch_loss(model,
                               init_batch,
                               batches,
                               use_unk_probs=True,
                               callback=callback,
                               evaluate=True,
                               calc_loss_kwargs={'dump_posterior_probs': True})

            compute_batch_loss(model,
                               init_batch,
                               flipped_batches,
                               use_unk_probs=True,
                               callback=callback_flip,
                               evaluate=True,
                               calc_loss_kwargs={'dump_posterior_probs': True})

        # credit: http://bayesjumping.net/log-sum-exp-trick/ for without using scipy
        def log_sum_exp(ns: List[float]):
            max_ = np.max(ns)
            sum_exp = np.exp(ns - max_).sum()
            return max_ + np.log(sum_exp)

        n_words = 0
        for seq_idx, (batch,
                      probs_dict) in enumerate(zip(batches, posterior_probs)):
            if max_segments != -1 and seq_idx >= max_segments:
                break

            # Advance the running word offset for EVERY segment, including the ones that are
            # skipped or whose header is not printed; otherwise later offsets are wrong.
            segment_start = n_words
            n_words += batch.ntokens

            tokens = batch.raw_sequence[0][1:]
            # Drop empty / out-of-range spans before grouping, so that the span seeding the first
            # group is checked the same way as all the others.
            valid_spans = [
                sp for sp in sorted(batch.spans[0],
                                    key=lambda x: (x.start, x.end))
                if sp.start <= sp.end and sp.end < batch.ntokens
            ]
            if not valid_spans:
                continue

            # Merge spans into groups of mutually overlapping spans, keyed by the (start, end)
            # range the group covers.
            overlap_group = defaultdict(list)
            latest = (valid_spans[0].start, valid_spans[0].end
                      )  # The most recent span group
            overlap_group[latest] = [valid_spans[0]]
            for sp in valid_spans[1:]:
                if sp.start <= latest[1]:
                    # Overlaps the current group: extend its range and re-key it.
                    grp = overlap_group.pop(latest)
                    latest = (latest[0], max(latest[1], sp.end))
                    overlap_group[latest] = grp + [sp]
                else:
                    latest = (sp.start, sp.end)
                    overlap_group[latest] = [sp]

            if not ignore_single_relation or any(
                    len(g) > 1 for g in overlap_group.values()):
                print(Logging.color(
                    'green', f"Segment #{seq_idx}: "
                    f"words {segment_start} - {segment_start + batch.ntokens}, "
                    f"ppl = {math.exp(seq_loss[seq_idx]):.4f}"),
                      file=file)

            for span, group in overlap_group.items():
                if ignore_single_relation and len(group) == 1:
                    continue

                alpha = marginal_prob["forward"][seq_idx][span[0] - 1]
                # The backward pass is stored in the order of `flipped_batches`; it is reversed
                # here before indexing by segment.
                # NOTE(review): presumably flip_batches reverses segment order - confirm.
                beta = marginal_prob["backward"][::-1][seq_idx][
                    batch.lengths[0] - (span[1] + 1) - 1]
                # Enumerate all the paths
                paths = search_paths(group, span[0], span[1])
                log_probs = []
                annotations = []
                delimiters = []
                for path in paths:
                    path_anno = []
                    path_delims = []
                    logprob = alpha + beta
                    for hop in path:
                        if hop.rel_typ == -100:  # dummy relation - word transition
                            # NOTE(review): this always reads the probability at span[0], not at
                            # the hop's own position (hop.start) - confirm this is intended.
                            logprob += word_prob["forward"][seq_idx][span[0]]
                            path_anno.append("word")
                        else:
                            logprob += probs_dict[hop][0]
                            path_anno.append(
                                dataset.rel_vocab.i2w[hop.rel_typ])
                        # One 3-character delimiter per token gap: blanks inside a hop, a bar
                        # between two hops.
                        path_delims += ["   "] * (hop.end - hop.start)
                        if hop.end < span[1]:
                            path_delims.append(" | ")

                    delimiters.append(path_delims)
                    log_probs.append(logprob)
                    annotations.append(path_anno)
                log_denom = log_sum_exp(log_probs)
                normalized_probs = [
                    np.exp(log_prob - log_denom) for log_prob in log_probs
                ]

                l = max(0, span[0] - n_context)
                # `span[1]` is inclusive, so the right window ends at span[1] + 1 + n_context to
                # show as many tokens as on the left (same as in `span_log_probs`).
                r = min(batch.ntokens, span[1] + 1 + n_context)
                token_string = " ".join([
                    '    ', '... ' if l > 0 else '',
                    ' '.join(tokens[idx]
                             for idx in range(l, span[0])), '|', '   '.join(
                                 Logging.color('red', tokens[idx])
                                 for idx in range(span[0], span[1] + 1)), '|',
                    ' '.join(tokens[idx] for idx in range(span[1] + 1, r)),
                    ' ...' if r < batch.ntokens else ''
                ])
                print(token_string, file=file)

                annotation_strings = [" => ".join(a) for a in annotations]
                max_anno_len = max([len(a) for a in annotation_strings])
                max_score_idx = np.argmax(normalized_probs)

                for idx, (delim, anno, prob) in enumerate(
                        zip(delimiters, annotation_strings, normalized_probs)):
                    # Re-print the span tokens aligned below the first " | " of the context line,
                    # joined by " _ " placeholders ...
                    matched_span_tokens = (
                        " " * token_string.index(" | ") + " | " +
                        ' _ '.join(tokens[idx]
                                   for idx in range(span[0], span[1] + 1)) +
                        " | ")
                    # ... and substitute each placeholder with this path's delimiter. All
                    # delimiters are 3 characters wide, so the match positions stay valid.
                    delim_positions = re.finditer(r" _ ", matched_span_tokens)
                    for d, match in zip(delim, delim_positions):
                        pos = match.start(0)
                        matched_span_tokens = (matched_span_tokens[:pos] + d +
                                               matched_span_tokens[(pos + 3):])
                    # Highlight the most probable path.
                    score = Logging.color(
                        "green", f"  {prob:1.4f}"
                    ) if idx == max_score_idx else f"  {prob:1.4f}"
                    matched_span_tokens += "  " + f"{anno}{' ' * (max_anno_len - len(anno))}" + score
                    print(matched_span_tokens, file=file)

                print(file=file)

    @get_example
    def span_log_probs(example,
                       _,
                       bptt_size=140,
                       ppl_threshold=200.0,
                       split_len=20,
                       n_context=5,
                       max_segments=-1):
        """
        Print every segment with its spans highlighted, followed by the relation / word perplexity
        of each span that has an entry in the model's posterior log-probabilities.

        :param bptt_size: Segment length passed to ``dataset.create_one_batch``.
        :param ppl_threshold: Perplexities below this value are highlighted.
        :param split_len: Number of tokens printed per line.
        :param n_context: Number of context tokens printed on each side of a span.
        :param max_segments: If not ``-1``, stop after this many segments.
        """
        (init_batch, batches) = dataset.create_one_batch([example], bptt_size)
        rels: List[Relation] = init_batch[0]

        posterior_probs: List[List[Optional[Tuple[float, float]]]] = [
        ]  # list(n_batches) of list(n_spans)
        seq_loss: List[float] = []

        # noinspection PyShadowingNames
        def callback(loss: Tensor, batch: BatchSequence) -> None:
            assert batch.spans is not None
            probs_dict: Dict[MatchedSpan, Tuple[
                float, float]] = model.model_cache['posterior_log_probs'][0]
            # Spans without an entry in the cache are recorded as None and skipped when printing.
            probs = [probs_dict.get(span, None) for span in batch.spans[0]]
            posterior_probs.append(probs)
            seq_loss.append(loss.item())

        def color_if_less(val: float,
                          threshold: float,
                          format_str: str = '{:.4f}',
                          color: str = 'yellow'):
            """Format ``val`` and color it when it is below ``threshold``."""
            s = format_str.format(val)
            return Logging.color(color, s) if val < threshold else s

        with torch.no_grad():
            model.eval()
            compute_batch_loss(model,
                               init_batch,
                               batches,
                               use_unk_probs=True,
                               callback=callback,
                               evaluate=True,
                               calc_loss_kwargs={'dump_posterior_probs': True})

        n_words = 0
        for seq_idx, (batch, probs) in enumerate(zip(batches,
                                                     posterior_probs)):
            if max_segments != -1 and seq_idx >= max_segments:
                break
            print(
                Logging.color(
                    'green', f"Segment #{seq_idx}: "
                    f"words {n_words} - {n_words + batch.ntokens}, "
                    f"ppl = {math.exp(seq_loss[seq_idx]):.4f}"))
            n_words += batch.ntokens

            tokens = batch.raw_sequence[0][1:]
            spans = batch.spans[0]
            # Mark every token that is covered by at least one valid span.
            is_in_span = [False] * batch.ntokens
            for span in spans:
                if span.start > span.end or span.end >= batch.ntokens:
                    continue
                is_in_span[span.start:(
                    span.end + 1)] = [True] * (span.end - span.start + 1)
            # Print the segment `split_len` tokens per line, prefixed by the token offset.
            for idx in range(0, batch.ntokens, split_len):
                print(
                    f'{idx:3d}:', ' '.join(
                        Logging.color('red', w) if in_span else w
                        for w, in_span in zip(
                            tokens[idx:(idx + split_len)], is_in_span[idx:(
                                idx + split_len)])))
            print()

            for span, prob in sorted(zip(spans, probs)):
                if prob is None:
                    continue
                rel_prob, word_prob = prob
                l = max(0, span.start - n_context)
                r = min(batch.ntokens, span.end + 1 + n_context)
                print(
                    f"[{span.start}, {span.end}]"
                    f" <{dataset.rel_vocab.i2w[span.rel_typ]}>"
                    f" {Logging.color('red', rels[span.rel_idx].obj_alias[span.surface_idx])}"
                    f"{' (alias)' if span.surface_idx > 0 else ''}"
                    f": rel = {color_if_less(math.exp(-rel_prob), ppl_threshold)}"
                    f", word = {color_if_less(math.exp(-word_prob), ppl_threshold)}"
                )
                print(
                    '    ', '... ' if l > 0 else '', ' '.join(
                        Logging.color('red', tokens[idx])
                        if span.start <= idx <= span.end else tokens[idx]
                        for idx in range(l, r)),
                    ' ...' if r < batch.ntokens else '')
            print()

    from IPython import embed
    embed()
Exemple #9
0
    def sampling_decode(self, vocab: Dict[str, Vocab], example: NKLMExample,
                        begin_symbol: int = 2, end_symbol: int = 5,
                        initial_hidden: Optional[HiddenState] = None, warm_up: Optional[int] = None,
                        max_length: int = 200, greedy=False, topk=None,
                        fill_incomplete=False, allow_invalid_pos=False,
                        print_info=True, color_outputs=False, color_incomplete=True,
                        show_ellipses=True, show_rel_type=True, show_copy_pos=False,
                        sanity_check=False, unkinfo: Optional[Tuple[Tensor, List[str]]] = None, **kwargs) \
            -> SampledOutput:
        """
        Sampling for NKLM.

        Output format:
        - Red words:       Copied from canonical form of entity.
        - Green words:     Copied from alias form of entity.
        - Yellow words:    Warm-up context.
        - Blue words:      Unknown words filled in from ``unkinfo``.
        - word_[type]:     "word" is an entity of type "type".
        - word...(a_b_c):  "word" is a partially copied entity with remaining suffix "a b c".
        - (a_b_c)...word:  "word" is a partially copied entity with remaining prefix "a b c".
        - @-@:             A dash in the original text without spaces around, e.g. M @-@ 82 => M-82.
        - <X>:             A token copied from an invalid position of an entity.

        :param vocab: Vocabulary containing id2word mapping.
        :param example: The :class:`Example` object of the current topic.
        :param begin_symbol: Start of sentence symbol.
        :param end_symbol: End of sentence symbol. Sampling stops when this symbol is generated.
        :param initial_hidden: If not specified, default hidden states returned by :meth:`init_hidden` is used.
        :param warm_up: Number of tokens to provide as context before performing sampling.
        :param max_length: If generated sentence exceeds specified length, sampling is force terminated.
        :param greedy: If ``True``, use greedy decoding instead of sampling.
        :param topk: If not ``None``, only sample from indices with top-k probabilities.
        :param fill_incomplete: If ``True``, entities that are partially copied will be completed.
        :param allow_invalid_pos: If ``True``, allowing copying from invalid positions, and use <unk> as input.

        :param print_info: If ``True``, print information about sampled result.
        :param color_outputs: If ``True``, include annotations for each output token. Tokens from entities will be
            colored red.
        :param color_incomplete: If ``True`` and ``color_outputs`` is ``True``, also color partially copied entities.
        :param show_ellipses: If ``True``, show ellipses at beginning or end of partially copied entities.
        :param show_rel_type: If ``True``, show relation types for copied entities.
        :param show_copy_pos: If ``True``, show the position from which the entity tokens are copied.

        :param sanity_check: If ``True``, perform sanity check on generated sample.
        :param unkinfo: Optional tuple of (scores, words). Only the entries from index ``self._vocab_size``
            onwards are used. When given, every generated token with id 0 is replaced in the output by a
            word sampled from the log-softmax of these scores.
        :param kwargs: Additional keyword arguments; accepted but not used by this method.

        :return: A :class:`SampledOutput` holding the formatted list of words (``sentence``), the
            per-token loss of the sample (``sample_loss``), the number of completely copied entities
            (``complete_copies``) and the number of copied tokens (``incomplete_copies``).
        """

        # `unki2w` and `normalized_unkprob` are only defined when `unkinfo` is given.
        if unkinfo is not None:
            unkprob, unki2w = unkinfo
            unkprob = unkprob[self._vocab_size:]
            unki2w = unki2w[self._vocab_size:]
            normalized_unkprob = F.log_softmax(unkprob, dim=0)

        # Markers used in `rel_ids` / `copy_pos` / `surface_indices`:
        #   UNK     - rel_id of a token that was replaced by a sampled unknown word;
        #   INVALID - "no relation / no copy position / no surface form" (token generated from the vocabulary);
        #   UNK_TOKEN - word id used when a copied token is not available;
        #   EPS     - tolerance for the sanity check.
        # noinspection PyPep8Naming
        UNK, INVALID, UNK_TOKEN, CANONICAL_IDX, EPS = -100, -1, 0, 0, 1e-4

        self.eval()
        self.init_hidden(1, [example.relations])

        word_vocab, rel_vocab = vocab['word'], vocab['rel']

        # Bind the device / decoding options once so the calls below stay short.
        tensor = functools.partial(sample_utils.tensor, device=self.device)
        randint = sample_utils.randint
        sample = functools.partial(sample_utils.sample,
                                   greedy=greedy,
                                   topk=topk)
        np_sample = functools.partial(sample_utils.np_sample,
                                      greedy=greedy,
                                      topk=topk)

        # Runs `calc_loss` on a full sequence; used for the warm-up context and for the sanity check.
        # noinspection PyShadowingNames
        def compute_loss(
                inputs: List[int],
                rel_ids: List[int],
                copy_pos: List[int],
                surface_indices: List[int],
                hidden: Optional[HiddenState] = None
        ) -> Tuple[float, HiddenState]:
            # Minimal stand-in for a batch object, carrying only the fields set here.
            batch = SimpleNamespace(
                sequence=tensor(inputs[:-1]),
                target=tensor(inputs[1:]),
                unkprob=None,
                seqs={
                    'rel_ids': tensor(rel_ids),
                    'copy_pos': tensor(copy_pos),
                    'surface_indices': tensor(surface_indices)
                },
                ntokens=len(inputs) - 1,
            )
            loss, next_hidden = self.calc_loss(batch,
                                               hidden=hidden)  # type: ignore
            return loss.item(), next_hidden

        # Initialization
        if warm_up is None:
            inputs = [begin_symbol]
            rel_ids = [INVALID]
            copy_pos = [INVALID]
            surface_indices = [INVALID]
            total_log_prob = 0.0
            hidden = initial_hidden
        else:
            # Seed the sequence with the first `warm_up` gold tokens and their annotations.
            inputs = list(word_vocab.numericalize(example.sentence[:warm_up]))
            rel_ids = list(example.rel_ids[:warm_up])
            copy_pos = list(example.copy_pos[:warm_up])
            surface_indices = list(example.surface_indices[:warm_up])
            total_log_prob, hidden = compute_loss(inputs, rel_ids, copy_pos,
                                                  surface_indices,
                                                  initial_hidden)
            # Convert the returned loss into a total log-probability over the warm-up tokens
            # (assumes `calc_loss` returns a per-token average - TODO confirm).
            total_log_prob = -total_log_prob * (len(inputs) - 1)

        # Sampling procedure
        while len(inputs) < max_length and inputs[-1] != end_symbol:
            # Step 1: choose the fact (relation) for the next token.
            fact_log_probs, output, _, next_hidden = \
                self._compute_fact_log_probs(tensor(inputs[-1]), tensor(rel_ids[-1]), tensor(copy_pos[-1]), hidden)
            rel_id, fact_loss = sample(fact_log_probs[0])
            # Shift by one so that sampled index 0 becomes -1, i.e. "no relation" (see the asserts below).
            rel_id -= 1
            total_log_prob += fact_loss

            # Step 2: decide between copying from the chosen entity and generating from the vocabulary.
            # next_fact_embed: (1, 1, fact_embed_dim)
            next_fact_embed = self._get_fact_embeds(tensor(rel_id))
            copy_indicator, alias_log_probs, pos_log_probs, vocab_log_probs = \
                self._compute_generate_log_probs(output, next_fact_embed, tensor([rel_ids[-1], rel_id]))
            if torch.bernoulli(copy_indicator).item():
                total_log_prob += torch.log(copy_indicator).item()
                assert rel_id != -1
                # copy entity
                aliases = example.relations[rel_id].obj_alias
                if self._alias_disamb is AliasDisamb.FastText:
                    assert alias_log_probs is not None
                    surface_idx, surface_loss = sample(alias_log_probs[0])
                else:
                    # No alias disambiguation: always use the first surface form.
                    surface_idx, surface_loss = 0, 0.0
                alias = self.alias_list[aliases[surface_idx]]

                entity: List[str] = alias.split()
                # normalization not required
                # TODO: keep consistent with _mask_invalid_pos setting
                pos, pos_loss = sample(pos_log_probs if allow_invalid_pos else
                                       pos_log_probs.squeeze()[:len(entity)])
                if self._mask_invalid_pos:
                    # Renormalize the position log-probability over the positions inside the entity.
                    pos_loss -= torch.logsumexp(
                        pos_log_probs.squeeze()[:len(entity)], dim=0).item()
                total_log_prob += surface_loss + pos_loss
                # Positions past the entity and out-of-vocabulary entity words are fed back as UNK_TOKEN.
                token = UNK_TOKEN if pos >= len(
                    entity) else word_vocab.w2i.get(entity[pos], UNK_TOKEN)
            else:
                total_log_prob += torch.log(1.0 - copy_indicator).item()
                assert rel_id == -1
                # generate word
                token, token_loss = sample(vocab_log_probs)
                total_log_prob += token_loss
                pos = INVALID
                surface_idx = INVALID

                if token == 0 and unkinfo is not None:  # unk
                    unk_idx, unk_loss = np_sample(normalized_unkprob)
                    total_log_prob += unk_loss
                    # Ugly multi-purpose use of variables.
                    surface_idx = unk_idx  # Record unk word index in surface_indices.
                    rel_id = UNK  # Record UNK in rel_ids.

            inputs.append(token)
            rel_ids.append(rel_id)
            copy_pos.append(pos)
            surface_indices.append(surface_idx)
            hidden = next_hidden

        # Average negative log-probability per predicted token (the first input is given, not predicted).
        sample_loss = -total_log_prob / (len(inputs) - 1)
        if print_info:
            print(
                f"Sample loss: {sample_loss:.3f}, PPL: {math.exp(sample_loss):.3f}"
            )
        # Sanity checks
        # Re-score the finished sample in one pass and compare the loss and the final hidden state
        # against the values accumulated step by step; mismatches are reported as warnings only.
        if sanity_check:
            loss_val, gold_hidden = compute_loss(inputs, rel_ids, copy_pos,
                                                 surface_indices,
                                                 initial_hidden)
            assert hidden is not None
            hidden_state_diff = max(
                torch.max(torch.abs(g - h)).item()
                for g, h in zip(gold_hidden, hidden))
            if hidden_state_diff > EPS:
                Logging.warn(
                    f"Hidden states do not match. Difference: {hidden_state_diff}"
                )
            if abs(sample_loss - loss_val) > EPS:
                Logging.warn(
                    f"Loss values do not match. "
                    f"Forward loss: {loss_val}, difference: {abs(sample_loss - loss_val)}"
                )

        # Format the output
        sentence = list(zip(inputs, rel_ids, copy_pos, surface_indices))
        words = []
        copy_count = 0  # number of copied tokens
        complete_count = 0  # number of entities copied from their first to their last token
        # State for tracking an entity that is being copied in order, starting from position 0:
        # `last_entity` is the (rel_id, surface_idx, pos) of the previous token of that entity.
        last_entity = None
        entity_continuing = False
        for idx, (token, rel_id, pos, surface_idx) in enumerate(sentence):
            is_warm_up = (warm_up is not None and idx < warm_up)
            if rel_id == INVALID:
                # Ordinary vocabulary word.
                word = word_vocab.i2w[token]
            elif rel_id == UNK:
                # Unknown word: `surface_idx` holds the index into `unki2w` (see sampling loop above).
                word = Logging.color('blue', unki2w[surface_idx])
            else:
                copy_count += 1
                entity_id = example.relations[rel_id].obj_alias[surface_idx]
                entity = self.alias_list[entity_id].split()
                if pos >= len(entity):
                    word = "<X>"
                else:
                    word = entity[pos]
                if show_copy_pos:
                    word = f"{pos}_{rel_id}_{surface_idx}_{word}"
                # A token ends (starts) a copied run unless its neighbour continues the same
                # relation and surface form at the adjacent position.
                is_last_word_in_entity = (idx == len(sentence) - 1
                                          or sentence[idx + 1][1:] !=
                                          (rel_id, pos + 1, surface_idx))
                is_first_word_in_entity = (idx == 0 or sentence[idx - 1][1:] !=
                                           (rel_id, pos - 1, surface_idx))
                # add entity tag after the last word
                if show_rel_type and is_last_word_in_entity:
                    word = f"{word}_[{rel_vocab.i2w[example.relations[rel_id].rel_typ]}]"

                # check whether fully copied
                if show_ellipses:
                    if pos < len(entity) - 1 and is_last_word_in_entity:
                        word = word + '...' + (
                            f"({'_'.join(entity[(pos + 1):])})"
                            if fill_incomplete else "")
                    if pos > 0 and is_first_word_in_entity:
                        word = (f"({'_'.join(entity[:pos])})"
                                if fill_incomplete else "") + '...' + word

                if entity_continuing:
                    if last_entity == (rel_id, surface_idx,
                                       pos - 1):  # Continuing
                        last_entity = (rel_id, surface_idx, pos)
                    else:
                        # The in-order copy was interrupted; drop the tracked entity.
                        entity_continuing = False
                        last_entity = None

                if pos == 0 and not entity_continuing:  # reset
                    entity_continuing = True
                    last_entity = (rel_id, surface_idx, 0)

                # With `color_incomplete` disabled only tokens of an in-order copy are colored.
                if color_outputs and not is_warm_up and (color_incomplete
                                                         or entity_continuing):
                    word = Logging.color(
                        'red' if surface_idx == 0 else 'green', word)

                if pos == len(entity) - 1 and entity_continuing:  # commit
                    entity_continuing = False
                    complete_count += 1

            if color_outputs and is_warm_up:
                word = Logging.color('yellow', word)
            words.append(word)

        if print_info:
            print(f"Copied, Completed: {copy_count}, {complete_count}")
        # NOTE(review): `incomplete_copies` receives the count of ALL copied tokens, not only those
        # belonging to incompletely copied entities - confirm this is what consumers expect.
        sampled_output = SampledOutput(sentence=words,
                                       sample_loss=sample_loss,
                                       complete_copies=complete_count,
                                       incomplete_copies=copy_count)
        return sampled_output