Example #1
    def __init__(self, chunked_sents=[], feature_func=None, model_file=None, training_opt={}):
 
        # Transform the trees into IOB-annotated sentences [(word, pos, chunk), ...]
        # chunked_sents = [tree2conlltags(sent) for sent in chunked_sents]
 
        # Transform the triplets into pairs so they match the tagger interface [((word, pos), chunk), ...]
        def triplets2tagged_pairs(iob_sent):
            return [((word, pos), chunk) for word, pos, chunk in iob_sent]
        
        chunked_sents = [convert_scheme(sent) for sent in chunked_sents]
        chunked_sents = [triplets2tagged_pairs(sent) for sent in chunked_sents]
 
        if feature_func is not None:
            feat_func = feature_func
        else:
            feat_func = self._feature_detector
        # Honour caller-supplied options; fall back to defaults only when none are given
        if not training_opt:
            training_opt = {'feature.minfreq': 20, 'c2': 4.}
        self.tagger = CRFTagger(feature_func=feat_func, training_opt=training_opt)
        if not model_file:
            raise ValueError("Provide a path to save the model file")
        self.model_file = model_file
        if chunked_sents:
            self.train(chunked_sents)
        else:
            self.tagger.set_model_file(self.model_file)
Example #2
 def __init__(self, config, task, vocab=None, parser=None):
     super(Substituting, self).__init__(vocab)
     self.task = task
     self.config = config
     self.index = self.get_index(config, vocab, parser)
     self.aug = get_blackbox_augmentor(config.blackbox_model,
                                       config.path,
                                       config.revised_rate,
                                       vocab=vocab,
                                       ftrain=config.ftrain)
     self.tag_dict = gen_tag_dict(Corpus.load(config.ftrain), vocab, 2,
                                  False)
     self.crf_tagger = CRFTagger()
Example #3
class CRFChunkParser2(ChunkParserI):
    def __init__(self, chunked_sents=[], feature_func=None, model_file=None, training_opt={}):
 
        # Transform the trees into IOB-annotated sentences [(word, pos, chunk), ...]
        # chunked_sents = [tree2conlltags(sent) for sent in chunked_sents]
 
        # Transform the triplets into pairs so they match the tagger interface [((word, pos), chunk), ...]
        def triplets2tagged_pairs(iob_sent):
            return [((word, pos), chunk) for word, pos, chunk in iob_sent]
        
        chunked_sents = [convert_scheme(sent) for sent in chunked_sents]
        chunked_sents = [triplets2tagged_pairs(sent) for sent in chunked_sents]
 
        if feature_func is not None:
            feat_func = feature_func
        else:
            feat_func = self._feature_detector
        # Honour caller-supplied options; fall back to defaults only when none are given
        if not training_opt:
            training_opt = {'feature.minfreq': 20, 'c2': 4.}
        self.tagger = CRFTagger(feature_func=feat_func, training_opt=training_opt)
        if not model_file:
            raise ValueError("Provide a path to save the model file")
        self.model_file = model_file
        if chunked_sents:
            self.train(chunked_sents)
        else:
            self.tagger.set_model_file(self.model_file)

    def train(self, chunked_sents):
        self.tagger.train(chunked_sents, self.model_file)
    
    def load(self, model_file):
        self.tagger.set_model_file(model_file)
 
    def parse(self, tagged_sent, return_tree = True):
        chunks = self.tagger.tag(tagged_sent)

 
        # Transform the result from [((w1, t1), iob1), ...] 
        # to the preferred list of triplets format [(w1, t1, iob1), ...]
        iob_triplets = revert_scheme([(w, t, c) for ((w, t), c) in chunks])
 
        # Transform the list of triplets to nltk.Tree format
        return conlltags2tree(iob_triplets) if return_tree else iob_triplets


    def _feature_detector(self, tokens, index):
        def shape(word):
            if re.match(r'[0-9]+(\.[0-9]*)?|[0-9]*\.[0-9]+$', word, re.UNICODE):
                return 'number'
            elif re.match(r'\W+$', word, re.UNICODE):
                return 'punct'
            elif re.match(r'\w+$', word, re.UNICODE):
                if word.istitle():
                    return 'upcase'
                elif word.islower():
                    return 'downcase'
                else:
                    return 'mixedcase'
            else:
                return 'other'


        def simplify_pos(s):
            if s.startswith('V'):
                return "V"
            else:
                return s.split('-')[0]

        word = tokens[index][0]
        pos = simplify_pos(tokens[index][1])

        if index == 0:
            prevword = prevprevword = ""
            prevpos = prevprevpos = ""
            prevshape = ""
        elif index == 1:
            prevword = tokens[index - 1][0].lower()
            prevprevword = ""
            prevpos = simplify_pos(tokens[index - 1][1])
            prevprevpos = ""
            prevshape = ""
        else:
            prevword = tokens[index - 1][0].lower()
            prevprevword = tokens[index - 2][0].lower()
            prevpos = simplify_pos(tokens[index - 1][1])
            prevprevpos = simplify_pos(tokens[index - 2][1])
            prevshape = shape(prevword)
        if index == len(tokens) - 1:
            nextword = nextnextword = ""
            nextpos = nextnextpos = ""
        elif index == len(tokens) - 2:
            nextword = tokens[index + 1][0].lower()
            nextpos = tokens[index + 1][1].lower()
            nextnextword = ""
            nextnextpos = ""
        else:
            nextword = tokens[index + 1][0].lower()
            nextpos = tokens[index + 1][1].lower()
            nextnextword = tokens[index + 2][0].lower()
            nextnextpos = tokens[index + 2][1].lower()

        def get_suffix_prefix(word, length):
            if len(word) > length:
                pref = word[:length].lower()
                suf = word[-length:].lower()
            else:
                pref = ""
                suf = ""
            return pref, suf

        suf_pref_lengths = [2,3]
        words = {
            'word': {'w': word, 'pos': pos, 'shape': shape(word)},
            'nword': {'w': nextword, 'pos': nextpos, 'shape': shape(nextword)},
            'nnword': {'w': nextnextword, 'pos': nextnextpos, 'shape': shape(nextnextword)},
            'pword': {'w': prevword, 'pos': prevpos, 'shape': shape(prevword)},
            'ppword': {'w': prevprevword, 'pos': prevprevpos, 'shape': shape(prevprevword)}
        }

        base_features = {}
        for word_position in words:
            for item in words[word_position]:
                base_features[word_position+"."+item] = words[word_position][item]

        prefix_suffix_features = {}
        for word_position in words:
            for l in suf_pref_lengths:
                feature_name_base = word_position+"."+repr(l)+"."
                pref, suf = get_suffix_prefix(words[word_position]['w'], l)
                prefix_suffix_features[feature_name_base+'pref'] = pref
                prefix_suffix_features[feature_name_base+'suf'] = suf
                prefix_suffix_features[feature_name_base+'pref.suf'] = '{}+{}'.format(pref, suf)
                prefix_suffix_features[feature_name_base+'posfix'] = '{}+{}+{}'.format(pref, words[word_position]['pos'], suf)
                prefix_suffix_features[feature_name_base+'shapefix'] = '{}+{}+{}'.format(pref, words[word_position]['shape'], suf)

        # pref3, suf3 = get_suffix_prefix(word)
        # prevpref3, prevsuf3 = get_suffix_prefix(prevword)
        # prevprevpref3, prevprevsuf3 = get_suffix_prefix(prevprevword)
        # nextpref3, nextsuf3 = get_suffix_prefix(nextword)
        # nextnextpref3, nextnextsuf3 = get_suffix_prefix(nextnextword)

        # postfix = '{}+{}+{}'.format(pref3, pos, suf3)

        # 89.6 (score obtained with this feature set)
        features = {
            # 'shape': shape(word),
            # 'wordlen': len(word),
            # 'prefix3': pref3,
            # 'suffix3': suf3,

            'pos': pos,
            'prevpos': prevpos,
            'nextpos': nextpos,
            'prevprevpos': prevprevpos,
            'nextnextpos': nextnextpos,

            # 'posfix': '{}+{}+{}'.format(pref3, pos, suf3),
            # 'prevposfix': '{}+{}+{}'.format(prevpref3, prevpos, prevsuf3),
            # 'prevprevposfix': '{}+{}+{}'.format(prevprevpref3, prevprevpos, prevprevsuf3),
            # 'nextposfix': '{}+{}+{}'.format(nextpref3, nextpos, nextsuf3),
            # 'nextnextposfix': '{}+{}+{}'.format(nextpref3, nextpos, nextsuf3),

            # 'word': word,
            # 'prevword': '{}'.format(prevword),
            # 'nextword': '{}'.format(nextword),
            # 'prevprevword': '{}'.format(prevprevword),
            # 'nextnextword': '{}'.format(nextnextword),

            # 'word+nextpos': '{0}+{1}'.format(postfix, nextpos),
            # 'word+nextnextpos': '{0}+{1}'.format(postfix, nextnextpos),
            # 'word+prevpos': '{0}+{1}'.format(postfix, prevpos),
            # 'word+prevprevpos': '{0}+{1}'.format(postfix, prevprevpos),
            
            'pos+nextpos': '{0}+{1}'.format(pos, nextpos),
            'pos+nextnextpos': '{0}+{1}'.format(pos, nextnextpos),
            'pos+prevpos': '{0}+{1}'.format(pos, prevpos),
            'pos+prevprevpos': '{0}+{1}'.format(pos, prevprevpos),
        }

        features.update(base_features)
        features.update(prefix_suffix_features)

        # return list(features.values())
        return features
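A minimal usage sketch for the class above, assuming the CoNLL-2000 corpus is available through NLTK and that convert_scheme/revert_scheme are the project's own IOB-scheme helpers already importable in this module; the model filename is hypothetical:

from nltk.corpus import conll2000
from nltk.chunk.util import tree2conlltags

# The constructor expects (word, pos, chunk) triplets, i.e. the commented-out
# tree2conlltags transformation applied up front
train_sents = [tree2conlltags(t) for t in conll2000.chunked_sents('train.txt')]

chunker = CRFChunkParser2(chunked_sents=train_sents,
                          model_file='chunker2.crf.tagger')  # hypothetical path; trains and saves here
tree = chunker.parse([('The', 'DT'), ('dog', 'NN'), ('barked', 'VBD')])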
Example #4
import os
from argparse import ArgumentParser
from operator import itemgetter

import pandas as pd
from nltk import CRFTagger

THIS_FILE_DIR = os.path.dirname(__file__)
DEEP_DISFLUENCY_FOLDER = os.path.join(THIS_FILE_DIR, 'deep_disfluency')
TAGGER_PATH = os.path.join(DEEP_DISFLUENCY_FOLDER,
                           'deep_disfluency/feature_extraction/crfpostagger')

POS_TAGGER = CRFTagger()
POS_TAGGER.set_model_file(TAGGER_PATH)


def pos_tag(in_tokens):
    tags = POS_TAGGER.tag(in_tokens)
    return map(itemgetter(1), tags)


def configure_argument_parser():
    parser = ArgumentParser(description='POS tag dataset')
    parser.add_argument('dataset')
    parser.add_argument('result_file')

    return parser


def main(in_src_file, in_result_file):
    dataset = pd.read_json(in_src_file)
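A short sketch of calling the pos_tag helper above. It assumes the pretrained crfpostagger model file exists at TAGGER_PATH; note that pos_tag returns a lazy map object in Python 3, so it is wrapped in list() here:

tokens = ['uh', 'i', 'mean', 'the', 'left', 'one']
tags = list(pos_tag(tokens))  # one POS tag per input token
print(list(zip(tokens, tags)))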
Example #5
class Substituting(BlackBoxMethod):
    def __init__(self, config, task, vocab=None, parser=None):
        super(Substituting, self).__init__(vocab)
        self.task = task
        self.config = config
        self.index = self.get_index(config, vocab, parser)
        self.aug = get_blackbox_augmentor(config.blackbox_model,
                                          config.path,
                                          config.revised_rate,
                                          vocab=vocab,
                                          ftrain=config.ftrain)
        self.tag_dict = gen_tag_dict(Corpus.load(config.ftrain), vocab, 2,
                                     False)
        self.crf_tagger = CRFTagger()
        self.crf_tagger.set_model_file(config.crf_tagger_path)

    def get_index(self, config, vocab=None, parser=None):
        if config.mode == 'augmentation':
            return AttackIndexRandomGenerator(config)
        if config.blackbox_index == 'pos':
            return AttackIndexPosTag(config)
        else:
            if parser is None and vocab is None:
                print('unk replacement requires both a parser and a vocab')
                exit()
            return AttackIndexUnkReplacement(config,
                                             vocab=vocab,
                                             parser=parser)

    def generate_attack_seq(self,
                            seqs,
                            seq_idx,
                            tags,
                            tag_idx,
                            chars,
                            arcs,
                            rels,
                            mask,
                            raw_metric=None):
        # generate word index to be attacked
        attack_index = self.index.get_attack_index(self.copy_str_to_list(seqs),
                                                   seq_idx, tags, tag_idx,
                                                   chars, arcs, mask)
        # generate word candidates to be attacked
        candidates, indexes = self.substituting(seqs, attack_index)
        # check candidates by pos_tagger
        candidates, indexes = self.check_pos_tag(seqs, tags, candidates,
                                                 indexes)
        attack_seq, revised_number = self.check_uas(seqs, tag_idx, arcs, rels,
                                                    candidates, indexes,
                                                    raw_metric)

        return ([Corpus.ROOT] + attack_seq, tag_idx, mask, arcs, rels,
                revised_number)

    def substituting(self, seq, index):
        try:
            # generate the attack sentence by index
            candidates, revised_indexes = self.aug.substitute(seq,
                                                              aug_idxes=index,
                                                              n=99)
        except Exception:
            try:
                # if error happens, generate the attack sentence by random
                candidates, revised_indexes = self.aug.substitute(seq)
            except Exception:
                candidates = None
                revised_indexes = []
        return candidates, revised_indexes

    def update_mask_arc_rel(self, mask, arc, rel, revised_list):
        return mask, arc, rel

    def check_pos_tag(self, seqs, tags, origin_candidates, indexes):
        tag_check_candidates = []
        tag_check_indexes = []
        for index, candidate in zip(indexes, origin_candidates):
            candidate = self.check_pos_tag_under_each_index(
                seqs, tags, candidate, index)
            if len(candidate) != 0:
                tag_check_indexes.append(index)
                tag_check_candidates.append(candidate)
        return tag_check_candidates, tag_check_indexes

    def check_pos_tag_under_each_index(self, seqs, tags, candidate, index):
        if self.config.blackbox_tagger == 'dict':
            if tags[index + 1] not in self.tag_dict:
                return []

            word_list_with_same_tag = self.tag_dict[tags[index + 1]]
            tag_check_candidate = []
            for i, cand in enumerate(candidate):
                cand_idx = self.vocab.word_dict.get(cand.lower(),
                                                    self.vocab.unk_index)
                if cand_idx != self.vocab.unk_index:
                    if cand_idx in word_list_with_same_tag:
                        tag_check_candidate.append(cand)
                        if len(tag_check_candidate) > self.config.blackbox_candidates:
                            break
            return tag_check_candidate
        elif self.config.blackbox_tagger == 'crf':
            tag_check_candidate = []
            sents = self.duplicate_sentence_with_candidate_replacement(
                seqs, candidate, index)
            word_tag_list = self.crf_tagger.tag_sents(sents)
            for count, word_tag in enumerate(word_tag_list):
                if word_tag[index][1] == tags[index + 1]:
                    tag_check_candidate.append(candidate[count])
                    if len(tag_check_candidate) > self.config.blackbox_candidates:
                        break
            return tag_check_candidate

    def check_uas(self, seqs, tag_idx, arcs, rels, candidates, indexes,
                  raw_metric):
        final_attack_seq = self.copy_str_to_list(seqs)
        revised_number = 0
        for index, candidate in zip(indexes, candidates):
            index_flag = self.check_uas_under_each_index(
                seqs, tag_idx, arcs, rels, candidate, index, raw_metric)
            final_attack_seq[index] = candidate[index_flag]
            revised_number += 1
        return final_attack_seq, revised_number

    def check_uas_under_each_index(self, seqs, tag_idx, arcs, rels, candidate,
                                   index, raw_metric):
        current_compare_uas = raw_metric.uas
        current_index = CONSTANT.FALSE_TOKEN
        for i, cand in enumerate(candidate):
            attack_seqs = self.copy_str_to_list(seqs)
            attack_seqs[index] = cand
            attack_metric = self.get_metric_by_seqs(attack_seqs, tag_idx, arcs,
                                                    rels)
            if attack_metric.uas < current_compare_uas:
                current_compare_uas = attack_metric.uas
                current_index = i
        if current_index == CONSTANT.FALSE_TOKEN:
            return 0
        else:
            return current_index

    def get_metric_by_seqs(self, attack_seqs, tag_idx, arcs, rels):
        attack_seq_idx = self.vocab.word2id([Corpus.ROOT] +
                                            attack_seqs).unsqueeze(0)
        if torch.cuda.is_available():
            attack_seq_idx = attack_seq_idx.cuda()
        if is_chars_judger(self.task.model):
            attack_chars = self.get_chars_idx_by_seq(attack_seqs)
            _, attack_metric = self.task.evaluate(
                [(attack_seq_idx, None, attack_chars, arcs, rels)],
                mst=self.config.mst)
        else:
            attack_tag_idx = tag_idx.clone()
            _, attack_metric = self.task.evaluate(
                [(attack_seq_idx, attack_tag_idx, None, arcs, rels)],
                mst=self.config.mst)
        return attack_metric

    def get_chars_idx_by_seq(self, sentence):
        chars = self.vocab.char2id(sentence).unsqueeze(0)
        if torch.cuda.is_available():
            chars = chars.cuda()
        return chars
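The 'crf' branch of check_pos_tag_under_each_index reduces to: retag every candidate sentence and keep only candidates whose predicted tag at the attacked position matches the original tag. A standalone sketch of that filter with plain NLTK, where the model path and helper name are hypothetical:

from nltk.tag import CRFTagger

tagger = CRFTagger()
tagger.set_model_file('pos.crf.tagger')  # hypothetical pretrained POS model

def filter_by_pos(sentence, index, candidates, gold_tag, max_keep=5):
    # Build one copy of the sentence per candidate, each with the word at
    # `index` replaced, then tag all copies in a single batch.
    variants = []
    for cand in candidates:
        variant = list(sentence)
        variant[index] = cand
        variants.append(variant)
    kept = []
    for cand, tagged in zip(candidates, tagger.tag_sents(variants)):
        if tagged[index][1] == gold_tag:
            kept.append(cand)
            if len(kept) >= max_keep:
                break
    return kept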
Example #6
 def crf_tagger(self) -> CRFTagger:
     if self.__crf_tagger is None:
         self.__crf_tagger = CRFTagger()
         self.__crf_tagger.set_model_file(self.config.crf_tagger_path)
     return self.__crf_tagger
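The property loads the model file only on first access; later accesses return the cached tagger. A brief usage sketch, with `hack` standing in for a hypothetical instance of the enclosing class:

# First access triggers set_model_file; later accesses reuse the same object
tags = hack.crf_tagger.tag(['The', 'cat', 'sleeps'])
assert hack.crf_tagger is hack.crf_tagger  # cached instance, not rebuilt each time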
Example #7
class IHack:
    def __init__(self):
        self.config: Config

        self.train_corpus: Corpus
        self.corpus: Corpus

        self.task: ParserTask

        self.__nn_tagger: PosTagger = None
        self.__trigram_tagger: TrigramTagger = None
        self.__crf_tagger: CRFTagger = None
        self.__tag_dict: dict = None
        self.__bert_aug: ContextualWordEmbsAug = None

        self.embed_searcher: EmbeddingSearcher

        self.loader: DataLoader

    @property
    def nn_tagger(self) -> PosTagger:
        if self.__nn_tagger is None:
            self.__nn_tagger = PosTagger.load(
                fetch_best_ckpt_name(self.config.tagger_model))
        return self.__nn_tagger

    @property
    def trigram_tagger(self) -> TrigramTagger:
        if self.__trigram_tagger is None:
            self.__trigram_tagger = auto_create(
                "trigram_tagger",
                lambda: train_gram_tagger(self.train_corpus, ngram=3),
                cache=True,
                path=self.config.workspace + '/saved_vars')
        return self.__trigram_tagger

    @property
    def crf_tagger(self) -> CRFTagger:
        if self.__crf_tagger is None:
            self.__crf_tagger = CRFTagger()
            self.__crf_tagger.set_model_file(self.config.crf_tagger_path)
        return self.__crf_tagger

    @property
    def tag_dict(self) -> dict:
        if self.__tag_dict is None:
            self.__tag_dict = auto_create(
                "tagdict3",
                lambda: gen_tag_dict(self.train_corpus, self.vocab, 3, False),
                cache=True,
                path=self.config.workspace + '/saved_vars')
            self.__tag_dict = {
                k: torch.tensor(v)
                for k, v in self.__tag_dict.items()
            }
        return self.__tag_dict

    @property
    def bert_aug(self) -> ContextualWordEmbsAug:
        if self.__bert_aug is None:
            self.__bert_aug = ContextualWordEmbsAug(
                model_path=self.config.path, top_k=2048)
        return self.__bert_aug

    @property
    def vocab(self) -> Vocab:
        return self.task.vocab

    @property
    def parser(self) -> Union[WordTagParser, WordParser]:
        return self.task.model

    def init_logger(self, config):
        if config.logf == 'on':
            if config.hk_use_worker == 'on':
                worker_info = "-{}@{}".format(config.hk_num_worker,
                                              config.hk_worker_id)
            else:
                worker_info = ""
            log_config('{}{}'.format(config.mode, worker_info),
                       log_path=config.workspace,
                       default_target='cf')
            from dpattack.libs.luna import log
        else:
            log = print

        log('[General Settings]')
        log(config)
        log('[Hack Settings]')
        for arg in config.kwargs:
            if arg.startswith('hk'):
                log(arg, '\t', config.kwargs[arg])
        log('------------------')

    def setup(self, config):
        self.config = config

        print("Load the models")
        vocab = torch.load(config.vocab)  # type: Vocab
        parser = load_parser(fetch_best_ckpt_name(config.parser_model))

        self.task = ParserTask(vocab, parser)

        print("Load the dataset")

        self.train_corpus = Corpus.load(config.ftrain)

        if config.hk_training_set == 'on':
            self.corpus = self.train_corpus
        else:
            self.corpus = Corpus.load(config.fdata)
        dataset = TextDataset(vocab.numericalize(self.corpus, True))
        # set the data loader
        self.loader = DataLoader(dataset=dataset, collate_fn=collate_fn)

        def embed_backward_hook(module, grad_in, grad_out):
            ram_write('embed_grad', grad_out[0])

        self.parser.embed.register_backward_hook(embed_backward_hook)
        self.parser.eval()

        self.embed_searcher = EmbeddingSearcher(
            embed=self.parser.embed.weight,
            idx2word=lambda x: self.vocab.words[x],
            word2idx=lambda x: self.vocab.word_dict[x])

        random.seed(self.config.seed)
        np.random.seed(self.config.seed)
        torch.manual_seed(self.config.seed)

    def hack(self, instance, **kwargs):
        raise NotImplementedError

    @lru_cache(maxsize=None)
    def _gen_tag_mask(self, tags: tuple, tsr_device, tsr_size):
        word_idxs = []
        for tag in tags:
            if tag in self.tag_dict:
                word_idxs.extend(self.tag_dict[tag])
        legal_tag_index = torch.tensor(word_idxs, device=tsr_device).long()
        legal_tag_mask = torch.zeros(tsr_size, device=tsr_device)\
            .index_fill_(0, legal_tag_index, 1.).byte()
        return legal_tag_mask

    @lru_cache(maxsize=10)
    def _gen_bert_mask(self, text, idx, tsr_device, tsr_size):
        bert_sub, _ = self.bert_aug.substitute(text, [idx], n=2000)
        bert_sub_idxs = self.vocab.word2id(bert_sub[0]).to(tsr_device).long()
        bert_mask = torch.zeros(tsr_size, device=tsr_device)\
            .index_fill_(0, bert_sub_idxs, 1.).byte()
        return bert_mask

    @torch.no_grad()
    def find_replacement(
        self,
        changed,
        must_tags,
        dist_measure,
        forbidden_idxs__,
        repl_method='tagdict',
        words=None,
        word_sid=None,  # Only need when using a tagger
        raw_words=None,
    ) -> (Optional[torch.Tensor], dict):
        if must_tags is None:
            must_tags = tuple(self.vocab.tags)
        if isinstance(must_tags, str):
            must_tags = (must_tags, )

        if repl_method == 'lstm':
            # Pipeline:
            #    256 minimum dists
            # -> Filtered by a NN tagger
            # -> Smallest one
            words = words.repeat(64, words.size(1))
            dists, idxs = self.embed_searcher.find_neighbours(
                changed, 64, dist_measure, False)
            for i, ele in enumerate(idxs):
                words[i][word_sid] = ele
            self.nn_tagger.eval()
            s_tags = self.nn_tagger(words)
            pred_tags = s_tags.argmax(-1)[:, word_sid]
            pred_tags = pred_tags.cpu().numpy().tolist()
            new_word_vid = None
            for i, ele in enumerate(pred_tags):
                if self.vocab.tags[ele] in must_tags:
                    if idxs[i] not in forbidden_idxs__:
                        new_word_vid = idxs[i]
                        break
            return new_word_vid, {
                "avgd": dists.mean().item(),
                "mind": dists.min().item()
            }
        elif repl_method in ['3gram', 'crf']:
            # Pipeline:
            #    256 minimum dists
            # -> Filtered by a Statistical tagger
            # -> Smallest one
            tagger = self.trigram_tagger if repl_method == '3gram' else self.crf_tagger
            word_texts = self.vocab.id2word(words)
            word_sid = word_sid.item()

            dists, idxs = self.embed_searcher.find_neighbours(
                changed, 64, dist_measure, False)

            cands = []
            for ele in cast_list(idxs):
                cand = word_texts.copy()
                cand[word_sid] = self.vocab.words[ele]
                cands.append(cand)

            pred_tags = tagger.tag_sents(cands)
            s_tags = [ele[word_sid][1] for ele in pred_tags]

            new_word_vid = None
            for i, ele in enumerate(s_tags):
                if ele in must_tags:
                    if idxs[i] not in forbidden_idxs__:
                        new_word_vid = idxs[i]
                        break
            return new_word_vid, {
                "avgd": dists.mean().item(),
                "mind": dists.min().item()
            }
        elif repl_method in ['tagdict', 'bertag']:
            # Pipeline:
            #    All dists/Bert filtered dists
            # -> Filtered by a tag dict
            # -> Smallest one
            dist = {
                'euc': euc_dist,
                'cos': cos_dist
            }[dist_measure](changed, self.parser.embed.weight)

            # Mask illegal words by its POS
            if repl_method == 'tagdict':
                msk = self._gen_tag_mask(must_tags, dist.device, dist.size())
            elif repl_method == 'bertag':
                msk = self._gen_tag_mask(must_tags, dist.device, dist.size())
                # Mask illegal words by BERT
                bert_mask = self._gen_bert_mask(
                    " ".join(self.vocab.id2word(raw_words)[1:]),
                    word_sid.item() - 1, dist.device, dist.size())
                # combine the two byte masks elementwise (multiplication acts as logical AND)
                msk = msk * bert_mask
            else:
                raise Exception

            dist.masked_fill_((1 - msk).bool(), 1000.)
            for ele in forbidden_idxs__:
                dist[ele] = 1000.
            mindist = dist.min()
            if abs(mindist - 1000.) < 0.001:
                new_word_vid = None
            else:
                new_word_vid = dist.argmin()
            return new_word_vid, {}
        else:
            raise NotImplementedError
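The 'tagdict' branch above picks the nearest embedding whose word can carry one of the required POS tags: it builds a 0/1 mask over the vocabulary from tag_dict, pushes the distance of every masked-out or forbidden word to a large constant, and takes the argmin. A self-contained sketch of that masking step, using made-up tensors in place of the real embedding distances and tag dictionary:

import torch

def nearest_allowed(dist, allowed_word_idxs, forbidden_idxs, big=1000.):
    # dist: 1-D tensor of distances from the perturbed embedding to every vocab word
    mask = torch.zeros_like(dist)
    mask.index_fill_(0, torch.tensor(allowed_word_idxs, dtype=torch.long), 1.)
    dist = dist.masked_fill(mask == 0, big)  # rule out words with the wrong POS
    for idx in forbidden_idxs:
        dist[idx] = big                      # rule out words already tried
    best = dist.argmin()
    return None if abs(dist[best].item() - big) < 1e-3 else best.item()

# toy vocabulary of six words; words 1 and 4 carry an acceptable tag, word 4 is forbidden
dist = torch.tensor([0.1, 0.3, 0.2, 0.9, 0.05, 0.4])
print(nearest_allowed(dist, allowed_word_idxs=[1, 4], forbidden_idxs=[4]))  # -> 1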
Example #8
class CRFChunkParser(ChunkParserI):
    def __init__(self, chunked_sents=[], feature_func=None, model_file=None, training_opt={}):
 
        # Transform the trees into IOB-annotated sentences [(word, pos, chunk), ...]
        # chunked_sents = [tree2conlltags(sent) for sent in chunked_sents]
 
        # Transform the triplets into pairs so they match the tagger interface [((word, pos), chunk), ...]
        def triplets2tagged_pairs(iob_sent):
            return [((word, pos), chunk) for word, pos, chunk in iob_sent]
        chunked_sents = [triplets2tagged_pairs(sent) for sent in chunked_sents]
 
        if feature_func is not None:
            feat_func = feature_func
        else:
            feat_func = self._feature_detector
        self.tagger = CRFTagger(feature_func=feat_func, training_opt=training_opt)
        if not model_file:
            raise Exception("Provide path to save model file")
        self.model_file = model_file
        if chunked_sents:
            self.train(chunked_sents)
        else:
            self.tagger.set_model_file(self.model_file)

    def train(self, chunked_sents):
        self.tagger.train(chunked_sents, self.model_file)
    
    def load(self, model_file):
        self.tagger.set_model_file(model_file)
 
    def parse(self, tagged_sent, return_tree = True):
        chunks = self.tagger.tag(tagged_sent)
 
        # Transform the result from [((w1, t1), iob1), ...] 
        # to the preferred list of triplets format [(w1, t1, iob1), ...]
        iob_triplets = [(w, t, c) for ((w, t), c) in chunks]
 
        # Transform the list of triplets to nltk.Tree format
        return conlltags2tree(iob_triplets) if return_tree else iob_triplets


    def _feature_detector(self, tokens, index):
        def shape(word):
            if re.match(r'[0-9]+(\.[0-9]*)?|[0-9]*\.[0-9]+$', word, re.UNICODE):
                return 'number'
            elif re.match(r'\W+$', word, re.UNICODE):
                return 'punct'
            elif re.match(r'\w+$', word, re.UNICODE):
                if word.istitle():
                    return 'upcase'
                elif word.islower():
                    return 'downcase'
                else:
                    return 'mixedcase'
            else:
                return 'other'


        def simplify_pos(s):
            if s.startswith('V'):
                return "V"
            else:
                return s.split('-')[0]

        word = tokens[index][0]
        pos = simplify_pos(tokens[index][1])
        if index == 0:
            prevword = prevprevword = ""
            prevpos = prevprevpos = ""
            prevshape = prevtag = prevprevtag = ""
        elif index == 1:
            prevword = tokens[index - 1][0].lower()
            prevprevword = ""
            prevpos = simplify_pos(tokens[index - 1][1])
            prevprevpos = ""
            prevtag = "" #history[index - 1][0]
            prevshape = prevprevtag = ""
        else:
            prevword = tokens[index - 1][0].lower()
            prevprevword = tokens[index - 2][0].lower()
            prevpos = simplify_pos(tokens[index - 1][1])
            prevprevpos = simplify_pos(tokens[index - 2][1])
            prevtag = "" #history[index - 1]
            prevprevtag = "" #history[index - 2]
            prevshape = shape(prevword)
        if index == len(tokens) - 1:
            nextword = nextnextword = ""
            nextpos = nextnextpos = ""
        elif index == len(tokens) - 2:
            nextword = tokens[index + 1][0].lower()
            nextpos = tokens[index + 1][1].lower()
            nextnextword = ""
            nextnextpos = ""
        else:
            nextword = tokens[index + 1][0].lower()
            nextpos = tokens[index + 1][1].lower()
            nextnextword = tokens[index + 2][0].lower()
            nextnextpos = tokens[index + 2][1].lower()

        # 89.6 (score obtained with this feature set)
        features = {
            'shape': '{}'.format(shape(word)),
            'wordlen': '{}'.format(len(word)),
            'prefix3': word[:3].lower(),
            'suffix3': word[-3:].lower(),
            'pos': pos,
            'word': word,
            # 'prevtag': '{}'.format(prevtag),
            'prevpos': '{}'.format(prevpos),
            'nextpos': '{}'.format(nextpos),
            'prevword': '{}'.format(prevword),
            'nextword': '{}'.format(nextword),
            'prevprevword': '{}'.format(prevprevword),
            'nextnextword': '{}'.format(nextnextword),
            'word+nextpos': '{0}+{1}'.format(word.lower(), nextpos),
            'word+nextnextpos': '{0}+{1}'.format(word.lower(), nextnextpos),
            'word+prevpos': '{0}+{1}'.format(word.lower(), prevpos),
            'word+prevprevpos': '{0}+{1}'.format(word.lower(), prevprevpos),
            'pos+nextpos': '{0}+{1}'.format(pos, nextpos),
            'pos+nextnextpos': '{0}+{1}'.format(pos, nextnextpos),
            'pos+prevpos': '{0}+{1}'.format(pos, prevpos),
            'pos+prevprevpos': '{0}+{1}'.format(pos, prevprevpos),
            # 'pos+prevtag': '{0}+{1}'.format(pos, prevtag),
            # 'shape+prevtag': '{0}+{1}'.format(prevshape, prevtag),
        }

        return list(features.values())
Example #9
from nltk import CRFTagger

if __name__ == '__main__':
    filename = "Indonesian_Manually_Tagged_Corpus.tsv"
    with open(filename, "r", encoding="utf-8") as f:
        datas = f.read().split("\n\n")
    taggedkalimat = []
    for data in datas:
        kalimat = data.split("\n")
        taggedkata = []
        for kata in kalimat:
            if not kata.strip():
                continue  # skip blank lines
            k, tag = kata.split("\t")
            taggedkata.append((k, tag))
        taggedkalimat.append(taggedkata)
    ctagger = CRFTagger()
    modelname = input("Save model as : ")
    ctagger.train(taggedkalimat, modelname)
    print("Generated model from %s into %s" % (filename, modelname))
    print("Usage :")
    print("\t\tnltk.CRFTagger().set_model_file(your_model)")
    print("\t\tnltk.CRFTagger().tag_sents(list(your_sentence))")