def __init__(self, checkpoint, lang='en', dictionary=None, save_wsd_details=True, maxlen=100, batch_size=5):
    super().__init__()
    try:
        self._set_spacy_extensions()
    except ValueError:
        pass
    self.lang = lang
    if lang != 'en':
        self.lemma_pos_dictionary = ResourceManager.get_lemma_pos_dictionary(lang=lang)
        self.lemma_pos_to_possible_offsets = ResourceManager.get_lemma_pos_to_possible_offsets(lang=lang)
    self.dictionary = dictionary or Dictionary.load(DEFAULT_DICTIONARY)
    self.output_dictionary = ResourceManager.get_offsets_dictionary()
    self.model = self._load_model(
        checkpoint,
        _FakeTask(self.dictionary, self.output_dictionary, 'wsd'))
    self.save_wsd_details = save_wsd_details
    self.maxlen = maxlen
    self.batch_size = batch_size
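# Minimal usage sketch for the constructor above. The import path and the checkpoint
# file name are assumptions made for illustration; the keyword arguments are the ones
# in the signature above. For a non-English checkpoint, passing e.g. lang='it' also
# triggers the loading of the language-specific lemma/POS resources in the branch above.
from ewiser.spacy.disambiguate import Disambiguator  # assumed module path

wsd = Disambiguator('checkpoints/ewiser.semcor.pt', lang='en', batch_size=5)  # hypothetical checkpoint path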
def __init__(self, path, dictionary=None, use_synsets=True, keep_string_data=False, lang="en"):

    assert use_synsets

    if not dictionary:
        dictionary = Dictionary.load(DEFAULT_DICTIONARY)

    if not os.path.exists(path):
        os.makedirs(path)

    self.vectors_path = os.path.join(path, 'vectors.hdf5')
    self.gold_path = os.path.join(path, 'gold.pkl')
    if os.path.exists(self.vectors_path):
        os.remove(self.vectors_path)
    if os.path.exists(self.gold_path):
        os.remove(self.gold_path)

    self.keep_string_data = keep_string_data
    if keep_string_data:
        self.oov_dictionary_path = os.path.join(path, 'oov.pkl')
        if os.path.exists(self.oov_dictionary_path):
            os.remove(self.oov_dictionary_path)
        self.oov_dictionary = {}
    else:
        self.oov_dictionary = None

    self.metadata = {"lang": lang}
    self.metadata_path = os.path.join(path, 'metadata.json')
    if os.path.exists(self.metadata_path):
        os.remove(self.metadata_path)
    self.lang = self.metadata['lang']

    self.hd5_file = h5py.File(self.vectors_path, mode='w')
    self.token_data = None
    self.seq_data = None
    self.gold = []
    self._max_gold = 0
    self._max_sequence = 0
    self.dictionary = dictionary
    self.output_dictionary = ResourceManager.get_offsets_dictionary()
    self.use_synsets = use_synsets
def _read_and_load_in_args(cls, args):
    args.decoder_structured_logits_edgelists = getattr(args, 'decoder_structured_logits_edgelists', [])
    adjacency = None
    if getattr(args, 'decoder_use_structured_logits', False):
        assert args.decoder_structured_logits_edgelists, 'No edges provided!'
    if args.decoder_structured_logits_edgelists:
        if isinstance(args.decoder_structured_logits_edgelists[0], torch.Tensor):
            adjacency = repack_sparse_tensor(*args.decoder_structured_logits_edgelists)
        else:
            from ewiser.fairseq_ext.data.dictionaries import ResourceManager
            adjacency = ResourceManager.make_adjacency_from_files(*args.decoder_structured_logits_edgelists)
        args.decoder_structured_logits_edgelists = unpack_sparse_tensor(adjacency.clone().cpu())
    else:
        adjacency = None
    return {
        'adjacency': adjacency,
    }
def setup_task(cls, args, **kwargs):
    """Setup the task (e.g., load dictionaries).

    Args:
        args (argparse.Namespace): parsed command-line arguments
    """
    dictionary = None
    output_dictionary = None
    if args.data:
        dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt'))
        print('| dictionary: {} types'.format(len(dictionary)))
        output_dictionary = ResourceManager.get_senses_dictionary(True)
        print('| output_dictionary: {} types'.format(len(output_dictionary)))
        criterion_weights = torch.ones(len(output_dictionary)).float()
        criterion_weights[:output_dictionary.nspecial] = 0.
        criterion_weights.requires_grad = False
    else:
        raise NotImplementedError
    return cls(args, dictionary, output_dictionary, criterion_weights=criterion_weights)
def main(args):

    print("Loading checkpoints: " + " ".join(args.checkpoints))

    data = torch.load(args.checkpoints[0], map_location='cpu')
    model_args = data['args']
    model_args.cpu = 'cuda' not in args.device
    model_args.context_embeddings_cache = args.device
    state = data['model']

    dictionary = Dictionary.load(DEFAULT_DICTIONARY)
    output_dictionary = ResourceManager.get_offsets_dictionary()
    target_manager = TargetManager(SequenceLabelingTaskKind.WSD)
    task = SequenceLabelingTask(model_args, dictionary, output_dictionary)

    if len(args.checkpoints) == 1:
        model = task.build_model(model_args).cpu().eval()
        model.load_state_dict(state, strict=True)
    else:
        checkpoints = LinearTaggerEnsembleModel.make_args_iterator(args.checkpoints)
        model = LinearTaggerEnsembleModel.build_model(checkpoints, task)

    model = model.eval()
    model.to(args.device)

    datasets = []
    for corpus in args.xmls:
        if corpus.endswith('.data.xml'):
            dataset = WSDDataset.read_raganato(
                corpus,
                dictionary,
                use_synsets=True,
                max_length=args.max_length,
                on_error='keep',
                quiet=args.quiet,
                read_by=args.read_by,
            )
        else:
            with open(corpus, 'rb') as pkl:
                dataset = pickle.load(pkl)
        datasets.append(dataset)

    corpora = zip(args.xmls, datasets)

    for corpus, dataset in corpora:

        hit, tot = 0, 0
        all_answers = {}
        for sample_original in DataLoader(dataset, collate_fn=dataset.collater, batch_size=args.batch_size):
            with torch.no_grad():
                net_output = model(**{
                    k: v.to(args.device) if isinstance(v, torch.Tensor) else v
                    for k, v in sample_original['net_input'].items()
                })
                lprobs = model.get_normalized_probs(net_output, log_probs=True).cpu()
            results, answers = target_manager.calulate_metrics(lprobs, sample_original)
            all_answers.update(answers)
            hit += results['hit']
            tot += results['tot']

        T = 0
        gold_answers = defaultdict(set)
        gold_path = Path(corpus.replace('data.xml', 'gold.key.txt'))
        bnids_map = None
        for line in gold_path.read_text().splitlines():
            pieces = line.strip().split(' ')
            if not pieces:
                continue
            trg, *gold = pieces
            T += 1
            for g in gold:
                if g.startswith('bn:'):
                    if bnids_map is None:
                        bnids_map = ResourceManager.get_bnids_to_offset_map()
                    o = bnids_map.get(g)
                    if o is None:
                        if args.on_error == 'keep':
                            o = {g}
                            gold_answers[trg] |= o
                    else:
                        gold_answers[trg] |= o
                elif g.startswith('wn:'):
                    gold_answers[trg].add(g)
                else:
                    try:
                        o = make_offset(patched_lemma_from_key(g).synset())
                    except Exception:
                        o = None
                    if o is None:
                        if args.on_error == 'keep':
                            gold_answers[trg].add(g)
                    else:
                        gold_answers[trg].add(o)

        all_answers = {k: output_dictionary.symbols[v] for k, v in all_answers.items()}

        if args.on_error == 'skip':
            N = len([t for t, aa in gold_answers.items() if aa])
        else:
            N = len(gold_answers)

        ok, notok = 0, 0
        for k, answ in all_answers.items():
            gold = gold_answers.get(k)
            if not gold:
                continue
            if not answ or answ == '<unk>':
                continue
            if answ in gold:
                ok += 1
            else:
                notok += 1

        M = 0
        for k, gg in gold_answers.items():
            if args.on_error == 'skip' and (not gg):
                continue
            valid = False
            for g in gg:
                if g.startswith('wn:'):
                    valid = True
            if not valid:
                print(k, all_answers.get(k), gg)
            a = all_answers.get(k)
            if a is None or a == '<unk>':
                M += 1

        try:
            precision = ok / (ok + notok)
        except ZeroDivisionError:
            precision = 0.
        try:
            recall = ok / N
        except ZeroDivisionError:
            recall = 0.
        try:
            f1 = (2 * precision * recall) / (precision + recall)
        except ZeroDivisionError:
            f1 = 0.
        print(corpus)
        print(
            f'P: {precision}\tR: {recall}\tF1: {f1}\tN/T: {N}/{T}\tY/N/M/S: {ok}/{notok}/{M}/{T-N}'
        )

        if args.predictions:
            if not os.path.exists(args.predictions):
                os.mkdir(args.predictions)
            name = ".".join(os.path.split(corpus)[-1].split('.')[:-2]) + '.results.key.txt'
            path = os.path.join(args.predictions, name)
            with open(path, 'w') as results_file:
                for k, v in sorted(all_answers.items()):
                    if not v or v == '<unk>':
                        v = ''
                    results_file.write(k + ' ' + v + '\n')
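# Self-contained sketch of the scoring logic above: precision is computed over attempted
# (non-'<unk>') predictions, recall over the scorable gold instances, and F1 is their
# harmonic mean. Instance ids and offsets in the toy example are illustrative only.
def score(all_answers, gold_answers):
    ok = notok = 0
    for key, answer in all_answers.items():
        gold = gold_answers.get(key)
        if not gold or not answer or answer == '<unk>':
            continue
        if answer in gold:
            ok += 1
        else:
            notok += 1
    n = len([g for g in gold_answers.values() if g])
    precision = ok / (ok + notok) if (ok + notok) else 0.
    recall = ok / n if n else 0.
    f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) else 0.
    return precision, recall, f1

# Toy check: 2 correct out of 3 attempted predictions, 4 scorable gold instances
# -> P = 2/3, R = 1/2, F1 = 4/7.
p, r, f1 = score(
    {'d0.s0.t0': 'wn:00001740n', 'd0.s0.t1': 'wn:00002137n', 'd0.s0.t2': '<unk>', 'd0.s1.t0': 'wn:00002137n'},
    {'d0.s0.t0': {'wn:00001740n'}, 'd0.s0.t1': {'wn:00001740n'},
     'd0.s0.t2': {'wn:00002137n'}, 'd0.s1.t0': {'wn:00002137n'}},
)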
def read_graph(*paths, input_keys=None):

    self_loops_count = 0
    offsets = ResourceManager.get_offsets_dictionary()

    # Sniff the key format (BabelNet ids, WordNet offsets or sensekeys) from the first
    # non-empty line, unless it was provided explicitly.
    if input_keys is None:
        with open(paths[0]) as file:
            for line in file:
                line = line.strip()
                if not line:
                    continue
                else:
                    offset1, offset2, *info = line.split()
                    if offset1.startswith('bn:'):
                        input_keys = 'bnids'
                    elif offset1.startswith('wn:'):
                        input_keys = 'offsets'
                    else:
                        input_keys = 'sensekeys'
                    break
    assert input_keys is not None

    remap = ResourceManager.get_bnids_to_offset_map()

    g = nx.DiGraph()
    for path in paths:
        with open(path) as file:
            for line in file:
                line = line.strip()
                if not line:
                    continue
                offset1, offset2, *info = line.split()
                if 0 <= len(info) <= 1:
                    w = None
                else:
                    try:
                        w = float(info[1])
                    except ValueError:
                        w = None
                if offset1.startswith('bn:'):
                    offsets1 = remap.get(offset1)
                elif offset1.startswith('wn:'):
                    offsets1 = [offset1]
                else:
                    raise NotImplementedError
                if offset2.startswith('bn:'):
                    offsets2 = remap.get(offset2)
                elif offset2.startswith('wn:'):
                    offsets2 = [offset2]
                else:
                    raise NotImplementedError
                for offset1, offset2 in itertools.product(offsets1, offsets2):
                    offset1 = fix_offset(offset1)  # v -> child in hypernymy
                    offset2 = fix_offset(offset2)  # u -> father in hypernymy
                    trg_node = offsets.index(offset1)
                    src_node = offsets.index(offset2)
                    g.add_edge(src_node, trg_node, w=w)
                    self_loops_count += int(src_node == trg_node)

    return g
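# Illustrative sketch of the edge-list format read_graph expects: one edge per line,
# "child father [extra [weight]]", using BabelNet ids ('bn:...'), WordNet offsets
# ('wn:...'), or sensekeys. Below, a tiny hypernymy edge list is generated from NLTK's
# WordNet with make_offset (defined in ewiser.fairseq_ext.data.utils); the output file
# name is hypothetical.
from nltk.corpus import wordnet as wn
from ewiser.fairseq_ext.data.utils import make_offset

with open('toy_hypernymy_edges.txt', 'w') as out:
    for synset in wn.synsets('dog', pos='n'):
        for hypernym in synset.hypernyms():
            # child first, then its hypernym (father), matching the comments in read_graph
            out.write(make_offset(synset) + ' ' + make_offset(hypernym) + '\n')

g = read_graph('toy_hypernymy_edges.txt', input_keys='offsets')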
def _read_raganato_xml(
        xml_path: str,
        read_by: Union[str, RaganatoReadBy] = RaganatoReadBy.TEXT,
        dictionary: Optional[Dictionary] = None,
        tagset='universal',  # universal, multilingual
        lang='en',
        inst_to_keep=None,
        on_error='skip',  # skip, keep, raise
        quiet=True,
) -> Tuple[np.ndarray, List[str], Dict[int, str]]:

    if isinstance(read_by, str):
        read_by = getattr(RaganatoReadBy, read_by.upper())
    elif isinstance(read_by, RaganatoReadBy):
        read_by = read_by.value
    else:
        raise TypeError

    if not dictionary:
        dictionary = Dictionary.load(DEFAULT_DICTIONARY)

    assert tagset in {'universal', 'multilingual'}

    oov_dictionary = {}
    pos_dictionary = ResourceManager.get_pos_dictionary()
    lemma_pos_dictionary = ResourceManager.get_lemma_pos_dictionary(lang)
    lemma_pos_to_possible_offsets_map = ResourceManager.get_lemma_pos_to_possible_offsets(lang)
    offsets_dictionary = ResourceManager.get_offsets_dictionary()

    total_text_n = []
    total_token = []
    total_lemma_pos = []
    total_pos = []
    total_gold_indices = []

    token = []
    token_strings = []
    lemma_pos = []
    pos = []
    gold_indices = []

    target_labels = []
    gold_idx = 0

    def discharge():
        # Flush the per-text buffers into the global arrays, but only if the text
        # contains at least one annotated instance.
        cond1 = bool(token)
        # cond2 = True
        cond2 = any(g > -1 for g in gold_indices)
        if cond1 and cond2:
            if not total_text_n:
                old_text_n = -1
            else:
                old_text_n = total_text_n[-1]
            text_n = [old_text_n + 1] * len(token)
            for token_number, (t, ts) in enumerate(zip(token, token_strings), start=len(total_token)):
                if t == dictionary.unk_index:
                    oov_dictionary[token_number] = ts
            total_text_n.extend(text_n)
            total_token.extend(token)
            total_lemma_pos.extend(lemma_pos)
            total_pos.extend(pos)
            total_gold_indices.extend(gold_indices)
        token.clear()
        token_strings.clear()
        lemma_pos.clear()
        pos.clear()
        gold_indices.clear()

    old_text_number = -1
    for token_number, (text_number, word) in enumerate(read_by(xml_path)):
        if not word.text:
            continue
        if text_number != old_text_number:
            discharge()
            old_text_number = text_number

        if tagset == 'universal':
            t = word.text.replace('_', ' ')
            p = word.attrib["pos"]
            if len(p) == 1:
                lp = lemma_pos_dictionary.index(
                    word.attrib.get("lemma", 'X').lower() + '#' + p.lower().replace('j', 'a')
                )
            else:
                lp = lemma_pos_dictionary.index(
                    word.attrib.get("lemma", 'X').lower() + '#' + _ud_to_wn.get(p, 'n')
                )
            pos.append(pos_dictionary.index(p))
        elif tagset == 'multilingual':
            raise  # NOTE: the multilingual handling below is unreachable as written
            try:
                text = word.text.strip().replace(' ', '_')
            except Exception as e:
                print(etree.tostring(word))
                raise e
            try:
                t, l, p = text.split('/')
            except ValueError as e:
                t = text.split('/')[0]
                p = word.attrib["pos"][0].lower()
                l = word.attrib["lemma"].lower()
                print(etree.tostring(word))
            l = l.lower().strip()
            p = p.lower().strip()
            lp = lemma_pos_dictionary.index(l + '#' + p)
            pos.append(pos_dictionary.unk_index)
        else:
            raise

        token_strings.append(t)
        token.append(dictionary.index(t))
        lemma_pos.append(lp)

        idx = word.attrib.get("id")
        if idx and word.tag == 'instance':
            ignore = False
            in_lemma_pos = True
            if len(p) == 1:
                lp_string = (
                    unicodedata.normalize('NFC', word.attrib["lemma"].lower())
                    + '#'
                    + word.attrib['pos'].lower().replace('j', 'a')
                )
            else:
                lem = word.attrib.get("lemma").lower()
                wnpos = _ud_to_wn.get(word.attrib["pos"], 'n')
                if not lem and word.tag == 'instance':
                    lem = wordnet.morphy(t, wnpos)
                if not lem:
                    lem = ''
                lp = lemma_pos_dictionary.index(lem + '#' + wnpos)
                lp_string = unicodedata.normalize('NFC', lem) + '#' + wnpos
            if lp_string not in lemma_pos_dictionary.indices:
                in_lemma_pos = False
                msg = f'Lemma and pos "{lp_string}" for instance "{idx}" not in the lemma pos dictionary.'
                if on_error == 'skip':
                    ignore = True
                    if not quiet:
                        logging.warning('SKIP:' + msg)
                elif on_error == 'keep':
                    if not quiet:
                        logging.warning('KEEP:' + msg)
                else:
                    raise KeyError(msg)

            if inst_to_keep and (idx in inst_to_keep) and in_lemma_pos:
                gold = inst_to_keep[idx]
                possible = lemma_pos_to_possible_offsets_map[lp]
                possible_str = [offsets_dictionary.symbols[x] for x in possible]
                for g in gold:
                    o = offsets_dictionary.symbols[g]
                    if o not in possible_str:
                        msg = (
                            f'"{o}" (instance "{idx}") '
                            f'not among the possible for lemma pos "{lp_string}". '
                            f'Possible: {possible_str}.'
                        )
                        if on_error == 'skip':
                            ignore = True
                            if not quiet:
                                logging.warning('SKIP:' + msg)
                        elif on_error == 'keep':
                            if not quiet:
                                logging.warning('KEEP:' + msg)
                        else:
                            raise KeyError(msg)

            if ignore:
                gold_indices.append(-1)
            elif (not inst_to_keep) or (idx in inst_to_keep):
                target_labels.append(idx)
                gold_indices.append(gold_idx)
                gold_idx += 1
            else:
                gold_indices.append(-1)
        else:
            gold_indices.append(-1)

    discharge()

    text_n = np.array(total_text_n, dtype=np.int64)
    token = np.array(total_token, dtype=np.int32)
    lemma_pos = np.array(total_lemma_pos, dtype=np.int32)
    pos = np.array(total_pos, dtype=np.int8)
    gold_indices = np.array(total_gold_indices, dtype=np.int32)

    raw_data = np.rec.fromarrays(
        [text_n, token, lemma_pos, pos, gold_indices],
        names=['text_n', 'token', 'lemma_pos', 'pos', 'gold_indices']
    )

    return raw_data, target_labels, oov_dictionary
def __init__(
        self,
        path: str,
        dictionary: Optional[Dictionary] = None,
        target_classes: str = "offsets",
        add_monosemous: bool = False,
        shuffle: bool = False,
        lazy: bool = False,
) -> None:

    self._loaded = False

    metadata_path = os.path.join(path, 'metadata.json')
    if os.path.exists(metadata_path):
        with open(metadata_path) as json_hdl:
            self.metadata = json.load(json_hdl)
    else:
        self.metadata = {"lang": "en"}

    self.use_synsets = target_classes == 'offsets'
    if not dictionary:
        dictionary = Dictionary.load(DEFAULT_DICTIONARY)
    self.dictionary = dictionary
    self.lemma_pos_dictionary = ResourceManager.get_lemma_pos_dictionary(self.lang)

    self.remap_senses = None
    if target_classes == 'offsets':
        self.output_dictionary = ResourceManager.get_offsets_dictionary()
        self.lemma_pos_to_possible_senses = ResourceManager.get_lemma_pos_to_possible_offsets(self.lang)
    elif target_classes == 'sensekeys':
        self.output_dictionary = ResourceManager.get_sensekeys_dictionary()
        self.lemma_pos_to_possible_senses = ResourceManager.get_lemma_pos_to_possible_sensekeys(self.lang)
    elif target_classes == 'bnids':
        self.output_dictionary = ResourceManager.get_offsets_dictionary()
        # self.remap_senses = ResourceManager.get_index_remap_offset_bnids()
        self.lemma_pos_to_possible_senses = ResourceManager.get_lemma_pos_to_possible_offsets(self.lang)
    else:
        raise ValueError(
            'target_classes must be in {"sensekeys", "offsets", "bnids"} but was "' + target_classes + '" instead')

    self.add_monosemous = add_monosemous
    self.shuffle = shuffle
    self.path = os.path.abspath(path)

    vectors_path = os.path.join(path, 'vectors.hdf5')
    gold_path = os.path.join(path, 'gold.pkl')
    oov_dictionary_path = os.path.join(path, 'oov.pkl')

    self._h5_file = h5py.File(vectors_path, mode="r")
    self._token_data_h5 = self._h5_file['.']['token_data']
    self._token_data_mem = None
    self._seq_data_h5 = self._h5_file['.']['seq_data']
    self._seq_data_mem = None

    with open(gold_path, 'rb') as pkl:
        self.gold = pickle.load(pkl)

    if os.path.exists(oov_dictionary_path):
        with open(oov_dictionary_path, 'rb') as pkl:
            self.oov_dictionary = pickle.load(pkl)
    else:
        self.oov_dictionary = None

    self.original_sizes = self.seq_data()[:, 1]
    self.target_sizes = None
    self._from_tmp = False

    if not lazy:
        self.load_in_memory()

    self.embeddings = None
    self.trg_manager = TargetManager('wsd')
def _read_plaintext(
        plaintext_path: str,
        dictionary=None,
        merge_with_prob: float = 0.0,
) -> Tuple[np.ndarray, List[str], Dict[int, str]]:

    assert dictionary is not None

    oov_dictionary = {}
    pos_dictionary = ResourceManager.get_pos_dictionary()
    lemma_pos_dictionary = ResourceManager.get_lemma_pos_dictionary()

    text_n = []
    token = []
    lemma_pos = []
    pos = []
    gold_indices = []

    target_labels = []
    gold_idx = 0

    text_i = 0
    old_text_n_real = 0
    for t_n, (text_n_real, word) in enumerate(_parse_with_spacy(plaintext_path)):

        if merge_with_prob <= 0.:
            text_i = text_n_real
        else:
            if text_n_real == 0:
                pass
            elif old_text_n_real != text_n_real:
                if random.random() > merge_with_prob:
                    pass
                else:
                    text_i += 1
            old_text_n_real = text_n_real

        text_n.append(text_i)

        t = word.text.replace(' ', '_')
        if t not in dictionary.indices:
            oov_dictionary[t_n] = t
            t = _longest(t)
        token.append(dictionary.index(t))

        p = word.attrib["pos"]
        lp = lemma_pos_dictionary.index(word.attrib["lemma"].lower() + '#' + _ud_to_wn.get(p, 'x'))
        lemma_pos.append(lp)
        pos.append(pos_dictionary.index(p))

        idx = word.attrib.get("id")
        if idx:
            target_labels.append(idx)
            gold_indices.append(gold_idx)
            gold_idx += 1
        else:
            gold_indices.append(-1)

    text_n = np.array(text_n, dtype=np.int64)
    token = np.array(token, dtype=np.int32)
    lemma_pos = np.array(lemma_pos, dtype=np.int32)
    pos = np.array(pos, dtype=np.int8)
    gold_indices = np.array(gold_indices, dtype=np.int32)
    raw_data = np.rec.fromarrays(
        [text_n, token, lemma_pos, pos, gold_indices],
        names=['text_n', 'token', 'lemma_pos', 'pos', 'gold_indices']
    )

    return raw_data, target_labels, oov_dictionary
def _read_raganato_gold_(
        gold_path: str,
        _use_synsets: bool = False,
        input_keys: str = "sensekeys",
        on_error: str = "skip",  # skip, keep, raise
        quiet: bool = False,
) -> Dict[str, List[int]]:

    if input_keys == 'bnids':
        bnids_map = ResourceManager.get_bnids_to_offset_map()

    target_dict = {}
    dictionary = (
        ResourceManager.get_offsets_dictionary() if _use_synsets
        else ResourceManager.get_sensekeys_dictionary()
    )

    with open(gold_path, encoding="utf8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue

            if input_keys == 'sensekeys':
                instance, *sensekeys = line.split()
                try:
                    senses = [make_offset(patched_lemma_from_key(sk, wordnet).synset()) for sk in sensekeys]
                except Exception as e:
                    print(instance, sensekeys)
                    raise e
            elif input_keys == 'offsets':
                instance, *offsets = line.split()
                offsets_ = []
                for offset in offsets:
                    if offset not in dictionary.indices:
                        msg = f'Error in gold file for instance {instance}: {offset} is not valid.'
                        if on_error == 'keep':
                            offsets_.append(offset)
                            if not quiet:
                                logging.warning('KEEP: ' + msg)
                        elif on_error == 'skip':
                            if not quiet:
                                logging.warning('SKIP: ' + msg)
                        else:
                            raise KeyError(msg)
                    else:
                        offsets_.append(offset)
                senses = offsets_
            elif input_keys == 'bnids':
                instance, *bnids = line.split()
                bnids_ = []
                for bnid in bnids:
                    if bnid not in bnids_map:
                        msg = f'Error in gold file for instance {instance}: {bnid} is not valid or not in WordNet subgraph.'
                        if on_error == 'keep':
                            bnids_.append(bnid)
                            if not quiet:
                                logging.warning('KEEP: ' + msg)
                        elif on_error == 'skip':
                            if not quiet:
                                logging.warning('SKIP: ' + msg)
                        else:
                            raise KeyError(msg)
                    else:
                        bnids_.append(bnid)
                bnids = bnids_
                senses = list({s for b in bnids for s in bnids_map[b]})
            else:
                senses = sensekeys

            if senses:
                senses = [dictionary.index(s) for s in senses]
                senses = remove_dup(senses, dictionary)
                target_dict[instance] = senses
            elif on_error == 'skip':
                if not quiet:
                    logging.warning(f'SKIP: empty gold for instance {instance}.')
            elif on_error == 'keep':
                target_dict[instance] = senses
                if not quiet:
                    logging.warning(f'KEEP: empty gold for instance {instance}.')
            else:
                raise ValueError(f'empty gold for instance {instance}.')

    return target_dict
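# Usage sketch for the gold reader above. With _use_synsets=True the returned dict maps
# instance ids to indices into the offsets dictionary; they can be turned back into
# readable 'wn:...' offsets via the dictionary's symbols, as the evaluation script does.
# The gold file path is hypothetical.
offsets_dictionary = ResourceManager.get_offsets_dictionary()
gold = _read_raganato_gold_('semeval2015.gold.key.txt', _use_synsets=True, input_keys='sensekeys')
readable_gold = {
    instance: [offsets_dictionary.symbols[i] for i in senses]
    for instance, senses in gold.items()
}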
from collections import namedtuple

from nltk.corpus import wordnet as wn
import numpy as np
import torch

from ewiser.fairseq_ext.data.dictionaries import Dictionary, ResourceManager, DEFAULT_DICTIONARY
from ewiser.fairseq_ext.data.utils import make_offset
from ewiser.fairseq_ext.models.sequence_tagging import LinearTaggerModel
from ewiser.fairseq_ext.modules.logit_convolution import repack_sparse_tensor
from spacy.tokens.token import Token

_FakeTask = namedtuple('_FakeTask', ('dictionary', 'output_dictionary', 'kind'))

UD_WNPOS = {'NOUN': 'n', 'VERB': 'v', 'ADJ': 'a', 'ADV': 'r'}

babelnet_map = ResourceManager.get_offset_to_bnids_map()


def entropy_getter(token):
    # Normalized Shannon entropy of the sense distribution attached to a token
    # (0.0 for fewer than two candidates, 1.0 for a uniform distribution).
    n = len(token._.offsets_distribution)
    if n < 2:
        return 0.
    else:
        probs = np.array(list(token._.offsets_distribution.values()))
        entropy = -np.sum(probs * np.log(probs))
        normalized_entropy = entropy / np.log(n)
        return normalized_entropy


class DisambiguatorInternals:
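# Quick numerical check of the normalization used by entropy_getter above: a uniform
# distribution over n senses has entropy ln(n), so the normalized value is 1.0, while a
# single-candidate distribution yields 0.0. The helper and the toy dicts below are
# stand-ins for a spaCy token's `token._.offsets_distribution`.
def _normalized_entropy(distribution):
    n = len(distribution)
    if n < 2:
        return 0.
    probs = np.array(list(distribution.values()))
    return float(-np.sum(probs * np.log(probs)) / np.log(n))

assert abs(_normalized_entropy({'wn:x': 0.5, 'wn:y': 0.5}) - 1.0) < 1e-12
assert _normalized_entropy({'wn:x': 1.0}) == 0.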
parser.add_argument('--read-by', default='text', choices=['text', 'sentence'])
parser.add_argument(
    '--on-error', default='skip', choices=('skip', 'keep', 'raise'),
    help='What to do when some inconsistency is encountered.')
parser.add_argument(
    '--quiet', action='store_true',
    help='Do not print to stderr when some inconsistency is encountered.')
args = parser.parse_args()

dictionary = Dictionary.load(DEFAULT_DICTIONARY)
output_dictionary = ResourceManager.get_senses_dictionary(use_synsets=True)

output = WSDDatasetBuilder(
    args.output,
    dictionary=dictionary,
    use_synsets=True,
    keep_string_data=True,
    lang=args.lang)

for xml_path in args.xmls:
    output.add_raganato(
        xml_path=xml_path,
        max_length=args.max_length,
        input_keys=args.input_keys,
        on_error=args.on_error,
        quiet=args.quiet,
        read_by=args.read_by,
    )