Example #1
    def _logit_for_synset(self, logits, offset_or_synset):
        # Accept either a raw offset string (e.g. 'wn:...') or a WordNet
        # Synset object, which is first converted to its offset string.
        if isinstance(offset_or_synset, str):
            offset = offset_or_synset
        else:
            offset = make_offset(offset_or_synset)
        # Look the offset up in the output vocabulary and return its score.
        idx = self.dictionary.index(offset)
        return logits[idx].item()
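
A minimal usage sketch (not from the source): both call forms resolve to the same vocabulary lookup. Here `model` and `logits` are assumed stand-ins for the surrounding disambiguator instance and a 1-D tensor of scores over the offsets vocabulary, and `make_offset` is the helper already used above.

from nltk.corpus import wordnet as wn

synset = wn.synsets('bank', 'n')[0]
score_a = model._logit_for_synset(logits, synset)               # Synset object
score_b = model._logit_for_synset(logits, make_offset(synset))  # offset string
assert score_a == score_b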
Example #2
    def _get_adjacency_WSD(self):
        offsets = ResourceManager.get_offsets_dictionary()

        coordinates = []
        values = []
        size = torch.Size([len(offsets)] * 2)

        for i, offset1 in enumerate(offsets.symbols):
            if offset1.startswith('wn:'):
                # Symbols are encoded as 'wn:' + zero-padded offset + POS
                # letter, so the POS is the last character and the numeric
                # offset sits in between.
                synset1 = wordnet.synset_from_pos_and_offset(
                    offset1[-1], int(offset1[3:-1]))
                # Add symmetric edges to hypernyms, hyponyms and
                # similar-to neighbours.
                for synset2 in itertools.chain(
                        synset1.hypernyms(),
                        synset1.hyponyms(),
                        synset1.similar_tos(),
                ):
                    offset2 = make_offset(synset2)
                    j = offsets.index(offset2)
                    coordinates.extend([(i, j), (j, i)])
                    values.extend([1., 1.])

        coordinates = torch.LongTensor(sorted(coordinates)).t()
        values = torch.FloatTensor(values)
        # Sparse COO adjacency over the full offsets vocabulary; duplicate
        # (i, j) entries are summed when the tensor is coalesced.
        adjacency = torch.sparse_coo_tensor(coordinates, values, size)
        return adjacency
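
A hedged usage sketch: an edge can be pushed twice (once from each endpoint), so coalesce() sums duplicate entries; a sparse matrix product then propagates per-synset scores along the hypernym/hyponym/similar-to edges. `model` is an assumed instance exposing the method above.

import torch

adjacency = model._get_adjacency_WSD().coalesce()
scores = torch.randn(adjacency.size(0), 1)       # one score per offset symbol
propagated = torch.sparse.mm(adjacency, scores)  # sums each node's neighbours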
Example #3
    def _tag_docs_from_model_output(self,
                                    logits,
                                    sequences_spacy,
                                    additional_data=None):

        # Mask the first two (special) output symbols so they can never win
        # the candidate argmax below.
        logits = logits.detach().cpu()
        logits[:, :, 0:2] = -1e7
        if (additional_data is not None) and (additional_data.get('prelogits')
                                              is not None):
            prelogits = additional_data['prelogits'].detach().cpu()
            prelogits[:, :, 0:2] = -1e7
        else:
            prelogits = None

        for i, seq in enumerate(sequences_spacy):
            for j, t in enumerate(seq):

                logits_token = logits[i, j]

                lemma = t._.lemma_preset_else_spacy.lower()
                pos = t._.pos_preset_else_spacy
                wnpos = UD_WNPOS.get(pos)

                # Skip tokens with no lemma or no WordNet-mappable POS.
                if lemma and wnpos:
                    if self.lang == 'en':
                        # English: candidate senses straight from WordNet.
                        synsets = wn.synsets(lemma, wnpos)
                        offsets = [make_offset(s) for s in synsets]
                    else:
                        # Other languages: candidates from the precomputed
                        # 'lemma#pos' -> possible offsets table.
                        lemma_pos = lemma + '#' + wnpos
                        lemma_pos_index = self.lemma_pos_dictionary.index(
                            lemma_pos)
                        offsets_indices = self.lemma_pos_to_possible_offsets[
                            lemma_pos_index]
                        offsets = [
                            self.output_dictionary.symbols[k]
                            for k in offsets_indices
                        ]
                        offsets = [o for o in offsets if o.startswith('wn:')]

                    if not offsets:
                        continue
                    indices = np.array(
                        [self.output_dictionary.index(o) for o in offsets])
                else:
                    continue

                # Restrict the argmax to the candidate senses and write the
                # winning offset onto the spaCy token.
                logits_synsets = logits_token[indices]
                index = torch.max(logits_synsets, -1).indices.item()
                t._.offset = offsets[index]

                if self.save_wsd_details:
                    internals = t._.disambiguator_internals = DisambiguatorInternals(
                        self, t)
                    internals.logits = logits_token
                    if prelogits is not None:
                        internals.logits_z = prelogits[i, j]
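
The core move above is restricting the argmax to a lemma's candidate senses instead of the whole output vocabulary. A self-contained sketch of that trick, with purely illustrative numbers:

import numpy as np
import torch

logits_token = torch.randn(120000)     # scores over the full offsets vocabulary
indices = np.array([17, 403, 52001])   # candidate sense rows for one lemma/POS
best = torch.max(logits_token[indices], -1).indices.item()
predicted = indices[best]              # winner, back in vocabulary space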
Example #4
# Stdlib/third-party imports this excerpt needs; EWISER-internal names
# (Dictionary, ResourceManager, TargetManager, SequenceLabelingTask,
# LinearTaggerEnsembleModel, WSDDataset, ...) come from the package itself
# and are left out here.
import os
import pickle
from collections import defaultdict
from pathlib import Path

import torch
from torch.utils.data import DataLoader


def main(args):

    print("Loading checkpoints: " + " ".join(args.checkpoints))

    # Load the first checkpoint on CPU to recover the training-time args.
    data = torch.load(
        args.checkpoints[0],
        map_location='cpu',
    )
    model_args = data['args']
    model_args.cpu = 'cuda' not in args.device
    model_args.context_embeddings_cache = args.device
    state = data['model']
    dictionary = Dictionary.load(DEFAULT_DICTIONARY)
    output_dictionary = ResourceManager.get_offsets_dictionary()

    target_manager = TargetManager(SequenceLabelingTaskKind.WSD)
    task = SequenceLabelingTask(model_args, dictionary, output_dictionary)

    # One checkpoint: a single tagger; several: an ensemble model.
    if len(args.checkpoints) == 1:
        model = task.build_model(model_args).cpu().eval()
        model.load_state_dict(state, strict=True)
    else:
        checkpoints = LinearTaggerEnsembleModel.make_args_iterator(
            args.checkpoints)
        model = LinearTaggerEnsembleModel.build_model(
            checkpoints,
            task,
        )

    model = model.eval()
    model.to(args.device)

    datasets = []

    # Each corpus is either a Raganato .data.xml file or a pickled dataset.
    for corpus in args.xmls:
        if corpus.endswith('.data.xml'):
            dataset = WSDDataset.read_raganato(
                corpus,
                dictionary,
                use_synsets=True,
                max_length=args.max_length,
                on_error='keep',
                quiet=args.quiet,
                read_by=args.read_by,
            )
        else:
            with open(corpus, 'rb') as pkl:
                dataset = pickle.load(pkl)

        datasets.append(dataset)

    corpora = zip(args.xmls, datasets)

    for corpus, dataset in corpora:

        hit, tot = 0, 0
        all_answers = {}
        # Batch the corpus and collect per-instance predictions.
        for sample_original in DataLoader(dataset,
                                          collate_fn=dataset.collater,
                                          batch_size=args.batch_size):
            with torch.no_grad():
                net_output = model(
                    **{
                        k:
                        v.to(args.device) if isinstance(v, torch.Tensor) else v
                        for k, v in sample_original['net_input'].items()
                    })
                lprobs = model.get_normalized_probs(net_output,
                                                    log_probs=True).cpu()

            results, answers = target_manager.calulate_metrics(
                lprobs, sample_original)
            all_answers.update(answers)
            hit += results['hit']
            tot += results['tot']

        # Read the gold key file next to the corpus and collect, for each
        # target instance, the set of acceptable gold offsets.
        T = 0  # total instances in the gold file
        gold_answers = defaultdict(set)
        gold_path = Path(corpus.replace('data.xml', 'gold.key.txt'))
        bnids_map = None
        for line in gold_path.read_text().splitlines():
            pieces = line.strip().split(' ')
            if not pieces:
                continue
            trg, *gold = pieces
            T += 1
            for g in gold:
                if g.startswith('bn:'):
                    # BabelNet ids map to sets of WordNet offsets; an unknown
                    # id is kept verbatim only when on_error == 'keep'.
                    if bnids_map is None:
                        bnids_map = ResourceManager.get_bnids_to_offset_map()
                    o = bnids_map.get(g)
                    if o is None:
                        if args.on_error == 'keep':
                            gold_answers[trg] |= {g}
                    else:
                        gold_answers[trg] |= o
                elif g.startswith('wn:'):
                    gold_answers[trg].add(g)
                else:
                    try:
                        o = make_offset(patched_lemma_from_key(g).synset())
                    except Exception:
                        o = None
                    if o is None:
                        if args.on_error == 'keep':
                            gold_answers[trg].add(g)
                    else:
                        gold_answers[trg].add(o)

        # Map predicted indices back to offset strings.
        all_answers = {
            k: output_dictionary.symbols[v]
            for k, v in all_answers.items()
        }

        # N = number of scoreable instances; with on_error='skip', instances
        # whose gold set ended up empty are excluded.
        if args.on_error == 'skip':
            N = len([t for t, aa in gold_answers.items() if aa])
        else:
            N = len(gold_answers)
        ok, notok = 0, 0
        for k, answ in all_answers.items():
            gold = gold_answers.get(k)

            if not gold:
                continue
            if not answ or answ == '<unk>':
                continue
            if answ in gold:
                ok += 1
            else:
                notok += 1

        # M = gold instances left unanswered (no prediction or '<unk>').
        M = 0
        for k, gg in gold_answers.items():
            if args.on_error == 'skip' and (not gg):
                continue
            valid = False
            for g in gg:
                if g.startswith('wn:'):
                    valid = True
            if not valid:
                print(k, all_answers.get(k), gg)
            a = all_answers.get(k)
            if a is None or a == '<unk>':
                M += 1

        try:
            precision = ok / (ok + notok)
        except ZeroDivisionError:
            precision = 0.

        try:
            recall = ok / N
        except ZeroDivisionError:
            recall = 0.

        try:
            f1 = (2 * precision * recall) / (precision + recall)
        except ZeroDivisionError:
            f1 = 0.

        print(corpus)
        print(
            f'P: {precision}\tR: {recall}\tF1: {f1}\tN/T:{N}/{T}\tY/N/M/S: {ok}/{notok}/{M}/{T-N}'
        )

        # Optionally dump predictions in the scorer's key file format.
        if args.predictions:
            if not os.path.exists(args.predictions):
                os.mkdir(args.predictions)
            name = ".".join(
                os.path.split(corpus)[-1].split('.')[:-2]) + '.results.key.txt'
            path = os.path.join(args.predictions, name)
            with open(path, 'w') as results_file:
                for k, v in sorted(all_answers.items()):
                    if not v or v == '<unk>':
                        v = ''
                    results_file.write(k + ' ' + v + '\n')
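
For reference, the metric block above reduces to the following helper, with the same zero-division guards (a sketch, not part of the source; precision is over attempted instances, recall over the N scoreable ones):

def prf1(ok, notok, n):
    precision = ok / (ok + notok) if (ok + notok) else 0.
    recall = ok / n if n else 0.
    f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) else 0.
    return precision, recall, f1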
Example #5
# Typing imports for the signature; the EWISER-internal helpers (wordnet,
# ResourceManager, make_offset, patched_lemma_from_key, remove_dup) and
# logging are assumed to be imported at module level.
from typing import Dict, List


def _read_raganato_gold_(
        gold_path: str,
        _use_synsets: bool = False,
        input_keys: str = "sensekeys",
        on_error: str = "skip",  # skip, keep, raise
        quiet: bool = False,
) -> Dict[str, List[int]]:

    if input_keys == 'bnids':
        bnids_map = ResourceManager.get_bnids_to_offset_map()

    target_dict = {}
    # Gold senses are indexed against either the offsets vocabulary or the
    # sensekeys vocabulary.
    dictionary = (
        ResourceManager.get_offsets_dictionary()
        if _use_synsets
        else ResourceManager.get_sensekeys_dictionary()
    )
    with open(gold_path, encoding="utf8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            if input_keys == 'sensekeys':
                # Resolve each sense key to its synset, then to an offset.
                instance, *sensekeys = line.split()
                try:
                    senses = [
                        make_offset(patched_lemma_from_key(sk, wordnet).synset())
                        for sk in sensekeys
                    ]
                except Exception:
                    print(instance, sensekeys)
                    raise
            elif input_keys == 'offsets':
                instance, *offsets = line.split()
                offsets_ = []
                for offset in offsets:
                    if offset not in dictionary.indices:
                        msg = f'Error in gold file for instance {instance}: {offset} is not valid.'
                        if on_error == 'keep':
                            offsets_.append(offset)
                            if not quiet:
                                logging.warning('KEEP: ' + msg)
                        elif on_error == 'skip':
                            if not quiet:
                                logging.warning('SKIP: ' + msg)
                        else:
                            raise KeyError(msg)
                    else:
                        offsets_.append(offset)
                senses = offsets_
            elif input_keys == 'bnids':
                instance, *bnids = line.split()
                bnids_ = []
                for bnid in bnids:
                    if bnid not in bnids_map:
                        msg = f'Error in gold file for instance {instance}: {bnid} is not valid or not in WordNet subgraph.'
                        if on_error == 'keep':
                            bnids_.append(bnid)
                            if not quiet:
                                logging.warning('KEEP: ' + msg)
                        elif on_error == 'skip':
                            if not quiet:
                                logging.warning('SKIP: ' + msg)
                        else:
                            raise KeyError(msg)
                    else:
                        bnids_.append(bnid)
                bnids = bnids_
                senses = list({s for b in bnids for s in bnids_map[b]})
            else:
                # Any other input_keys value would leave `senses` undefined;
                # fail explicitly instead.
                raise ValueError(f'unknown input_keys: {input_keys!r}')

            if senses:
                # Map to vocabulary indices and drop duplicate senses.
                senses = [dictionary.index(s) for s in senses]
                senses = remove_dup(senses, dictionary)
                target_dict[instance] = senses
            elif on_error == 'skip':
                if not quiet:
                    logging.warning(f'SKIP: empty gold for instance {instance}.')
            elif on_error == 'keep':
                target_dict[instance] = senses
                if not quiet:
                    logging.warning(f'KEEP: empty gold for instance {instance}.')
            else:
                raise ValueError(f'empty gold for instance {instance}.')
    return target_dict
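
A hypothetical call, assuming a Raganato-style gold file on disk whose lines read 'instance_id key [key ...]'; the file name is made up, and the arguments mirror the signature above:

targets = _read_raganato_gold_(
    'semeval2007.gold.key.txt',  # assumed path
    _use_synsets=True,
    input_keys='sensekeys',
    on_error='keep',
)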
Example #6
    @classmethod
    def get_lemma_pos_to_possible_offsets(cls, lang='en') -> List[List[int]]:
        # Build (and cache per language) the 'lemma#pos' -> candidate
        # offsets table.
        if cls._lemma_pos_to_possible_offsets.get(lang) is None:
            offsets_dictionary = cls.get_offsets_dictionary()
            lemma_pos_to_possible_offsets = []
            if lang == 'en':
                lemma_pos_dictionary = cls.get_lemma_pos_dictionary(lang=lang)
                for i, lemma_pos in enumerate(lemma_pos_dictionary.symbols):
                    if i < lemma_pos_dictionary.nspecial:
                        # Special symbols map to their own index.
                        lemma_pos_to_possible_offsets.append(
                            [lemma_pos_dictionary.index(lemma_pos)])
                    else:
                        # Symbols look like 'lemma#p': strip the POS suffix.
                        lemma, pos = lemma_pos[:-2], lemma_pos[-1]
                        senses = [
                            offsets_dictionary.index(make_offset(s.synset()))
                            for s in wordnet.lemmas(lemma, pos)
                        ]
                        lemma_pos_to_possible_offsets.append(senses)
                        if lemma_pos not in lemma_pos_dictionary.indices:
                            raise KeyError(
                                f'Lemma pos {lemma_pos} from the lemma pos to possible offsets dictionary '
                                'is not in the lemma pos dictionary.')
            else:
                lemma_pos_dictionary = cls.get_lemma_pos_dictionary(lang)
                string_map = cls.get_bnids_to_offset_map()

                lemma_pos_string_to_offsets_strings = {}

                # lemma_pos2offsets.<lang>.txt maps each lemma#pos string to
                # BabelNet ids, resolved here to WordNet offsets.
                with open(
                        os.path.join(EWISER_RES_DIR, 'dictionaries',
                                     'lemma_pos2offsets.' + lang +
                                     '.txt')) as file:
                    for line in file:
                        line = line.strip()
                        if not line:
                            continue
                        lemma_pos, *bnids = line.split(cls._LEMMA2OFFSETS_SEP)
                        if lemma_pos not in lemma_pos_dictionary.indices:
                            raise KeyError(
                                f'Lemma pos {lemma_pos} from the lemma pos to possible offsets dictionary '
                                'is not in the lemma pos dictionary.')

                        offsets = lemma_pos_string_to_offsets_strings.get(
                            lemma_pos, [])

                        # Extend with the offsets of every listed BabelNet
                        # id, de-duplicating while preserving order.
                        offsets = offsets + [
                            offset for bnid in bnids
                            for offset in string_map[bnid]
                        ]
                        offsets = list(OrderedDict.fromkeys(offsets))
                        lemma_pos_string_to_offsets_strings[
                            lemma_pos] = offsets

                # Fall back to the symbol's own index when a lemma#pos has no
                # known offsets (same convention as the special symbols).
                for lemma_pos in lemma_pos_dictionary.symbols:
                    if lemma_pos not in lemma_pos_string_to_offsets_strings:
                        lemma_pos_to_possible_offsets.append(
                            [lemma_pos_dictionary.index(lemma_pos)])
                    else:
                        offsets = lemma_pos_string_to_offsets_strings[
                            lemma_pos]
                        senses = [offsets_dictionary.index(o) for o in offsets]
                        lemma_pos_to_possible_offsets.append(senses)

            cls._lemma_pos_to_possible_offsets[
                lang] = lemma_pos_to_possible_offsets
        return cls._lemma_pos_to_possible_offsets[lang]
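
A hypothetical lookup sketch tying the cached table back to the 'lemma#pos' convention seen in Example #3; it assumes the class above is the same ResourceManager used throughout these examples:

table = ResourceManager.get_lemma_pos_to_possible_offsets(lang='en')
lemma_pos_dictionary = ResourceManager.get_lemma_pos_dictionary(lang='en')
offsets_dictionary = ResourceManager.get_offsets_dictionary()

row = table[lemma_pos_dictionary.index('bank#n')]
candidates = [offsets_dictionary.symbols[i] for i in row]  # 'wn:...' strings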