def _logit_for_synset(self, logits, offset_or_synset):
    # Accept either a ready-made 'wn:...' offset string or an NLTK synset object.
    if isinstance(offset_or_synset, str):
        offset = offset_or_synset
    else:
        offset = make_offset(offset_or_synset)
    idx = self.dictionary.index(offset)
    return logits[idx].item()
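# Hedged sketch (not part of this module): the 'wn:...' strings used throughout are
# assumed to be WordNet offsets rendered as 'wn:' + zero-padded offset + POS tag,
# which is what `make_offset` is expected to produce for an NLTK synset.
from nltk.corpus import wordnet as wn

synset = wn.synset('dog.n.01')
offset_string = 'wn:' + str(synset.offset()).zfill(8) + synset.pos()  # e.g. 'wn:02084071n'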
def _get_adjacency_WSD(self):
    # Symmetric sparse adjacency over the offsets vocabulary: each WordNet synset is
    # connected to its hypernyms, hyponyms and similar-to synsets.
    offsets = ResourceManager.get_offsets_dictionary()
    coordinates = []
    values = []
    size = torch.Size([len(offsets)] * 2)
    for i, offset1 in enumerate(offsets.symbols):
        if offset1.startswith('wn:'):
            synset1 = wordnet.synset_from_pos_and_offset(
                offset1[-1], int(offset1[3:-1]))
            for synset2 in itertools.chain(
                    synset1.hypernyms(),
                    synset1.hyponyms(),
                    synset1.similar_tos(),
            ):
                offset2 = make_offset(synset2)
                j = offsets.index(offset2)
                # Undirected edge: add both (i, j) and (j, i).
                coordinates.extend([(i, j), (j, i)])
                values.extend([1., 1.])
    coordinates = torch.LongTensor(sorted(coordinates)).t()
    values = torch.FloatTensor(values)
    adjacency = torch.sparse.FloatTensor(coordinates, values, size)
    return adjacency
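# Hedged sketch (illustrative only, not this class's API): one way a sparse synset
# adjacency like the one returned above can be used, namely adding to each synset's
# score a fraction of the scores of its neighbours. The 4-node graph, the score
# vector and the 0.5 mixing weight are made up for the example.
import torch

n = 4
indices = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])        # symmetric edges 0-1 and 1-2
values = torch.ones(indices.shape[1])
adjacency = torch.sparse_coo_tensor(indices, values, (n, n))

scores = torch.randn(n)
mixed = scores + 0.5 * torch.sparse.mm(adjacency, scores.unsqueeze(-1)).squeeze(-1)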
def _tag_docs_from_model_output(self, logits, sequences_spacy, additional_data=None):
    # Move scores to CPU and push the first two (special) output symbols out of contention.
    logits = logits.detach().cpu()
    logits[:, :, 0:2] = -1e7
    if (additional_data is not None) and (additional_data.get('prelogits') is not None):
        prelogits = additional_data['prelogits'].detach().cpu()
        prelogits[:, :, 0:2] = -1e7
    else:
        prelogits = None
    for i, seq in enumerate(sequences_spacy):
        for j, t in enumerate(seq):
            logits_token = logits[i, j]
            lemma = t._.lemma_preset_else_spacy.lower()
            pos = t._.pos_preset_else_spacy
            wnpos = UD_WNPOS.get(pos)
            if lemma and wnpos:
                # Candidate synsets: via NLTK WordNet for English, via the
                # lemma_pos dictionaries for other languages.
                if self.lang == 'en':
                    synsets = wn.synsets(lemma, wnpos)
                    offsets = [make_offset(s) for s in synsets]
                else:
                    lemma_pos = lemma + '#' + wnpos
                    lemma_pos_index = self.lemma_pos_dictionary.index(lemma_pos)
                    offsets_indices = self.lemma_pos_to_possible_offsets[lemma_pos_index]
                    offsets = [
                        self.output_dictionary.symbols[k]
                        for k in offsets_indices
                    ]
                    offsets = [o for o in offsets if o.startswith('wn:')]
                if not offsets:
                    continue
                indices = np.array(
                    [self.output_dictionary.index(o) for o in offsets])
            else:
                continue
            # Pick the highest-scoring candidate and attach it to the spaCy token.
            logits_synsets = logits_token[indices]
            index = torch.max(logits_synsets, -1).indices.item()
            t._.offset = offsets[index]
            if self.save_wsd_details:
                # Optionally expose the raw scores for later inspection.
                internals = t._.disambiguator_internals = DisambiguatorInternals(self, t)
                internals.logits = logits_token
                if prelogits is not None:
                    internals.logits_z = prelogits[i, j]
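# Standalone, hedged sketch of the candidate restriction performed above (all names
# and sizes are illustrative, not the module's API): only the output-dictionary
# indices licensed by the token's lemma and POS compete in the argmax.
import numpy as np
import torch

vocab_size = 1000                                  # illustrative output dictionary size
logits_token = torch.randn(vocab_size)             # one score per output symbol
candidate_indices = np.array([17, 93, 408])        # indices of the candidate synsets (illustrative)
best = candidate_indices[torch.argmax(logits_token[candidate_indices]).item()]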
def main(args):
    print("Loading checkpoints: " + " ".join(args.checkpoints))

    # Load the first checkpoint to recover the training args, then build either a
    # single model or an ensemble over all checkpoints.
    data = torch.load(args.checkpoints[0], map_location='cpu')
    model_args = data['args']
    model_args.cpu = 'cuda' not in args.device
    model_args.context_embeddings_cache = args.device
    state = data['model']

    dictionary = Dictionary.load(DEFAULT_DICTIONARY)
    output_dictionary = ResourceManager.get_offsets_dictionary()

    target_manager = TargetManager(SequenceLabelingTaskKind.WSD)
    task = SequenceLabelingTask(model_args, dictionary, output_dictionary)

    if len(args.checkpoints) == 1:
        model = task.build_model(model_args).cpu().eval()
        model.load_state_dict(state, strict=True)
    else:
        checkpoints = LinearTaggerEnsembleModel.make_args_iterator(args.checkpoints)
        model = LinearTaggerEnsembleModel.build_model(checkpoints, task)
        model = model.eval()
    model.to(args.device)

    # Read each evaluation corpus: a Raganato-style .data.xml or a pickled dataset.
    datasets = []
    for corpus in args.xmls:
        if corpus.endswith('.data.xml'):
            dataset = WSDDataset.read_raganato(
                corpus,
                dictionary,
                use_synsets=True,
                max_length=args.max_length,
                on_error='keep',
                quiet=args.quiet,
                read_by=args.read_by,
            )
        else:
            with open(corpus, 'rb') as pkl:
                dataset = pickle.load(pkl)
        datasets.append(dataset)

    corpora = zip(args.xmls, datasets)

    for corpus, dataset in corpora:

        # Run the model and collect its answers.
        hit, tot = 0, 0
        all_answers = {}
        for sample_original in DataLoader(dataset, collate_fn=dataset.collater,
                                          batch_size=args.batch_size):
            with torch.no_grad():
                net_output = model(**{
                    k: v.to(args.device) if isinstance(v, torch.Tensor) else v
                    for k, v in sample_original['net_input'].items()
                })
                lprobs = model.get_normalized_probs(net_output, log_probs=True).cpu()
            results, answers = target_manager.calulate_metrics(lprobs, sample_original)
            all_answers.update(answers)
            hit += results['hit']
            tot += results['tot']

        # Read the gold file, normalizing BabelNet ids and sense keys to WordNet offsets.
        T = 0
        gold_answers = defaultdict(set)
        gold_path = Path(corpus.replace('data.xml', 'gold.key.txt'))
        bnids_map = None
        for line in gold_path.read_text().splitlines():
            pieces = line.strip().split(' ')
            if not pieces:
                continue
            trg, *gold = pieces
            T += 1
            for g in gold:
                if g.startswith('bn:'):
                    if bnids_map is None:
                        bnids_map = ResourceManager.get_bnids_to_offset_map()
                    o = bnids_map.get(g)
                    if o is None:
                        if args.on_error == 'keep':
                            o = {g,}
                            gold_answers[trg] |= o
                    else:
                        gold_answers[trg] |= o
                elif g.startswith('wn:'):
                    gold_answers[trg].add(g)
                else:
                    try:
                        o = make_offset(patched_lemma_from_key(g).synset())
                    except Exception:
                        o = None
                    if o is None:
                        if args.on_error == 'keep':
                            gold_answers[trg].add(g)
                    else:
                        gold_answers[trg].add(o)

        all_answers = {
            k: output_dictionary.symbols[v]
            for k, v in all_answers.items()
        }

        # N is the number of scorable gold instances.
        if args.on_error == 'skip':
            N = len([t for t, aa in gold_answers.items() if aa])
        else:
            N = len(gold_answers)

        # Count correct (ok) and wrong (notok) answers.
        ok, notok = 0, 0
        for k, answ in all_answers.items():
            gold = gold_answers.get(k)
            if not gold:
                continue
            if not answ or answ == '<unk>':
                continue
            if answ in gold:
                ok += 1
            else:
                notok += 1

        # M counts gold instances left unanswered; also flag gold sets that contain
        # no WordNet offset at all.
        M = 0
        for k, gg in gold_answers.items():
            if args.on_error == 'skip' and (not gg):
                continue
            valid = False
            for g in gg:
                if g.startswith('wn:'):
                    valid = True
            if not valid:
                print(k, all_answers.get(k), gg)
            a = all_answers.get(k)
            if a is None or a == '<unk>':
                M += 1

        try:
            precision = ok / (ok + notok)
        except ZeroDivisionError:
            precision = 0.
        try:
            recall = ok / N
        except ZeroDivisionError:
            recall = 0.
        try:
            f1 = (2 * precision * recall) / (precision + recall)
        except ZeroDivisionError:
            f1 = 0.

        print(corpus)
        print(
            f'P: {precision}\tR: {recall}\tF1: {f1}\tN/T:{N}/{T}\tY/N/M/S: {ok}/{notok}/{M}/{T-N}'
        )

        # Optionally write the predictions in the scorer's key format.
        if args.predictions:
            if not os.path.exists(args.predictions):
                os.mkdir(args.predictions)
            name = ".".join(os.path.split(corpus)[-1].split('.')[:-2]) + '.results.key.txt'
            path = os.path.join(args.predictions, name)
            with open(path, 'w') as results_file:
                for k, v in sorted(all_answers.items()):
                    if not v or v == '<unk>':
                        v = ''
                    results_file.write(k + ' ' + v + '\n')
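# Worked example of the bookkeeping above (numbers are made up): with ok correct
# answers, notok wrong answers and N scorable gold instances,
#   P = ok / (ok + notok), R = ok / N, F1 = 2PR / (P + R).
ok, notok, N = 70, 20, 100
precision = ok / (ok + notok)                          # 0.777...
recall = ok / N                                        # 0.7
f1 = (2 * precision * recall) / (precision + recall)   # ~0.736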
def _read_raganato_gold_(
        gold_path: str,
        _use_synsets: bool = False,
        input_keys: str = "sensekeys",
        on_error: str = "skip",  # skip, keep, raise
        quiet: bool = False,
) -> Dict[str, List[int]]:

    # Map BabelNet ids to WordNet offsets only when needed.
    if input_keys == 'bnids':
        bnids_map = ResourceManager.get_bnids_to_offset_map()

    target_dict = {}
    dictionary = \
        ResourceManager.get_offsets_dictionary() if _use_synsets else ResourceManager.get_sensekeys_dictionary()

    with open(gold_path, encoding="utf8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue

            if input_keys == 'sensekeys':
                instance, *sensekeys = line.split()
                try:
                    senses = [make_offset(patched_lemma_from_key(sk, wordnet).synset()) for sk in sensekeys]
                except Exception as e:
                    print(instance, sensekeys)
                    raise e

            elif input_keys == 'offsets':
                instance, *offsets = line.split()
                offsets_ = []
                for offset in offsets:
                    if offset not in dictionary.indices:
                        msg = f'Error in gold file for instance {instance}: {offset} is not valid.'
                        if on_error == 'keep':
                            offsets_.append(offset)
                            if not quiet:
                                logging.warning('KEEP: ' + msg)
                        elif on_error == 'skip':
                            if not quiet:
                                logging.warning('SKIP: ' + msg)
                        else:
                            raise KeyError(msg)
                    else:
                        offsets_.append(offset)
                senses = offsets_

            elif input_keys == 'bnids':
                instance, *bnids = line.split()
                bnids_ = []
                for bnid in bnids:
                    if bnid not in bnids_map:
                        msg = f'Error in gold file for instance {instance}: {bnid} is not valid or not in WordNet subgraph.'
                        if on_error == 'keep':
                            bnids_.append(bnid)
                            if not quiet:
                                logging.warning('KEEP: ' + msg)
                        elif on_error == 'skip':
                            if not quiet:
                                logging.warning('SKIP: ' + msg)
                        else:
                            raise KeyError(msg)
                    else:
                        bnids_.append(bnid)
                bnids = bnids_
                senses = list({s for b in bnids for s in bnids_map[b]})

            else:
                senses = sensekeys

            if senses:
                senses = [dictionary.index(s) for s in senses]
                senses = remove_dup(senses, dictionary)
                target_dict[instance] = senses
            elif on_error == 'skip':
                if not quiet:
                    logging.warning(f'SKIP: empty gold for instance {instance}.')
            elif on_error == 'keep':
                target_dict[instance] = senses
                if not quiet:
                    logging.warning(f'KEEP: empty gold for instance {instance}.')
            else:
                raise ValueError(f'empty gold for instance {instance}.')

    return target_dict
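# Hedged sketch of the gold line format this reader expects (ids and keys are
# illustrative): one instance id followed by one or more keys, whitespace-separated,
# exactly as consumed by the `instance, *keys = line.split()` pattern above.
example_line = 'd000.s000.t000 dog%1:05:00::'
instance, *sensekeys = example_line.split()
assert instance == 'd000.s000.t000' and sensekeys == ['dog%1:05:00::']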
def get_lemma_pos_to_possible_offsets(cls, lang='en') -> List[List[int]]:
    # Build and cache, once per language, the mapping lemma#pos -> candidate offset indices.
    if cls._lemma_pos_to_possible_offsets.get(lang) is None:
        offsets_dictionary = cls.get_offsets_dictionary()
        lemma_pos_to_possible_offsets = []

        if lang == 'en':
            # English: candidates come straight from NLTK WordNet.
            lemma_pos_dictionary = cls.get_lemma_pos_dictionary(lang=lang)
            for i, lemma_pos in enumerate(lemma_pos_dictionary.symbols):
                if i < lemma_pos_dictionary.nspecial:
                    lemma_pos_to_possible_offsets.append(
                        [lemma_pos_dictionary.index(lemma_pos)])
                else:
                    lemma, pos = lemma_pos[:-2], lemma_pos[-1]
                    senses = [
                        offsets_dictionary.index(make_offset(s.synset()))
                        for s in wordnet.lemmas(lemma, pos)
                    ]
                    lemma_pos_to_possible_offsets.append(senses)
                    if lemma_pos not in lemma_pos_dictionary.indices:
                        raise KeyError(
                            f'Lemma pos {lemma_pos} from the lemma pos to possible offsets dictionary '
                            'is not in the lemma pos dictionary.')
        else:
            # Other languages: candidates come from the bundled lemma_pos2offsets.<lang>.txt
            # file, with BabelNet ids mapped to WordNet offsets.
            lemma_pos_dictionary = cls.get_lemma_pos_dictionary(lang)
            offsets_dictionary = cls.get_offsets_dictionary()
            string_map = cls.get_bnids_to_offset_map()
            lemma_pos_string_to_offsets_strings = {}
            with open(
                    os.path.join(EWISER_RES_DIR, 'dictionaries',
                                 'lemma_pos2offsets.' + lang + '.txt')) as file:
                for line in file:
                    line = line.strip()
                    if not line:
                        continue
                    lemma_pos, *bnids = line.split(cls._LEMMA2OFFSETS_SEP)
                    if lemma_pos not in lemma_pos_dictionary.indices:
                        raise KeyError(
                            f'Lemma pos {lemma_pos} from the lemma pos to possible offsets dictionary '
                            'is not in the lemma pos dictionary.')
                    if lemma_pos in lemma_pos_string_to_offsets_strings:
                        offsets = lemma_pos_string_to_offsets_strings[lemma_pos]
                    else:
                        offsets = []
                    offsets = offsets + [
                        offset for bnid in bnids for offset in string_map[bnid]
                    ]
                    offsets = list(OrderedDict.fromkeys(offsets).keys())
                    lemma_pos_string_to_offsets_strings[lemma_pos] = offsets

            for i, lemma_pos in enumerate(lemma_pos_dictionary.symbols):
                if lemma_pos not in lemma_pos_string_to_offsets_strings:
                    lemma_pos_to_possible_offsets.append(
                        [lemma_pos_dictionary.index(lemma_pos)])
                else:
                    offsets = lemma_pos_string_to_offsets_strings[lemma_pos]
                    senses = [offsets_dictionary.index(o) for o in offsets]
                    lemma_pos_to_possible_offsets.append(senses)

        cls._lemma_pos_to_possible_offsets[lang] = lemma_pos_to_possible_offsets

    return cls._lemma_pos_to_possible_offsets[lang]
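# Small sketch of the deduplication idiom used above: OrderedDict.fromkeys keeps only
# the first occurrence of each offset while preserving insertion order (the offset
# strings here are illustrative).
from collections import OrderedDict

offsets = ['wn:02084071n', 'wn:10023039n', 'wn:02084071n']
deduped = list(OrderedDict.fromkeys(offsets))   # ['wn:02084071n', 'wn:10023039n']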