Code example #1
File: disambiguate.py  Project: FrancescaSAL/ewiser
    def __init__(self,
                 checkpoint,
                 lang='en',
                 dictionary=None,
                 save_wsd_details=True,
                 maxlen=100,
                 batch_size=5):
        super().__init__()

        try:
            self._set_spacy_extensions()
        except ValueError:
            pass

        self.lang = lang
        if lang != 'en':
            self.lemma_pos_dictionary = ResourceManager.get_lemma_pos_dictionary(
                lang=lang)
            self.lemma_pos_to_possible_offsets = ResourceManager.get_lemma_pos_to_possible_offsets(
                lang=lang)

        self.dictionary = dictionary or Dictionary.load(DEFAULT_DICTIONARY)
        self.output_dictionary = ResourceManager.get_offsets_dictionary()
        self.model = self._load_model(
            checkpoint,
            _FakeTask(self.dictionary, self.output_dictionary, 'wsd'))
        self.save_wsd_details = save_wsd_details
        self.maxlen = maxlen
        self.batch_size = batch_size
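
The constructor above belongs to ewiser's spaCy disambiguation component (disambiguate.py). Below is a minimal usage sketch: the import path, class name, checkpoint file, the enable() registration call and the token extension name are assumptions inferred from the project layout, not verified API.

# Hypothetical usage sketch (unverified): build the component from a trained
# checkpoint and attach it to a spaCy pipeline.
import spacy

from ewiser.spacy.disambiguate import Disambiguator  # assumed module path

nlp = spacy.load("en_core_web_sm", disable=["parser", "ner"])
wsd = Disambiguator("checkpoints/ewiser.ckpt", lang="en", batch_size=5)  # placeholder checkpoint
wsd.enable(nlp, "wsd")  # assumed registration helper; adds the component to nlp

doc = nlp("The bank approved the loan.")
for token in doc:
    # WSD output is expected as custom token extensions set by the component
    print(token.text, getattr(token._, "offset", None))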
Code example #2
    def __init__(self, path, dictionary=None, use_synsets=True, keep_string_data=False, lang="en"):

        assert use_synsets

        if not dictionary:

            dictionary = Dictionary.load(DEFAULT_DICTIONARY)

        if not os.path.exists(path):
            os.makedirs(path)

        self.vectors_path = os.path.join(path, 'vectors.hdf5')
        self.gold_path = os.path.join(path, 'gold.pkl')
        if os.path.exists(self.vectors_path):
            os.remove(self.vectors_path)
        if os.path.exists(self.gold_path):
            os.remove(self.gold_path)

        self.keep_string_data = keep_string_data
        if keep_string_data:
            self.oov_dictionary_path = os.path.join(path, 'oov.pkl')
            if os.path.exists(self.oov_dictionary_path):
                os.remove(self.oov_dictionary_path)
            self.oov_dictionary = {}
        else:
            self.oov_dictionary = None

        self.metadata = {
            "lang": lang
        }
        self.metadata_path = os.path.join(path, 'metadata.json')
        if os.path.exists(self.metadata_path):
            os.remove(self.metadata_path)
        self.lang = self.metadata['lang']

        self.hd5_file = h5py.File(self.vectors_path, mode='w')
        self.token_data = None
        self.seq_data = None
        self.gold = []

        self._max_gold = 0
        self._max_sequence = 0

        self.dictionary = dictionary
        self.output_dictionary = ResourceManager.get_offsets_dictionary()

        self.use_synsets = use_synsets
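
This constructor prepares an output directory for a precomputed WSD corpus: vectors.hdf5, gold.pkl, an optional oov.pkl and metadata.json. A small standalone sketch of that on-disk layout follows, using only h5py, json and os; the dataset names mirror this snippet and example #6, while the directory name, shapes and dtypes are placeholder assumptions.

# Sketch of the layout the writer initializes; names follow the snippet,
# shapes/dtypes are placeholder assumptions.
import json
import os

import h5py

path = "corpus_out"  # placeholder output directory
os.makedirs(path, exist_ok=True)

with h5py.File(os.path.join(path, "vectors.hdf5"), mode="w") as hd5:
    # resizable datasets so the writer can append batches later
    hd5.create_dataset("token_data", shape=(0, 4), maxshape=(None, 4), dtype="i8")
    hd5.create_dataset("seq_data", shape=(0, 2), maxshape=(None, 2), dtype="i8")

with open(os.path.join(path, "metadata.json"), "w") as fp:
    json.dump({"lang": "en"}, fp)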
Code example #3
def main(args):

    print("Loading checkpoints: " + " ".join(args.checkpoints))

    data = torch.load(
        args.checkpoints[0],
        map_location='cpu',
    )
    model_args = data['args']
    model_args.cpu = 'cuda' not in args.device
    model_args.context_embeddings_cache = args.device
    state = data['model']
    dictionary = Dictionary.load(DEFAULT_DICTIONARY)
    output_dictionary = ResourceManager.get_offsets_dictionary()

    target_manager = TargetManager(SequenceLabelingTaskKind.WSD)
    task = SequenceLabelingTask(model_args, dictionary, output_dictionary)

    if len(args.checkpoints) == 1:
        model = task.build_model(model_args).cpu().eval()
        model.load_state_dict(state, strict=True)
    else:
        checkpoints = LinearTaggerEnsembleModel.make_args_iterator(
            args.checkpoints)
        model = LinearTaggerEnsembleModel.build_model(
            checkpoints,
            task,
        )

    model = model.eval()
    model.to(args.device)

    datasets = []

    for corpus in args.xmls:
        if corpus.endswith('.data.xml'):
            dataset = WSDDataset.read_raganato(
                corpus,
                dictionary,
                use_synsets=True,
                max_length=args.max_length,
                on_error='keep',
                quiet=args.quiet,
                read_by=args.read_by,
            )
        else:
            with open(corpus, 'rb') as pkl:
                dataset = pickle.load(pkl)

        datasets.append(dataset)

    corpora = zip(args.xmls, datasets)

    for corpus, dataset in corpora:

        hit, tot = 0, 0
        all_answers = {}
        for sample_original in DataLoader(dataset,
                                          collate_fn=dataset.collater,
                                          batch_size=args.batch_size):
            with torch.no_grad():
                net_output = model(
                    **{
                        k:
                        v.to(args.device) if isinstance(v, torch.Tensor) else v
                        for k, v in sample_original['net_input'].items()
                    })
                lprobs = model.get_normalized_probs(net_output,
                                                    log_probs=True).cpu()

            results, answers = target_manager.calulate_metrics(
                lprobs, sample_original)
            all_answers.update(answers)
            hit += results['hit']
            tot += results['tot']

        T = 0
        gold_answers = defaultdict(set)
        gold_path = Path(corpus.replace('data.xml', 'gold.key.txt'))
        bnids_map = None
        for line in gold_path.read_text().splitlines():
            pieces = line.strip().split(' ')
            if not pieces:
                continue
            trg, *gold = pieces
            T += 1
            for g in gold:
                if g.startswith('bn:'):
                    if bnids_map is None:
                        bnids_map = ResourceManager.get_bnids_to_offset_map()
                    o = bnids_map.get(g)
                    if o is None:
                        if args.on_error == 'keep':
                            o = {g}
                            gold_answers[trg] |= o
                    else:
                        gold_answers[trg] |= o
                elif g.startswith('wn:'):
                    gold_answers[trg].add(g)
                else:
                    try:
                        o = make_offset(patched_lemma_from_key(g).synset())
                    except Exception:
                        o = None
                    if o is None:
                        if args.on_error == 'keep':
                            gold_answers[trg].add(g)
                    else:
                        gold_answers[trg].add(o)

        all_answers = {
            k: output_dictionary.symbols[v]
            for k, v in all_answers.items()
        }

        if args.on_error == 'skip':
            N = len([t for t, aa in gold_answers.items() if aa])
        else:
            N = len(gold_answers)
        ok, notok = 0, 0
        for k, answ in all_answers.items():
            gold = gold_answers.get(k)

            if not gold:
                continue
            if not answ or answ == '<unk>':
                continue
            if answ in gold:
                ok += 1
            else:
                notok += 1

        M = 0
        for k, gg in gold_answers.items():
            if args.on_error == 'skip' and (not gg):
                continue
            valid = False
            for g in gg:
                if g.startswith('wn:'):
                    valid = True
            if not valid:
                print(k, all_answers.get(k), gg)
            a = all_answers.get(k)
            if a is None or a == '<unk>':
                M += 1

        try:
            precision = ok / (ok + notok)
        except ZeroDivisionError:
            precision = 0.

        try:
            recall = ok / N
        except ZeroDivisionError:
            recall = 0.

        try:
            f1 = (2 * precision * recall) / (precision + recall)
        except ZeroDivisionError:
            f1 = 0.

        print(corpus)
        print(
            f'P: {precision}\tR: {recall}\tF1: {f1}\tN/T:{N}/{T}\tY/N/M/S: {ok}/{notok}/{M}/{T-N}'
        )

        if args.predictions:
            if not os.path.exists(args.predictions):
                os.mkdir(args.predictions)
            name = ".".join(
                os.path.split(corpus)[-1].split('.')[:-2]) + '.results.key.txt'
            path = os.path.join(args.predictions, name)
            with open(path, 'w') as results_file:
                for k, v in sorted(all_answers.items()):
                    if not v or v == '<unk>':
                        v = ''
                    results_file.write(k + ' ' + v + '\n')
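
The scoring at the end of main reduces to micro-averaged precision, recall and F1 over the evaluated corpus; the helper below restates that arithmetic on its own (the function name and sample counts are illustrative only).

def micro_prf(ok, notok, n_gold):
    """Precision over attempted instances, recall over gold instances,
    harmonic-mean F1 -- the same quantities printed by main()."""
    precision = ok / (ok + notok) if (ok + notok) else 0.0
    recall = ok / n_gold if n_gold else 0.0
    denom = precision + recall
    f1 = (2 * precision * recall) / denom if denom else 0.0
    return precision, recall, f1

# e.g. 60 correct and 20 wrong answers over 100 gold instances:
print(micro_prf(60, 20, 100))  # (0.75, 0.6, 0.666...)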
Code example #4
def read_graph(*paths, input_keys=None):

    self_loops_count = 0
    offsets = ResourceManager.get_offsets_dictionary()

    if input_keys is None:
        with open(paths[0]) as file:
            for line in file:
                line = line.strip()
                if not line:
                    continue
                else:
                    offset1, offset2, *info = line.split()
                    if offset1.startswith('bn:'):
                        input_keys = 'bnids'
                    elif offset1.startswith('wn:'):
                        input_keys = 'offsets'
                    else:
                        input_keys = 'sensekeys'
                    break
    assert input_keys is not None

    remap = ResourceManager.get_bnids_to_offset_map()
    g = nx.DiGraph()

    for path in paths:

        with open(path) as file:
            for line in file:
                line = line.strip()
                if not line:
                    continue
                offset1, offset2, *info = line.split()
                if 0 <= len(info) <= 1:
                    w = None
                else:
                    try:
                        w = float(info[1])
                    except ValueError:
                        w = None

                if offset1.startswith('bn:'):
                    offsets1 = remap.get(offset1)
                elif offset1.startswith('wn:'):
                    offsets1 = [offset1]
                else:
                    raise NotImplementedError

                if offset2.startswith('bn:'):
                    offsets2 = remap.get(offset2)
                elif offset2.startswith('wn:'):
                    offsets2 = [offset2]
                else:
                    raise NotImplementedError

                for offset1, offset2 in itertools.product(offsets1, offsets2):
                    offset1 = fix_offset(offset1)  # v -> child in hypernymy
                    offset2 = fix_offset(offset2)  # u -> father in hypernymy
                    trg_node = offsets.index(offset1)
                    src_node = offsets.index(offset2)
                    g.add_edge(src_node, trg_node, w=w)
                    self_loops_count += int(src_node == trg_node)

    return g
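
read_graph treats each input line as "child parent [...] [weight]" and inserts the edge in the father-to-child direction, indexing nodes through the offsets dictionary. A minimal networkx sketch of that convention; the offsets below are placeholders and the dictionary lookup is omitted.

import networkx as nx

# Placeholder edge lines; real files carry WordNet offsets that read_graph
# resolves through the offsets dictionary, which is skipped here.
lines = [
    "wn:02084071n wn:02083346n hypernym 1.0",
    "wn:02083346n wn:02075296n hypernym 0.5",
]

g = nx.DiGraph()
for line in lines:
    child, parent, *info = line.split()
    w = float(info[1]) if len(info) > 1 else None  # weight from the second extra column, as above
    g.add_edge(parent, child, w=w)  # edge direction: father -> child

print(g.number_of_edges())                               # 2
print(g.get_edge_data("wn:02083346n", "wn:02084071n"))   # {'w': 1.0}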
Code example #5
def _read_raganato_xml(
        xml_path: str,
        read_by: Union[str, RaganatoReadBy] = RaganatoReadBy.TEXT,
        dictionary: Optional[Dictionary] = None,
        tagset='universal', #universal, multilingual
        lang='en',
        inst_to_keep=None,
        on_error='skip', # skip, keep, raise
        quiet=True,
) -> Tuple[np.ndarray, List[str], Dict[int, str]]:

    if isinstance(read_by, str):
        read_by = getattr(RaganatoReadBy, read_by.upper())
    elif isinstance(read_by, RaganatoReadBy):
        read_by = read_by.value
    else:
        raise TypeError

    if not dictionary:
        dictionary = Dictionary.load(DEFAULT_DICTIONARY)

    assert tagset in {'universal', 'multilingual'}

    oov_dictionary = {}

    pos_dictionary = ResourceManager.get_pos_dictionary()
    lemma_pos_dictionary = ResourceManager.get_lemma_pos_dictionary(lang)
    lemma_pos_to_possible_offsets_map = ResourceManager.get_lemma_pos_to_possible_offsets(lang)
    offsets_dictionary = ResourceManager.get_offsets_dictionary()

    total_text_n = []
    total_token = []
    total_lemma_pos = []
    total_pos = []
    total_gold_indices = []

    token = []
    token_strings = []
    lemma_pos = []
    pos = []
    gold_indices = []
    target_labels = []
    gold_idx = 0

    def discharge():

        cond1 = bool(token)
        # cond2 = True
        cond2 = any(g > -1 for g in gold_indices)

        if cond1 and cond2:

            if not total_text_n:
                old_text_n = -1
            else:
                old_text_n = total_text_n[-1]

            text_n = [old_text_n + 1] * len(token)

            for token_number, (t, ts) in enumerate(zip(token, token_strings), start=len(total_token)):

                if t == dictionary.unk_index:
                    oov_dictionary[token_number] = ts

            total_text_n.extend(text_n)
            total_token.extend(token)
            total_lemma_pos.extend(lemma_pos)
            total_pos.extend(pos)
            total_gold_indices.extend(gold_indices)

        token.clear()
        token_strings.clear()
        lemma_pos.clear()
        pos.clear()
        gold_indices.clear()

    old_text_number = -1

    for token_number, (text_number, word) in enumerate(read_by(xml_path)):

        if not word.text:
            continue

        if text_number != old_text_number:
            discharge()

        old_text_number = text_number

        if tagset == 'universal':
            t = word.text.replace('_', ' ')
            p = word.attrib["pos"]
            if len(p) == 1:
                lp = lemma_pos_dictionary.index(
                    word.attrib.get("lemma", 'X').lower() + '#' + p.lower().replace('j', 'a')
                )
            else:
                lp = lemma_pos_dictionary.index(
                    word.attrib.get("lemma", 'X').lower() + '#' + _ud_to_wn.get(p, 'n')
                )
            pos.append(pos_dictionary.index(p))

        elif tagset == 'multilingual':
            raise  # bare raise: the multilingual branch is disabled here, so the code below never runs
            try:
                text = word.text.strip().replace(' ', '_')
            except Exception as e:
                print(etree.tostring(word))
                raise e
            try:
                t, l, p = text.split('/')
            except ValueError as e:
                t = text.split('/')[0]
                p = word.attrib["pos"][0].lower()
                l = word.attrib["lemma"].lower()
                print(etree.tostring(word))
            l = l.lower().strip()
            p = p.lower().strip()
            lp = lemma_pos_dictionary.index(l + '#' + p)
            pos.append(pos_dictionary.unk_index)

        else:
            raise

        token_strings.append(t)
        token.append(dictionary.index(t))
        lemma_pos.append(lp)

        idx = word.attrib.get("id")
        if idx and word.tag == 'instance':

            ignore = False
            in_lemma_pos = True

            if len(p) == 1:
                lp_string = unicodedata.normalize('NFC', word.attrib["lemma"].lower()) + '#' + word.attrib['pos'].lower().replace('j', 'a')
            else:
                lem = word.attrib.get("lemma").lower()
                wnpos = _ud_to_wn.get(word.attrib["pos"], 'n')
                if not lem and word.tag == 'instance':
                    lem = wordnet.morphy(t, wnpos)
                    if not lem:
                        lem = ''
                lp = lemma_pos_dictionary.index(
                    lem + '#' + wnpos
                )
                lp_string = unicodedata.normalize('NFC', lem) + '#' + wnpos
            if lp_string not in lemma_pos_dictionary.indices:
                in_lemma_pos = False
                msg = f'Lemma and pos "{lp_string}" for instance "{idx}" not in the lemma pos dictionary.'
                if on_error == 'skip':
                    ignore = True
                    if not quiet:
                        logging.warning('SKIP:' + msg)
                elif on_error == 'keep':
                    if not quiet:
                        logging.warning('KEEP:' + msg)
                else:
                    raise KeyError(msg)

            if inst_to_keep and (idx in inst_to_keep) and in_lemma_pos:

                gold = inst_to_keep[idx]
                possible = lemma_pos_to_possible_offsets_map[lp]
                possible_str = [offsets_dictionary.symbols[x] for x in possible]
                for g in gold:

                    o = offsets_dictionary.symbols[g]

                    if o not in possible_str:

                        msg = (
                            f'"{o}" (instance "{idx}") '
                            f'not among the possible for lemma pos "{lp_string}". '
                            f'Possible: {possible_str}.'
                        )
                        if on_error == 'skip':
                            ignore = True
                            if not quiet:
                                logging.warning('SKIP:' + msg)
                        elif on_error == 'keep':
                            if not quiet:
                                logging.warning('KEEP:' + msg)
                        else:
                            raise KeyError(msg)

            if ignore:
                gold_indices.append(-1)

            elif (not inst_to_keep) or (idx in inst_to_keep):

                target_labels.append(idx)
                gold_indices.append(gold_idx)
                gold_idx += 1

            else:
                gold_indices.append(-1)
        else:
            gold_indices.append(-1)

    discharge()

    text_n = np.array(total_text_n, dtype=np.int64)
    token = np.array(total_token, dtype=np.int32)
    lemma_pos = np.array(total_lemma_pos, dtype=np.int32)
    pos = np.array(total_pos, dtype=np.int8)
    gold_indices = np.array(total_gold_indices, dtype=np.int32)

    raw_data = np.rec.fromarrays(
        [text_n, token, lemma_pos, pos, gold_indices],
        names=['text_n', 'token', 'lemma_pos', 'pos', 'gold_indices']
    )

    return raw_data, target_labels, oov_dictionary
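
The reader's final step packs its per-token columns into a single numpy record array. The toy sketch below shows that packing and how the fields are addressed afterwards; the values stand in for the dictionary indices built above.

import numpy as np

# Three tokens, one of them an annotated instance (gold index 0).
text_n = np.array([0, 0, 0], dtype=np.int64)
token = np.array([11, 42, 7], dtype=np.int32)
lemma_pos = np.array([3, 5, 9], dtype=np.int32)
pos = np.array([1, 2, 1], dtype=np.int8)
gold_indices = np.array([-1, 0, -1], dtype=np.int32)

raw_data = np.rec.fromarrays(
    [text_n, token, lemma_pos, pos, gold_indices],
    names=['text_n', 'token', 'lemma_pos', 'pos', 'gold_indices'],
)

print(raw_data.token)                        # [11 42  7]
print(raw_data[raw_data.gold_indices >= 0])  # rows that carry gold annotations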
Code example #6
    def __init__(
            self,
            path: str,
            dictionary: Optional[Dictionary] = None,
            target_classes: str = "offsets",
            add_monosemous: bool = False,
            shuffle: bool = False,
            lazy: bool = False,
    ) -> None:

        self._loaded = False

        metadata_path = os.path.join(path, 'metadata.json')
        if os.path.exists(metadata_path):
            with open(metadata_path) as json_hdl:
                self.metadata = json.load(json_hdl)
        else:
            self.metadata = {"lang": "en"}

        self.use_synsets = target_classes == 'offsets'

        if not dictionary:
            dictionary = Dictionary.load(DEFAULT_DICTIONARY)
        self.dictionary = dictionary

        self.lemma_pos_dictionary = ResourceManager.get_lemma_pos_dictionary(self.lang)  # self.lang is expected to resolve to metadata['lang']

        self.remap_senses = None
        if target_classes == 'offsets':
            self.output_dictionary = ResourceManager.get_offsets_dictionary()
            self.lemma_pos_to_possible_senses = ResourceManager.get_lemma_pos_to_possible_offsets(self.lang)
        elif target_classes == 'sensekeys':
            self.output_dictionary = ResourceManager.get_sensekeys_dictionary()
            self.lemma_pos_to_possible_senses = ResourceManager.get_lemma_pos_to_possible_sensekeys(self.lang)
        elif target_classes == 'bnids':
            self.output_dictionary = ResourceManager.get_offsets_dictionary()
            #self.remap_senses = ResourceManager.get_index_remap_offset_bnids()
            self.lemma_pos_to_possible_senses = ResourceManager.get_lemma_pos_to_possible_offsets(self.lang)
        else:
            raise ValueError('target_classes must be in {"sensekeys", "offsets", "bnids"} but was "' + target_classes +
                             '" instead')

        self.add_monosemous = add_monosemous
        self.shuffle = shuffle

        self.path = os.path.abspath(path)

        vectors_path = os.path.join(path, 'vectors.hdf5')
        gold_path = os.path.join(path, 'gold.pkl')
        oov_dictionary_path = os.path.join(path, 'oov.pkl')

        self._h5_file = h5py.File(vectors_path, mode="r")
        self._token_data_h5 = self._h5_file['.']['token_data']
        self._token_data_mem = None
        self._seq_data_h5 = self._h5_file['.']['seq_data']
        self._seq_data_mem = None
        with open(gold_path, 'rb') as pkl:
            self.gold = pickle.load(pkl)
        if os.path.exists(oov_dictionary_path):
            with open(oov_dictionary_path, 'rb') as pkl:
                self.oov_dictionary = pickle.load(pkl)
        else:
            self.oov_dictionary = None

        self.original_sizes = self.seq_data()[:, 1]
        self.target_sizes = None
        self._from_tmp = False

        if not lazy:
            self.load_in_memory()

        self.embeddings = None
        self.trg_manager = TargetManager('wsd')
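
On the read side, the constructor opens vectors.hdf5 and takes the original sequence sizes from the second column of seq_data. A small h5py sketch of that access pattern; the file path is a placeholder and the dataset names are taken from the snippet.

import h5py

# Read-side sketch; "corpus_out/vectors.hdf5" is a placeholder path.
with h5py.File("corpus_out/vectors.hdf5", mode="r") as h5:
    seq_data = h5["seq_data"][()]        # materialize the whole dataset
    token_data = h5["token_data"]        # left lazy, sliced on demand
    original_sizes = seq_data[:, 1]      # second column: original sequence sizes
    first_tokens = token_data[0]
    print(original_sizes.shape, first_tokens.shape)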
Code example #7
def _read_raganato_gold_(
        gold_path: str,
        _use_synsets: bool = False,
        input_keys: str = "sensekeys",
        on_error: str = "skip", # skip, keep, raise
        quiet: bool = False,
) -> Dict[str, List[int]]:

    if input_keys == 'bnids':
        bnids_map = ResourceManager.get_bnids_to_offset_map()

    target_dict = {}
    dictionary = \
        ResourceManager.get_offsets_dictionary() if _use_synsets else ResourceManager.get_sensekeys_dictionary()
    with open(gold_path, encoding="utf8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            if input_keys == 'sensekeys':
                instance, *sensekeys = line.split()
                try:
                    senses = [make_offset(patched_lemma_from_key(sk, wordnet).synset()) for sk in sensekeys]
                except Exception as e:
                    print(instance, sensekeys)
                    raise e
            elif input_keys == 'offsets':
                instance, *offsets = line.split()
                offsets_ = []
                for offset in offsets:
                    if offset not in dictionary.indices:
                        msg = f'Error in gold file for instance {instance}: {offset} is not valid.'
                        if on_error == 'keep':
                            offsets_.append(offset)
                            if not quiet:
                                logging.warning('KEEP: ' + msg)
                        elif on_error == 'skip':
                            if not quiet:
                                logging.warning('SKIP: ' + msg)
                        else:
                            raise KeyError(msg)
                    else:
                        offsets_.append(offset)
                senses = offsets_
            elif input_keys == 'bnids':
                instance, *bnids = line.split()
                bnids_ = []
                for bnid in bnids:
                    if bnid not in bnids_map:
                        msg = f'Error in gold file for instance {instance}: {bnid} is not valid or not in WordNet subgraph.'
                        if on_error == 'keep':
                            bnids_.append(bnid)
                            if not quiet:
                                logging.warning('KEEP: ' + msg)
                        elif on_error == 'skip':
                            if not quiet:
                                logging.warning('SKIP: ' + msg)
                        else:
                            raise KeyError(msg)
                    else:
                        bnids_.append(bnid)
                bnids = bnids_
                senses = list({s for b in bnids for s in bnids_map[b]})
            else:
                senses = sensekeys

            if senses:
                senses = [dictionary.index(s) for s in senses]
                senses = remove_dup(senses, dictionary)
                target_dict[instance] = senses
            elif on_error == 'skip':
                if not quiet:
                    logging.warning(f'SKIP: empty gold for instance {instance}.')
            elif on_error == 'keep':
                target_dict[instance] = senses
                if not quiet:
                    logging.warning(f'KEEP: empty gold for instance {instance}.')
            else:
                raise ValueError(f'empty gold for instance {instance}.')
    return target_dict
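
The gold files this reader consumes follow the Raganato key format: an instance id followed by one or more sense keys (or offsets, or BabelNet ids) per line. A minimal parsing sketch with placeholder keys; the conversion to offsets and dictionary indices performed above is omitted.

# Placeholder gold.key.txt content; the keys are illustrative, not verified.
gold_text = """\
d000.s000.t000 bank%1:14:00::
d000.s000.t001 rate%1:21:00:: rate%1:28:00::
"""

target_dict = {}
for line in gold_text.splitlines():
    line = line.strip()
    if not line:
        continue
    instance, *keys = line.split()
    target_dict[instance] = keys

print(target_dict["d000.s000.t001"])  # ['rate%1:21:00::', 'rate%1:28:00::']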