Example #1
 def build_bpe_encoder(cls,
                       encoder_json_path=None,
                       vocab_bpe_path=None,
                       *,
                       node_freq_min=5,
                       node_file_path=None,
                       others_file_path=None):
     # use the default path if not provided
     encoder_json_path = encoder_json_path or file_utils.cached_path(
         DEFAULT_ENCODER_JSON)
     vocab_bpe_path = vocab_bpe_path or file_utils.cached_path(
         DEFAULT_VOCAB_BPE)
     # build the gpt2 bpe encoder
     with open(encoder_json_path, "r") as f:
         encoder = json.load(f)
     with open(vocab_bpe_path, "r", encoding="utf-8") as f:
         bpe_data = f.read()
     bpe_merges = [
         tuple(merge_str.split())
         for merge_str in bpe_data.split("\n")[1:-1]
     ]
     bpe_encoder = cls(encoder, bpe_merges, errors="replace")
     bpe_encoder.add_amr_action_vocabulary(
         node_freq_min=node_freq_min,
         node_file_path=node_file_path,
         others_file_path=others_file_path)
     return bpe_encoder
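Since `build_bpe_encoder` takes `cls` as its first argument, it is presumably exposed as a `@classmethod` factory on `AMRActionBPEEncoder`. A minimal usage sketch, assuming the placeholder vocabulary file names below and that `encode`/`decode` follow the GPT-2 BPE encoder interface the class extends:

# Hypothetical usage; file names are placeholders and encode()/decode() are
# assumed to behave like the underlying GPT-2 byte-level BPE encoder.
bpe = AMRActionBPEEncoder.build_bpe_encoder(
    node_freq_min=5,
    node_file_path="amr_nodes.txt",
    others_file_path="amr_actions.txt",
)
ids = bpe.encode("The boy wants to go")   # list of BPE token ids
text = bpe.decode(ids)                    # round-trips back to the input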
Example #2
 def __init__(self, args):
     encoder_json = file_utils.cached_path(
         getattr(args, 'gpt2_encoder_json', DEFAULT_ENCODER_JSON)
     )
     vocab_bpe = file_utils.cached_path(
         getattr(args, 'gpt2_vocab_bpe', DEFAULT_VOCAB_BPE)
     )
     self.bpe = get_encoder(encoder_json, vocab_bpe)
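For context, the object returned by fairseq's `get_encoder` is a GPT-2 byte-level BPE encoder with `encode`/`decode` methods; a minimal round-trip sketch, assuming the resolved paths from above:

# Sketch: encode() yields integer token ids, decode() restores the text
# (exact for plain ASCII input).
bpe = get_encoder(encoder_json, vocab_bpe)
ids = bpe.encode("Hello world!")
assert bpe.decode(ids) == "Hello world!"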
Example #3
    def __init__(self, cfg):
        try:
            from tokenizers import ByteLevelBPETokenizer
        except ImportError:
            raise ImportError("Please install huggingface/tokenizers with: "
                              "pip install tokenizers")

        bpe_vocab = file_utils.cached_path(cfg.bpe_vocab)
        bpe_merges = file_utils.cached_path(cfg.bpe_merges)

        self.bpe = ByteLevelBPETokenizer(
            bpe_vocab,
            bpe_merges,
            add_prefix_space=cfg.bpe_add_prefix_space,
        )
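A short usage sketch for the tokenizer built above; the vocabulary and merges paths are placeholders, and the `Encoding` object returned by `encode` exposes `.tokens` and `.ids`:

# Hypothetical standalone usage of huggingface/tokenizers' ByteLevelBPETokenizer.
from tokenizers import ByteLevelBPETokenizer
tok = ByteLevelBPETokenizer("vocab.json", "merges.txt", add_prefix_space=True)
enc = tok.encode("Hello world")
print(enc.tokens)   # byte-level subword strings
print(enc.ids)      # corresponding vocabulary ids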
Example #4
 def __init__(self, args):
     if args.bpe_codes is None:
         raise ValueError('--bpe-codes is required for --bpe=subword_nmt')
     codes = file_utils.cached_path(args.bpe_codes)
     try:
         from subword_nmt import apply_bpe
         bpe_parser = apply_bpe.create_parser()
         bpe_args = bpe_parser.parse_args([
             '--codes',
             codes,
             '--separator',
             args.bpe_separator,
         ])
         import codecs
         bpe_args.codes = codecs.open(codes, encoding='utf-8')
         self.bpe = apply_bpe.BPE(
             bpe_args.codes,
             bpe_args.merges,
             bpe_args.separator,
             None,
             bpe_args.glossaries,
         )
         self.bpe_symbol = bpe_args.separator + ' '
     except ImportError:
         raise ImportError(
             'Please install subword_nmt with: pip install subword-nmt')
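For reference, applying a subword-nmt model outside of this wrapper might look like the sketch below; the codes file is a placeholder, and `process_line` segments one whitespace-tokenized sentence at a time:

# Hypothetical standalone usage of subword_nmt's BPE class.
import codecs
from subword_nmt import apply_bpe
bpe = apply_bpe.BPE(codecs.open("codes.bpe", encoding="utf-8"))
print(bpe.process_line("the quick brown fox"))
# e.g. "the qu@@ ick br@@ own fox" (actual splits depend on the learned codes)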
Example #5
    def __init__(self,
                 dict_txt_path=None,
                 node_freq_min=5,
                 node_file_path=None,
                 others_file_path=None,
                 **kwargs):
        super().__init__(**kwargs)

        # build the base dictionary on gpt2 bpe token ids (used by BART)
        dict_txt_path = dict_txt_path or file_utils.cached_path(
            DEFAULT_DICT_TXT)
        self.add_from_file(dict_txt_path)

        # size of the original BART vocabulary; needed to truncate the pretrained
        # BART vocabulary, which includes '<mask>' from the denoising task (and
        # bart.base pads its vocabulary further with <madeupwordxxxx> entries)
        self.bart_vocab_size = len(self.symbols)

        # build the extended bpe tokenizer and encoder
        self.bpe = AMRActionBPEEncoder.build_bpe_encoder(
            node_freq_min=node_freq_min,
            node_file_path=node_file_path,
            others_file_path=others_file_path)

        # add the new tokens to the vocabulary (NOTE: each added symbol is the
        # string form of the token's index in the BPE vocabulary)
        for tok in self.bpe.additions:
            self.add_symbol(str(self.bpe.encoder[tok]))
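A hedged sketch of how a token would flow through the extended dictionary at encoding time, assuming the surrounding class follows fairseq's `Dictionary` API (`add_symbol`, `index`):

# Hypothetical encoding path: the BPE encoder maps text to GPT-2 token ids,
# and the dictionary maps the string form of each id to its own index
# (the AMR action tokens were registered the same way in __init__ above).
bpe_ids = self.bpe.encode("want-01")              # illustrative ids only
dict_ids = [self.index(str(i)) for i in bpe_ids]  # indices in this dictionary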
Example #6
    def __init__(self, cfg):
        if cfg.bpe_codes is None:
            raise ValueError("--bpe-codes is required for --bpe=subword_nmt")
        codes = file_utils.cached_path(cfg.bpe_codes)
        try:
            from subword_nmt import apply_bpe

            bpe_parser = apply_bpe.create_parser()
            bpe_args = bpe_parser.parse_args([
                "--codes",
                codes,
                "--separator",
                cfg.bpe_separator,
            ])
            self.bpe = apply_bpe.BPE(
                bpe_args.codes,
                bpe_args.merges,
                bpe_args.separator,
                None,
                bpe_args.glossaries,
            )
            self.bpe_symbol = bpe_args.separator + " "
        except ImportError:
            raise ImportError(
                "Please install subword_nmt with: pip install subword-nmt")
Example #7
 def __init__(self, args):
     vocab = file_utils.cached_path(args.sentencepiece_vocab)
     try:
         import sentencepiece as spm
         self.sp = spm.SentencePieceProcessor()
         self.sp.Load(vocab)
     except ImportError:
         raise ImportError('Please install sentencepiece with: pip install sentencepiece')
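A usage sketch for the loaded SentencePiece model; the model path is a placeholder:

# Hypothetical standalone usage of sentencepiece.
import sentencepiece as spm
sp = spm.SentencePieceProcessor()
sp.Load("spm.model")
pieces = sp.EncodeAsPieces("Hello world")   # e.g. ['▁Hello', '▁world']
text = sp.DecodePieces(pieces)              # restores "Hello world"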
Example #8
 def __init__(self, args):
     if args.bpe_codes is None:
         raise ValueError('--bpe-codes is required for --bpe=subword_nmt')
     codes = file_utils.cached_path(args.bpe_codes)
     try:
         import fastBPE
         self.bpe = fastBPE.fastBPE(codes)
         self.bpe_symbol = "@@ "
     except ImportError:
         raise ImportError('Please install fastBPE with: pip install fastBPE')
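A usage sketch for fastBPE; the codes file is a placeholder, and `apply` takes a list of whitespace-tokenized sentences:

# Hypothetical standalone usage of fastBPE.
import fastBPE
bpe = fastBPE.fastBPE("codes.bpe")
print(bpe.apply(["the quick brown fox"]))
# e.g. ["the qu@@ ick br@@ own fox"], with "@@ " marking non-final pieces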
Example #9
    def __init__(self, cfg):
        sentencepiece_model = file_utils.cached_path(cfg.sentencepiece_model)
        try:
            import sentencepiece as spm

            self.sp = spm.SentencePieceProcessor()
            self.sp.Load(sentencepiece_model)
        except ImportError:
            raise ImportError(
                "Please install sentencepiece with: pip install sentencepiece")
Example #10
    def __init__(self, cfg):
        if cfg.bpe_codes is None:
            raise ValueError("--bpe-codes is required for --bpe=fastbpe")
        codes = file_utils.cached_path(cfg.bpe_codes)
        try:
            import fastBPE

            self.bpe = fastBPE.fastBPE(codes)
            self.bpe_symbol = "@@ "
        except ImportError:
            raise ImportError("Please install fastBPE with: pip install fastBPE")
Example #11
    def __init__(self, args):
        vocab = file_utils.cached_path(args.sentencepiece_model)
        try:
            import sentencepiece as spm
            self.sp = spm.SentencePieceProcessor()
            self.sp.Load(vocab)
        except ImportError:
            raise ImportError('Please install sentencepiece with: pip install sentencepiece')

        self.mixed_case_regex = regex.compile('(▁?[[:upper:]]?[^[:upper:]\s▁]+|▁?[[:upper:]]+|▁)')
        self.tags = ['<medical>']  # protect these tags at the beginning of sentences
        self.medical = args.medical
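The extra `mixed_case_regex` (note it relies on the third-party `regex` module, which supports POSIX classes such as `[[:upper:]]`) appears intended to re-split SentencePiece pieces at transitions into uppercase letters; a hedged illustration:

# Illustration only: roughly splits a piece at the start of each new uppercase
# run while keeping the '▁' word-boundary marker attached.
import regex
mixed_case_regex = regex.compile(r'(▁?[[:upper:]]?[^[:upper:]\s▁]+|▁?[[:upper:]]+|▁)')
mixed_case_regex.findall('▁McDonald')   # e.g. ['▁Mc', 'Donald']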
Example #12
def parse_label_schema(path):
    LabelSchema = namedtuple(
        "LabelSchema",
        [
            "labels",
            "group_name_to_labels",
            "label_categories",
            "category_to_group_names",
            "separator",
            "group_names",
            "null",
            "null_leaf",
            "ignore_categories",
        ],
    )
    path = file_utils.cached_path(path)
    with open(path, "r") as fp:
        j_obj = json.load(fp)
    return LabelSchema(**j_obj)
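`parse_label_schema` simply deserializes a JSON object whose keys match the namedtuple fields; a hypothetical schema file and call (field contents are illustrative only):

# label_schema.json (hypothetical contents):
# {
#   "labels": ["PER", "ORG"],
#   "group_name_to_labels": {"entities": ["PER", "ORG"]},
#   "label_categories": ["entity"],
#   "category_to_group_names": {"entity": ["entities"]},
#   "separator": "|",
#   "group_names": ["entities"],
#   "null": "NULL",
#   "null_leaf": "NULL_LEAF",
#   "ignore_categories": []
# }
schema = parse_label_schema("label_schema.json")
print(schema.labels, schema.null)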
Example #13
 def __init__(self, args):
     codes = file_utils.cached_path(args.bpe_codes)
     try:
         from subword_nmt import apply_bpe
         bpe_parser = apply_bpe.create_parser()
         bpe_args = bpe_parser.parse_args([
             '--codes',
             codes,
             '--separator',
             args.bpe_separator,
         ])
         self.bpe = apply_bpe.BPE(
             bpe_args.codes,
             bpe_args.merges,
             bpe_args.separator,
             None,
             bpe_args.glossaries,
         )
         self.bpe_symbol = bpe_args.separator + ' '
     except ImportError:
         raise ImportError(
             'Please install subword_nmt with: pip install subword-nmt')
Example #14
 def __init__(self, args):
     vocab_bpe = file_utils.cached_path(
         getattr(args, 'vocab_bpe', DEFAULT_VOCAB_BPE))
     self.bpe = yttm.BPE(model=vocab_bpe)
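Here `yttm` refers to the YouTokenToMe package (imported at module level in the original file); a hypothetical standalone sketch of encoding with it, using a placeholder model path:

# Hypothetical usage of YouTokenToMe.
import youtokentome as yttm
bpe = yttm.BPE(model="yttm.model")
bpe.encode(["Hello world"], output_type=yttm.OutputType.SUBWORD)
# e.g. [['▁Hello', '▁world']]; use yttm.OutputType.ID for integer ids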
Example #15
 def __init__(self, cfg):
     encoder_json = file_utils.cached_path(cfg.gpt2_encoder_json)
     vocab_bpe = file_utils.cached_path(cfg.gpt2_vocab_bpe)
     self.bpe = get_encoder(encoder_json, vocab_bpe)
Example #16
 def __init__(self, encoder_json, vocab_bpe):
     encoder_json = file_utils.cached_path(encoder_json)
     vocab_bpe = file_utils.cached_path(vocab_bpe)
     self.bpe = get_encoder(encoder_json, vocab_bpe)
Example #17
class Diagnostic():
    DEFAULT_ENCODER_JSON = 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
    DEFAULT_VOCAB_BPE = 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'

    encoder_json = file_utils.cached_path(
        getattr(None, 'gpt2_encoder_json', DEFAULT_ENCODER_JSON))
    vocab_bpe = file_utils.cached_path(
        getattr(None, 'gpt2_vocab_bpe', DEFAULT_VOCAB_BPE))

    filter_tokens = [cd.pad_df]
    replace_tokens = [
        cd.head_token_df,
        cd.tail_token_df,
        cd.blank_token_df,
        cd.e1_start_token_df,
        cd.e1_end_token_df,
        cd.e2_start_token_df,
        cd.e2_end_token_df,
        cd.blank_head_other_df,
        cd.blank_tail_other_df,
    ]

    def __init__(self, dictionary, entity_dictionary, task=None):
        self.bpe = get_encoder(self.encoder_json, self.vocab_bpe)
        self.dictionary = dictionary
        self.entity_dictionary = entity_dictionary
        self.task = task

    def decode_text(self, text):
        token_ids = self.dictionary.string(text).split()
        processed_ids = []
        for token in token_ids:
            if token in self.replace_tokens:
                processed_ids += self.bpe.encode(' ' + token)
            elif token not in self.filter_tokens:
                processed_ids.append(token)

        decoded_text = self.bpe.decode([int(x) for x in processed_ids])

        return decoded_text

    def inspect_item(self, text_id, head_id=None, tail_id=None):

        decoded_text = self.decode_text(text_id)
        print('\n\nTEXT ID LIST:\n {}\n'.format(
            self.dictionary.string(text_id).split()))
        print('DECODED TEXT:\n {}\n'.format(decoded_text))

        if head_id is not None:
            head_ent = self.entity_dictionary[head_id]
            print('<head> ENTITY:\n {} (ID={})\n'.format(head_ent, head_id))
        if tail_id is not None:
            tail_ent = self.entity_dictionary[tail_id]
            print('<tail> ENTITY:\n {} (ID={})\n'.format(tail_ent, tail_id))
        print('\n')

    def inspect_mtb_pairs(self, A_dict=None, B_dict=None):

        if A_dict is not None:
            print('\n')
            for key, val in A_dict.items():
                if key == 'textA':
                    decoded_text = self.decode_text(val)
                    print('{} ID LIST:\n {}\n'.format(
                        key,
                        self.dictionary.string(val).split()))
                    print('{} DECODED TEXT:\n {}\n'.format(key, decoded_text))
                else:
                    ent = self.entity_dictionary[val]
                    print('<{}> ENTITY:\n {} (ID={})\n'.format(key, ent, val))
            print('\n')

        if B_dict is not None:
            for i in range(len(B_dict)):
                pair_type = 'POSITIVE' if i == 0 else 'NEGATIVE {}'.format(i)
                decoded_text = self.decode_text(B_dict['textB'][i])
                print('\n{} ID LIST ({}):\n {}\n'.format(
                    'textB', pair_type,
                    self.dictionary.string(B_dict['textB'][i]).split()))
                print('{} DECODED TEXT ({}):\n {}\n'.format(
                    'textB', pair_type, decoded_text))

                head_ent = self.entity_dictionary[B_dict['headB'][i]]
                print('<{}> ENTITY ({}):\n {} (ID={})\n'.format(
                    'headB', pair_type, head_ent, B_dict['headB'][i]))

                tail_ent = self.entity_dictionary[B_dict['tailB'][i]]
                print('<{}> ENTITY ({}):\n {} (ID={})\n\n'.format(
                    'tailB', pair_type, tail_ent, B_dict['tailB'][i]))

        print('------------------------------------------------------------')

    def inspect_batch(self, batch, ent_filter=None, scores=None):

        batch_size = batch['size']
        target = batch['target']

        if self.task.args.task == 'triplet_inference':
            text_id = batch['text']
            head_id = batch['head']
            tail_id = batch['tail']
        elif self.task.args.task == 'fewrel':
            text_id = batch['text']
            n_way = self.task.args.n_way
            n_shot = self.task.args.n_shot
            exemplars_id = batch['exemplars'].reshape(batch_size, n_way,
                                                      n_shot, -1)
        elif self.task.args.task in ['kbp37', 'semeval2010task8', 'tacred']:
            text_id = batch['text']
        elif self.task.args.task == 'mtb':
            textA_id = batch['textA']
            textB_id = []
            for cluster_id, cluster_texts in batch['textB'].items():
                textB_chunks = list(
                    torch.chunk(cluster_texts, cluster_texts.shape[0], dim=0))
                textB_id += textB_chunks
            headA_id = batch['headA']
            tailA_id = batch['tailA']
            headB_id = batch['headB']
            tailB_id = batch['tailB']
            n_pairs = int(batch['A2B'].numel() / batch['size'])

        for i in range(batch_size):
            if self.task.args.task == 'triplet_inference':
                if ent_filter is None:
                    pass
                elif head_id[i, 0] not in ent_filter or tail_id[
                        i, 0] not in ent_filter:
                    continue
                decoded_text = self.decode_text(text_id[i])
                pos_head_ent = self.task.entity_dictionary[head_id[i, 0]]
                pos_tail_ent = self.task.entity_dictionary[tail_id[i, 0]]
                neg_head_ent = [
                    self.task.entity_dictionary[head_id[i, j]]
                    for j in range(1, head_id.shape[1])
                ]
                neg_tail_ent = [
                    self.task.entity_dictionary[tail_id[i, j]]
                    for j in range(1, tail_id.shape[1])
                ]

                print('\n\nTEXT ID LIST:\n {}\n'.format(
                    self.task.dictionary.string(text_id[i]).split()))
                print('DECODED TEXT:\n {}\n'.format(decoded_text))
                print('POSITIVE <head> ENTITY:\n {} (ID={})\n'.format(
                    pos_head_ent, head_id[i, 0].item()))
                print('POSITIVE <tail> ENTITY:\n {} (ID={})\n'.format(
                    pos_tail_ent, tail_id[i, 0].item()))
                print('NEGATIVE <head> ENTITIES:\n {} (ID={})\n'.format(
                    neg_head_ent, head_id[i, 1:].cpu().numpy()))
                print('NEGATIVE <tail> ENTITIES:\n {} (ID={})\n'.format(
                    neg_tail_ent, tail_id[i, 1:].cpu().numpy()))

                print('TARGET: \n {}\n'.format(
                    target[i].cpu().detach().numpy()))
                if scores is not None:
                    print('SCORES: \n {}\n'.format(
                        F.softmax(scores[i, :],
                                  dim=-1).cpu().detach().numpy()))
                else:
                    print('\n')

            elif self.task.args.task == 'fewrel':
                decoded_text = self.decode_text(text_id[i])
                decoded_exemplars = {}
                for j in range(n_way):
                    decoded_exemplars[j] = set()
                    for k in range(n_shot):
                        # cur_exemplar_id = self.task.dictionary.string(exemplars_id[i,j,k,:]).split()
                        decoded_exemplars[j].add(
                            self.decode_text(exemplars_id[i, j, k, :]))

                print('\n\nTEXT ID LIST:\n {}\n'.format(
                    self.task.dictionary.string(text_id[i]).split()))
                print('DECODED TEXT:\n {}\n'.format(decoded_text))
                print('DECODED EXEMPLARS (w/o ENTITIES):')
                for j in range(n_way):
                    print('Class {0}: {1}\n'.format(j, decoded_exemplars[j]))

                print('TARGET: \n {}\n'.format(
                    target[i].cpu().detach().numpy()))
                if scores is not None:
                    print('SCORES: \n {}\n'.format(
                        F.softmax(scores[i, :],
                                  dim=-1).cpu().detach().numpy()))
                else:
                    print('\n')

            elif self.task.args.task in [
                    'kbp37', 'semeval2010task8', 'tacred'
            ]:
                decoded_text = self.decode_text(text_id[i])
                print('\n\nTEXT ID LIST:\n {}\n'.format(
                    self.task.dictionary.string(text_id[i]).split()))
                print('DECODED TEXT:\n {}\n'.format(decoded_text))

                print('TARGET: \n {}\n'.format(
                    target[i].cpu().detach().numpy()))
                if scores is not None:
                    print('SCORES: \n {}\n'.format(
                        F.softmax(scores[i, :],
                                  dim=-1).cpu().detach().numpy()))
                else:
                    print('\n')

            elif self.task.args.task == 'mtb':
                if ent_filter is None:
                    pass
                elif headA_id[i] not in ent_filter or tailA_id[
                        i] not in ent_filter or headB_neg_id[
                            i] not in ent_filter or tailB_neg_id[
                                i] not in ent_filter:
                    continue

                # Print textA, headA, and tailA
                cur_textA = textA_id[i]
                decoded_textA = self.decode_text(cur_textA)
                print('\nTEXTA ID LIST:\n {}\n'.format(
                    self.task.dictionary.string(cur_textA).split()))
                print('DECODED TEXTA:\n {}\n'.format(decoded_textA))

                headA_ent = self.task.entity_dictionary[headA_id[i]]
                print('<headA> ENTITY:\n {} (ID={})\n'.format(
                    headA_ent, headA_id[i].item()))

                tailA_ent = self.task.entity_dictionary[tailA_id[i]]
                print('<tailA> ENTITY:\n {} (ID={})\n'.format(
                    tailA_ent, tailA_id[i].item()))

                # Print textB, headB, and tailB
                for j in range(n_pairs):
                    cur_textB = textB_id[batch['A2B'][i * n_pairs + j]]
                    decoded_textB = self.decode_text(cur_textB)
                    pair_type = 'POSITIVE' if j == 0 else 'NEGATIVE {}'.format(
                        j)
                    print('\nTEXTB ID LIST ({}):\n {}\n'.format(
                        pair_type,
                        self.task.dictionary.string(cur_textB).split()))
                    print('DECODED TEXTB ({}):\n {}\n'.format(
                        pair_type, decoded_textB))

                    headB_ent = self.task.entity_dictionary[headB_id[i, j]]
                    print('<headB> ENTITY ({}):\n {} (ID={})\n'.format(
                        pair_type, headB_ent, headB_id[i, j].item()))

                    tailB_ent = self.task.entity_dictionary[tailB_id[i, j]]
                    print('<tailB> ENTITY ({}):\n {} (ID={})\n'.format(
                        pair_type, tailB_ent, tailB_id[i, j].item()))

                # Print targets and scores
                print('TARGETS: \n {}\n'.format(
                    np.array([1] + (n_pairs - 1) * [0])))
                if scores is not None:
                    print('SCORES: \n {}\n\n'.format(
                        np.round(torch.sigmoid(
                            scores.reshape(batch_size,
                                           n_pairs)[i]).detach().cpu().numpy(),
                                 decimals=5)))
                else:
                    print('\n')
                print(
                    '------------------------------------------------------------\n'
                )
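A rough usage sketch for the `Diagnostic` helper, assuming a fairseq-style task object that carries `dictionary` and `entity_dictionary` attributes; all variable names below are hypothetical:

# Hypothetical wiring: task, dictionaries, batch, and scores come from the
# surrounding training or evaluation code.
diag = Diagnostic(task.dictionary, task.entity_dictionary, task=task)

# Inspect a single encoded sample and its head/tail entities.
diag.inspect_item(sample['text'][0], head_id=sample['head'][0], tail_id=sample['tail'][0])

# Inspect a whole batch, optionally with model scores.
diag.inspect_batch(batch, scores=model_scores)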