def create_exp_dir(path: str, script_path: str, overwrite=False) -> str:
    """ Create experiment directory, and return path to log file. """
    if os.path.exists(path):
        if not overwrite:
            print(Logging.color(
                col='red',
                s=f"The experiment name: {path} already exists. "
                  f"Training will not proceed without the `--overwrite` flag."))
            sys.exit(1)
        else:  # Radical
            print(Logging.color(col='green', s=f"Overwriting the experiment: {path} ..."))
            shutil.rmtree(path)
    os.mkdir(path)
    shutil.copy(script_path, path)
    logfile = os.path.join(path, f"{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}.log")
    return logfile
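# Illustrative usage of `create_exp_dir` (a sketch; the experiment and script
# paths below are made up, not taken from the training scripts):
#
#     logfile = create_exp_dir("outputs/my-run", script_path="train.py", overwrite=True)
#     with open(logfile, "a") as f:
#         f.write("training started\n")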
def _validate(self):
    rules = self.validate() or []
    for pattern, validator in rules:
        regex = re.compile(rf"^{pattern}$")
        matches = 0
        for k in dir(self):
            if not k.startswith('_') and regex.match(k):
                matches += 1
                v = getattr(self, k)
                try:
                    result = validator(v)
                except Exception:
                    raise ValidationError(k, v, validator.__name__)
                else:
                    if not result:
                        raise ValidationError(k, v, validator.__name__)
        if matches == 0:
            Logging.warn(f"regex \"{pattern}\" did not match any arguments")
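# A sketch of the `validate()` rules consumed by `_validate()` above: each
# (pattern, predicate) pair is matched (fully anchored) against public
# attribute names, and a falsy result or an exception raises ValidationError.
# The class and attribute names here are hypothetical.
#
#     class TrainingArguments(Arguments):
#         batch_size: int = 32
#         dropout: float = 0.5
#
#         def validate(self):
#             return [
#                 (r"batch_size", lambda v: v > 0),
#                 (r".*dropout.*", lambda v: 0.0 <= v < 1.0),
#             ]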
def sampling_decode(self, vocab: Dict[str, Vocab], example: LRLMExample,
                    begin_symbol: int = 2, end_symbol: int = 5,
                    initial_hidden: Optional[HiddenState] = None,
                    warm_up: Optional[int] = None, max_length: int = 200,
                    greedy: bool = False, topk: Optional[int] = None,
                    print_info: bool = True, color_outputs: bool = False,
                    init_batch=None, **_kwargs) -> SampledOutput:
    tensor = functools.partial(sample_utils.tensor, device=self.device)
    sample = functools.partial(sample_utils.sample, greedy=greedy, topk=topk)

    self.eval()
    self.init_hidden(1, init_batch)

    if warm_up is None:
        inputs = [begin_symbol]
        hidden = initial_hidden
        total_log_prob = 0.0
    else:
        inputs = list(vocab["word"].numericalize(example.sentence[:warm_up]))
        total_log_prob, hidden = self.forward(tensor(inputs[:-1]), target=tensor(inputs[1:]))
        total_log_prob = -torch.sum(total_log_prob).item() * (len(inputs) - 1)

    while len(inputs) < max_length and inputs[-1] != end_symbol:
        # Full forward pass; dropout layers are present but inactive because the model is in eval mode.
        # Run the LSTM over the latest word
        word_log_probs, new_hidden = self.forward(tensor(inputs[-1]), hidden)
        word_id, word_log_prob = sample(word_log_probs)
        inputs.append(word_id)
        hidden = new_hidden
        total_log_prob += word_log_prob

    sample_loss = -total_log_prob / (len(inputs) - 1)
    if print_info:
        print(f"Sample loss: {sample_loss:.3f}, PPL: {math.exp(sample_loss):.3f}")

    # Format the output
    words = [vocab["word"].i2w[token] for token in inputs]
    if color_outputs and warm_up is not None:
        words[:warm_up] = [Logging.color('yellow', w) for w in words[:warm_up]]

    output = SampledOutput(sentence=words, sample_loss=sample_loss,
                           complete_copies=0, incomplete_copies=0)
    return output
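# Illustrative call of the sampler above (a sketch; `model`, `vocab`, and
# `example` are assumed to be loaded elsewhere, and all keyword values are
# made up):
#
#     output = model.sampling_decode(
#         vocab, example,
#         warm_up=20,        # condition on the first 20 gold tokens
#         max_length=150,    # stop after 150 tokens if the end symbol is never generated
#         greedy=False,      # sample instead of taking the argmax
#         topk=40,           # restrict sampling to the 40 most probable words
#         color_outputs=True)
#     print(" ".join(output.sentence))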
def _try_load_cache(self, path: Path) -> bool:
    r"""Try loading a cached dataset from the data directory.

    :param path: The path to data directory.
    :return: Whether loading was successful.
    """
    cache_dir = path / '_cached'
    if not cache_dir.exists():
        return False
    params_index_path = cache_dir / 'params_index.pkl'
    params_index: List[Dict[str, Any]] = loadpkl(params_index_path)
    params = self._get_params()
    index = next(
        (idx for idx, p in enumerate(params_index)
         if (cache_dir / f'{idx}.pkl').exists() and self._compare_params(p, params)),
        -1)
    if index != -1:
        load_path = cache_dir / f'{index}.pkl'
        self.batches = loadpkl(load_path)
        self.ntokens = {
            split: sum(batch.ntokens for _, batches in dataset for batch in batches)
            for split, dataset in self.batches.items()
        }
        LOGGER.info(f"Cached dataset loaded from {load_path}, with settings: {params}")
        # check for excluded keys and warn in case of mismatch
        load_params = params_index[index]
        for key in self.EXCLUDE_KEYS:
            if key in params or key in load_params:
                current = params.get(key, "<does not exist>")
                loaded = load_params.get(key, "<does not exist>")
                if current != loaded:
                    LOGGER.info(Logging.color(
                        'red',
                        f"Ignored data param '{key}' mismatch "
                        f"(current: {current}, loaded: {loaded})"))
        return True
    return False
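# On-disk layout consumed by `_try_load_cache` above, inferred from the code
# (only `params_index.pkl` and the `{index}.pkl` pattern are fixed; the data
# directory name is illustrative):
#
#     <data_dir>/_cached/params_index.pkl   # list of param dicts, one per cached variant
#     <data_dir>/_cached/0.pkl              # pickled `self.batches` matching params_index[0]
#     <data_dir>/_cached/1.pkl              # ... and so on for other parameter settings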
def _parse_type_spec(cls) -> Dict[str, _ArgTypeSpec]:
    """
    :return: A dict mapping argument names to their type-specs
    """
    _attr_name = '__type_dict__'
    if hasattr(cls, _attr_name):
        return getattr(cls, _attr_name)
    type_dict = {}

    # gather annotations from the current class and all its base classes
    annotations = {}
    for base in reversed(cls.__mro__):
        if base not in [object, Arguments]:
            annotations.update(base.__dict__.get('__annotations__', {}))

    bad_names = []
    warn_names = []

    def check_name_conventions(name):
        if name.startswith('_') or name.endswith('_'):
            # names must not begin or end with underscores
            bad_names.append(name)
        if name != name.lower() or any(ord(c) >= 128 for c in name):
            # names are recommended to contain non-uppercase ASCII characters only
            warn_names.append(name)

    # check that all attributes are annotated, except for `Switch`es
    for arg_name in dir(cls):
        if arg_name.startswith('__') or cls._check_reserved(arg_name):  # magic stuff
            continue
        arg_val = getattr(cls, arg_name)
        if isinstance(arg_val, Arguments.Switch):
            check_name_conventions(arg_name)
            # noinspection PyProtectedMember,PyCallByClass
            type_dict[arg_name.lower()] = Arguments._ArgTypeSpec(
                Arguments.Switch, nullable=False, required=False, default=arg_val._default)
        elif arg_name not in cls.__annotations__:
            raise ArgumentError(
                f"Type is not specified for argument '{arg_name}'. "
                f"Type annotation can be omitted only when the argument is a `Switch`.")

    # iterate over annotated values and generate type-specs
    for arg_name, arg_typ in annotations.items():
        if cls._check_reserved(arg_name):
            raise ArgumentError(
                f"'{arg_name}' cannot be used as argument name because it is reserved.")
        check_name_conventions(arg_name)

        nullable = False
        # hacky check of whether `arg_typ` is `Optional`: `Optional[X]` is `Union[X, type(None)]`
        if getattr(arg_typ, '__origin__', None) is Union and NoneType in arg_typ.__args__:
            nullable = True
            # extract the type wrapped inside `Optional`
            arg_typ = next(t for t in arg_typ.__args__ if not isinstance(t, NoneType))  # type: ignore
        arg_val = getattr(cls, arg_name, None)
        required = not hasattr(cls, arg_name) or (arg_val is None and not nullable)
        type_dict[arg_name] = Arguments._ArgTypeSpec(
            arg_typ, nullable=nullable, required=required, default=arg_val)

    if len(bad_names) > 0:
        bad_names_str = ', '.join(f"'{s}'" for s in bad_names)
        raise ArgumentError(f"Invalid argument names: {bad_names_str}. "
                            f"Names cannot begin or end with underscores.")
    if len(warn_names) > 0:
        warn_names_str = ', '.join(f"'{s}'" for s in warn_names)
        Logging.warn(f"Consider changing these argument names: {warn_names_str}. "
                     f"Names are recommended to contain non-uppercase ASCII characters only.")

    setattr(cls, _attr_name, type_dict)
    return type_dict
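# A sketch of how annotations on a hypothetical subclass map to `_ArgTypeSpec`
# entries (the class, the attribute names, and the `Switch` constructor call
# are illustrative):
#
#     class ExampleArgs(Arguments):
#         lr: float = 1e-3                 # nullable=False, required=False, default=1e-3
#         data_path: Optional[str] = None  # nullable=True,  required=False, default=None
#         n_epochs: int                    # no default -> required=True
#         cuda = Arguments.Switch(True)    # Switches need no type annotation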
def __init__(self, *args, **kwargs) -> None:
    self._check_types()
    for k, v in kwargs.items():
        setattr(self, k, v)
    # TODO: Add non-null checks
    # TODO: Add "no-" prefix stuff for switches
    # TODO: Generate help by inspecting comments
    if len(args) == 0:
        argv = sys.argv
    elif len(args) == 1:
        argv = args[0]
    else:
        raise ValueError(
            f"Argument class takes zero or one positional arguments but {len(args)} were given")

    i = 1
    while i < len(argv):
        arg: str = argv[i]
        if arg.startswith('--'):
            argname = arg[2:].replace('-', '_')
            if argname.startswith('no_') and not hasattr(self, argname) and hasattr(self, argname[3:]):
                attr = getattr(self, argname[3:])
                if isinstance(attr, Arguments.Switch):
                    attr._value = False
                    i += 1
                    continue
            if hasattr(self, argname):
                attr = getattr(self, argname)
                if isinstance(attr, Arguments.Switch):
                    attr._value = True
                    i += 1
                    continue
                nullable, typ = self._get_arg_type(argname)
                argval: str = argv[i + 1]
                if argval.lower() == 'none':
                    if nullable:
                        val = None
                    else:
                        assert typ is str or is_choices(typ), \
                            f"Cannot assign None to non-nullable, non-str argument '{argname}'"
                        val = argval
                elif isinstance(typ, custom_types.NoneType):  # type: ignore
                    val = None  # just to suppress "ref before assign" warning
                    try:
                        # priority: low -> high
                        for target_typ in [str, float, int]:
                            val = target_typ(argval)
                    except ValueError:
                        pass
                elif typ is str:
                    val = argval
                elif isinstance(typ, custom_types.Path) or typ is custom_types.Path:
                    val = Path(argval)
                    if isinstance(typ, custom_types.Path) and typ.exists:
                        assert val.exists(), ValueError(
                            f"Argument '{argname}' requires an existing path, "
                            f"but '{argval}' does not exist")
                elif is_choices(typ):
                    val = argval
                    assert val in typ.__values__, \
                        f"Invalid value '{val}' for argument '{arg}', " \
                        f"available choices are: {typ.__values__}"
                elif issubclass(typ, Arguments.Enum):
                    # experimental support for custom enum
                    try:
                        # noinspection PyCallingNonCallable
                        val = typ(argval)
                    except ValueError:
                        valid_args = {x.value for x in typ}
                        raise ValueError(
                            f"Invalid value '{argval}' for argument '{argname}', "
                            f"available choices are: {valid_args}") from None
                elif typ is bool:
                    val = argval in ['true', '1', 'True', 'y', 'yes']
                else:
                    try:
                        val = ast.literal_eval(argval)
                    except ValueError:
                        raise ValueError(
                            f"Invalid value '{argval}' for argument '{argname}'") from None
                setattr(self, argname, val)
                i += 2
            else:
                raise ValueError(f"Invalid argument: '{arg}'")
        else:
            Logging.warn(f"Unrecognized command line argument: '{arg}'")
            i += 1

    if self.pdb:  # enter IPython debugger on exception
        from IPython.core import ultratb
        ipython_hook = ultratb.FormattedTB(mode='Context', color_scheme='Linux', call_pdb=1)

        def excepthook(type, value, traceback):
            if type is KeyboardInterrupt:
                # don't capture keyboard interrupts (Ctrl+C)
                sys.__excepthook__(type, value, traceback)
            else:
                ipython_hook(type, value, traceback)

        sys.excepthook = excepthook

    self.preprocess()

    # check whether non-optional attributes are None
    for arg in dir(self):
        if not arg.startswith('_') and arg not in self._reserved_keys:
            attr = getattr(self, arg)
            nullable, _ = self._get_arg_type(arg)
            if attr is None and not nullable:
                raise ValueError(f"argument '{arg}' cannot be none")

    self._validate()
    self.postprocess()

    # convert switches to bool, and strings annotated as paths to `Path`
    for arg in dir(self):
        if not arg.startswith('_') and arg not in self._reserved_keys:
            attr = getattr(self, arg)
            typ = self.__annotations__.get(arg, None)
            if isinstance(attr, Arguments.Switch):
                # noinspection PyProtectedMember
                setattr(self, arg, bool(attr))
            if isinstance(typ, type) and issubclass(typ, Path) and isinstance(attr, str):
                setattr(self, arg, Path(attr))
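# A hypothetical invocation of the parser above (argument and class names are
# illustrative; `argv[0]` is skipped, so a program name comes first):
#
#     args = ExampleArgs(['train.py', '--lr', '3e-4',
#                         '--data-path', './data', '--no-cuda'])
#
# Here `--lr 3e-4` falls through to `ast.literal_eval` and becomes a float,
# `--data-path` stays a string (its annotation in the sketch is `Optional[str]`),
# and `--no-cuda` flips the hypothetical `cuda` Switch to False.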
def __init__(self, **kwargs) -> None: self._check_types() for k, v in kwargs.items(): setattr(self, k, v) # TODO: Add non-null checks # TODO: Add "no-" prefix stuff for switches # TODO: Generate help by inspecting comments i = 1 while i < len(sys.argv): arg: str = sys.argv[i] if arg.startswith('--'): argname = arg[2:].replace('-', '_') if argname.startswith('no_') and not hasattr( self, argname) and hasattr(self, argname[3:]): attr = getattr(self, argname[3:]) if isinstance(attr, Arguments.Switch): attr._value = False i += 1 continue if hasattr(self, argname): attr = getattr(self, argname) if isinstance(attr, Arguments.Switch): attr._value = True i += 1 continue typ = self.__annotations__.get(argname, type(attr)) nullable = False # TODO: hacks here if hasattr( typ, '__origin__') and typ.__origin__ == Union and type( None) in typ.__args__: # hacky check of whether `typ` is `Optional` nullable = True typ = next(t for t in typ.__args__ if not isinstance(t, custom_types.NoneType) ) # type: ignore argval: str = sys.argv[i + 1] if argval.lower() == 'none': if nullable: val = None else: assert typ is str or is_choices(typ), \ f"Cannot assign None to non-nullable, non-str argument '{argname}'" val = argval elif isinstance(typ, custom_types.NoneType): # type: ignore val = None # just to suppress "ref before assign" warning try: # priority: low -> high for target_typ in [str, float, int]: val = target_typ(argval) except ValueError: pass elif typ is str: val = argval elif isinstance( typ, custom_types.Path) or typ is custom_types.Path: val = Path(argval) if isinstance(typ, custom_types.Path) and typ.exists: assert val.exists(), ValueError( f"Argument '{argname}' requires an existing path, " f"but '{argval}' does not exist") elif is_choices(typ): val = argval assert val in typ.__values__, f"Invalid value '{val}' for argument '{arg}', " \ f"available choices are: {typ.__values__}" elif issubclass(Arguments.Enum, typ): # experimental support for custom enum try: # noinspection PyCallingNonCallable val = typ(argval) except ValueError: valid_args = {x.value for x in typ} raise ValueError( f"Invalid value '{argval}' for argument '{argname}', " f"available choices are: {valid_args}" ) from None elif typ is bool: val = argval in ['true', '1', 'True', 'y', 'yes'] else: try: val = ast.literal_eval(argval) except ValueError: raise ValueError( f"Invalid value '{argval}' for argument '{argname}'" ) from None setattr(self, argname, val) i += 2 else: raise ValueError(f"Invalid argument: '{arg}'") else: Logging.warn(f"Unrecognized command line argument: '{arg}'") i += 1 if self.ipdb: # enter IPython debugger on exception from IPython.core import ultratb sys.excepthook = ultratb.FormattedTB(mode='Context', color_scheme='Linux', call_pdb=1) self.preprocess() self._validate() self.postprocess()
def sampling_decode(self, vocab: Dict[str, Vocab], example: LRLMExample, begin_symbol: int = 2, end_symbol: int = 5, initial_hidden: Optional[HiddenState] = None, warm_up: Optional[int] = None, max_length: int = 200, greedy: bool = False, topk: Optional[int] = None, print_info: bool = True, color_outputs: bool = False, show_rel_type: bool = True, sanity_check: bool = False, unkinfo: Optional[Tuple[Tensor, List[str]]] = None, **kwargs) \ -> SampledOutput: r""" Sampling for LRLM. Output format: - Red words: Copied from canonical form of entity. - Green words: Copied from alias form of entity. - Yellow words: Warm-up context. - word_[type]: "word" is an entity of type "type". - @-@: A dash in the original text without spaces around, e.g. M @-@ 82 => M-82. :param vocab: Vocabulary containing id2word mapping. :param example: The :class:`Example` object of the current topic. :param begin_symbol: Start of sentence symbol. :param end_symbol: End of sentence symbol. Sampling stops when this symbol is generated. :param initial_hidden: If not specified, default hidden states returned by :meth:`init_hidden` is used. :param warm_up: Number of tokens to provide as context before performing sampling. :param max_length: If generated sentence exceeds specified length, sampling is force terminated. :param greedy: If ``True``, use greedy decoding instead of sampling. :param topk: If not ``None``, only sample from indices with top-k probabilites. :param print_info: If ``True``, print information about sampled result. :param color_outputs: If ``True``, include annotations for each output token. Tokens from entities will be colored red. :param show_rel_type: If ``True``, show relation types for copied entities. :param sanity_check: If ``True``, perform sanity check on generated sample. :param unkinfo: Precomputed unkprobs and the index-to-vocabulary mapping. :return: A tuple of (loss_value, formatted list of words). 
""" if unkinfo is not None: unkprob, unki2w = unkinfo unkprob = unkprob[self._vocab_size:] unki2w = unki2w[self._vocab_size:] normalized_unkprob = F.log_softmax(unkprob, dim=0) # noinspection PyPep8Naming UNK, INVALID, CANONICAL_IDX, WORD_PREDICTOR, REL_PREDICTOR, EPS = -100, -1, 0, 0, 1, 1e-4 self.eval() self.init_hidden(1, [example.relations]) word_vocab, rel_vocab = vocab['word'], vocab['rel'] tensor = functools.partial(sample_utils.tensor, device=self.device) sample = functools.partial(sample_utils.sample, greedy=greedy, topk=topk) np_sample = functools.partial(sample_utils.np_sample, greedy=greedy, topk=topk) # noinspection PyShadowingNames def compute_loss( inputs: List[int], spans: List[MatchedSpan], hidden: Optional[HiddenState] = None ) -> Tuple[float, HiddenState]: batch = SimpleNamespace( sequence=tensor(inputs[:-1]), target=tensor(inputs[1:]), spans=[spans], unkprob=None, lengths=torch.tensor([len(inputs) - 1], device=self.device), ntokens=len(inputs) - 1, ) loss, next_hidden = self.calc_loss(batch, hidden=hidden) # type: ignore return loss.item(), next_hidden if warm_up is None: inputs = [begin_symbol] rel_ids = [INVALID] surface_indices = [INVALID] spans: List[MatchedSpan] = [] total_log_prob = 0.0 marginal_log_prob = 0.0 hidden = initial_hidden else: inputs = list(word_vocab.numericalize(example.sentence[:warm_up])) rel_ids = [INVALID] * len( inputs) # assume everything is generated from vocabulary surface_indices = [INVALID] * len(inputs) spans = [span for span in example.spans if span.end < warm_up] loss, hidden = compute_loss(inputs, spans, initial_hidden) total_log_prob = -loss * (len(inputs) - 1) marginal_log_prob = -loss * (len(inputs) - 1) while len(inputs) < max_length and inputs[-1] != end_symbol: computed_log_probs, new_hidden = self._compute_log_probs( tensor(inputs[-1]), hidden) predictor, selector_loss = sample(computed_log_probs.selector) if predictor == REL_PREDICTOR: rel_id, rel_loss = sample(computed_log_probs.rel[0]) if self._alias_disamb is AliasDisamb.FastText: assert computed_log_probs.alias_logits is not None aliases = example.relations[rel_id].obj_alias alias_vecs = self.alias_vec[aliases] surface_log_prob = F.log_softmax(torch.mv( alias_vecs, computed_log_probs.alias_logits.flatten()), dim=0) surface_idx, alias_loss = sample(surface_log_prob) alias = self.alias_list[aliases[surface_idx]] else: # can't tell which one under oracle, use the canonical (first) alias surface_idx = 0 alias_loss = 0.0 alias = example.relations[rel_id].obj_alias[ 0] # type: ignore # forward the hidden state according to the generated in-vocab tokens raw_tokens: List[str] = alias.split() token_ids: List[int] = word_vocab.numericalize(raw_tokens) if len(raw_tokens) > 1: _, new_hidden = self._compute_log_probs( tensor(token_ids[:-1]), new_hidden) # compute marginal probability for current span span_inputs = tensor([inputs[-1]] + token_ids[:-1]) span_computed_log_probs, _ = self._compute_log_probs( span_inputs, hidden) word_gen_loss = torch.sum( span_computed_log_probs.selector[0, :, WORD_PREDICTOR] + torch.gather(span_computed_log_probs.word, index=tensor(token_ids).unsqueeze(-1), dim=2).flatten()).item() marginal_log_prob += torch.logsumexp(tensor( [selector_loss + rel_loss + alias_loss, word_gen_loss]), dim=1).item() spans.append( MatchedSpan( len(inputs) - 1, len(inputs) + len(token_ids) - 1, example.relations[rel_id].rel_typ, rel_id, surface_idx)) inputs.extend(token_ids) rel_ids.extend([rel_id] + [INVALID] * (len(token_ids) - 1)) surface_indices.extend([surface_idx] + 
[INVALID] * (len(token_ids) - 1)) total_log_prob += selector_loss + rel_loss + alias_loss elif predictor == WORD_PREDICTOR: word, word_loss = sample(computed_log_probs.word) total_log_prob += selector_loss + word_loss marginal_log_prob += selector_loss + word_loss if word == 0 and unkinfo is not None: # unk unk_idx, unk_loss = np_sample(normalized_unkprob) total_log_prob += unk_loss marginal_log_prob += unk_loss # Ugly multi-purpose use of variables. surface_indices.append( unk_idx) # Record unk word index in surface_indices. rel_ids.append(UNK) # Record UNK in rel_ids. else: rel_ids.append(INVALID) surface_indices.append(INVALID) inputs.append(word) else: raise ValueError hidden = new_hidden sample_loss = -total_log_prob / (len(inputs) - 1) marginal_loss = -marginal_log_prob / (len(inputs) - 1) if print_info: print( f"Sample loss: {sample_loss:.3f}, PPL: {math.exp(sample_loss):.3f}" ) print( f"Marginal sample loss: {marginal_loss:.3f}, PPL: {math.exp(marginal_loss):.3f}" ) # Sanity checks if sanity_check: # noinspection PyTypeChecker loss_val, gold_hidden = compute_loss(inputs, spans, initial_hidden) assert hidden is not None hidden_state_diff = max( torch.max(torch.abs(g - h)).item() for g, h in zip(gold_hidden, hidden)) if hidden_state_diff > EPS: Logging.warn( f"Hidden states do not match. Difference: {hidden_state_diff}" ) if abs(marginal_loss - loss_val) > EPS: Logging.warn( f"Marginal loss values do not match. " f"Forward loss: {loss_val}, difference: {abs(marginal_loss - loss_val)}" ) num_rels_generated = sum(int(rel_id != INVALID) for rel_id in rel_ids) if print_info: print( f"Relations [Generated / Annotated]: " f"[{num_rels_generated} / {len([s for s in example.spans if s.end < max_length])}]" ) words = [] idx = 0 copy_count = 0 while idx < len(inputs): is_warm_up = (warm_up is not None and idx < warm_up) token_id, rel_id, surface_idx = inputs[idx], rel_ids[ idx], surface_indices[idx] if rel_id == INVALID: token = word_vocab.i2w[token_id] idx += 1 elif rel_id == UNK: token = Logging.color('blue', unki2w[surface_idx]) idx += 1 else: copy_count += 1 word_id = example.relations[rel_id].obj_alias[ surface_idx] # multiple words token = self.alias_list[word_id] idx += len(token.split()) if show_rel_type: token = f"{token}_[{rel_vocab.i2w[example.relations[rel_id].rel_typ]}]" if color_outputs and not is_warm_up: token = Logging.color( 'red' if surface_idx == CANONICAL_IDX else 'green', token) if color_outputs and is_warm_up: token = Logging.color('yellow', token) words.append(token) if print_info: print(f"# of copied entities: {copy_count}") output = SampledOutput(sentence=words, sample_loss=sample_loss, complete_copies=copy_count, incomplete_copies=0) return output
def main():
    Logging.verbosity_level = Logging.VERBOSE
    Logging.warn("This program requires lots of memory (preferably >= 30GB).")

    if not SAVE_DIR.exists():
        SAVE_DIR.mkdir(parents=True)

    # Read the Wikidata IDs for each article, and filter the relations
    topic_ids: Set[WikidataID] = set()
    split_title_id: Dict[str, List[Tuple[str, WikidataID]]] = {}
    for split in ['train', 'valid', 'test']:
        with utils.work_in_progress(f"Loading {split} set titles"), \
                open(TOPIC_JSON_PATH(split=split)) as f:
            j = json.load(f)
            split_title_id[split] = [(article['title'], WikidataID(article['id']))
                                     for article in j]
            topic_ids.update([wid for _, wid in split_title_id[split]])
        del j

    with utils.work_in_progress("Loading Wikidata ID mapping"):
        id2rel = load_id2str(WIKIDATA_DUMP_DIR / 'properties.txt')

    # Match the relations
    matched_dataset = read_data(ALIGNED_DATA_DIR)

    # Gather entities & relation vectors
    found_entities = set()
    found_rels = set()
    for split in matched_dataset:
        for example in matched_dataset[split]:
            found_entities.add(example.topic_id)
            for rel in example.relations:
                found_entities.add(rel.obj_id)
                found_rels.add(rel.rel_typ)
    found_entities -= {UNK_ENTITY}
    found_rels -= {NAF, ANCHOR, TOPIC_ITSELF}

    with utils.work_in_progress("Building rel vecs"):
        rel_map = load_relations(found_rels)
        rel_map.update({NAF: -1, ANCHOR: -2, TOPIC_ITSELF: -3})
        unk_rels = found_rels.difference(rel_map)
        # NOTE: `unk_rels` is a set with undetermined order; sort it so IDs are consistent between runs
        for idx, rel in enumerate(sorted(unk_rels)):
            rel_map[rel] = -idx - 4  # starting from -4, going towards -inf

    with utils.work_in_progress("Building entity vecs"):
        entity_map = load_entities(found_entities)
        entity_map.update({UNK_ENTITY: -1})
    print(f"Topic ID coverage: {len(topic_ids.intersection(entity_map))}/{len(topic_ids)}")

    # save relation type names for use during generation
    id_to_rel_name = dict(id2rel)
    id_to_rel_name.update({NAF: 'Not-A-Fact', ANCHOR: 'ANCHOR', TOPIC_ITSELF: 'TITLE'})
    rel_names: Dict[int, str] = {}
    for r_rel, rel_id in rel_map.items():
        rel_names[rel_id] = id_to_rel_name[r_rel]
    with (SAVE_DIR / 'rel_names.pkl').open('wb') as f:
        pickle.dump(rel_names, f)
    print(f"Relation names saved to {(SAVE_DIR / 'rel_names.pkl')}")

    # Convert into numbers to create the final dataset
    for split in matched_dataset:
        with utils.work_in_progress(f"Converting {split} set"):
            dataset, matched_spans = numericalize_rel(matched_dataset[split], rel_map, entity_map)
        path = SAVE_DIR / f'{split}.pkl'
        with path.open('wb') as f:
            pickle.dump(dataset, f)
        print(f"Dataset split '{split}' saved to {path}, {len(dataset)} examples")
        path = SAVE_DIR / f'{split}.span.pkl'
        with path.open('wb') as f:
            pickle.dump(matched_spans, f)
        print(f"Matched spans split '{split}' saved to {path}")
def read_data(path: Path) -> Dict[str, List[RawExampleWikiID]]:
    bad_examples: List[Tuple[str, int, str]] = []
    data = {}
    for split in ['train', 'valid', 'test']:
        with (path / f'{split}.pkl').open('rb') as f:
            # relation tuple: (span, rel_type_desc, name, canonical_name)
            with utils.work_in_progress(f"Loading {split} set"):
                dump: List[RawDump] = pickle.load(f)

        examples = []
        for idx, (sent, rels) in enumerate(utils.progress(dump, desc='Reading data')):
            # map (rel_typ, canonical) to list of aliases, since lists aren't hashable
            rel_to_alias: Dict[Tuple[str, str], List[str]] = \
                {(rel[0][0], obj_id): alias for obj_id, _, _, rel, _, alias in rels}
            # sort it so the order is consistent
            relations: List[RelationWikiID] = sorted([
                RelationWikiID(WikidataID(rel_id), WikidataID(obj_id), obj_alias)
                for (rel_id, obj_id), obj_alias in rel_to_alias.items()
            ])
            rel_to_id: Dict[Tuple[str, str], int] = {
                (rel_id, obj_id): idx
                for idx, (rel_id, obj_id, obj_alias) in enumerate(relations)
            }
            # dedup to remove duplicate (-1, -1)
            mentions: List[EntityMention] = list(set(
                EntityMention(span, surface, rel_to_id[(rel_info[0][0], obj_id)])
                for obj_id, head_id, span, rel_info, surface, _ in rels))
            try:
                # must exist - the head ID of the `@TITLE@` relation is the topic's WikidataID
                topic_id = next(head_id for _, head_id, _, rel_info, surface, alias in rels
                                if rel_info[0][0] == "@TITLE@")
            except StopIteration:
                bad_examples.append((split, idx, ' '.join(sent)[:100]))
                continue
            converted_relations = []
            for r in relations:
                converted_relations.append(RelationWikiID(
                    TOPIC_ITSELF if r.rel_typ == "@TITLE@" else r.rel_typ,
                    r.obj, r.obj_alias))
            example = RawExampleWikiID(WikidataID(topic_id), sent, converted_relations, mentions)
            examples.append(example)
        data[split] = examples

    if len(bad_examples) > 0:
        Logging.warn(f"{len(bad_examples)} bad examples:\n"
                     f"{pprint.pformat(bad_examples)}")
    else:
        Logging.verbose("All examples are good")
    return data
def repl(dataset: KBLMDataset, model: BaseLM): import re from run import compute_batch_loss # avoid circular import print( Logging.color( s="Execute `sample(name=\"Barack Obama\")` to generate samples for given entity.\n" "Execute `sample(split='test', index=1)` to generate samples for specific data entry.\n" "For more configurable settings, please refer to method `models.lrlm.LRLM.sampling_decode`.\n", col='green')) def get_example(func): def find_name(name: str) -> Tuple[str, int]: for split in dataset.data: try: index = next(idx for idx, ex in enumerate(dataset.data[split]) if name in ' '.join(ex.sentence[:10])) return split, index except StopIteration: continue else: raise ValueError("Name not found!") @functools.wraps(func) def wrapped(name: Optional[str] = None, split: str = 'test', index: int = 1, start_symbol: str = '<s>', **kwargs): if name is not None: try: split, index = find_name(f"= {name} =") except ValueError: split, index = find_name(name) if start_symbol is None: start_symbol = '<s>' start_symbol = dataset.word_vocab.w2i[start_symbol] if isinstance( start_symbol, str) else start_symbol example = dataset.data[split][index] try: name = ' '.join( example.sentence[2:example.sentence.index('=', 2)]) except ValueError: # probably WikiFacts pos = min((example.sentence.index(token) for token in ['is', 'was', '('] if token in example.sentence), default=3) name = ' '.join(example.sentence[1:pos]) if "file" not in kwargs: file = sys.stdout else: file = kwargs["file"] print(f"Data split {split}, index {index}, name \"{name}\"", file=file) func(example, start_symbol, **kwargs) return wrapped end_symbol = dataset.word_vocab.w2i['</s>'] @get_example def sample(example, start_symbol, max_length=500, n_tries=1, bptt_size=150, **kwargs): (init_batch, batches) = dataset.create_one_batch([example], bptt_size) if kwargs.get('generate_unk', False): fulli2w = [ w for w, i in sorted(dataset.total_w2i.items(), key=lambda x: x[1]) ] unkinfo = (dataset.unk_probs, fulli2w) del kwargs["generate_unk"] else: unkinfo = None best_output = None for _ in range(n_tries): print_info = kwargs.get('print_info', n_tries == 1) output: SampledOutput = model.sampling_decode( dataset.vocab, example, begin_symbol= start_symbol, # this is the actual start symbol in all articles end_symbol=end_symbol, max_length=max_length, color_outputs=True, # generate colored outputs in terminal print_info=print_info, unkinfo=unkinfo, init_batch=init_batch, **kwargs) if best_output is None or output.sample_loss < best_output.sample_loss: best_output = output # format title & subtitles if n_tries > 1: print( f"Sample loss: {best_output.sample_loss:.3f}, PPL: {math.exp(best_output.sample_loss):.3f}" ) print( f"Complete / incomplete entities: {best_output.complete_copies} / {best_output.incomplete_copies}" ) sentence = re.sub(r'= = (.*?) = = ', r'\n\n== \1 ==\n ', ' '.join(best_output.sentence)) sentence = re.sub(r'= (.*?) 
= ', r'= \1 =\n ', sentence) print(sentence) @get_example def average_copies(example, start_symbol, max_length=200, n_tries=10, progress=False, show_samples=False, **kwargs): complete_copies = utils.SimpleAverage() incomplete_copies = utils.SimpleAverage() if show_samples: sample_kwargs = dict(print_info=True, color_outputs=True, color_incomplete=False, **kwargs) else: sample_kwargs = dict(print_info=False, color_outputs=False, **kwargs) for _ in utils.progress(n_tries, verbose=progress): output: SampledOutput = model.sampling_decode( dataset.vocab, example, begin_symbol=start_symbol, end_symbol=end_symbol, max_length=max_length, **sample_kwargs) complete_copies.add(output.complete_copies) incomplete_copies.add(output.incomplete_copies) if show_samples: print(' '.join(output.sentence)) if isinstance(example, LRLMExample): n_gold_rels = sum( int(span.end < max_length) for span in example.spans) print( f"Complete / Gold entities: {complete_copies.value()} / {n_gold_rels}" ) else: print( f"Complete / Incomplete entities: {complete_copies.value()} / {incomplete_copies.value()}" ) @get_example def logprob_article(example, _start_symbol, bptt_size=150, **_kwargs): (init_batch, batches) = dataset.create_one_batch([example], bptt_size) # log-prob calculation def callback(loss: Tensor, batch: BatchSequence, *_extra): if batch.spans is None: # NKLM rel_ids = np.concatenate( [[-1], batch.seqs['rel_ids'][0].cpu().numpy(), [-1]]) diff = np.ediff1d(rel_ids) n_rels = np.count_nonzero(np.cumsum(diff[diff != 0])) else: n_rels = len(batch.spans[0]) print(f"{n_rels:2d}\t{math.exp(loss.item()):2.3f}") print("#rels\tPPL") with torch.no_grad(): model.eval() total_loss = compute_batch_loss(model, init_batch, batches, use_unk_probs=True, callback=callback, evaluate=True) total_loss /= sum(batch.ntokens for batch in batches) print(f"Total PPL: {math.exp(total_loss)}") @get_example def posterior_log_probs(example, _, bptt_size=140, n_context=5, max_segments=-1, ignore_single_relation=False, file=sys.stdout): from collections import defaultdict from dataset.utils import flip_batches, search_paths (init_batch, batches) = dataset.create_one_batch([example], bptt_size) flipped_batches = flip_batches(batches) word_prob: Dict[str, List[np.ndarray]] = { k: [] for k in ["forward", "backward"] } marginal_prob: Dict[str, List[np.ndarray]] = { k: [] for k in ["forward", "backward"] } posterior_probs: List[Dict[MatchedSpan, Tuple[float, float]]] = [ ] # list(n_batches) of list(n_spans) seq_loss: List[float] = [] def callback(loss: Tensor, _batch: BatchSequence): posterior_probs.append(model.model_cache['posterior_log_probs'][0]) word_prob['forward'].append( model.model_cache['target_cond_log_probs'][0]) marginal_prob['forward'].append( model.model_cache['stacked_log_probs'][0]) seq_loss.append(loss.item()) def callback_flip(_loss: Tensor, _batch: BatchSequence): marginal_prob['backward'].append( model.model_cache['stacked_log_probs'][0]) with torch.no_grad(): model.eval() compute_batch_loss(model, init_batch, batches, use_unk_probs=True, callback=callback, evaluate=True, calc_loss_kwargs={'dump_posterior_probs': True}) compute_batch_loss(model, init_batch, flipped_batches, use_unk_probs=True, callback=callback_flip, evaluate=True, calc_loss_kwargs={'dump_posterior_probs': True}) # credit: http://bayesjumping.net/log-sum-exp-trick/ for without using scipy def log_sum_exp(ns: List[int]): max_ = np.max(ns) sum_exp = np.exp(ns - max_).sum() return max_ + np.log(sum_exp) n_words = 0 for seq_idx, (batch, probs_dict) in 
enumerate(zip(batches, posterior_probs)): if max_segments != -1 and seq_idx >= max_segments: break tokens = batch.raw_sequence[0][1:] spans = batch.spans[0] if len(spans) == 0: continue overlap_group = defaultdict(list) sorted_spans = sorted(spans, key=lambda x: (x.start, x.end)) latest = (sorted_spans[0].start, sorted_spans[0].end ) # The most recent span group overlap_group[latest] = [sorted_spans[0]] for sp in sorted_spans[1:]: if sp.start > sp.end or sp.end >= batch.ntokens: continue if sp.start <= latest[1]: grp = overlap_group[latest] del overlap_group[latest] latest = (latest[0], max(latest[1], sp.end)) overlap_group[latest] = grp + [sp] else: latest = (sp.start, sp.end) overlap_group[latest] = [sp] if np.any([len(g) > 1 for g in overlap_group.values() ]) or not ignore_single_relation: print(Logging.color( 'green', f"Segment #{seq_idx}: " f"words {n_words} - {n_words + batch.ntokens}, " f"ppl = {math.exp(seq_loss[seq_idx]):.4f}"), file=file) n_words += batch.ntokens for span, group in overlap_group.items(): if ignore_single_relation and len(group) == 1: continue alpha = marginal_prob["forward"][seq_idx][span[0] - 1] beta = marginal_prob["backward"][::-1][seq_idx][ batch.lengths[0] - (span[1] + 1) - 1] # Enumerate all the paths paths = search_paths(group, span[0], span[1]) log_probs = [] annotations = [] delimiters = [] for path in paths: path_anno = [] path_delims = [] logprob = alpha + beta for hop in path: if hop.rel_typ == -100: # dummy relation - word transition logprob += word_prob["forward"][seq_idx][span[0]] path_anno.append("word") else: logprob += probs_dict[hop][0] path_anno.append( dataset.rel_vocab.i2w[hop.rel_typ]) path_delims += [" "] * (hop.end - hop.start) if hop.end < span[1]: path_delims.append(" | ") delimiters.append(path_delims) log_probs.append(logprob) annotations.append(path_anno) log_denom = log_sum_exp(log_probs) normalized_probs = [ np.exp(log_prob - log_denom) for log_prob in log_probs ] l = max(0, span[0] - n_context) r = min(batch.ntokens, span[1] + n_context) token_string = " ".join([ ' ', '... ' if l > 0 else '', ' '.join(tokens[idx] for idx in range(l, span[0])), '|', ' '.join( Logging.color('red', tokens[idx]) for idx in range(span[0], span[1] + 1)), '|', ' '.join(tokens[idx] for idx in range(span[1] + 1, r)), ' ...' 
if r < batch.ntokens else '' ]) print(token_string, file=file) annotation_strings = [" => ".join(a) for a in annotations] max_anno_len = max([len(a) for a in annotation_strings]) max_score_idx = np.argmax(normalized_probs) for idx, (delim, anno, prob) in enumerate( zip(delimiters, annotation_strings, normalized_probs)): matched_span_tokens = ( " " * token_string.index(" | ") + " | " + ' _ '.join(tokens[idx] for idx in range(span[0], span[1] + 1)) + " | ") delim_positions = re.finditer(r" _ ", matched_span_tokens) for d, match in zip(delim, delim_positions): pos = match.start(0) matched_span_tokens = matched_span_tokens[: pos] + d + matched_span_tokens[ (pos + 3):] score = Logging.color( "green", f" {prob:1.4f}" ) if idx == max_score_idx else f" {prob:1.4f}" matched_span_tokens += " " + f"{anno}{' ' * (max_anno_len - len(anno))}" + score print(matched_span_tokens, file=file) print(file=file) @get_example def span_log_probs(example, _, bptt_size=140, ppl_threshold=200.0, split_len=20, n_context=5, max_segments=-1): (init_batch, batches) = dataset.create_one_batch([example], bptt_size) rels: List[Relation] = init_batch[0] posterior_probs: List[List[Optional[Tuple[float, float]]]] = [ ] # list(n_batches) of list(n_spans) seq_loss: List[float] = [] # noinspection PyShadowingNames def callback(loss: Tensor, batch: BatchSequence) -> None: assert batch.spans is not None probs_dict: Dict[MatchedSpan, Tuple[ float, float]] = model.model_cache['posterior_log_probs'][0] probs = [probs_dict.get(span, None) for span in batch.spans[0]] posterior_probs.append(probs) seq_loss.append(loss.item()) def color_if_less(val: float, threshold: float, format_str: str = '{:.4f}', color: str = 'yellow'): s = format_str.format(val) return Logging.color(color, s) if val < threshold else s with torch.no_grad(): model.eval() compute_batch_loss(model, init_batch, batches, use_unk_probs=True, callback=callback, evaluate=True, calc_loss_kwargs={'dump_posterior_probs': True}) n_words = 0 for seq_idx, (batch, probs) in enumerate(zip(batches, posterior_probs)): if max_segments != -1 and seq_idx >= max_segments: break print( Logging.color( 'green', f"Segment #{seq_idx}: " f"words {n_words} - {n_words + batch.ntokens}, " f"ppl = {math.exp(seq_loss[seq_idx]):.4f}")) n_words += batch.ntokens tokens = batch.raw_sequence[0][1:] spans = batch.spans[0] is_in_span = [False] * batch.ntokens for span in spans: if span.start > span.end or span.end >= batch.ntokens: continue is_in_span[span.start:( span.end + 1)] = [True] * (span.end - span.start + 1) for idx in range(0, batch.ntokens, split_len): print( f'{idx:3d}:', ' '.join( Logging.color('red', w) if in_span else w for w, in_span in zip( tokens[idx:(idx + split_len)], is_in_span[idx:( idx + split_len)]))) print() for span, prob in sorted(zip(spans, probs)): if prob is None: continue rel_prob, word_prob = prob l = max(0, span.start - n_context) r = min(batch.ntokens, span.end + 1 + n_context) print( f"[{span.start}, {span.end}]" f" <{dataset.rel_vocab.i2w[span.rel_typ]}>" f" {Logging.color('red', rels[span.rel_idx].obj_alias[span.surface_idx])}" f"{' (alias)' if span.surface_idx > 0 else ''}" f": rel = {color_if_less(math.exp(-rel_prob), ppl_threshold)}" f", word = {color_if_less(math.exp(-word_prob), ppl_threshold)}" ) print( ' ', '... ' if l > 0 else '', ' '.join( Logging.color('red', tokens[idx]) if span.start <= idx <= span.end else tokens[idx] for idx in range(l, r)), ' ...' if r < batch.ntokens else '') print() from IPython import embed embed()
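# The `log_sum_exp` helper defined inside `posterior_log_probs` above uses the
# standard max-shift trick for numerical stability. A minimal standalone check
# (not part of the original module):
#
#     import numpy as np
#     xs = np.array([-1000.0, -1001.0, -1002.0])
#     naive = np.log(np.exp(xs).sum())                          # underflows to -inf
#     stable = xs.max() + np.log(np.exp(xs - xs.max()).sum())   # ~ -999.59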
def sampling_decode(self, vocab: Dict[str, Vocab], example: NKLMExample, begin_symbol: int = 2, end_symbol: int = 5, initial_hidden: Optional[HiddenState] = None, warm_up: Optional[int] = None, max_length: int = 200, greedy=False, topk=None, fill_incomplete=False, allow_invalid_pos=False, print_info=True, color_outputs=False, color_incomplete=True, show_ellipses=True, show_rel_type=True, show_copy_pos=False, sanity_check=False, unkinfo: Optional[Tuple[Tensor, List[str]]] = None, **kwargs) \ -> SampledOutput: """ Sampling for NKLM. Output format: - Red words: Copied from canonical form of entity. - Green words: Copied from alias form of entity. - Yellow words: Warm-up context. - word_[type]: "word" is an entity of type "type". - word...(a_b_c): "word" is a partially copied entity with remaining suffix "a b c". - (a_b_c)...word: "word" is a partially copied entity with remaining prefix "a b c". - @-@: A dash in the original text without spaces around, e.g. M @-@ 82 => M-82. - <X>: A token copied from an invalid position of an entity. :param vocab: Vocabulary containing id2word mapping. :param example: The :class:`Example` object of the current topic. :param begin_symbol: Start of sentence symbol. :param end_symbol: End of sentence symbol. Sampling stops when this symbol is generated. :param initial_hidden: If not specified, default hidden states returned by :meth:`init_hidden` is used. :param warm_up: Number of tokens to provide as context before performing sampling. :param max_length: If generated sentence exceeds specified length, sampling is force terminated. :param greedy: If ``True``, use greedy decoding instead of sampling. :param topk: If not ``None``, only sample from indices with top-k probabilites. :param fill_incomplete: If ``True``, entities that are partially copied will be completed. :param allow_invalid_pos: If ``True``, allowing copying from invalid positions, and use <unk> as input. :param print_info: If ``True``, print information about sampled result. :param color_outputs: If ``True``, include annotations for each output token. Tokens from entities will be colored red. :param color_incomplete: If ``True`` and ``color_outputs`` is ``True``, also color partially copied entities. :param show_ellipses: If ``True``, show ellipses at beginning or end of partially copied entities. :param show_rel_type: If ``True``, show relation types for copied entities. :param show_copy_pos: If ``True``, show the position from which the entity tokens are copied. :param sanity_check: If ``True``, perform sanity check on generated sample. :return: A tuple of (loss_value, formatted list of words). 
""" if unkinfo is not None: unkprob, unki2w = unkinfo unkprob = unkprob[self._vocab_size:] unki2w = unki2w[self._vocab_size:] normalized_unkprob = F.log_softmax(unkprob, dim=0) # noinspection PyPep8Naming UNK, INVALID, UNK_TOKEN, CANONICAL_IDX, EPS = -100, -1, 0, 0, 1e-4 self.eval() self.init_hidden(1, [example.relations]) word_vocab, rel_vocab = vocab['word'], vocab['rel'] tensor = functools.partial(sample_utils.tensor, device=self.device) randint = sample_utils.randint sample = functools.partial(sample_utils.sample, greedy=greedy, topk=topk) np_sample = functools.partial(sample_utils.np_sample, greedy=greedy, topk=topk) # noinspection PyShadowingNames def compute_loss( inputs: List[int], rel_ids: List[int], copy_pos: List[int], surface_indices: List[int], hidden: Optional[HiddenState] = None ) -> Tuple[float, HiddenState]: batch = SimpleNamespace( sequence=tensor(inputs[:-1]), target=tensor(inputs[1:]), unkprob=None, seqs={ 'rel_ids': tensor(rel_ids), 'copy_pos': tensor(copy_pos), 'surface_indices': tensor(surface_indices) }, ntokens=len(inputs) - 1, ) loss, next_hidden = self.calc_loss(batch, hidden=hidden) # type: ignore return loss.item(), next_hidden # Initialization if warm_up is None: inputs = [begin_symbol] rel_ids = [INVALID] copy_pos = [INVALID] surface_indices = [INVALID] total_log_prob = 0.0 hidden = initial_hidden else: inputs = list(word_vocab.numericalize(example.sentence[:warm_up])) rel_ids = list(example.rel_ids[:warm_up]) copy_pos = list(example.copy_pos[:warm_up]) surface_indices = list(example.surface_indices[:warm_up]) total_log_prob, hidden = compute_loss(inputs, rel_ids, copy_pos, surface_indices, initial_hidden) total_log_prob = -total_log_prob * (len(inputs) - 1) # Sampling procedure while len(inputs) < max_length and inputs[-1] != end_symbol: fact_log_probs, output, _, next_hidden = \ self._compute_fact_log_probs(tensor(inputs[-1]), tensor(rel_ids[-1]), tensor(copy_pos[-1]), hidden) rel_id, fact_loss = sample(fact_log_probs[0]) rel_id -= 1 total_log_prob += fact_loss # next_fact_embed: (1, 1, fact_embed_dim) next_fact_embed = self._get_fact_embeds(tensor(rel_id)) copy_indicator, alias_log_probs, pos_log_probs, vocab_log_probs = \ self._compute_generate_log_probs(output, next_fact_embed, tensor([rel_ids[-1], rel_id])) if torch.bernoulli(copy_indicator).item(): total_log_prob += torch.log(copy_indicator).item() assert rel_id != -1 # copy entity aliases = example.relations[rel_id].obj_alias if self._alias_disamb is AliasDisamb.FastText: assert alias_log_probs is not None surface_idx, surface_loss = sample(alias_log_probs[0]) else: surface_idx, surface_loss = 0, 0.0 alias = self.alias_list[aliases[surface_idx]] entity: List[str] = alias.split() # normalization not required # TODO: keep consistent with _mask_invalid_pos setting pos, pos_loss = sample(pos_log_probs if allow_invalid_pos else pos_log_probs.squeeze()[:len(entity)]) if self._mask_invalid_pos: pos_loss -= torch.logsumexp( pos_log_probs.squeeze()[:len(entity)], dim=0).item() total_log_prob += surface_loss + pos_loss token = UNK_TOKEN if pos >= len( entity) else word_vocab.w2i.get(entity[pos], UNK_TOKEN) else: total_log_prob += torch.log(1.0 - copy_indicator).item() assert rel_id == -1 # generate word token, token_loss = sample(vocab_log_probs) total_log_prob += token_loss pos = INVALID surface_idx = INVALID if token == 0 and unkinfo is not None: # unk unk_idx, unk_loss = np_sample(normalized_unkprob) total_log_prob += unk_loss # Ugly multi-purpose use of variables. 
surface_idx = unk_idx # Record unk word index in surface_indices. rel_id = UNK # Record UNK in rel_ids. inputs.append(token) rel_ids.append(rel_id) copy_pos.append(pos) surface_indices.append(surface_idx) hidden = next_hidden sample_loss = -total_log_prob / (len(inputs) - 1) if print_info: print( f"Sample loss: {sample_loss:.3f}, PPL: {math.exp(sample_loss):.3f}" ) # Sanity checks if sanity_check: loss_val, gold_hidden = compute_loss(inputs, rel_ids, copy_pos, surface_indices, initial_hidden) assert hidden is not None hidden_state_diff = max( torch.max(torch.abs(g - h)).item() for g, h in zip(gold_hidden, hidden)) if hidden_state_diff > EPS: Logging.warn( f"Hidden states do not match. Difference: {hidden_state_diff}" ) if abs(sample_loss - loss_val) > EPS: Logging.warn( f"Loss values do not match. " f"Forward loss: {loss_val}, difference: {abs(sample_loss - loss_val)}" ) # Format the output sentence = list(zip(inputs, rel_ids, copy_pos, surface_indices)) words = [] copy_count = 0 complete_count = 0 last_entity = None entity_continuing = False for idx, (token, rel_id, pos, surface_idx) in enumerate(sentence): is_warm_up = (warm_up is not None and idx < warm_up) if rel_id == INVALID: word = word_vocab.i2w[token] elif rel_id == UNK: word = Logging.color('blue', unki2w[surface_idx]) else: copy_count += 1 entity_id = example.relations[rel_id].obj_alias[surface_idx] entity = self.alias_list[entity_id].split() if pos >= len(entity): word = "<X>" else: word = entity[pos] if show_copy_pos: word = f"{pos}_{rel_id}_{surface_idx}_{word}" is_last_word_in_entity = (idx == len(sentence) - 1 or sentence[idx + 1][1:] != (rel_id, pos + 1, surface_idx)) is_first_word_in_entity = (idx == 0 or sentence[idx - 1][1:] != (rel_id, pos - 1, surface_idx)) # add entity tag after the last word if show_rel_type and is_last_word_in_entity: word = f"{word}_[{rel_vocab.i2w[example.relations[rel_id].rel_typ]}]" # check whether fully copied if show_ellipses: if pos < len(entity) - 1 and is_last_word_in_entity: word = word + '...' + ( f"({'_'.join(entity[(pos + 1):])})" if fill_incomplete else "") if pos > 0 and is_first_word_in_entity: word = (f"({'_'.join(entity[:pos])})" if fill_incomplete else "") + '...' + word if entity_continuing: if last_entity == (rel_id, surface_idx, pos - 1): # Continuing last_entity = (rel_id, surface_idx, pos) else: entity_continuing = False last_entity = None if pos == 0 and not entity_continuing: # reset entity_continuing = True last_entity = (rel_id, surface_idx, 0) if color_outputs and not is_warm_up and (color_incomplete or entity_continuing): word = Logging.color( 'red' if surface_idx == 0 else 'green', word) if pos == len(entity) - 1 and entity_continuing: # commit entity_continuing = False complete_count += 1 if color_outputs and is_warm_up: word = Logging.color('yellow', word) words.append(word) if print_info: print(f"Copied, Completed: {copy_count}, {complete_count}") sampled_output = SampledOutput(sentence=words, sample_loss=sample_loss, complete_copies=complete_count, incomplete_copies=copy_count) return sampled_output