Beispiel #1
0
def create_exp_dir(path: str, script_path: str, overwrite=False) -> str:
    """
    Create experiment directory, and return path to log file.

    Missing parent directories of ``path`` are created as well, so nested
    experiment names such as ``runs/2020/exp1`` work on a fresh checkout.

    :param path: Directory to create for the experiment.
    :param script_path: File that is copied into the new directory.
    :param overwrite: If ``path`` already exists, delete it and start over
        when ``True``; otherwise print an error and exit with status 1.
    :return: Path of a timestamped ``.log`` file inside ``path``. Only the
        path is built here; the file itself is not created.
    """
    if os.path.exists(path):
        if not overwrite:
            print(
                Logging.color(
                    col='red',
                    s=f"The experiment name: {path} already exists. "
                    f"Training will not proceed without the `--overwrite` flag."
                ))
            sys.exit(1)
        # Destructive: everything under the existing experiment directory is removed.
        print(
            Logging.color(col='green',
                          s=f"Overwriting the experiment: {path} ..."))
        shutil.rmtree(path)

    # `makedirs` (rather than `mkdir`) so that a missing parent directory is not an error.
    os.makedirs(path)
    shutil.copy(script_path, path)
    timestamp = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    return os.path.join(path, f"{timestamp}.log")
Beispiel #2
0
 def _validate(self):
     rules = self.validate() or []
     for pattern, validator in rules:
         regex = re.compile(rf"^{pattern}$")
         matches = 0
         for k in dir(self):
             if not k.startswith('_') and regex.match(k):
                 matches += 1
                 v = getattr(self, k)
                 try:
                     result = validator(v)
                 except Exception:
                     raise ValidationError(k, v, validator.__name__)
                 else:
                     if not result:
                         raise ValidationError(k, v, validator.__name__)
         if matches == 0:
             Logging.warn(
                 f"regex \"{pattern}\" did not match any arguments")
Beispiel #3
0
    def sampling_decode(self, vocab: Dict[str, Vocab], example: LRLMExample,
                        begin_symbol: int = 2, end_symbol: int = 5,
                        initial_hidden: Optional[HiddenState] = None, warm_up: Optional[int] = None,
                        max_length: int = 200, greedy: bool = False, topk: Optional[int] = None,
                        print_info: bool = True, color_outputs: bool = False, init_batch=None, **_kwargs) \
            -> SampledOutput:
        r"""
        Generate a sentence token by token and report its per-token loss.

        :param vocab: Vocabularies by name; only the ``"word"`` entry is used here.
        :param example: Example whose ``sentence`` supplies the warm-up tokens.
        :param begin_symbol: Token id used as the only input when there is no warm-up.
        :param end_symbol: Token id that stops generation once it has been produced.
        :param initial_hidden: Hidden state used for the first step when there is no warm-up.
        :param warm_up: Number of tokens of ``example.sentence`` to feed as context, or ``None``.
        :param max_length: Generation stops once this many tokens have been collected.
        :param greedy: Forwarded to ``sample_utils.sample``.
        :param topk: Forwarded to ``sample_utils.sample``.
        :param print_info: If ``True``, print the sample loss and perplexity.
        :param color_outputs: If ``True``, color the warm-up words yellow.
        :param init_batch: Passed through to :meth:`init_hidden`.
        :return: The generated words and their loss; both copy counters are always 0 here.
        """
        to_tensor = functools.partial(sample_utils.tensor, device=self.device)
        draw = functools.partial(sample_utils.sample, greedy=greedy, topk=topk)

        self.eval()
        self.init_hidden(1, init_batch)

        if warm_up is not None:
            # Score the given context in one pass and continue from the state it leaves behind.
            prefix = example.sentence[:warm_up]
            inputs = list(vocab["word"].numericalize(prefix))
            warm_up_loss, hidden = self.forward(to_tensor(inputs[:-1]),
                                                target=to_tensor(inputs[1:]))
            total_log_prob = -torch.sum(warm_up_loss).item() * (len(inputs) - 1)
        else:
            inputs = [begin_symbol]
            hidden = initial_hidden
            total_log_prob = 0.0

        while True:
            if len(inputs) >= max_length or inputs[-1] == end_symbol:
                break
            # Same forward pass as in training; `.eval()` above is relied on to disable dropout.
            step_log_probs, hidden_after = self.forward(to_tensor(inputs[-1]),
                                                        hidden)
            token_id, token_log_prob = draw(step_log_probs)
            inputs.append(token_id)
            hidden = hidden_after
            total_log_prob += token_log_prob

        sample_loss = -total_log_prob / (len(inputs) - 1)
        if print_info:
            print(
                f"Sample loss: {sample_loss:.3f}, PPL: {math.exp(sample_loss):.3f}"
            )

        # Map ids back to words, optionally highlighting the warm-up context.
        id_to_word = vocab["word"].i2w
        words = [id_to_word[token] for token in inputs]
        if color_outputs and warm_up is not None:
            context_words = words[:warm_up]
            words[:warm_up] = [Logging.color('yellow', w) for w in context_words]

        return SampledOutput(sentence=words,
                             sample_loss=sample_loss,
                             complete_copies=0,
                             incomplete_copies=0)
Beispiel #4
0
    def _try_load_cache(self, path: Path) -> bool:
        r"""Try loading a cached dataset from the data directory.

        :param path: The path to data directory.
        :return: Whether loading was successful.
        """
        cache_dir = path / '_cached'
        if not cache_dir.exists():
            return False
        params_index_path = cache_dir / 'params_index.pkl'
        params_index: List[Dict[str, Any]] = loadpkl(params_index_path)
        params = self._get_params()
        index = next(
            (idx for idx, p in enumerate(params_index)
             if (cache_dir /
                 f'{idx}.pkl').exists() and self._compare_params(p, params)),
            -1)
        if index != -1:
            load_path = cache_dir / f'{index}.pkl'
            self.batches = loadpkl(load_path)
            self.ntokens = {
                split: sum(batch.ntokens for _, batches in dataset
                           for batch in batches)
                for split, dataset in self.batches.items()
            }
            LOGGER.info(
                f"Cached dataset loaded from {load_path}, with settings: {params}"
            )
            # check for excluded keys and warn in case of mismatch
            load_params = params_index[index]
            for key in self.EXCLUDE_KEYS:
                if key in params or key in load_params:
                    current = params.get(key, "<does not exist>")
                    loaded = load_params.get(key, "<does not exist>")
                    if current != loaded:
                        LOGGER.info(
                            Logging.color(
                                'red', f"Ignored data param '{key}' mismatch "
                                f"(current: {current}, loaded: {loaded})"))
            return True
        return False
Beispiel #5
0
    def _parse_type_spec(cls) -> Dict[str, _ArgTypeSpec]:
        """
        Build the type-spec of every argument declared on ``cls`` and cache it on the class.

        Annotations are gathered from ``cls`` and all of its base classes (except ``object``
        and ``Arguments``). Attributes holding a ``Switch`` do not need an annotation.

        :return: A dict mapping argument names to their type-specs
        :raises ArgumentError: If a non-``Switch`` attribute has no annotation, if an annotated
            name is reserved, or if a name starts or ends with an underscore.
        """
        _attr_name = '__type_dict__'
        # Only the class's own namespace counts as a cache hit: `hasattr` would also find the
        # dict cached on a base class, which lacks the arguments added by this subclass.
        if _attr_name in cls.__dict__:
            return cls.__dict__[_attr_name]

        type_dict = {}

        # get annotations from the current class and all its base classes as well
        annotations = {}
        for base in reversed(cls.__mro__):
            if base not in [object, Arguments]:
                annotations.update(base.__dict__.get('__annotations__', {}))
        # whatever attribute lookup on `cls` resolves to; kept so that nothing accepted
        # before this lookup was widened to the whole MRO becomes an error
        visible_annotations = getattr(cls, '__annotations__', {})

        bad_names = []
        warn_names = []

        def check_name_conventions(name):
            if name.startswith('_') or name.endswith('_'):
                # names should not begin or end with underscores
                bad_names.append(name)
            if name != name.lower() or any(ord(c) >= 128 for c in name):
                # names are recommended to contain non-uppercase ASCII characters only
                warn_names.append(name)

        # check that all attributes are annotated, except for `Switch`es
        for arg_name in dir(cls):
            if arg_name.startswith('__') or cls._check_reserved(
                    arg_name):  # magic stuff
                continue
            arg_val = getattr(cls, arg_name)
            if isinstance(arg_val, Arguments.Switch):
                check_name_conventions(arg_name)
                # noinspection PyProtectedMember,PyCallByClass
                type_dict[arg_name.lower()] = Arguments._ArgTypeSpec(
                    Arguments.Switch,
                    nullable=False,
                    required=False,
                    default=arg_val._default)
            elif arg_name not in annotations and arg_name not in visible_annotations:
                # an argument annotated on any base class counts as annotated
                raise ArgumentError(
                    f"Type is not specified for argument '{arg_name}'. "
                    f"Type annotation can omitted only when argument is a `Switch`."
                )

        # iterate over annotated values and generate type-specs
        for arg_name, arg_typ in annotations.items():
            if cls._check_reserved(arg_name):
                raise ArgumentError(
                    f"'{arg_name}' cannot be used as argument name because it is reserved."
                )

            check_name_conventions(arg_name)

            nullable = False
            # hacky check of whether `arg_typ` is `Optional`: `Optional` is `Union` with `type(None)`
            if getattr(arg_typ, '__origin__',
                       None) is Union and NoneType in arg_typ.__args__:
                nullable = True
                # extract the type wrapped inside `Optional`; the members of `__args__` are
                # types, so the `None` member has to be recognised by identity
                arg_typ = next(t for t in arg_typ.__args__
                               if t is not NoneType)  # type: ignore

            arg_val = getattr(cls, arg_name, None)
            # an argument must be supplied when it has no default, or when its default is
            # `None` although the type does not allow `None`
            required = not hasattr(cls, arg_name) or (arg_val is None
                                                      and not nullable)
            type_dict[arg_name] = Arguments._ArgTypeSpec(arg_typ,
                                                         nullable=nullable,
                                                         required=required,
                                                         default=arg_val)

        if len(bad_names) > 0:
            bad_names_str = ', '.join(f"'{s}'" for s in bad_names)
            raise ArgumentError(f"Invalid argument names: {bad_names_str}. "
                                f"Names cannot begin or end with underscores.")
        if len(warn_names) > 0:
            warn_names_str = ', '.join(f"'{s}'" for s in warn_names)
            Logging.warn(
                f"Consider changing these argument names: {warn_names_str}. "
                f"Names are recommended to contain non-uppercase ASCII characters only."
            )

        setattr(cls, _attr_name, type_dict)
        return type_dict
Beispiel #6
0
    def __init__(self, *args, **kwargs) -> None:
        """
        Populate the arguments from keyword overrides and a command line.

        Keyword arguments are assigned as attributes first and the command line is parsed
        afterwards, so values given on the command line win. ``--some-name value`` sets the
        attribute ``some_name``; a ``Switch`` takes no value and is turned off again with
        ``--no-some-name``.

        :param args: Either nothing, in which case ``sys.argv`` is parsed, or one list of
            tokens laid out like ``sys.argv`` (the first element is skipped).
        :param kwargs: Attribute values to assign before the command line is parsed.
        :raises ValueError: For more than one positional argument, an unknown ``--`` argument,
            an argument with no value after it, a value that cannot be converted, or a
            non-nullable argument that ends up being ``None``.
        :raises AssertionError: For a value outside the allowed choices, a required path that
            does not exist, or ``none`` given to a non-nullable, non-string argument.
        """
        self._check_types()

        for k, v in kwargs.items():
            setattr(self, k, v)

        # TODO: Add non-null checks
        # TODO: Add "no-" prefix stuff for switches
        # TODO: Generate help by inspecting comments

        if len(args) == 0:
            argv = sys.argv
        elif len(args) == 1:
            argv = args[0]
        else:
            raise ValueError(
                f"Argument class takes zero or one positional arguments but {len(args)} were given"
            )
        i = 1
        while i < len(argv):
            arg: str = argv[i]
            if arg.startswith('--'):
                argname = arg[2:].replace('-', '_')
                # `--no-x` switches `x` off, unless an argument is literally named `no_x`
                if argname.startswith('no_') and not hasattr(
                        self, argname) and hasattr(self, argname[3:]):
                    attr = getattr(self, argname[3:])
                    if isinstance(attr, Arguments.Switch):
                        attr._value = False
                        i += 1
                        continue

                if hasattr(self, argname):
                    attr = getattr(self, argname)
                    if isinstance(attr, Arguments.Switch):
                        attr._value = True
                        i += 1
                        continue

                    nullable, typ = self._get_arg_type(argname)
                    if i + 1 >= len(argv):
                        # a trailing `--name` would otherwise surface as a bare IndexError
                        raise ValueError(
                            f"Missing value for argument '{arg}'")
                    argval: str = argv[i + 1]
                    if argval.lower() == 'none':
                        if nullable:
                            val = None
                        else:
                            assert typ is str or is_choices(typ), \
                                f"Cannot assign None to non-nullable, non-str argument '{argname}'"
                            val = argval
                    elif isinstance(typ,
                                    custom_types.NoneType):  # type: ignore
                        val = None  # just to suppress "ref before assign" warning
                        try:
                            # priority: low -> high
                            # the last conversion that succeeds before the first failure wins
                            for target_typ in [str, float, int]:
                                val = target_typ(argval)
                        except ValueError:
                            pass
                    elif typ is str:
                        val = argval
                    elif isinstance(
                            typ,
                            custom_types.Path) or typ is custom_types.Path:
                        val = Path(argval)
                        if isinstance(typ, custom_types.Path) and typ.exists:
                            assert val.exists(), ValueError(
                                f"Argument '{argname}' requires an existing path, "
                                f"but '{argval}' does not exist")
                    elif is_choices(typ):
                        val = argval
                        assert val in typ.__values__, f"Invalid value '{val}' for argument '{arg}', " \
                                                      f"available choices are: {typ.__values__}"
                    elif isinstance(typ, type) and issubclass(
                            typ, Arguments.Enum):
                        # experimental support for custom enum
                        # (`issubclass` raises TypeError for non-class types such as typing
                        # generics, hence the `isinstance` guard)
                        try:
                            # noinspection PyCallingNonCallable
                            val = typ(argval)
                        except ValueError:
                            valid_args = {x.value for x in typ}
                            raise ValueError(
                                f"Invalid value '{argval}' for argument '{argname}', "
                                f"available choices are: {valid_args}"
                            ) from None

                    elif typ is bool:
                        val = argval in ['true', '1', 'True', 'y', 'yes']
                    else:
                        try:
                            val = ast.literal_eval(argval)
                        except ValueError:
                            raise ValueError(
                                f"Invalid value '{argval}' for argument '{argname}'"
                            ) from None
                    setattr(self, argname, val)
                    i += 2
                else:
                    raise ValueError(f"Invalid argument: '{arg}'")
            else:
                Logging.warn(f"Unrecognized command line argument: '{arg}'")
                i += 1

        if self.pdb:
            # enter IPython debugger on exception
            from IPython.core import ultratb
            ipython_hook = ultratb.FormattedTB(mode='Context',
                                               color_scheme='Linux',
                                               call_pdb=1)

            def excepthook(exc_type, exc_value, exc_tb):
                if exc_type is KeyboardInterrupt:
                    # don't capture keyboard interrupts (Ctrl+C)
                    sys.__excepthook__(exc_type, exc_value, exc_tb)
                else:
                    ipython_hook(exc_type, exc_value, exc_tb)

            sys.excepthook = excepthook

        self.preprocess()

        # check whether non-optional attributes are none
        for arg in dir(self):
            if not arg.startswith('_') and arg not in self._reserved_keys:
                attr = getattr(self, arg)
                nullable, _ = self._get_arg_type(arg)
                if attr is None and not nullable:
                    raise ValueError(f"argument '{arg}' cannot be none")

        self._validate()
        self.postprocess()

        # convert switches to bool
        for arg in dir(self):
            if not arg.startswith('_') and arg not in self._reserved_keys:
                attr = getattr(self, arg)
                typ = self.__annotations__.get(arg, None)
                if isinstance(attr, Arguments.Switch):
                    # noinspection PyProtectedMember
                    setattr(self, arg, bool(attr))
                # string defaults of path-typed arguments become `Path` objects
                if isinstance(typ, type) and issubclass(
                        typ, Path) and isinstance(attr, str):
                    setattr(self, arg, Path(attr))
Beispiel #7
0
    def __init__(self, **kwargs) -> None:
        """
        Populate the arguments from keyword overrides and ``sys.argv``.

        Keyword arguments are assigned as attributes first and the command line is parsed
        afterwards, so values given on the command line win. ``--some-name value`` sets the
        attribute ``some_name``; a ``Switch`` takes no value and is turned off again with
        ``--no-some-name``.

        :param kwargs: Attribute values to assign before the command line is parsed.
        :raises ValueError: For an unknown ``--`` argument, an argument with no value after
            it, or a value that cannot be converted.
        :raises AssertionError: For a value outside the allowed choices, a required path that
            does not exist, or ``none`` given to a non-nullable, non-string argument.
        """
        self._check_types()

        for k, v in kwargs.items():
            setattr(self, k, v)

        # TODO: Add non-null checks
        # TODO: Add "no-" prefix stuff for switches
        # TODO: Generate help by inspecting comments

        i = 1
        while i < len(sys.argv):
            arg: str = sys.argv[i]
            if arg.startswith('--'):
                argname = arg[2:].replace('-', '_')
                # `--no-x` switches `x` off, unless an argument is literally named `no_x`
                if argname.startswith('no_') and not hasattr(
                        self, argname) and hasattr(self, argname[3:]):
                    attr = getattr(self, argname[3:])
                    if isinstance(attr, Arguments.Switch):
                        attr._value = False
                        i += 1
                        continue

                if hasattr(self, argname):
                    attr = getattr(self, argname)
                    if isinstance(attr, Arguments.Switch):
                        attr._value = True
                        i += 1
                        continue

                    # fall back to the type of the current value when there is no annotation
                    typ = self.__annotations__.get(argname, type(attr))
                    nullable = False
                    # TODO: hacks here
                    if hasattr(
                            typ,
                            '__origin__') and typ.__origin__ == Union and type(
                                None) in typ.__args__:
                        # hacky check of whether `typ` is `Optional`
                        nullable = True
                        # the members of `__args__` are types, so the `None` member has to be
                        # recognised by identity rather than with `isinstance`
                        typ = next(t for t in typ.__args__
                                   if t is not type(None))  # type: ignore
                    if i + 1 >= len(sys.argv):
                        # a trailing `--name` would otherwise surface as a bare IndexError
                        raise ValueError(
                            f"Missing value for argument '{arg}'")
                    argval: str = sys.argv[i + 1]
                    if argval.lower() == 'none':
                        if nullable:
                            val = None
                        else:
                            assert typ is str or is_choices(typ), \
                                f"Cannot assign None to non-nullable, non-str argument '{argname}'"
                            val = argval
                    elif isinstance(typ,
                                    custom_types.NoneType):  # type: ignore
                        val = None  # just to suppress "ref before assign" warning
                        try:
                            # priority: low -> high
                            # the last conversion that succeeds before the first failure wins
                            for target_typ in [str, float, int]:
                                val = target_typ(argval)
                        except ValueError:
                            pass
                    elif typ is str:
                        val = argval
                    elif isinstance(
                            typ,
                            custom_types.Path) or typ is custom_types.Path:
                        val = Path(argval)
                        if isinstance(typ, custom_types.Path) and typ.exists:
                            assert val.exists(), ValueError(
                                f"Argument '{argname}' requires an existing path, "
                                f"but '{argval}' does not exist")
                    elif is_choices(typ):
                        val = argval
                        assert val in typ.__values__, f"Invalid value '{val}' for argument '{arg}', " \
                            f"available choices are: {typ.__values__}"
                    elif isinstance(typ, type) and issubclass(
                            typ, Arguments.Enum):
                        # experimental support for custom enum
                        # (the check is "is `typ` an enum"; with the operands the other way
                        # round no enum subclass would ever reach this branch)
                        try:
                            # noinspection PyCallingNonCallable
                            val = typ(argval)
                        except ValueError:
                            valid_args = {x.value for x in typ}
                            raise ValueError(
                                f"Invalid value '{argval}' for argument '{argname}', "
                                f"available choices are: {valid_args}"
                            ) from None

                    elif typ is bool:
                        val = argval in ['true', '1', 'True', 'y', 'yes']
                    else:
                        try:
                            val = ast.literal_eval(argval)
                        except ValueError:
                            raise ValueError(
                                f"Invalid value '{argval}' for argument '{argname}'"
                            ) from None
                    setattr(self, argname, val)
                    i += 2
                else:
                    raise ValueError(f"Invalid argument: '{arg}'")
            else:
                Logging.warn(f"Unrecognized command line argument: '{arg}'")
                i += 1

        if self.ipdb:
            # enter IPython debugger on exception
            from IPython.core import ultratb
            sys.excepthook = ultratb.FormattedTB(mode='Context',
                                                 color_scheme='Linux',
                                                 call_pdb=1)

        self.preprocess()
        self._validate()
        self.postprocess()
Beispiel #8
0
    def sampling_decode(self, vocab: Dict[str, Vocab], example: LRLMExample,
                        begin_symbol: int = 2, end_symbol: int = 5,
                        initial_hidden: Optional[HiddenState] = None, warm_up: Optional[int] = None,
                        max_length: int = 200, greedy: bool = False, topk: Optional[int] = None,
                        print_info: bool = True, color_outputs: bool = False, show_rel_type: bool = True,
                        sanity_check: bool = False, unkinfo: Optional[Tuple[Tensor, List[str]]] = None, **kwargs) \
            -> SampledOutput:
        r"""
        Sampling for LRLM.

        Output format:
        - Red words:       Copied from canonical form of entity.
        - Green words:     Copied from alias form of entity.
        - Yellow words:    Warm-up context.
        - Blue words:      Unknown words drawn from ``unkinfo``.
        - word_[type]:     "word" is an entity of type "type".
        - @-@:             A dash in the original text without spaces around, e.g. M @-@ 82 => M-82.

        :param vocab: Vocabulary containing id2word mapping.
        :param example: The :class:`Example` object of the current topic.
        :param begin_symbol: Start of sentence symbol.
        :param end_symbol: End of sentence symbol. Sampling stops when this symbol is generated.
        :param initial_hidden: Hidden state passed to the first step (or to the warm-up pass).
        :param warm_up: Number of tokens to provide as context before performing sampling.
        :param max_length: If generated sentence exceeds specified length, sampling is force terminated.
        :param greedy: If ``True``, use greedy decoding instead of sampling.
        :param topk: If not ``None``, only sample from indices with top-k probabilites.

        :param print_info: If ``True``, print information about sampled result.
        :param color_outputs: If ``True``, include annotations for each output token. Tokens from entities will be
            colored red.
        :param show_rel_type: If ``True``, show relation types for copied entities.

        :param sanity_check: If ``True``, perform sanity check on generated sample.

        :param unkinfo: Precomputed unkprobs and the index-to-vocabulary mapping.
        :param kwargs: Accepted for call compatibility; not read in this method.

        :return: A :class:`SampledOutput` with the formatted list of words, the sample loss, and the number of
            copied entities (``incomplete_copies`` is always 0).
        """
        if unkinfo is not None:
            unkprob, unki2w = unkinfo
            # keep only the entries from index `self._vocab_size` onwards and normalize them into log-probabilities
            unkprob = unkprob[self._vocab_size:]
            unki2w = unki2w[self._vocab_size:]
            normalized_unkprob = F.log_softmax(unkprob, dim=0)

        # UNK / INVALID: sentinels stored in `rel_ids` (unknown word / not a relation).
        # CANONICAL_IDX: surface index colored as the canonical alias.
        # WORD_PREDICTOR / REL_PREDICTOR: the two outcomes of the selector.
        # EPS: tolerance used by the sanity check.
        # noinspection PyPep8Naming
        UNK, INVALID, CANONICAL_IDX, WORD_PREDICTOR, REL_PREDICTOR, EPS = -100, -1, 0, 0, 1, 1e-4

        self.eval()
        self.init_hidden(1, [example.relations])

        word_vocab, rel_vocab = vocab['word'], vocab['rel']

        tensor = functools.partial(sample_utils.tensor, device=self.device)
        sample = functools.partial(sample_utils.sample,
                                   greedy=greedy,
                                   topk=topk)
        np_sample = functools.partial(sample_utils.np_sample,
                                      greedy=greedy,
                                      topk=topk)

        # Scores a whole token sequence with `calc_loss` by wrapping it in a one-example batch.
        # noinspection PyShadowingNames
        def compute_loss(
                inputs: List[int],
                spans: List[MatchedSpan],
                hidden: Optional[HiddenState] = None
        ) -> Tuple[float, HiddenState]:
            batch = SimpleNamespace(
                sequence=tensor(inputs[:-1]),
                target=tensor(inputs[1:]),
                spans=[spans],
                unkprob=None,
                lengths=torch.tensor([len(inputs) - 1], device=self.device),
                ntokens=len(inputs) - 1,
            )
            loss, next_hidden = self.calc_loss(batch,
                                               hidden=hidden)  # type: ignore
            return loss.item(), next_hidden

        # `inputs`, `rel_ids` and `surface_indices` are parallel lists with one entry per token.
        if warm_up is None:
            inputs = [begin_symbol]
            rel_ids = [INVALID]
            surface_indices = [INVALID]
            spans: List[MatchedSpan] = []
            total_log_prob = 0.0
            marginal_log_prob = 0.0
            hidden = initial_hidden
        else:
            inputs = list(word_vocab.numericalize(example.sentence[:warm_up]))
            rel_ids = [INVALID] * len(
                inputs)  # assume everything is generated from vocabulary
            surface_indices = [INVALID] * len(inputs)
            # only spans that end before the warm-up boundary are given to the loss
            spans = [span for span in example.spans if span.end < warm_up]
            loss, hidden = compute_loss(inputs, spans, initial_hidden)
            # the loss is scaled by the number of predicted tokens
            # NOTE(review): this presumes `calc_loss` returns a per-token average -- confirm
            total_log_prob = -loss * (len(inputs) - 1)
            marginal_log_prob = -loss * (len(inputs) - 1)

        while len(inputs) < max_length and inputs[-1] != end_symbol:
            computed_log_probs, new_hidden = self._compute_log_probs(
                tensor(inputs[-1]), hidden)
            # first decide whether the next token(s) come from a relation or from the vocabulary
            predictor, selector_loss = sample(computed_log_probs.selector)

            if predictor == REL_PREDICTOR:
                rel_id, rel_loss = sample(computed_log_probs.rel[0])

                if self._alias_disamb is AliasDisamb.FastText:
                    # choose among the aliases of the relation's object by scoring each alias vector
                    assert computed_log_probs.alias_logits is not None
                    aliases = example.relations[rel_id].obj_alias
                    alias_vecs = self.alias_vec[aliases]
                    surface_log_prob = F.log_softmax(torch.mv(
                        alias_vecs, computed_log_probs.alias_logits.flatten()),
                                                     dim=0)
                    surface_idx, alias_loss = sample(surface_log_prob)
                    alias = self.alias_list[aliases[surface_idx]]
                else:
                    # can't tell which one under oracle, use the canonical (first) alias
                    surface_idx = 0
                    alias_loss = 0.0
                    alias = example.relations[rel_id].obj_alias[
                        0]  # type: ignore

                # forward the hidden state according to the generated in-vocab tokens
                raw_tokens: List[str] = alias.split()
                token_ids: List[int] = word_vocab.numericalize(raw_tokens)
                if len(raw_tokens) > 1:
                    # the last alias token is not fed here; it becomes `inputs[-1]` of the next iteration
                    _, new_hidden = self._compute_log_probs(
                        tensor(token_ids[:-1]), new_hidden)

                # compute marginal probability for current span
                # (re-runs the model over the span from the pre-span hidden state to score the
                # alternative of generating the same tokens word by word)
                span_inputs = tensor([inputs[-1]] + token_ids[:-1])
                span_computed_log_probs, _ = self._compute_log_probs(
                    span_inputs, hidden)
                word_gen_loss = torch.sum(
                    span_computed_log_probs.selector[0, :, WORD_PREDICTOR] +
                    torch.gather(span_computed_log_probs.word,
                                 index=tensor(token_ids).unsqueeze(-1),
                                 dim=2).flatten()).item()
                # combine the copy path and the word-by-word path
                # NOTE(review): `dim=1` presumes `sample_utils.tensor` adds a leading dimension -- confirm
                marginal_log_prob += torch.logsumexp(tensor(
                    [selector_loss + rel_loss + alias_loss, word_gen_loss]),
                                                     dim=1).item()

                spans.append(
                    MatchedSpan(
                        len(inputs) - 1,
                        len(inputs) + len(token_ids) - 1,
                        example.relations[rel_id].rel_typ, rel_id,
                        surface_idx))
                inputs.extend(token_ids)
                # only the first token of a copied span carries the relation id / surface index
                rel_ids.extend([rel_id] + [INVALID] * (len(token_ids) - 1))
                surface_indices.extend([surface_idx] + [INVALID] *
                                       (len(token_ids) - 1))

                total_log_prob += selector_loss + rel_loss + alias_loss
            elif predictor == WORD_PREDICTOR:
                word, word_loss = sample(computed_log_probs.word)
                total_log_prob += selector_loss + word_loss
                marginal_log_prob += selector_loss + word_loss

                if word == 0 and unkinfo is not None:  # unk
                    unk_idx, unk_loss = np_sample(normalized_unkprob)
                    total_log_prob += unk_loss
                    marginal_log_prob += unk_loss
                    # Ugly multi-purpose use of variables.
                    surface_indices.append(
                        unk_idx)  # Record unk word index in surface_indices.
                    rel_ids.append(UNK)  # Record UNK in rel_ids.
                else:
                    rel_ids.append(INVALID)
                    surface_indices.append(INVALID)

                inputs.append(word)
            else:
                raise ValueError

            hidden = new_hidden

        # average over the number of predicted tokens (the first input token is given, not predicted)
        sample_loss = -total_log_prob / (len(inputs) - 1)
        marginal_loss = -marginal_log_prob / (len(inputs) - 1)
        if print_info:
            print(
                f"Sample loss: {sample_loss:.3f}, PPL: {math.exp(sample_loss):.3f}"
            )
            print(
                f"Marginal sample loss: {marginal_loss:.3f}, PPL: {math.exp(marginal_loss):.3f}"
            )
        # Sanity checks
        # Re-score the finished sample in a single pass and compare against the step-by-step results.
        if sanity_check:
            # noinspection PyTypeChecker
            loss_val, gold_hidden = compute_loss(inputs, spans, initial_hidden)
            assert hidden is not None
            hidden_state_diff = max(
                torch.max(torch.abs(g - h)).item()
                for g, h in zip(gold_hidden, hidden))
            if hidden_state_diff > EPS:
                Logging.warn(
                    f"Hidden states do not match. Difference: {hidden_state_diff}"
                )
            if abs(marginal_loss - loss_val) > EPS:
                Logging.warn(
                    f"Marginal loss values do not match. "
                    f"Forward loss: {loss_val}, difference: {abs(marginal_loss - loss_val)}"
                )

        # counts every entry that is not INVALID, so UNK entries are included as well
        num_rels_generated = sum(int(rel_id != INVALID) for rel_id in rel_ids)
        if print_info:
            print(
                f"Relations [Generated / Annotated]: "
                f"[{num_rels_generated} / {len([s for s in example.spans if s.end < max_length])}]"
            )

        # Turn the token ids into display strings; a copied entity becomes a single output item.
        words = []
        idx = 0
        copy_count = 0
        while idx < len(inputs):
            is_warm_up = (warm_up is not None and idx < warm_up)
            token_id, rel_id, surface_idx = inputs[idx], rel_ids[
                idx], surface_indices[idx]
            if rel_id == INVALID:
                token = word_vocab.i2w[token_id]
                idx += 1
            elif rel_id == UNK:
                # for UNK entries `surface_idx` holds the index into `unki2w`
                token = Logging.color('blue', unki2w[surface_idx])
                idx += 1
            else:
                copy_count += 1
                word_id = example.relations[rel_id].obj_alias[
                    surface_idx]  # multiple words
                token = self.alias_list[word_id]
                # skip over all tokens covered by this alias
                idx += len(token.split())
                if show_rel_type:
                    token = f"{token}_[{rel_vocab.i2w[example.relations[rel_id].rel_typ]}]"
                if color_outputs and not is_warm_up:
                    token = Logging.color(
                        'red' if surface_idx == CANONICAL_IDX else 'green',
                        token)
            if color_outputs and is_warm_up:
                token = Logging.color('yellow', token)
            words.append(token)

        if print_info:
            print(f"# of copied entities: {copy_count}")

        output = SampledOutput(sentence=words,
                               sample_loss=sample_loss,
                               complete_copies=copy_count,
                               incomplete_copies=0)
        return output
Beispiel #9
0
def main():
    """Convert the aligned dataset into numeric IDs and pickle it under ``SAVE_DIR``.

    Steps, in order:

    1. Collect the topic Wikidata IDs of every article in the three splits.
    2. Load the property-ID-to-name mapping and the aligned data.
    3. Gather every entity and relation ID occurring in the data and build the
       ID maps; the special relations and relations missing from the loaded
       map receive negative IDs.
    4. Save the relation names (``rel_names.pkl``), then for each split the
       numericalized examples (``<split>.pkl``) and the matched spans
       (``<split>.span.pkl``).
    """
    Logging.verbosity_level = Logging.VERBOSE

    Logging.warn("This program requires lots of memory (preferably >= 30GB).")

    # `exist_ok=True` replaces the racy "check, then create" sequence.
    SAVE_DIR.mkdir(parents=True, exist_ok=True)

    # Read the Wikimedia IDs for each article, and filter the relations
    topic_ids: Set[WikidataID] = set()
    split_title_id: Dict[str, List[Tuple[str, WikidataID]]] = {}
    for split in ['train', 'valid', 'test']:
        # JSON text is UTF-8; do not rely on the locale's default encoding.
        with utils.work_in_progress(f"Loading {split} set titles"), \
             open(TOPIC_JSON_PATH(split=split), encoding='utf-8') as f:
            j = json.load(f)
        split_title_id[split] = [(article['title'], WikidataID(article['id']))
                                 for article in j]
        topic_ids.update(wid for _, wid in split_title_id[split])
        del j  # drop the reference to the parsed JSON before the next split

    with utils.work_in_progress("Loading Wikidata ID mapping"):
        id2rel = load_id2str(WIKIDATA_DUMP_DIR / 'properties.txt')

    # Match the relations
    matched_dataset = read_data(ALIGNED_DATA_DIR)

    # Gather entities & relation vectors
    found_entities = set()
    found_rels = set()
    for examples in matched_dataset.values():
        for example in examples:
            found_entities.add(example.topic_id)
            for rel in example.relations:
                found_entities.add(rel.obj_id)
                found_rels.add(rel.rel_typ)
    # The special IDs get fixed negative indices below, so keep them out of
    # the sets handed to the loaders.
    found_entities -= {UNK_ENTITY}
    found_rels -= {NAF, ANCHOR, TOPIC_ITSELF}
    with utils.work_in_progress("Building rel vecs"):
        rel_map = load_relations(found_rels)
        rel_map.update({NAF: -1, ANCHOR: -2, TOPIC_ITSELF: -3})
        unk_rels = found_rels.difference(rel_map)
        # NOTE: unk_rels is a set, its order is undetermined, so we sort it to
        # make sure it's consistent between runs
        for idx, rel in enumerate(sorted(unk_rels)):
            rel_map[rel] = -idx - 4  # starting from -4, going towards -inf
    with utils.work_in_progress("Building entity vecs"):
        entity_map = load_entities(found_entities)
        entity_map.update({UNK_ENTITY: -1})
        print(
            f"Topic ID coverage: {len(topic_ids.intersection(entity_map))}/{len(topic_ids)}"
        )

    # save relation type names for use during generation
    id_to_rel_name = dict(id2rel)
    id_to_rel_name.update({
        NAF: 'Not-A-Fact',
        ANCHOR: 'ANCHOR',
        TOPIC_ITSELF: 'TITLE'
    })
    rel_names: Dict[int, str] = {
        rel_id: id_to_rel_name[r_rel]
        for r_rel, rel_id in rel_map.items()
    }
    rel_names_path = SAVE_DIR / 'rel_names.pkl'
    with rel_names_path.open('wb') as f:
        pickle.dump(rel_names, f)
    # Report only after the file has been closed (and therefore flushed).
    print(f"Relation names saved to {rel_names_path}")

    # Convert into numbers to create the final dataset
    for split, examples in matched_dataset.items():
        with utils.work_in_progress(f"Converting {split} set"):
            dataset, matched_spans = numericalize_rel(examples, rel_map,
                                                      entity_map)

        path = SAVE_DIR / f'{split}.pkl'
        with path.open('wb') as f:
            pickle.dump(dataset, f)
        print(
            f"Dataset split '{split}' saved to {path}, {len(dataset)} examples"
        )

        path = SAVE_DIR / f'{split}.span.pkl'
        with path.open('wb') as f:
            pickle.dump(matched_spans, f)
        print(f"Matched spans split '{split}' saved to {path}")
Beispiel #10
0
def read_data(path: Path) -> Dict[str, List[RawExampleWikiID]]:
    """Load the pickled aligned splits in ``path`` and convert them to examples.

    For every ``(sentence, relations)`` entry of ``{split}.pkl`` this builds a
    sorted, de-duplicated relation list, the entity mentions indexing into
    that list, and takes the topic ID from the record carrying the ``@TITLE@``
    relation.  Entries without such a record are skipped and reported in a
    warning at the end.

    :param path: directory containing ``train.pkl``, ``valid.pkl`` and
        ``test.pkl``.
    :return: mapping from split name to the list of converted examples.
    """
    title_rel = "@TITLE@"
    bad_examples: List[Tuple[str, int, str]] = []
    data = {}
    for split in ['train', 'valid', 'test']:
        # Only the load needs the file handle; close it before the (long)
        # conversion loop instead of keeping it open throughout.
        with (path / f'{split}.pkl').open('rb') as f:
            with utils.work_in_progress(f"Loading {split} set"):
                dump: List[RawDump] = pickle.load(f)

        examples = []
        for ex_idx, (sent, rels) in enumerate(
                utils.progress(dump, desc='Reading data')):
            # Each relation record is unpacked below as
            # (obj_id, head_id, span, rel_info, surface, alias).
            # map (rel_typ, canonical) to list of aliases, since lists aren't hashable
            rel_to_alias: Dict[Tuple[str, str], List[str]] = {
                (rel[0][0], obj_id): alias
                for obj_id, _, _, rel, _, alias in rels
            }

            # sort it so the order is consistent
            relations: List[RelationWikiID] = sorted(
                RelationWikiID(WikidataID(rel_id), WikidataID(obj_id),
                               obj_alias)
                for (rel_id, obj_id), obj_alias in rel_to_alias.items())
            rel_to_id: Dict[Tuple[str, str], int] = {
                (rel_id, obj_id): rel_idx
                for rel_idx, (rel_id, obj_id, _) in enumerate(relations)
            }
            # dedup to remove duplicate (-1, -1)
            mentions: List[EntityMention] = list({
                EntityMention(span, surface,
                              rel_to_id[(rel_info[0][0], obj_id)])
                for obj_id, _, span, rel_info, surface, _ in rels
            })

            # must exist - head id with the relation: @TITLE@ is the topic WikidataID
            for _, head_id, _, rel_info, _, _ in rels:
                if rel_info[0][0] == title_rel:
                    topic_id = head_id
                    break
            else:
                bad_examples.append((split, ex_idx, ' '.join(sent)[:100]))
                continue

            # Relations are unpacked positionally, exactly like in `rel_to_id`
            # above, so this does not depend on the name of the object field.
            # NOTE(review): this used to read `r.obj`, while main() reads the
            # same objects as `rel.obj_id` - confirm the field name on
            # RelationWikiID.
            converted_relations = [
                RelationWikiID(
                    TOPIC_ITSELF if rel_typ == title_rel else rel_typ,
                    obj_id, obj_alias)
                for rel_typ, obj_id, obj_alias in relations
            ]

            examples.append(
                RawExampleWikiID(WikidataID(topic_id), sent,
                                 converted_relations, mentions))
        data[split] = examples

    if bad_examples:
        Logging.warn(f"{len(bad_examples)} bad examples:\n"
                     f"{pprint.pformat(bad_examples)}")
    else:
        Logging.verbose("All examples are good")

    return data
Beispiel #11
0
 def color_if_less(val: float,
                   threshold: float,
                   format_str: str = '{:.4f}',
                   color: str = 'yellow'):
     """Format ``val`` with ``format_str``; color it when ``val < threshold``.

     Values at or above the threshold are returned as plain formatted text.
     """
     text = format_str.format(val)
     if val < threshold:
         return Logging.color(color, text)
     return text
Beispiel #12
0
    def span_log_probs(example,
                       _,
                       bptt_size=140,
                       ppl_threshold=200.0,
                       split_len=20,
                       n_context=5,
                       max_segments=-1):
        """Print each segment's tokens and the posterior perplexities of its spans.

        The example is batched with ``dataset.create_one_batch`` and evaluated
        with ``compute_batch_loss``; a callback collects, per segment, the loss
        and the posterior log-probabilities the model cached for each span.

        :param example: the example to analyse.
        :param _: unused positional argument.
        :param bptt_size: passed to ``dataset.create_one_batch``.
        :param ppl_threshold: perplexities below this value are colored.
        :param split_len: number of tokens printed per row.
        :param n_context: context tokens printed on each side of a span.
        :param max_segments: stop after this many segments; ``-1`` = no limit.
        """
        init_batch, batches = dataset.create_one_batch([example], bptt_size)
        relations: List[Relation] = init_batch[0]

        # One entry per callback invocation, i.e. per evaluated segment.
        per_segment_probs: List[List[Optional[Tuple[float, float]]]] = []
        segment_losses: List[float] = []

        # noinspection PyShadowingNames
        def callback(loss: Tensor, batch: BatchSequence) -> None:
            assert batch.spans is not None
            cached: Dict[MatchedSpan, Tuple[float, float]] = \
                model.model_cache['posterior_log_probs'][0]
            # Spans without a cached entry are recorded as None.
            per_segment_probs.append(
                [cached.get(span, None) for span in batch.spans[0]])
            segment_losses.append(loss.item())

        def color_if_less(val: float,
                          threshold: float,
                          format_str: str = '{:.4f}',
                          color: str = 'yellow'):
            text = format_str.format(val)
            if val < threshold:
                return Logging.color(color, text)
            return text

        with torch.no_grad():
            model.eval()
            compute_batch_loss(model,
                               init_batch,
                               batches,
                               use_unk_probs=True,
                               callback=callback,
                               evaluate=True,
                               calc_loss_kwargs={'dump_posterior_probs': True})

        words_seen = 0
        for seg_no, (batch, probs) in enumerate(
                zip(batches, per_segment_probs)):
            if max_segments != -1 and seg_no >= max_segments:
                break
            seg_len = batch.ntokens
            header = (f"Segment #{seg_no}: "
                      f"words {words_seen} - {words_seen + seg_len}, "
                      f"ppl = {math.exp(segment_losses[seg_no]):.4f}")
            print(Logging.color('green', header))
            words_seen += seg_len

            tokens = batch.raw_sequence[0][1:]
            spans = batch.spans[0]

            # Mark every token covered by a well-formed span.
            highlighted = [False] * seg_len
            for span in spans:
                if span.start <= span.end < seg_len:
                    width = span.end - span.start + 1
                    highlighted[span.start:span.end + 1] = [True] * width

            # Print the segment in rows of `split_len` tokens.
            for row_start in range(0, seg_len, split_len):
                row_end = row_start + split_len
                row = [
                    Logging.color('red', word) if marked else word
                    for word, marked in zip(tokens[row_start:row_end],
                                            highlighted[row_start:row_end])
                ]
                print(f'{row_start:3d}:', ' '.join(row))
            print()

            for span, prob in sorted(zip(spans, probs)):
                if prob is None:
                    continue
                rel_prob, word_prob = prob
                ctx_lo = max(0, span.start - n_context)
                ctx_hi = min(seg_len, span.end + 1 + n_context)

                rel_name = dataset.rel_vocab.i2w[span.rel_typ]
                alias = Logging.color(
                    'red', relations[span.rel_idx].obj_alias[span.surface_idx])
                alias_tag = ' (alias)' if span.surface_idx > 0 else ''
                rel_ppl = color_if_less(math.exp(-rel_prob), ppl_threshold)
                word_ppl = color_if_less(math.exp(-word_prob), ppl_threshold)
                print(f"[{span.start}, {span.end}]"
                      f" <{rel_name}>"
                      f" {alias}"
                      f"{alias_tag}"
                      f": rel = {rel_ppl}"
                      f", word = {word_ppl}")

                context = ' '.join(
                    Logging.color('red', tokens[pos])
                    if span.start <= pos <= span.end else tokens[pos]
                    for pos in range(ctx_lo, ctx_hi))
                print('    ', '... ' if ctx_lo > 0 else '', context,
                      ' ...' if ctx_hi < seg_len else '')
            print()
Beispiel #13
0
    def posterior_log_probs(example,
                            _,
                            bptt_size=140,
                            n_context=5,
                            max_segments=-1,
                            ignore_single_relation=False,
                            file=sys.stdout):
        """Print, for each group of overlapping spans, the normalized score of every path.

        The example is batched with ``dataset.create_one_batch`` and run through
        ``compute_batch_loss`` twice (original and flipped batches) to collect
        the values the model leaves in ``model.model_cache``.  Within each
        segment the well-formed spans are merged into groups of overlapping
        spans; every path returned by ``search_paths`` for a group is scored
        and the scores are normalized over the paths of that group.

        :param example: the example to analyse.
        :param _: unused positional argument.
        :param bptt_size: passed to ``dataset.create_one_batch``.
        :param n_context: context tokens printed on each side of a group.
        :param max_segments: stop after this many segments; ``-1`` = no limit.
        :param ignore_single_relation: skip groups made of a single span.
        :param file: stream the report is written to.
        """
        from dataset.utils import flip_batches, search_paths

        (init_batch, batches) = dataset.create_one_batch([example], bptt_size)
        flipped_batches = flip_batches(batches)

        word_prob: Dict[str, List[np.ndarray]] = {
            k: []
            for k in ["forward", "backward"]
        }
        marginal_prob: Dict[str, List[np.ndarray]] = {
            k: []
            for k in ["forward", "backward"]
        }
        # One dict per callback invocation, i.e. per evaluated segment.
        posterior_probs: List[Dict[MatchedSpan, Tuple[float, float]]] = []
        seq_loss: List[float] = []

        def callback(loss: Tensor, _batch: BatchSequence):
            posterior_probs.append(model.model_cache['posterior_log_probs'][0])
            word_prob['forward'].append(
                model.model_cache['target_cond_log_probs'][0])
            marginal_prob['forward'].append(
                model.model_cache['stacked_log_probs'][0])
            seq_loss.append(loss.item())

        def callback_flip(_loss: Tensor, _batch: BatchSequence):
            marginal_prob['backward'].append(
                model.model_cache['stacked_log_probs'][0])

        with torch.no_grad():
            model.eval()
            compute_batch_loss(model,
                               init_batch,
                               batches,
                               use_unk_probs=True,
                               callback=callback,
                               evaluate=True,
                               calc_loss_kwargs={'dump_posterior_probs': True})

            compute_batch_loss(model,
                               init_batch,
                               flipped_batches,
                               use_unk_probs=True,
                               callback=callback_flip,
                               evaluate=True,
                               calc_loss_kwargs={'dump_posterior_probs': True})

        # credit: http://bayesjumping.net/log-sum-exp-trick/ for without using scipy
        def log_sum_exp(ns: List[float]):
            values = np.asarray(ns)
            max_ = np.max(values)
            # Subtracting the maximum keeps the exponentials from overflowing.
            return max_ + np.log(np.exp(values - max_).sum())

        n_words = 0
        for seq_idx, (batch,
                      probs_dict) in enumerate(zip(batches, posterior_probs)):
            if max_segments != -1 and seq_idx >= max_segments:
                break

            # Advance the running word offset for EVERY segment.  Previously it
            # only moved when a header was printed, so segments without spans
            # (or with only single-span groups) shifted all later word ranges.
            segment_start = n_words
            n_words += batch.ntokens

            tokens = batch.raw_sequence[0][1:]
            # Drop malformed / out-of-segment spans up front so that the first
            # span is validated the same way as all the others.
            sorted_spans = sorted(
                (sp for sp in batch.spans[0]
                 if sp.start <= sp.end < batch.ntokens),
                key=lambda x: (x.start, x.end))
            if not sorted_spans:
                continue

            # Merge spans into groups keyed by the (start, end) range they
            # jointly cover; `latest` is the key of the most recent group.
            overlap_group = {}
            latest = None
            for sp in sorted_spans:
                if latest is not None and sp.start <= latest[1]:
                    grp = overlap_group.pop(latest)
                    latest = (latest[0], max(latest[1], sp.end))
                    overlap_group[latest] = grp + [sp]
                else:
                    latest = (sp.start, sp.end)
                    overlap_group[latest] = [sp]

            if not ignore_single_relation or any(
                    len(g) > 1 for g in overlap_group.values()):
                print(Logging.color(
                    'green', f"Segment #{seq_idx}: "
                    f"words {segment_start} - {segment_start + batch.ntokens}, "
                    f"ppl = {math.exp(seq_loss[seq_idx]):.4f}"),
                      file=file)

            for span, group in overlap_group.items():
                if ignore_single_relation and len(group) == 1:
                    continue

                # NOTE(review): for a group starting at position 0 this reads
                # index -1 of the forward marginals - confirm that is intended.
                alpha = marginal_prob["forward"][seq_idx][span[0] - 1]
                beta = marginal_prob["backward"][::-1][seq_idx][
                    batch.lengths[0] - (span[1] + 1) - 1]
                # Enumerate all the paths
                paths = search_paths(group, span[0], span[1])
                log_probs = []
                annotations = []
                delimiters = []
                for path in paths:
                    path_anno = []
                    path_delims = []
                    logprob = alpha + beta
                    for hop in path:
                        if hop.rel_typ == -100:  # dummy relation - word transition
                            # NOTE(review): indexed by the group start rather
                            # than hop.start - confirm that is intended.
                            logprob += word_prob["forward"][seq_idx][span[0]]
                            path_anno.append("word")
                        else:
                            logprob += probs_dict[hop][0]
                            path_anno.append(
                                dataset.rel_vocab.i2w[hop.rel_typ])
                        path_delims += ["   "] * (hop.end - hop.start)
                        if hop.end < span[1]:
                            path_delims.append(" | ")

                    delimiters.append(path_delims)
                    log_probs.append(logprob)
                    annotations.append(path_anno)
                log_denom = log_sum_exp(log_probs)
                normalized_probs = [
                    np.exp(log_prob - log_denom) for log_prob in log_probs
                ]

                ctx_start = max(0, span[0] - n_context)
                # Exclusive end: `n_context` tokens after the last token of
                # the group (the "+ 1" was missing, showing one token fewer
                # on the right than on the left).
                ctx_end = min(batch.ntokens, span[1] + 1 + n_context)
                token_string = " ".join([
                    '    ', '... ' if ctx_start > 0 else '',
                    ' '.join(tokens[i] for i in range(ctx_start, span[0])),
                    '|',
                    '   '.join(
                        Logging.color('red', tokens[i])
                        for i in range(span[0], span[1] + 1)),
                    '|',
                    ' '.join(tokens[i] for i in range(span[1] + 1, ctx_end)),
                    ' ...' if ctx_end < batch.ntokens else ''
                ])
                print(token_string, file=file)

                annotation_strings = [" => ".join(a) for a in annotations]
                max_anno_len = max(len(a) for a in annotation_strings)
                max_score_idx = np.argmax(normalized_probs)

                # Uncolored copy of the group's tokens, indented by the offset
                # of the first " | " in `token_string`; " _ " marks the gaps
                # that are filled in per path.
                span_line = (" " * token_string.index(" | ") + " | " +
                             ' _ '.join(tokens[i]
                                        for i in range(span[0], span[1] + 1)) +
                             " | ")
                for path_idx, (delim, anno, prob) in enumerate(
                        zip(delimiters, annotation_strings, normalized_probs)):
                    matched_span_tokens = span_line
                    # Each placeholder is replaced by a delimiter of the same
                    # width, so positions found on `span_line` stay valid.
                    for d, match in zip(delim,
                                        re.finditer(r" _ ", span_line)):
                        pos = match.start(0)
                        matched_span_tokens = (matched_span_tokens[:pos] + d +
                                               matched_span_tokens[pos + 3:])
                    score = f"  {prob:1.4f}"
                    if path_idx == max_score_idx:
                        score = Logging.color("green", score)
                    matched_span_tokens += "  " + f"{anno}{' ' * (max_anno_len - len(anno))}" + score
                    print(matched_span_tokens, file=file)

                print(file=file)
Beispiel #14
0
def repl(dataset: KBLMDataset, model: BaseLM):
    import re
    from run import compute_batch_loss  # avoid circular import

    print(
        Logging.color(
            s="Execute `sample(name=\"Barack Obama\")` to generate samples for given entity.\n"
            "Execute `sample(split='test', index=1)` to generate samples for specific data entry.\n"
            "For more configurable settings, please refer to method `models.lrlm.LRLM.sampling_decode`.\n",
            col='green'))

    def get_example(func):
        """Decorator that resolves an example and hands it to ``func``.

        The wrapped function is called as ``func(example, start_symbol,
        **kwargs)``, where the example is selected either by ``name`` (searched
        in the first 10 tokens of every example) or by ``split`` / ``index``.
        """

        def find_name(name: str) -> Tuple[str, int]:
            # First example, in split order, whose leading tokens contain `name`.
            for split in dataset.data:
                for idx, ex in enumerate(dataset.data[split]):
                    if name in ' '.join(ex.sentence[:10]):
                        return split, idx
            raise ValueError("Name not found!")

        @functools.wraps(func)
        def wrapped(name: Optional[str] = None,
                    split: str = 'test',
                    index: int = 1,
                    start_symbol: str = '<s>',
                    **kwargs):
            if name is not None:
                # Prefer a match on the title markup, then fall back to the
                # bare name.
                try:
                    split, index = find_name(f"= {name} =")
                except ValueError:
                    split, index = find_name(name)

            if start_symbol is None:
                start_symbol = '<s>'
            if isinstance(start_symbol, str):
                start_symbol = dataset.word_vocab.w2i[start_symbol]

            example = dataset.data[split][index]
            sentence = example.sentence
            try:
                title_end = sentence.index('=', 2)
                name = ' '.join(sentence[2:title_end])
            except ValueError:
                # probably WikiFacts
                candidates = [
                    sentence.index(token) for token in ['is', 'was', '(']
                    if token in sentence
                ]
                pos = min(candidates, default=3)
                name = ' '.join(sentence[1:pos])

            file = kwargs.get("file", sys.stdout)
            print(f"Data split {split}, index {index}, name \"{name}\"",
                  file=file)
            func(example, start_symbol, **kwargs)

        return wrapped

    end_symbol = dataset.word_vocab.w2i['</s>']

    @get_example
    def sample(example,
               start_symbol,
               max_length=500,
               n_tries=1,
               bptt_size=150,
               **kwargs):
        """Sample ``n_tries`` outputs for ``example`` and print the lowest-loss one.

        :param example: the example to condition on.
        :param start_symbol: passed to the decoder as ``begin_symbol``.
        :param max_length: passed to the decoder as ``max_length``.
        :param n_tries: number of samples to draw; nothing is printed when it
            is smaller than 1.
        :param bptt_size: passed to ``dataset.create_one_batch``.
        :param kwargs: ``generate_unk`` and ``print_info`` are consumed here
            (``print_info`` defaults to ``n_tries == 1``); everything else is
            forwarded to ``model.sampling_decode``.
        """
        init_batch, _batches = dataset.create_one_batch([example], bptt_size)

        # Options handled here are popped so they are not forwarded a second
        # time through **kwargs.  `print_info` used to be read with `.get()`
        # and then passed both explicitly and inside **kwargs, which raises
        # "TypeError: got multiple values for keyword argument".
        if kwargs.pop('generate_unk', False):
            fulli2w = [
                w for w, _ in sorted(dataset.total_w2i.items(),
                                     key=lambda x: x[1])
            ]
            unkinfo = (dataset.unk_probs, fulli2w)
        else:
            unkinfo = None
        print_info = kwargs.pop('print_info', n_tries == 1)

        best_output = None
        for _ in range(n_tries):
            output: SampledOutput = model.sampling_decode(
                dataset.vocab,
                example,
                begin_symbol=
                start_symbol,  # this is the actual start symbol in all articles
                end_symbol=end_symbol,
                max_length=max_length,
                color_outputs=True,  # generate colored outputs in terminal
                print_info=print_info,
                unkinfo=unkinfo,
                init_batch=init_batch,
                **kwargs)
            if best_output is None or output.sample_loss < best_output.sample_loss:
                best_output = output

        if best_output is None:
            # n_tries < 1: nothing was sampled, so there is nothing to print.
            return

        if n_tries > 1:
            print(
                f"Sample loss: {best_output.sample_loss:.3f}, PPL: {math.exp(best_output.sample_loss):.3f}"
            )
            print(
                f"Complete / incomplete entities: {best_output.complete_copies} / {best_output.incomplete_copies}"
            )
        # format title & subtitles
        sentence = re.sub(r'= = (.*?) = = ', r'\n\n== \1 ==\n    ',
                          ' '.join(best_output.sentence))
        sentence = re.sub(r'= (.*?) = ', r'= \1 =\n    ', sentence)
        print(sentence)

    @get_example
    def average_copies(example,
                       start_symbol,
                       max_length=200,
                       n_tries=10,
                       progress=False,
                       show_samples=False,
                       **kwargs):
        """Sample ``n_tries`` outputs and print the averaged copy counts.

        :param example: the example to condition on.
        :param start_symbol: passed to the decoder as ``begin_symbol``.
        :param max_length: passed to the decoder as ``max_length``.
        :param n_tries: number of samples to draw.
        :param progress: passed to ``utils.progress`` as ``verbose``.
        :param show_samples: print every sampled sentence (with decoder info
            and colors enabled).
        :param kwargs: forwarded to ``model.sampling_decode``.
        """
        complete_copies = utils.SimpleAverage()
        incomplete_copies = utils.SimpleAverage()

        if show_samples:
            decode_flags = {
                'print_info': True,
                'color_outputs': True,
                'color_incomplete': False,
            }
        else:
            decode_flags = {'print_info': False, 'color_outputs': False}
        # Double unpacking keeps the duplicate-keyword TypeError of a
        # keyword-style dict(...) call.
        sample_kwargs = dict(**decode_flags, **kwargs)

        for _ in utils.progress(n_tries, verbose=progress):
            output: SampledOutput = model.sampling_decode(
                dataset.vocab,
                example,
                begin_symbol=start_symbol,
                end_symbol=end_symbol,
                max_length=max_length,
                **sample_kwargs)
            complete_copies.add(output.complete_copies)
            incomplete_copies.add(output.incomplete_copies)
            if show_samples:
                print(' '.join(output.sentence))

        if not isinstance(example, LRLMExample):
            print(
                f"Complete / Incomplete entities: {complete_copies.value()} / {incomplete_copies.value()}"
            )
            return

        # Count the annotated spans that end before `max_length`.
        n_gold_rels = sum(1 for span in example.spans
                          if span.end < max_length)
        print(
            f"Complete / Gold entities: {complete_copies.value()} / {n_gold_rels}"
        )

    @get_example
    def logprob_article(example, _start_symbol, bptt_size=150, **_kwargs):
        """Print a relation count and perplexity per segment, then the total perplexity.

        :param example: the example to evaluate.
        :param _start_symbol: unused.
        :param bptt_size: passed to ``dataset.create_one_batch``.
        """
        init_batch, batches = dataset.create_one_batch([example], bptt_size)

        # log-prob calculation
        def callback(loss: Tensor, batch: BatchSequence, *_extra):
            if batch.spans is not None:
                n_rels = len(batch.spans[0])
            else:
                # NKLM
                padded_ids = np.concatenate(
                    [[-1], batch.seqs['rel_ids'][0].cpu().numpy(), [-1]])
                steps = np.ediff1d(padded_ids)
                n_rels = np.count_nonzero(np.cumsum(steps[steps != 0]))
            print(f"{n_rels:2d}\t{math.exp(loss.item()):2.3f}")

        print("#rels\tPPL")
        with torch.no_grad():
            model.eval()
            total_loss = compute_batch_loss(model,
                                            init_batch,
                                            batches,
                                            use_unk_probs=True,
                                            callback=callback,
                                            evaluate=True)
        # Normalize the summed loss by the number of tokens over all segments.
        n_tokens = sum(batch.ntokens for batch in batches)
        total_loss /= n_tokens
        print(f"Total PPL: {math.exp(total_loss)}")

    @get_example
    def posterior_log_probs(example,
                            _,
                            bptt_size=140,
                            n_context=5,
                            max_segments=-1,
                            ignore_single_relation=False,
                            file=sys.stdout):
        """Print, for each group of overlapping spans, the normalized score of every path.

        The example is batched with ``dataset.create_one_batch`` and run through
        ``compute_batch_loss`` twice (original and flipped batches) to collect
        the values the model leaves in ``model.model_cache``.  Within each
        segment the well-formed spans are merged into groups of overlapping
        spans; every path returned by ``search_paths`` for a group is scored
        and the scores are normalized over the paths of that group.

        :param example: the example to analyse.
        :param _: unused positional argument.
        :param bptt_size: passed to ``dataset.create_one_batch``.
        :param n_context: context tokens printed on each side of a group.
        :param max_segments: stop after this many segments; ``-1`` = no limit.
        :param ignore_single_relation: skip groups made of a single span.
        :param file: stream the report is written to.
        """
        from dataset.utils import flip_batches, search_paths

        (init_batch, batches) = dataset.create_one_batch([example], bptt_size)
        flipped_batches = flip_batches(batches)

        word_prob: Dict[str, List[np.ndarray]] = {
            k: []
            for k in ["forward", "backward"]
        }
        marginal_prob: Dict[str, List[np.ndarray]] = {
            k: []
            for k in ["forward", "backward"]
        }
        # One dict per callback invocation, i.e. per evaluated segment.
        posterior_probs: List[Dict[MatchedSpan, Tuple[float, float]]] = []
        seq_loss: List[float] = []

        def callback(loss: Tensor, _batch: BatchSequence):
            posterior_probs.append(model.model_cache['posterior_log_probs'][0])
            word_prob['forward'].append(
                model.model_cache['target_cond_log_probs'][0])
            marginal_prob['forward'].append(
                model.model_cache['stacked_log_probs'][0])
            seq_loss.append(loss.item())

        def callback_flip(_loss: Tensor, _batch: BatchSequence):
            marginal_prob['backward'].append(
                model.model_cache['stacked_log_probs'][0])

        with torch.no_grad():
            model.eval()
            compute_batch_loss(model,
                               init_batch,
                               batches,
                               use_unk_probs=True,
                               callback=callback,
                               evaluate=True,
                               calc_loss_kwargs={'dump_posterior_probs': True})

            compute_batch_loss(model,
                               init_batch,
                               flipped_batches,
                               use_unk_probs=True,
                               callback=callback_flip,
                               evaluate=True,
                               calc_loss_kwargs={'dump_posterior_probs': True})

        # credit: http://bayesjumping.net/log-sum-exp-trick/ for without using scipy
        def log_sum_exp(ns: List[float]):
            values = np.asarray(ns)
            max_ = np.max(values)
            # Subtracting the maximum keeps the exponentials from overflowing.
            return max_ + np.log(np.exp(values - max_).sum())

        n_words = 0
        for seq_idx, (batch,
                      probs_dict) in enumerate(zip(batches, posterior_probs)):
            if max_segments != -1 and seq_idx >= max_segments:
                break

            # Advance the running word offset for EVERY segment.  Previously it
            # only moved when a header was printed, so segments without spans
            # (or with only single-span groups) shifted all later word ranges.
            segment_start = n_words
            n_words += batch.ntokens

            tokens = batch.raw_sequence[0][1:]
            # Drop malformed / out-of-segment spans up front so that the first
            # span is validated the same way as all the others.
            sorted_spans = sorted(
                (sp for sp in batch.spans[0]
                 if sp.start <= sp.end < batch.ntokens),
                key=lambda x: (x.start, x.end))
            if not sorted_spans:
                continue

            # Merge spans into groups keyed by the (start, end) range they
            # jointly cover; `latest` is the key of the most recent group.
            overlap_group = {}
            latest = None
            for sp in sorted_spans:
                if latest is not None and sp.start <= latest[1]:
                    grp = overlap_group.pop(latest)
                    latest = (latest[0], max(latest[1], sp.end))
                    overlap_group[latest] = grp + [sp]
                else:
                    latest = (sp.start, sp.end)
                    overlap_group[latest] = [sp]

            if not ignore_single_relation or any(
                    len(g) > 1 for g in overlap_group.values()):
                print(Logging.color(
                    'green', f"Segment #{seq_idx}: "
                    f"words {segment_start} - {segment_start + batch.ntokens}, "
                    f"ppl = {math.exp(seq_loss[seq_idx]):.4f}"),
                      file=file)

            for span, group in overlap_group.items():
                if ignore_single_relation and len(group) == 1:
                    continue

                # NOTE(review): for a group starting at position 0 this reads
                # index -1 of the forward marginals - confirm that is intended.
                alpha = marginal_prob["forward"][seq_idx][span[0] - 1]
                beta = marginal_prob["backward"][::-1][seq_idx][
                    batch.lengths[0] - (span[1] + 1) - 1]
                # Enumerate all the paths
                paths = search_paths(group, span[0], span[1])
                log_probs = []
                annotations = []
                delimiters = []
                for path in paths:
                    path_anno = []
                    path_delims = []
                    logprob = alpha + beta
                    for hop in path:
                        if hop.rel_typ == -100:  # dummy relation - word transition
                            # NOTE(review): indexed by the group start rather
                            # than hop.start - confirm that is intended.
                            logprob += word_prob["forward"][seq_idx][span[0]]
                            path_anno.append("word")
                        else:
                            logprob += probs_dict[hop][0]
                            path_anno.append(
                                dataset.rel_vocab.i2w[hop.rel_typ])
                        path_delims += ["   "] * (hop.end - hop.start)
                        if hop.end < span[1]:
                            path_delims.append(" | ")

                    delimiters.append(path_delims)
                    log_probs.append(logprob)
                    annotations.append(path_anno)
                log_denom = log_sum_exp(log_probs)
                normalized_probs = [
                    np.exp(log_prob - log_denom) for log_prob in log_probs
                ]

                ctx_start = max(0, span[0] - n_context)
                # Exclusive end: `n_context` tokens after the last token of
                # the group (the "+ 1" was missing, showing one token fewer
                # on the right than on the left).
                ctx_end = min(batch.ntokens, span[1] + 1 + n_context)
                token_string = " ".join([
                    '    ', '... ' if ctx_start > 0 else '',
                    ' '.join(tokens[i] for i in range(ctx_start, span[0])),
                    '|',
                    '   '.join(
                        Logging.color('red', tokens[i])
                        for i in range(span[0], span[1] + 1)),
                    '|',
                    ' '.join(tokens[i] for i in range(span[1] + 1, ctx_end)),
                    ' ...' if ctx_end < batch.ntokens else ''
                ])
                print(token_string, file=file)

                annotation_strings = [" => ".join(a) for a in annotations]
                max_anno_len = max(len(a) for a in annotation_strings)
                max_score_idx = np.argmax(normalized_probs)

                # Uncolored copy of the group's tokens, indented by the offset
                # of the first " | " in `token_string`; " _ " marks the gaps
                # that are filled in per path.
                span_line = (" " * token_string.index(" | ") + " | " +
                             ' _ '.join(tokens[i]
                                        for i in range(span[0], span[1] + 1)) +
                             " | ")
                for path_idx, (delim, anno, prob) in enumerate(
                        zip(delimiters, annotation_strings, normalized_probs)):
                    matched_span_tokens = span_line
                    # Each placeholder is replaced by a delimiter of the same
                    # width, so positions found on `span_line` stay valid.
                    for d, match in zip(delim,
                                        re.finditer(r" _ ", span_line)):
                        pos = match.start(0)
                        matched_span_tokens = (matched_span_tokens[:pos] + d +
                                               matched_span_tokens[pos + 3:])
                    score = f"  {prob:1.4f}"
                    if path_idx == max_score_idx:
                        score = Logging.color("green", score)
                    matched_span_tokens += "  " + f"{anno}{' ' * (max_anno_len - len(anno))}" + score
                    print(matched_span_tokens, file=file)

                print(file=file)

    @get_example
    def span_log_probs(example,
                       _,
                       bptt_size=140,
                       ppl_threshold=200.0,
                       split_len=20,
                       n_context=5,
                       max_segments=-1):
        """
        Print per-segment span diagnostics for ``example``.

        The example is turned into segments with ``dataset.create_one_batch`` and scored with
        ``compute_batch_loss`` under ``torch.no_grad()``. For every segment this prints:

        - a green header with the word range and ``exp(loss)`` of the segment;
        - the segment tokens in rows of ``split_len``, tokens covered by a span shown in red;
        - for each span that has a cached score pair, one summary line followed by the span
          with up to ``n_context`` tokens of context on either side.

        :param example: The example to analyze.
        :param _: Unused.
        :param bptt_size: Forwarded to ``dataset.create_one_batch``.
        :param ppl_threshold: Exponentiated span scores below this value are colored yellow.
        :param split_len: Number of tokens printed per row of the segment dump.
        :param n_context: Number of context tokens shown on each side of a span.
        :param max_segments: Stop after this many segments; ``-1`` means no limit.
        """
        init_batch, batches = dataset.create_one_batch([example], bptt_size)
        rels: List[Relation] = init_batch[0]

        # Both lists get one entry per evaluated segment (appended by the callback below).
        posterior_probs: List[List[Optional[Tuple[float, float]]]] = []
        seq_loss: List[float] = []

        # noinspection PyShadowingNames
        def callback(loss: Tensor, batch: BatchSequence) -> None:
            assert batch.spans is not None
            cached: Dict[MatchedSpan, Tuple[float, float]] = \
                model.model_cache['posterior_log_probs'][0]
            # Spans without a cached entry are recorded as None and skipped when printing.
            posterior_probs.append([cached.get(sp) for sp in batch.spans[0]])
            seq_loss.append(loss.item())

        def color_if_less(val: float,
                          threshold: float,
                          format_str: str = '{:.4f}',
                          color: str = 'yellow'):
            text = format_str.format(val)
            if val < threshold:
                return Logging.color(color, text)
            return text

        with torch.no_grad():
            model.eval()
            compute_batch_loss(
                model,
                init_batch,
                batches,
                use_unk_probs=True,
                callback=callback,
                evaluate=True,
                calc_loss_kwargs={'dump_posterior_probs': True},
            )

        n_words = 0
        for seq_idx, (batch, probs) in enumerate(zip(batches, posterior_probs)):
            if max_segments != -1 and seq_idx >= max_segments:
                break

            header = (f"Segment #{seq_idx}: "
                      f"words {n_words} - {n_words + batch.ntokens}, "
                      f"ppl = {math.exp(seq_loss[seq_idx]):.4f}")
            print(Logging.color('green', header))
            n_words += batch.ntokens

            tokens = batch.raw_sequence[0][1:]
            spans = batch.spans[0]

            # Flag every token position that is covered by at least one in-range span.
            is_in_span = [False] * batch.ntokens
            for span in spans:
                if span.start > span.end or span.end >= batch.ntokens:
                    continue
                width = span.end - span.start + 1
                is_in_span[span.start:(span.end + 1)] = [True] * width

            # Dump the segment in rows of `split_len` tokens.
            for row_start in range(0, batch.ntokens, split_len):
                row_stop = row_start + split_len
                painted = [
                    Logging.color('red', tok) if flagged else tok
                    for tok, flagged in zip(tokens[row_start:row_stop],
                                            is_in_span[row_start:row_stop])
                ]
                print(f'{row_start:3d}:', ' '.join(painted))
            print()

            for span, scores in sorted(zip(spans, probs)):
                if scores is None:
                    continue
                rel_score, word_score = scores
                ctx_lo = max(0, span.start - n_context)
                ctx_hi = min(batch.ntokens, span.end + 1 + n_context)

                rel_name = dataset.rel_vocab.i2w[span.rel_typ]
                surface = Logging.color(
                    'red', rels[span.rel_idx].obj_alias[span.surface_idx])
                alias_tag = ' (alias)' if span.surface_idx > 0 else ''
                rel_text = color_if_less(math.exp(-rel_score), ppl_threshold)
                word_text = color_if_less(math.exp(-word_score), ppl_threshold)
                print(f"[{span.start}, {span.end}]"
                      f" <{rel_name}>"
                      f" {surface}"
                      f"{alias_tag}"
                      f": rel = {rel_text}"
                      f", word = {word_text}")

                context = ' '.join(
                    Logging.color('red', tokens[pos])
                    if span.start <= pos <= span.end else tokens[pos]
                    for pos in range(ctx_lo, ctx_hi))
                print('    ', '... ' if ctx_lo > 0 else '', context,
                      ' ...' if ctx_hi < batch.ntokens else '')
            print()

    from IPython import embed
    embed()
Beispiel #15
0
    def sampling_decode(self, vocab: Dict[str, Vocab], example: NKLMExample,
                        begin_symbol: int = 2, end_symbol: int = 5,
                        initial_hidden: Optional[HiddenState] = None, warm_up: Optional[int] = None,
                        max_length: int = 200, greedy=False, topk=None,
                        fill_incomplete=False, allow_invalid_pos=False,
                        print_info=True, color_outputs=False, color_incomplete=True,
                        show_ellipses=True, show_rel_type=True, show_copy_pos=False,
                        sanity_check=False, unkinfo: Optional[Tuple[Tensor, List[str]]] = None, **kwargs) \
            -> SampledOutput:
        """
        Sampling for NKLM.

        Tokens are generated one at a time: a fact (relation) index is sampled first, then a copy/generate
        decision is drawn, and finally either a position inside the chosen entity surface form is sampled
        (copy) or a word is sampled from the vocabulary distribution (generate). The log-probability of every
        decision is accumulated and reported as an average loss.

        Output format:
        - Red words:       Copied from canonical form of entity.
        - Green words:     Copied from alias form of entity.
        - Yellow words:    Warm-up context.
        - Blue words:      Unknown word whose surface form was sampled via ``unkinfo``.
        - word_[type]:     "word" is an entity of type "type".
        - word...(a_b_c):  "word" is a partially copied entity with remaining suffix "a b c".
        - (a_b_c)...word:  "word" is a partially copied entity with remaining prefix "a b c".
        - @-@:             A dash in the original text without spaces around, e.g. M @-@ 82 => M-82.
        - <X>:             A token copied from an invalid position of an entity.

        :param vocab: Vocabulary containing id2word mapping. The ``'word'`` and ``'rel'`` entries are used.
        :param example: The :class:`Example` object of the current topic.
        :param begin_symbol: Start of sentence symbol.
        :param end_symbol: End of sentence symbol. Sampling stops when this symbol is generated.
        :param initial_hidden: If not specified, default hidden states returned by :meth:`init_hidden` is used.
        :param warm_up: Number of tokens to provide as context before performing sampling.
        :param max_length: If generated sentence exceeds specified length, sampling is force terminated.
        :param greedy: If ``True``, use greedy decoding instead of sampling.
        :param topk: If not ``None``, only sample from indices with top-k probabilities.
        :param fill_incomplete: If ``True``, entities that are partially copied will be completed.
        :param allow_invalid_pos: If ``True``, allowing copying from invalid positions, and use <unk> as input.

        :param print_info: If ``True``, print information about sampled result.
        :param color_outputs: If ``True``, include annotations for each output token. Tokens from entities will be
            colored red.
        :param color_incomplete: If ``True`` and ``color_outputs`` is ``True``, also color partially copied entities.
        :param show_ellipses: If ``True``, show ellipses at beginning or end of partially copied entities.
        :param show_rel_type: If ``True``, show relation types for copied entities.
        :param show_copy_pos: If ``True``, show the position from which the entity tokens are copied.

        :param sanity_check: If ``True``, perform sanity check on generated sample.
        :param unkinfo: Optional pair ``(unkprob, unki2w)``. Both are sliced from index ``self._vocab_size``
            onward. When token ``0`` is generated, an index is sampled from the log-softmax of the sliced
            ``unkprob`` and the matching entry of ``unki2w`` is used as the displayed word.
        :param kwargs: Accepted but not used by this method.

        :return: A :class:`SampledOutput` with ``sentence`` (the formatted list of words), ``sample_loss``
            (negated total log-probability divided by ``len(inputs) - 1``), ``complete_copies`` (number of
            entities copied from their first through their last token) and ``incomplete_copies`` (populated
            with the total number of copied tokens).
        """

        if unkinfo is not None:
            unkprob, unki2w = unkinfo
            # Keep only the entries beyond the model vocabulary and renormalize them.
            unkprob = unkprob[self._vocab_size:]
            unki2w = unki2w[self._vocab_size:]
            normalized_unkprob = F.log_softmax(unkprob, dim=0)

        # Sentinel values: UNK marks a sampled unknown word in `rel_ids`; INVALID marks "no relation / no copy
        # position / no surface form"; UNK_TOKEN is the vocabulary index used when a copied token cannot be
        # looked up; EPS is the tolerance used by the sanity check. CANONICAL_IDX is not referenced below.
        # noinspection PyPep8Naming
        UNK, INVALID, UNK_TOKEN, CANONICAL_IDX, EPS = -100, -1, 0, 0, 1e-4

        self.eval()
        self.init_hidden(1, [example.relations])

        word_vocab, rel_vocab = vocab['word'], vocab['rel']

        # Bind the decoding options once so the sampling loop stays readable.
        tensor = functools.partial(sample_utils.tensor, device=self.device)
        # NOTE(review): `randint` is not referenced anywhere else in this method.
        randint = sample_utils.randint
        sample = functools.partial(sample_utils.sample,
                                   greedy=greedy,
                                   topk=topk)
        np_sample = functools.partial(sample_utils.np_sample,
                                      greedy=greedy,
                                      topk=topk)

        # Scores a whole sequence in a single call to `calc_loss`; used for the warm-up context and for the
        # sanity check.
        # noinspection PyShadowingNames
        def compute_loss(
                inputs: List[int],
                rel_ids: List[int],
                copy_pos: List[int],
                surface_indices: List[int],
                hidden: Optional[HiddenState] = None
        ) -> Tuple[float, HiddenState]:
            # Build a minimal stand-in for a batch: inputs are all but the last token, targets all but the first.
            batch = SimpleNamespace(
                sequence=tensor(inputs[:-1]),
                target=tensor(inputs[1:]),
                unkprob=None,
                seqs={
                    'rel_ids': tensor(rel_ids),
                    'copy_pos': tensor(copy_pos),
                    'surface_indices': tensor(surface_indices)
                },
                ntokens=len(inputs) - 1,
            )
            loss, next_hidden = self.calc_loss(batch,
                                               hidden=hidden)  # type: ignore
            return loss.item(), next_hidden

        # Initialization
        if warm_up is None:
            # No context: start from the begin symbol with no associated fact.
            inputs = [begin_symbol]
            rel_ids = [INVALID]
            copy_pos = [INVALID]
            surface_indices = [INVALID]
            total_log_prob = 0.0
            hidden = initial_hidden
        else:
            # Use the first `warm_up` gold tokens (and their annotations) as context.
            inputs = list(word_vocab.numericalize(example.sentence[:warm_up]))
            rel_ids = list(example.rel_ids[:warm_up])
            copy_pos = list(example.copy_pos[:warm_up])
            surface_indices = list(example.surface_indices[:warm_up])
            total_log_prob, hidden = compute_loss(inputs, rel_ids, copy_pos,
                                                  surface_indices,
                                                  initial_hidden)
            # Convert the returned loss into a summed log-probability.
            # NOTE(review): presumably `calc_loss` returns a per-token average -- confirm.
            total_log_prob = -total_log_prob * (len(inputs) - 1)

        # Sampling procedure
        while len(inputs) < max_length and inputs[-1] != end_symbol:
            fact_log_probs, output, _, next_hidden = \
                self._compute_fact_log_probs(tensor(inputs[-1]), tensor(rel_ids[-1]), tensor(copy_pos[-1]), hidden)
            rel_id, fact_loss = sample(fact_log_probs[0])
            # Shift the sampled index down by one, so index 0 becomes -1 (the value of INVALID); the two
            # branches below assert that -1 is only ever paired with the "generate" decision.
            rel_id -= 1
            total_log_prob += fact_loss

            # next_fact_embed: (1, 1, fact_embed_dim)
            next_fact_embed = self._get_fact_embeds(tensor(rel_id))
            copy_indicator, alias_log_probs, pos_log_probs, vocab_log_probs = \
                self._compute_generate_log_probs(output, next_fact_embed, tensor([rel_ids[-1], rel_id]))
            # The copy/generate decision is always drawn at random from `copy_indicator`.
            if torch.bernoulli(copy_indicator).item():
                total_log_prob += torch.log(copy_indicator).item()
                assert rel_id != -1
                # copy entity
                aliases = example.relations[rel_id].obj_alias
                if self._alias_disamb is AliasDisamb.FastText:
                    assert alias_log_probs is not None
                    surface_idx, surface_loss = sample(alias_log_probs[0])
                else:
                    # Without alias disambiguation, always use surface form 0.
                    surface_idx, surface_loss = 0, 0.0
                alias = self.alias_list[aliases[surface_idx]]

                entity: List[str] = alias.split()
                # normalization not required
                # TODO: keep consistent with _mask_invalid_pos setting
                pos, pos_loss = sample(pos_log_probs if allow_invalid_pos else
                                       pos_log_probs.squeeze()[:len(entity)])
                if self._mask_invalid_pos:
                    # Renormalize the position log-probability over the valid positions of this entity.
                    pos_loss -= torch.logsumexp(
                        pos_log_probs.squeeze()[:len(entity)], dim=0).item()
                total_log_prob += surface_loss + pos_loss
                # Positions past the end of the entity, and words missing from the vocabulary, are fed back
                # as UNK_TOKEN.
                token = UNK_TOKEN if pos >= len(
                    entity) else word_vocab.w2i.get(entity[pos], UNK_TOKEN)
            else:
                total_log_prob += torch.log(1.0 - copy_indicator).item()
                assert rel_id == -1
                # generate word
                token, token_loss = sample(vocab_log_probs)
                total_log_prob += token_loss
                pos = INVALID
                surface_idx = INVALID

                # Token 0 has the same value as UNK_TOKEN.
                if token == 0 and unkinfo is not None:  # unk
                    unk_idx, unk_loss = np_sample(normalized_unkprob)
                    total_log_prob += unk_loss
                    # Ugly multi-purpose use of variables.
                    surface_idx = unk_idx  # Record unk word index in surface_indices.
                    rel_id = UNK  # Record UNK in rel_ids.

            inputs.append(token)
            rel_ids.append(rel_id)
            copy_pos.append(pos)
            surface_indices.append(surface_idx)
            hidden = next_hidden

        # NOTE(review): this divides by zero when `inputs` holds a single token, e.g. `warm_up is None` and
        # `max_length <= 1`.
        sample_loss = -total_log_prob / (len(inputs) - 1)
        if print_info:
            print(
                f"Sample loss: {sample_loss:.3f}, PPL: {math.exp(sample_loss):.3f}"
            )
        # Sanity checks
        if sanity_check:
            # Re-score the full sampled sequence in one pass and compare both the final hidden state and the
            # loss against the values accumulated step by step above; mismatches only produce warnings.
            loss_val, gold_hidden = compute_loss(inputs, rel_ids, copy_pos,
                                                 surface_indices,
                                                 initial_hidden)
            assert hidden is not None
            hidden_state_diff = max(
                torch.max(torch.abs(g - h)).item()
                for g, h in zip(gold_hidden, hidden))
            if hidden_state_diff > EPS:
                Logging.warn(
                    f"Hidden states do not match. Difference: {hidden_state_diff}"
                )
            if abs(sample_loss - loss_val) > EPS:
                Logging.warn(
                    f"Loss values do not match. "
                    f"Forward loss: {loss_val}, difference: {abs(sample_loss - loss_val)}"
                )

        # Format the output
        sentence = list(zip(inputs, rel_ids, copy_pos, surface_indices))
        words = []
        copy_count = 0  # number of tokens produced by copying
        complete_count = 0  # number of entities copied from position 0 through their last token
        last_entity = None  # (rel_id, surface_idx, pos) of the previous token of the entity being tracked
        entity_continuing = False  # True while a copy that started at position 0 is still unbroken
        for idx, (token, rel_id, pos, surface_idx) in enumerate(sentence):
            is_warm_up = (warm_up is not None and idx < warm_up)
            if rel_id == INVALID:
                # Plain vocabulary word.
                word = word_vocab.i2w[token]
            elif rel_id == UNK:
                # Unknown word: `surface_idx` holds the index sampled from `unkinfo`.
                word = Logging.color('blue', unki2w[surface_idx])
            else:
                copy_count += 1
                entity_id = example.relations[rel_id].obj_alias[surface_idx]
                entity = self.alias_list[entity_id].split()
                if pos >= len(entity):
                    word = "<X>"
                else:
                    word = entity[pos]
                if show_copy_pos:
                    word = f"{pos}_{rel_id}_{surface_idx}_{word}"
                # A copied run continues only if the neighbouring token has the same relation and surface
                # form and the adjacent copy position.
                is_last_word_in_entity = (idx == len(sentence) - 1
                                          or sentence[idx + 1][1:] !=
                                          (rel_id, pos + 1, surface_idx))
                is_first_word_in_entity = (idx == 0 or sentence[idx - 1][1:] !=
                                           (rel_id, pos - 1, surface_idx))
                # add entity tag after the last word
                if show_rel_type and is_last_word_in_entity:
                    word = f"{word}_[{rel_vocab.i2w[example.relations[rel_id].rel_typ]}]"

                # check whether fully copied
                if show_ellipses:
                    if pos < len(entity) - 1 and is_last_word_in_entity:
                        word = word + '...' + (
                            f"({'_'.join(entity[(pos + 1):])})"
                            if fill_incomplete else "")
                    if pos > 0 and is_first_word_in_entity:
                        word = (f"({'_'.join(entity[:pos])})"
                                if fill_incomplete else "") + '...' + word

                # Track whether the current run is still an unbroken copy of one entity.
                if entity_continuing:
                    if last_entity == (rel_id, surface_idx,
                                       pos - 1):  # Continuing
                        last_entity = (rel_id, surface_idx, pos)
                    else:
                        entity_continuing = False
                        last_entity = None

                if pos == 0 and not entity_continuing:  # reset
                    entity_continuing = True
                    last_entity = (rel_id, surface_idx, 0)

                if color_outputs and not is_warm_up and (color_incomplete
                                                         or entity_continuing):
                    word = Logging.color(
                        'red' if surface_idx == 0 else 'green', word)

                if pos == len(entity) - 1 and entity_continuing:  # commit
                    entity_continuing = False
                    complete_count += 1

            if color_outputs and is_warm_up:
                word = Logging.color('yellow', word)
            words.append(word)

        if print_info:
            print(f"Copied, Completed: {copy_count}, {complete_count}")
        # NOTE(review): `incomplete_copies` receives the total copied-token count, not a count of incomplete
        # entities -- confirm this is what consumers of SampledOutput expect.
        sampled_output = SampledOutput(sentence=words,
                                       sample_loss=sample_loss,
                                       complete_copies=complete_count,
                                       incomplete_copies=copy_count)
        return sampled_output