Example 1
def prepro(config):
    word_counter = Counter()
    char_counter, bpe_counter, pos_counter = None, None, None
    bpe_model = None
    pos_model = None
    if config.use_bpe:
        if not config.use_bpe_pretrained_codes:
            # train bpe on train set
            train_bpe(config)
            bpe_model = BPE(open(config.bpe_codes_file, 'r'))
        else:
            print('Loading BPE codes: {}'.format(
                config.bpe_pretrained_codes_file))
            bpe_model = BPE(open(config.bpe_pretrained_codes_file, 'r'))
        bpe_counter = Counter()

    if config.use_char:
        char_counter = Counter()

    if config.use_pos:
        pos_model = Mystem()
        pos_counter = Counter()

    train_examples, train_eval = process_file(config,
                                              config.train_file,
                                              "train",
                                              word_counter,
                                              char_counter,
                                              bpe_counter,
                                              pos_counter,
                                              config.remove_unicode,
                                              bpe_model,
                                              pos_model,
                                              is_test=False)
    dev_examples, dev_eval = process_file(config,
                                          config.dev_file,
                                          "dev",
                                          word_counter,
                                          char_counter,
                                          bpe_counter,
                                          pos_counter,
                                          config.remove_unicode,
                                          bpe_model,
                                          pos_model,
                                          is_test=False)
    test_examples, test_eval = process_file(
        config,
        config.test_file,
        "test",
        remove_unicode=config.remove_unicode,
        bpe_model=bpe_model,
        pos_model=pos_model,
        is_test=True)

    word_emb_file = config.fasttext_file if config.fasttext else config.glove_word_file
    char_emb_file = config.glove_char_file if config.pretrained_char else None
    char_emb_size = config.glove_char_size if config.pretrained_char else None
    char_emb_dim = config.glove_dim if config.pretrained_char else config.char_dim
    bpe_emb_file = config.glove_bpe_file if config.pretrained_bpe_emb else None
    bpe_emb_size = config.glove_bpe_size if config.pretrained_bpe_emb else None
    bpe_emb_dim = config.bpe_glove_dim if config.pretrained_bpe_emb else config.bpe_dim

    word_emb_mat, word2idx_dict = get_embedding(word_counter,
                                                "word",
                                                emb_file=word_emb_file,
                                                size=config.glove_word_size,
                                                vec_size=config.glove_dim)
    char_emb_mat, char2idx_dict = None, None
    if config.use_char:
        char_emb_mat, char2idx_dict = get_embedding(char_counter,
                                                    "char",
                                                    emb_file=char_emb_file,
                                                    size=char_emb_size,
                                                    vec_size=char_emb_dim)
    bpe_emb_mat, bpe2idx_dict = None, None
    if config.use_bpe:
        bpe_emb_mat, bpe2idx_dict = get_embedding(bpe_counter,
                                                  "bpe",
                                                  emb_file=bpe_emb_file,
                                                  size=bpe_emb_size,
                                                  vec_size=bpe_emb_dim)

    pos_emb_mat, pos2idx_dict = None, None
    if config.use_pos:
        pos_emb_mat, pos2idx_dict = get_embedding(pos_counter,
                                                  "pos",
                                                  emb_file=None,
                                                  size=None,
                                                  vec_size=config.pos_dim)

    pickle.dump(word2idx_dict, open(config.word2idx_dict_file, 'wb'))
    pickle.dump(char2idx_dict, open(config.char2idx_dict_file, 'wb'))
    pickle.dump(bpe2idx_dict, open(config.bpe2idx_dict_file, 'wb'))
    pickle.dump(pos2idx_dict, open(config.pos2idx_dict_file, 'wb'))

    build_features(config, train_examples, "train", config.train_record_file,
                   word2idx_dict, char2idx_dict, bpe2idx_dict, pos2idx_dict)
    dev_meta = build_features(config, dev_examples, "dev",
                              config.dev_record_file, word2idx_dict,
                              char2idx_dict, bpe2idx_dict, pos2idx_dict)
    test_meta = build_features(config,
                               test_examples,
                               "test",
                               config.test_record_file,
                               word2idx_dict,
                               char2idx_dict,
                               bpe2idx_dict,
                               pos2idx_dict,
                               is_test=True)

    save(config.word_emb_file, word_emb_mat, message="word embedding")
    save(config.char_emb_file, char_emb_mat, message="char embedding")
    save(config.bpe_emb_file, bpe_emb_mat, message="bpe embedding")
    save(config.pos_emb_file, pos_emb_mat, message="pos embedding")
    save(config.train_eval_file, train_eval, message="train eval")
    save(config.dev_eval_file, dev_eval, message="dev eval")
    save(config.test_eval_file, test_eval, message="test eval")
    save(config.dev_meta, dev_meta, message="dev meta")
    save(config.test_meta, test_meta, message="test meta")
Example 2
'''
build an UnsupData which returns v_d, v_f for a batch

supTrainData which returns v_d, v_f, label for DDI only
supTrainData.num_of_iter_in_a_epoch contains the number of iterations in an epoch

ValData which returns v_d, v_f, label for DDI only
'''

import codecs

import pandas as pd
from subword_nmt.apply_bpe import BPE

dataFolder = './data'

vocab_path = dataFolder + '/codes.txt'
bpe_codes_fin = codecs.open(vocab_path)
bpe = BPE(bpe_codes_fin, merges=-1, separator='')

vocab_map = pd.read_csv(dataFolder + '/subword_units_map.csv')
idx2word = vocab_map['index'].values
words2idx = dict(zip(idx2word, range(0, len(idx2word))))
max_set = 30


def smiles2index(s1, s2):
    t1 = bpe.process_line(s1).split()  #split
    t2 = bpe.process_line(s2).split()  #split
    i1 = [words2idx[i] for i in t1]  # index
    i2 = [words2idx[i] for i in t2]  # index
    return i1, i2

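A minimal usage sketch of smiles2index; the two SMILES strings are hypothetical inputs, and the sketch assumes ./data/codes.txt and ./data/subword_units_map.csv were loaded as above (subwords missing from words2idx would raise a KeyError):

# Hedged usage sketch with hypothetical SMILES inputs.
i1, i2 = smiles2index('CC(=O)OC1=CC=CC=C1C(=O)O',
                      'CN1C=NC2=C1C(=O)N(C(=O)N2C)C')
print(len(i1), len(i2))  # number of BPE substructure indices per drug
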
Example 3
def load_bpe(bpe_size: int = 32000):
    bpe_codes = codecs.open(os.path.join(shared.VOCAB_FOLDER, f"bpe.{bpe_size}"), encoding='utf-8')
    return BPE(bpe_codes)
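
A brief usage sketch, assuming shared.VOCAB_FOLDER contains a bpe.32000 codes file learned with subword-nmt; the input sentence is arbitrary and the exact segmentation depends on the learned merges:

# Hedged usage sketch.
bpe = load_bpe(32000)
print(bpe.process_line("subword tokenization example"))  # e.g. "sub@@ word token@@ ization example"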
Example 4
    else:
        input = request.args.get('in')
    score, pred = _translate(input)
    print(' score = %s , pred = %s ' % (str(score), str(pred)))
    res = {"output": pred}
    return json.dumps(res, ensure_ascii=False)


if __name__ == '__main__':

    parser = _get_parser()
    opt = parser.parse_args()

    logger = init_logger(opt.log_file)

    translator = _get_translator(opt)

    c = codecs.open(
        '/root/workspace/translate_data/my_corpus_v6.en.tok.processed6-bpe-code',
        encoding='utf-8')
    m = -1
    sp = '@@'
    voc = None
    bpe = BPE(c, m, sp, voc, None)

    proc = PrePostProc()
    proc.load_data('py_ent_dict.txt')

    app.debug = True
    app.run(host='0.0.0.0', port=5002)
Example 5
    def initialize(self,
                   data_dir=_data_dir,
                   model_path=_model,
                   user_dir=_user_dir,
                   task='xmasked_seq2seq',
                   s_lang='en',
                   t_lang='zh',
                   beam=5,
                   cpu=False,
                   align_dict=None,
                   bpe_codes=_bpe_codes_en,
                   tokenizer=True):
        self.parser = options.get_generation_parser(interactive=True)
        self.src, self.tgt = s_lang, t_lang

        # generate args
        input_args = [data_dir, '--path', model_path]
        if cpu:
            input_args.append('--cpu')
        if user_dir:
            input_args.append('--user-dir')
            input_args.append(user_dir)
        if task:
            input_args.append('--task')
            input_args.append(task)
        if align_dict:
            input_args.append('--replace-unk')
            input_args.append(align_dict)
        input_args.append('--langs')
        input_args.append('{},{}'.format(s_lang, t_lang))
        input_args.append('--source-langs')
        input_args.append(s_lang)
        input_args.append('--target-langs')
        input_args.append(t_lang)
        input_args.append('-s')
        input_args.append(s_lang)
        input_args.append('-t')
        input_args.append(t_lang)
        input_args.append('--beam')
        input_args.append(str(beam))
        input_args.append('--remove-bpe')

        self.bpe = BPE(open(bpe_codes, 'r'))
        self.tokenizer = tokenizer

        self.args = options.parse_args_and_arch(self.parser,
                                                input_args=input_args)

        # initialize model
        utils.import_user_module(self.args)

        if self.args.buffer_size < 1:
            self.args.buffer_size = 1
        if self.args.max_tokens is None and self.args.max_sentences is None:
            self.args.max_sentences = 1

        assert not self.args.sampling or self.args.nbest == self.args.beam, \
            '--sampling requires --nbest to be equal to --beam'
        assert not self.args.max_sentences or self.args.max_sentences <= self.args.buffer_size, \
            '--max-sentences/--batch-size cannot be larger than --buffer-size'

        self.use_cuda = torch.cuda.is_available() and not self.args.cpu

        # Setup task, e.g., translation
        self.task = tasks.setup_task(self.args)

        # Load ensemble
        self.models, _model_args = checkpoint_utils.load_model_ensemble(
            self.args.path.split(':'),
            arg_overrides=eval(self.args.model_overrides),
            task=self.task,
        )

        # Set dictionaries
        self.src_dict = self.task.source_dictionary
        self.tgt_dict = self.task.target_dictionary

        # Optimize ensemble for generation
        for model in self.models:
            model.make_generation_fast_(
                beamable_mm_beam_size=None
                if self.args.no_beamable_mm else self.args.beam,
                need_attn=self.args.print_alignment,
            )
            if self.args.fp16:
                model.half()
            if self.use_cuda:
                model.cuda()

        # Initialize generator
        self.generator = self.task.build_generator(self.args)

        def encode_fn(x):
            if tokenizer:
                x = tokenize(x, is_zh=(s_lang == 'zh'))
            if bpe_codes:
                x = self.bpe.process_line(x)
            return x

        # Hack to support GPT-2 BPE
        if self.args.remove_bpe == 'gpt2':
            pass
        else:
            self.decoder = None
            # self.encode_fn = lambda x: x
            self.encode_fn = encode_fn

        # Load alignment dictionary for unknown word replacement
        # (None if no unknown word replacement, empty if no path to align dictionary)
        self.align_dict = utils.load_align_dict(self.args.replace_unk)

        self.max_positions = utils.resolve_max_positions(
            self.task.max_positions(),
            *[model.max_positions() for model in self.models])
Example 6
import torch
from torch.utils import data
from torch.autograd import Variable
from descriptastorus.descriptors import rdDescriptors, rdNormalizedDescriptors
from DeepPurpose.chemutils import get_mol, atom_features, bond_features, MAX_NB, ATOM_FDIM, BOND_FDIM
from subword_nmt.apply_bpe import BPE
import codecs
import pandas as pd
import pickle
import wget
from zipfile import ZipFile
import os

# ESPF encoding
vocab_path = './DeepPurpose/ESPF/drug_codes_chembl_freq_1500.txt'
bpe_codes_drug = codecs.open(vocab_path)
dbpe = BPE(bpe_codes_drug, merges=-1, separator='')
sub_csv = pd.read_csv(
    './DeepPurpose/ESPF/subword_units_map_chembl_freq_1500.csv')

idx2word_d = sub_csv['index'].values
words2idx_d = dict(zip(idx2word_d, range(0, len(idx2word_d))))

vocab_path = './DeepPurpose/ESPF/protein_codes_uniprot_2000.txt'
bpe_codes_protein = codecs.open(vocab_path)
pbpe = BPE(bpe_codes_protein, merges=-1, separator='')
#sub_csv = pd.read_csv(dataFolder + '/subword_units_map_protein.csv')
sub_csv = pd.read_csv('./DeepPurpose/ESPF/subword_units_map_uniprot_2000.csv')

idx2word_p = sub_csv['index'].values
words2idx_p = dict(zip(idx2word_p, range(0, len(idx2word_p))))
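
For illustration, a hedged sketch of how dbpe and words2idx_d are typically combined to turn a SMILES string into fixed-length subword indices; the function name drug2indices, the max_d length, and the padding scheme are assumptions, not part of the module above:

import numpy as np

def drug2indices(smiles, max_d=50):
    # Segment the SMILES with the ESPF drug codes, then map subwords to indices,
    # silently dropping subwords missing from the vocabulary map.
    tokens = dbpe.process_line(smiles).split()
    idx = np.asarray([words2idx_d[t] for t in tokens if t in words2idx_d])
    # Pad or truncate to max_d and build a mask for batching.
    mask = np.zeros(max_d)
    if len(idx) < max_d:
        padded = np.pad(idx, (0, max_d - len(idx)), 'constant')
        mask[:len(idx)] = 1
    else:
        padded = idx[:max_d]
        mask[:] = 1
    return padded, mask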
Example 7
import tensorflow as tf
import codecs
from subword_nmt.apply_bpe import BPE

codes = codecs.open("{codes_file}", encoding='utf-8')
# output = codecs.open(args.output.name, 'w', encoding='utf-8')
vocabulary = codecs.open("./100/{vocab_file}.L1", encoding='utf-8')
bpe = BPE(codes, 100, '@@', vocabulary)

# output = codecs.open(args.output.name, 'w', encoding='utf-8')
vocabulary_nl = codecs.open("./100/{vocab_file}.L2", encoding='utf-8')
bpe_nl = BPE(codes, 100, '@@', vocabulary_nl)

graph = tf.get_default_graph()
Example 8
class Preprocessor(object):
    def __init__(self, bpe_code_file):
        super(Preprocessor, self).__init__()

        symbols = ''
        symbol_set = set({})

        for k in name2codepoint.keys():
            symbol_set.add(k)

        for k in html5.keys():
            symbol_set.add(k.strip(';'))

        for s in symbol_set:
            symbols += '|' + s

        symbols = symbols.strip('|')

        self.single = re.compile('&[ ]?(' + symbols + ')[ ]?;', re.IGNORECASE)
        self.double = re.compile('&[ ]?amp[ ]?;[ ]?(' + symbols + ')[ ]?;',
                                 re.IGNORECASE)

        self.singleNum = re.compile('&[ ]?#[ ]?([0-9]+)[ ]?;', re.IGNORECASE)
        self.doubleNum = re.compile('&[ ]?amp[ ]?;[ ]?#[ ]?([0-9]+)[ ]?;',
                                    re.IGNORECASE)

        self.singleXNum = re.compile('&[ ]?#[ ]?x[ ]?([a-f0-9]+)[ ]?;',
                                     re.IGNORECASE)
        self.doubleXNum = re.compile(
            '&[ ]?amp[ ]?;[ ]?#[ ]?x[ ]?([a-f0-9]+)[ ]?;', re.IGNORECASE)

        self.nbsp = re.compile(
            '(&[ ]?x?[ ]?n[ ]?b[ ]?([a-z][ ]?){0,6}[ ]?;)|(&[ ]?o[ ]?s[ ]?p[ ]?;)',
            re.IGNORECASE)

        self.shy = re.compile('[ ]?&[ ]?s[ ]?h[ ]?y[ ]?;[ ]?', re.IGNORECASE)

        self.bpe = None
        if bpe_code_file:
            with open(bpe_code_file, mode='r', encoding='utf-8') as f:
                self.bpe = BPE(f)
        else:
            logging.error('No BPE code file specified')

    def unescape(self, line):
        # put html-escaped (or double escaped) codes back into canonical format
        line = re.sub(self.double, r'&\1;', line)
        line = re.sub(self.doubleNum, r'&#\1;', line)
        line = re.sub(self.doubleXNum, r'&#x\1;', line)
        line = re.sub(self.single, r'&\1;', line)
        line = re.sub(self.singleNum, r'&#\1;', line)
        line = re.sub(self.singleXNum, r'&#x\1;', line)

        # get rid of this tag
        # alphabetic characters -- need only get rid of space around their canonical escaped forms
        line = re.sub(self.shy, '', line)

        # unescape
        line = html.unescape(line)

        # clean up weird errors in the escaping of the non-breaking space
        line = re.sub(self.nbsp, ' ', line)
        return line

    def bpe_encode(self, text):
        return self.bpe.process_line(text).strip()
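
A short usage sketch; 'codes.bpe' is a hypothetical path to a subword-nmt codes file, and the segmented output depends on those codes:

# Hedged usage sketch ('codes.bpe' is a placeholder path).
proc = Preprocessor('codes.bpe')
line = proc.unescape('Caf&amp;eacute; au lait')  # double-escaped entity -> 'Café au lait'
print(proc.bpe_encode(line))                     # BPE-segmented line; subwords depend on the codes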
Example 9
class BeamSearch:
    def __init__(self,
                 decoder,
                 encoder_outputs,
                 decoder_hidden,
                 output_lang,
                 beam_size=3,
                 attentionOverrideMap=None,
                 correctionMap=None,
                 unk_map=None,
                 beam_length=0.5,
                 beam_coverage=0.5,
                 max_length=MAX_LENGTH):
        self.decoder = decoder
        self.encoder_outputs = encoder_outputs
        self.decoder_hidden = decoder_hidden
        self.beam_size = beam_size
        self.max_length = max_length
        self.output_lang = output_lang
        self.attention_override_map = attentionOverrideMap
        self.correction_map = correctionMap
        self.unk_map = unk_map
        self.beam_length = beam_length
        self.beam_coverage = beam_coverage
        self.bpe = BPE(open(hp.bpe_file))
        self.prefix = self.compute_prefix()
        self.process_corrections()

    def compute_prefix(self):
        if not self.correction_map:
            return
        print(self.correction_map)
        assert len(list(self.correction_map.keys())) == 1

        prefix = list(self.correction_map.keys())[0]
        correction = self.correction_map[prefix]

        prefix = prefix + " " + correction
        raw_words = prefix.split(" ")[1:]
        bpe_words = [
            self.bpe.process_line(word) if not word.endswith("@@") else word
            for word in raw_words
        ]
        words = [
            word for bpe_word in bpe_words for word in bpe_word.split(" ")
        ]
        return words

    def process_corrections(self):
        """Apply BPE to correction map, ignoring words that are already BPE'd"""
        if not self.correction_map:
            return
        prefixes = list(self.correction_map.keys())

        for prefix in prefixes:
            raw_words = self.correction_map.pop(prefix).split(" ")
            bpe_words = [
                self.bpe.process_line(word)
                if not word.endswith("@@") else word for word in raw_words
            ]
            words = [
                word for bpe_word in bpe_words for word in bpe_word.split(" ")
            ]

            for i in range(len(words)):
                curr_prefix = " ".join(prefix.split(" ") + words[:i])
                self.correction_map[curr_prefix] = words[i]

    def decode_topk(self, latest_tokens, states, last_attn_vectors, partials):
        """Decode all current hypotheses on the beam, returning len(hypotheses) x beam_size candidates"""

        # len(latest_tokens) x self.beam_size
        topk_ids = [[0 for _ in range(self.beam_size)]
                    for _ in range(len(latest_tokens))]
        topk_log_probs = [[0 for _ in range(self.beam_size)]
                          for _ in range(len(latest_tokens))]
        new_states = [None for _ in range(len(states))]
        new_attn_vectors = [None for _ in range(len(states))]
        attns = [None for _ in range(len(states))]
        topk_words = [["" for _ in range(self.beam_size)]
                      for _ in range(len(latest_tokens))]
        is_unk = [False for _ in range(len(latest_tokens))]

        # Loop over all hypotheses
        for token, state, attn_vector, i in zip(latest_tokens, states,
                                                last_attn_vectors,
                                                range(len(latest_tokens))):
            decoder_input = Variable(torch.LongTensor([token]))

            if use_cuda:
                decoder_input = decoder_input.cuda()

            attention_override = None
            if self.attention_override_map:
                if partials[i] in self.attention_override_map:
                    attention_override = self.attention_override_map[
                        partials[i]]

            decoder_output, decoder_hidden, decoder_attention, last_attn_vector = self.decoder(
                decoder_input, state, self.encoder_outputs, attn_vector,
                attention_override)

            top_id = decoder_output.data.topk(1)[1]
            if use_cuda:
                top_id = top_id.cpu()

            top_id = top_id.numpy()[0].tolist()[0]

            if top_id == UNK_token:
                print("UNK found partial = {}".format(partials[i]))
            if top_id == UNK_token and self.unk_map and partials[
                    i] in self.unk_map:
                # Replace UNK token based on user given mapping
                word = self.unk_map[partials[i]]
                topk_words[i][0] = word
                print("Replaced UNK token with {}".format(word))
                if word not in self.output_lang.word2index:
                    is_unk[i] = True
                else:
                    idx = self.output_lang.word2index[word]
                    decoder_output.data[0][idx] = 1000
            elif self.correction_map and partials[i] in self.correction_map:
                word = self.correction_map[partials[i]]
                print("Corrected {} for partial= {}".format(word, partials[i]))
                if not word in self.output_lang.word2index:
                    topk_words[i][0] = word
                    is_unk[i] = True
                idx = self.output_lang.word2index[word]
                decoder_output.data[0][idx] = 1000

            decoder_output = nn.functional.log_softmax(decoder_output)
            topk_v, topk_i = decoder_output.data.topk(self.beam_size)
            if use_cuda:
                topk_v, topk_i = topk_v.cpu(), topk_i.cpu()
            topk_v, topk_i = topk_v.numpy()[0], topk_i.numpy()[0]

            topk_ids[i] = topk_i.tolist()
            topk_log_probs[i] = topk_v.tolist()
            topk_words[i] = [
                self.output_lang.index2word[id]
                if not topk_words[i][j] else topk_words[i][j]
                for j, id in enumerate(topk_ids[i])
            ]

            new_states[i] = tuple(h.clone() for h in decoder_hidden)
            new_attn_vectors[i] = last_attn_vector.clone()
            attns[i] = decoder_attention.data
            if use_cuda:
                attns[i] = attns[i].cpu()
            attns[i] = attns[i].numpy().tolist()[0]

        return topk_ids, topk_words, topk_log_probs, new_states, new_attn_vectors, attns, is_unk

    def to_partial(self, tokens):
        return " ".join(
            [self.output_lang.index2word[token] for token in tokens])

    def init_hypothesis(self):
        start_attn = []
        last_attn_vector = torch.zeros((1, self.decoder.hidden_size))

        if use_cuda:
            last_attn_vector = last_attn_vector.cuda()

        if not self.correction_map:
            return [
                Hypothesis([SOS_token],
                           [self.output_lang.index2word[SOS_token]], [0.0],
                           tuple(h.clone() for h in self.decoder_hidden),
                           last_attn_vector.clone(), start_attn, [[]], [False])
                for _ in range(self.beam_size)
            ]

        # Assume at most 1 correction prefix at all times
        prefix = [SOS_token] + [
            self.output_lang.word2index[token] for token in self.prefix
        ]
        # We need: hidden state at the end of prefix, last_attn_vector, attention, tokens, candidates

        tokens = []
        candidates = []
        decoder_hidden = tuple(h.clone() for h in self.decoder_hidden)

        curr_candidates = []
        attn = [[0 for _ in range(self.encoder_outputs.size(0))]]

        for token in prefix:
            decoder_input = Variable(torch.LongTensor([token]))

            # Compute
            if use_cuda:
                decoder_input = decoder_input.cuda()
            decoder_output, decoder_hidden, decoder_attention, last_attn_vector = self.decoder(
                decoder_input, decoder_hidden, self.encoder_outputs,
                last_attn_vector)

            # Update
            start_attn += [attn]
            tokens += [token]
            candidates += [curr_candidates]

            attn = decoder_attention.data
            if use_cuda:
                attn = attn.cpu()
            attn = attn.numpy().tolist()[0]

            topk_v, topk_i = decoder_output.data.topk(self.beam_size)
            if use_cuda:
                topk_v, topk_i = topk_v.cpu(), topk_i.cpu()
            topk_v, topk_i = topk_v.numpy()[0], topk_i.numpy()[0]

            topk_ids = topk_i.tolist()
            topk_log_probs = topk_v.tolist()
            curr_candidates = [
                self.output_lang.index2word[i] for i in topk_ids
            ]

        return [
            Hypothesis(
                list(tokens),
                [self.output_lang.index2word[token]
                 for token in tokens], [0.0] * len(tokens),
                tuple(h.clone()
                      for h in decoder_hidden), last_attn_vector.clone(),
                list(start_attn), candidates, [False] * len(tokens))
            for _ in range(self.beam_size)
        ]

    def search(self):

        start_attn = [[[0 for _ in range(self.encoder_outputs.size(0))]]]
        last_attn_vector = torch.zeros((1, self.decoder.hidden_size))

        if use_cuda:
            last_attn_vector = last_attn_vector.cuda()

        hyps = self.init_hypothesis()

        for h in hyps:
            h.alpha = self.beam_length
            h.beta = self.beam_coverage

        result = []

        steps = 0

        while steps < self.max_length * 2 and len(result) < self.beam_size:
            latest_tokens = [hyp.latest_token for hyp in hyps]
            states = [hyp.state for hyp in hyps]
            partials = [self.to_partial(hyp.tokens) for hyp in hyps]
            last_attn_vectors = [hyp.last_attn_vector for hyp in hyps]
            all_hyps = []

            num_beam_source = 1 if steps == 0 else len(hyps)
            topk_ids, topk_words, topk_log_probs, new_states, new_attn_vectors, attns, is_unk = self.decode_topk(
                latest_tokens, states, last_attn_vectors, partials)

            for i in range(num_beam_source):
                h, ns, av, attn = hyps[i], new_states[i], new_attn_vectors[
                    i], attns[i]

                for j in range(self.beam_size):
                    candidates = [
                        self.output_lang.index2word[c]
                        for c in (topk_ids[i][:j] + topk_ids[i][j + 1:])
                    ]

                    all_hyps.append(
                        h.extend(topk_ids[i][j], topk_words[i][j],
                                 topk_log_probs[i][j], ns, av, attn,
                                 candidates, is_unk[i]))

            # Filter
            hyps = []
            # print("All Hyps")
            for h in all_hyps:
                pass
            # print([(word, log_prob) for word, log_prob in zip(h.words, h.log_probs)])
            # print("====")

            for h in self._best_hyps(all_hyps):
                if h.latest_token == EOS_token:
                    result.append(h)
                else:
                    hyps.append(h)
                if len(hyps) == self.beam_size or len(
                        result) == self.beam_size:
                    break
            steps += 1

        # print("Beam Search found {} hypotheses for beam_size {}".format(len(result), self.beam_size))
        res = self._best_hyps(result, normalize=True)
        if res:
            res[0].is_golden = True
        return res

    def _best_hyps(self, hyps, normalize=False):
        """Sort the hyps based on log probs and length.
        Args:
          hyps: A list of hypothesis.
        Returns:
          hyps: A list of sorted hypothesis in reverse log_prob order.
        """
        if normalize:
            return sorted(hyps, key=lambda h: h.score(), reverse=True)
        else:
            return sorted(hyps, key=lambda h: h.log_prob, reverse=True)
Example 10
def translate_segment_OpenNMT(segment):
    try:
        url = "http://" + OpenNMTEngine_ip + ":" + str(
            OpenNMTEngine_port) + "/translator/translate"

        headers = {'content-type': 'application/json'}
        tags = re.findall('(<[^>]+>)', segment)

        equiltag = {}
        cont = 0
        for tag in tags:
            if tag.find(" ") > -1:
                tagmod = "<tag" + str(cont) + ">"
                equiltag[tagmod] = tag
                segment = segment.replace(tag, tagmod)
                t = tag.split(" ")[0].replace("<", "")
                ttanc = "</" + t + ">"
                ttancmod = "</tag" + str(cont) + ">"
                segment = segment.replace(ttanc, ttancmod)
                equiltag[ttancmod] = ttanc
                cont += 1

        if MTUOCServer_verbose:
            now = str(datetime.now())
            print("---------")
            print(now)
            print("Segment: ", segment)

        #Dealing with uppercased sentences
        toLower = False
        if segment == segment.upper():
            segment = segment.lower().capitalize()
            toLower = True

        segmentNOTAGS = remove_tags(segment)
        if MTUOCServer_verbose: print("Segment No Tags: ", segmentNOTAGS)
        if preprocess_type == "SentencePiece":
            segmentPre = to_MT(segmentNOTAGS, tokenizerA, tokenizerB)
        elif preprocess_type == "NMT":
            segmentPre = to_MT(segmentNOTAGS, tokenizer, tcmodel, bpeobject,
                               joiner, bos_annotate, eos_annotate)
        if MTUOCServer_verbose: print("Segment Pre: ", segmentPre)
        params = [{"src": segmentPre}]

        response = requests.post(url, json=params, headers=headers)
        target = response.json()
        selectedtranslationPre = target[0][0]["tgt"]
        if "align" in target[0][0]:
            selectedalignment = target[0][0]["align"][0]
        else:
            selectedalignment = ""
        if MTUOCServer_verbose:
            print("Translation Pre: ", selectedtranslationPre)
        if MTUOCServer_restore_tags:
            try:
                if preprocess_type == "SentencePiece":
                    SOURCETAGSTOK = tokenizerA.tokenize(segment)
                    SOURCETAGSTOKSP = tokenizerA.unprotect(" ".join(
                        tokenizerB.tokenize(
                            tokenizerA.protect_tags(SOURCETAGSTOK))[0]))
                    SOURCETAGSTOKSP = "<s> " + SOURCETAGSTOKSP + " </s>"
                    selectedtranslationRestored = MTUOC_tags.reinsert_wordalign(
                        SOURCETAGSTOKSP,
                        selectedalignment,
                        selectedtranslationPre,
                        splitter="▁")
                elif preprocess_type == "NMT":
                    print
                    SOURCETAGSTOK = tokenizerA.tokenize(segment)
                    glossary = []
                    tags = re.findall('(<[^>]+>)', SOURCETAGSTOK)
                    for tag in tags:
                        glossary.append(tag)
                    bpe = BPE(open(Preprocess_bpecodes, encoding="utf-8"),
                              separator=joiner,
                              glossaries=glossary)
                    SOURCETAGSTOKBPE = bpe.process_line(SOURCETAGSTOK)
                    selectedtranslationRestored = MTUOC_tags.reinsert_wordalign(
                        SOURCETAGSTOKBPE, selectedalignment,
                        selectedtranslationPre)
                    print("*****************", selectedtranslationRestored)
            except:
                print("ERROR RESTORING:", sys.exc_info())
        else:
            selectedtranslationRestored = selectedtranslationPre

        if preprocess_type == "SentencePiece":
            selectedtranslation = from_MT(selectedtranslationRestored)
        elif preprocess_type == "NMT":
            selectedtranslation = from_MT(selectedtranslationRestored,
                                          detokenizer, joiner, bos_annotate,
                                          eos_annotate)

        selectedtranslation = MTUOC_tags.fix_markup_ws(segment,
                                                       selectedtranslation)

        if MTUOCServer_verbose:
            print("Translation No Tags: ", selectedtranslationPre)
        if MTUOCServer_verbose:
            print("Translation Tags: ", selectedtranslationRestored)
        if MTUOCServer_verbose: print("Word Alignment: ", selectedalignment)

        for clau in equiltag.keys():
            selectedtranslation = selectedtranslation.replace(
                clau, equiltag[clau])
        if toLower:
            selectedtranslation = selectedtranslation.upper()

        if MTUOCServer_verbose: print("Translation: ", selectedtranslation)

        return (selectedtranslation)
    except:
        ("ERROR:", sys.exc_info())
Example 11
def translate_segment_Marian(segment):
    try:
        #Translate tags with attributes
        tags = re.findall('(<[^>]+>)', segment)

        equiltag = {}
        cont = 0
        for tag in tags:
            if tag.find(" ") > -1:
                tagmod = "<tag" + str(cont) + ">"
                equiltag[tagmod] = tag
                segment = segment.replace(tag, tagmod)
                t = tag.split(" ")[0].replace("<", "")
                ttanc = "</" + t + ">"
                ttancmod = "</tag" + str(cont) + ">"
                segment = segment.replace(ttanc, ttancmod)
                equiltag[ttancmod] = ttanc
                cont += 1
        #Dealing with uppercased sentences
        toLower = False
        if segment == segment.upper():
            segment = segment.lower().capitalize()
            toLower = True

        if MTUOCServer_verbose:
            now = str(datetime.now())
            print("---------")
            print(now)
            print("Segment: ", segment)
        segmentNOTAGS = remove_tags(segment)
        if MTUOCServer_verbose: print("Segment No Tags: ", segmentNOTAGS)
        if preprocess_type == "SentencePiece":
            segmentPre = to_MT(segmentNOTAGS, tokenizerA, tokenizerB)
        elif preprocess_type == "NMT":
            segmentPre = to_MT(segmentNOTAGS, tokenizer, tcmodel, bpeobject,
                               joiner, bos_annotate, eos_annotate)
        if MTUOCServer_verbose: print("Segment Pre: ", segmentPre)
        lseg = len(segmentPre)
        ws.send(segmentPre)
        translations = ws.recv()
        cont = 0
        firsttranslationPre = ""
        selectedtranslation = ""
        selectedalignment = ""
        candidates = translations.split("\n")
        translation = ""
        alignments = ""
        for candidate in candidates:
            camps = candidate.split(" ||| ")
            if len(camps) > 2:
                translation = camps[1]
                alignments = camps[2]
                if cont == 0:
                    selectedtranslationPre = translation
                    selectedalignment = alignments
                ltran = len(translation)
                if ltran >= lseg * MarianEngine_min_len_factor:
                    selectedtranslationPre = translation
                    selectedalignment = alignments
                    break
                cont += 1
        if MTUOCServer_verbose:
            print("Translation Pre: ", selectedtranslationPre)
        if MTUOCServer_restore_tags:
            try:
                if preprocess_type == "SentencePiece":
                    SOURCETAGSTOK = tokenizerA.tokenize(segment)
                    SOURCETAGSTOKSP = tokenizerA.unprotect(" ".join(
                        tokenizerB.tokenize(
                            tokenizerA.protect_tags(SOURCETAGSTOK))[0]))
                    SOURCETAGSTOKSP = "<s> " + SOURCETAGSTOKSP + " </s>"
                    selectedtranslationRestored = MTUOC_tags.reinsert_wordalign(
                        SOURCETAGSTOKSP,
                        selectedalignment,
                        selectedtranslationPre,
                        splitter="▁")
                elif preprocess_type == "NMT":
                    SOURCETAGSTOK = tokenizerA.tokenize(segment)
                    glossary = []
                    tags = re.findall('(<[^>]+>)', SOURCETAGSTOK)
                    for tag in tags:
                        glossary.append(tag)
                    bpe = BPE(open(Preprocess_bpecodes, encoding="utf-8"),
                              separator=joiner,
                              glossaries=glossary)
                    SOURCETAGSTOKBPE = bpe.process_line(SOURCETAGSTOK)
                    selectedtranslationRestored = MTUOC_tags.reinsert_wordalign(
                        SOURCETAGSTOKBPE, selectedalignment,
                        selectedtranslationPre)

            except:
                print("ERROR RESTORING:", sys.exc_info())
        else:
            selectedtranslationRestored = selectedtranslationPre

        if preprocess_type == "SentencePiece":
            selectedtranslation = from_MT(selectedtranslationRestored)
        elif preprocess_type == "NMT":
            selectedtranslation = from_MT(selectedtranslationRestored,
                                          detokenizer, joiner, bos_annotate,
                                          eos_annotate)
        selectedtranslation = MTUOC_tags.fix_markup_ws(segment,
                                                       selectedtranslation)
        if MTUOCServer_verbose:
            print("Translation No Tags: ", selectedtranslationPre)
        if MTUOCServer_verbose:
            print("Translation Tags: ", selectedtranslationRestored)
        if MTUOCServer_verbose: print("Word Alignment: ", selectedalignment)

        for clau in equiltag.keys():
            selectedtranslation = selectedtranslation.replace(
                clau, equiltag[clau])

        if toLower:
            selectedtranslation = selectedtranslation.upper()
        if MTUOCServer_verbose: print("Translation: ", selectedtranslation)

    except:
        print("ERROR:", sys.exc_info())
    return (selectedtranslation)
Example 12
 def __setstate__(self, state):
     self.__dict__ = state
     from subword_nmt.apply_bpe import BPE
     with open(self._model_path, 'r', encoding='utf-8') as merge_codes:
         self._bpe = BPE(codes=merge_codes, separator=self._separator)
Example 13
class SubwordNMTTokenizer(BaseTokenizerWithVocab):
    def __init__(self,
                 model_path,
                 vocab: Union[str, Vocab],
                 separator: str = '@@',
                 bpe_dropout: float = 0.0,
                 suffix: str = '</w>'):
        """

        Parameters
        ----------
        model_path
        vocab
        separator
        bpe_dropout
        suffix
        """
        try_import_subword_nmt()
        from subword_nmt.apply_bpe import BPE
        self._model_path = model_path
        self._vocab = load_vocab(vocab)
        self._separator = separator
        self._bpe_dropout = bpe_dropout
        self._suffix = suffix
        with open(self._model_path, 'r', encoding='utf-8') as merge_codes:
            self._bpe = BPE(codes=merge_codes, separator=self._separator)
        self._last_subword_id_set = frozenset([
            self._vocab[ele] for ele in self._vocab.all_tokens
            if not ele.endswith(self._separator)
        ])

    def transform_sentence(self, sentence):
        """replace the separator in encoded result with suffix

        a@@, b@@, c ->  a, b, c</w>

        Parameters
        ----------
        sentence

        Returns
        -------
        new_sentence
        """
        return [
            word[:-2] if len(word) > 2 and word[-2:] == self._separator else
            word + self._suffix for word in sentence
        ]

    def encode(self, sentences, output_type=str):
        is_multi_sentences = isinstance(sentences, list)
        if not is_multi_sentences:
            sentences = [sentences]
        if output_type is str:
            ret = [
                self.transform_sentence(
                    self._bpe.segment(sentence,
                                      dropout=self._bpe_dropout).split(' '))
                for sentence in sentences
            ]
        elif output_type is int:
            if self._vocab is None:
                raise TokenizerEncodeWithoutVocabError
            ret = [
                self._vocab[self.transform_sentence(
                    self._bpe.segment(sentence,
                                      dropout=self._bpe_dropout).split(' '))]
                for sentence in sentences
            ]
        else:
            raise TokenTypeNotSupportedError(output_type)
        if is_multi_sentences:
            return ret
        else:
            return ret[0]

    def encode_with_offsets(self, sentences, output_type=str):
        is_multi_sentences = isinstance(sentences, list)
        if not is_multi_sentences:
            sentences = [sentences]
        tokens = []
        token_ids = []
        offsets = []
        for sentence in sentences:
            encode_token = self.transform_sentence(
                self._bpe.segment(sentence,
                                  dropout=self._bpe_dropout).split(' '))
            encode_id = self._vocab[encode_token]
            encode_token_without_suffix = [
                x.replace(self._suffix, '') for x in encode_token
            ]
            encode_offset = rebuild_offset_from_tokens(
                sentence, encode_token_without_suffix)
            tokens.append(encode_token)
            token_ids.append(encode_id)
            offsets.append(encode_offset)
        if not is_multi_sentences:
            tokens = tokens[0]
            token_ids = token_ids[0]
            offsets = offsets[0]
        if output_type is str:
            return tokens, offsets
        elif output_type is int:
            return token_ids, offsets
        else:
            raise TokenTypeNotSupportedError(output_type)

    def decode(self, tokens: Union[TokensType, TokenIDsType]) -> SentencesType:
        is_multiple_sentences = is_tokens_from_multiple_sentences(tokens)
        if not is_multiple_sentences:
            tokens = [tokens]
        token_type = get_token_type(tokens)
        if token_type is str:
            ret = [
                ''.join(ele_tokens).replace(self._suffix, ' ').strip()
                for ele_tokens in tokens
            ]
        elif token_type is int:
            if self._vocab is None:
                raise TokenizerDecodeWithoutVocabError
            ret = [
                ''.join(self._vocab.to_tokens(ele_tokens)).replace(
                    self._suffix, ' ').strip() for ele_tokens in tokens
            ]
        else:
            raise TokenTypeNotSupportedError(token_type)
        if is_multiple_sentences:
            return ret
        else:
            return ret[0]

    def is_last_subword(self, tokens: Union[str, int, List[str], List[int]]) \
            -> Union[bool, List[bool]]:
        """Whether the token is the last subword token. This can be used
        for whole-word masking.

        Parameters
        ----------
        tokens
            The input tokens

        Returns
        -------
        ret
            Whether the token is the last subword token in the list of subwords
        """
        if isinstance(tokens, str):
            return not tokens.endswith(self._separator)
        elif isinstance(tokens, int):
            return tokens in self._last_subword_id_set
        elif isinstance(tokens, list):
            if len(tokens) == 0:
                return []
            if isinstance(tokens[0], str):
                return [not ele.endswith(self._separator) for ele in tokens]
            elif isinstance(tokens[0], int):
                return [ele in self._last_subword_id_set for ele in tokens]
            else:
                raise NotImplementedError
        else:
            raise NotImplementedError

    @property
    def vocab(self) -> Optional[Vocab]:
        return self._vocab

    def set_vocab(self, vocab: Vocab):
        self._vocab = vocab

    def set_bpe_dropout(self, bpe_dropout: float):
        self._bpe_dropout = bpe_dropout

    def __repr__(self):
        ret = '{}(\n' \
              '   model_path = {}\n' \
              '   separator = {}\n' \
              '   bpe_dropout = {}\n' \
              '   vocab = {}\n' \
              ')'.format(self.__class__.__name__,
                         os.path.realpath(self._model_path),
                         self._separator,
                         self._bpe_dropout,
                         self._vocab)
        return ret

    def __getstate__(self):
        state = self.__dict__.copy()
        state['_bpe'] = None
        return state

    def __setstate__(self, state):
        self.__dict__ = state
        from subword_nmt.apply_bpe import BPE
        with open(self._model_path, 'r', encoding='utf-8') as merge_codes:
            self._bpe = BPE(codes=merge_codes, separator=self._separator)
Example 14
    def __init__(self, bpe_code_file):
        super().__init__()

        with open(bpe_code_file, mode='r', encoding='utf-8') as f:
            self.bpe = BPE(f)
Example 15
 def __init__(self, bpe_codes, separator="@@", encoding='utf-8'):
     self.separator = separator.strip()
     self.bpe = BPE(codecs.open(bpe_codes, encoding=encoding),
                    separator=self.separator)
Example 16
 def __init__(self, model_path):
     super().__init__()
     from subword_nmt.apply_bpe import BPE
     with open(model_path) as f:
         self.model = BPE(f)
Example 17
class TokenProcessor(object):
    def __init__(self, config_file):
        with open(config_file) as f:
            self.__dict__.update(yaml.safe_load(f))
        assert self.type in {"cn2en", "en2cn"}
        codes = codecs.open(self.codes_file, encoding='utf-8')
        cur_path = os.path.dirname(os.path.realpath(__file__))
        self.tokenizer = BPE(codes)

        if self.type == "en2cn":
            # pre_process: normalize, tokenize, subEntity,to_lower,bpe
            # post_process: delbpe,remove_space
            self.en_tokenizer = os.path.join(cur_path, self.en_tokenizer)
            self.en_normalize_punctuation = sacremoses.MosesPunctNormalizer(
                lang="en")
            self.en_tokenizer = sacremoses.MosesTokenizer(
                lang='en', custom_nonbreaking_prefixes_file=self.en_tokenizer)
        elif self.type == "cn2en":
            # pre_process: tokenize, bpe
            # post_process: delbpe,detruecase,detokenize
            self.detruecase = sacremoses.MosesDetruecaser()
            self.detokenize = sacremoses.MosesDetokenizer(lang='en')
            self.client = aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(total=3600),
                connector=aiohttp.TCPConnector(limit=sys.maxsize,
                                               limit_per_host=sys.maxsize))
            self.cn2en_trans_dict = slang_dict(self.trans_dict_file)
            self.chinese_char_pattern = re.compile(u"[\u4E00-\u9FA5]+")
            self.stops = re.compile(u"[.!?!?。。]+")

    def in_trans_dict(self, sent: str):
        if self.type == "cn2en":
            if self.stops.sub("", sent) in self.cn2en_trans_dict:
                return True, self.cn2en_trans_dict[self.stops.sub("", sent)]
            elif not self.chinese_char_pattern.search(sent):
                return True, sent
        return False, sent

    async def preprocess(self, sent: str):
        if self.type == "cn2en":
            sent = convert(sent, "zh-cn")
            if self.stops.sub("", sent) in self.cn2en_trans_dict or \
                not self.chinese_char_pattern.search(sent):
                return sent

            async with self.client.post(self.tokenize_url,
                                        json={
                                            'q': sent,
                                            "mode": self.tokenize_mode
                                        }) as rsp:
                rsp = await rsp.json()
                sent = " ".join(rsp['words'])
                sent = remove_ngram(sent, min_n_gram=2, max_n_gram=4)
                sent = self.tokenizer.segment(sent)
        elif self.type == "en2cn":
            sent = self.en_normalize_punctuation.normalize(sent)
            sent = self.en_tokenizer.tokenize(sent, return_str=True)
            tok = E2V(sent)
            tok = tok.lower()
            tok = remove_ngram(tok, min_n_gram=2, max_n_gram=4)
            sent = self.tokenizer.segment(tok)
        else:
            raise Exception("This type({}) is not supported.".format(self.type))
        return sent

    def post_process(self, sent: str):
        if self.type == "cn2en":
            delbpe = sent.replace("@@ ", "")
            detruecase = self.detruecase.detruecase(delbpe)
            tok_out = " ".join(detruecase)
            remove_dup = remove_ngram(tok_out, min_n_gram=2, max_n_gram=4)
            detruecase = remove_dup.split()
            sent = self.detokenize.detokenize(detruecase, return_str=True)
        elif self.type == "en2cn":
            delbpe = sent.replace("@@ ", "")
            tok_out = " ".join(delbpe)
            remove_dup = remove_ngram(tok_out, min_n_gram=2, max_n_gram=4)
            delbpe = remove_dup.split()
            sent = "".join(delbpe)
        return sent
Example 18
def main(argv):
    import argparse
    from io import StringIO

    # argument parsing
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="adapts an UD dataset to context-sensitive lemmatization")

    io_group = parser.add_argument_group('io')
    io_group.add_argument("--input", help="file to be transformed", type=str)
    io_group.add_argument("--output",
                          help="output source and target files",
                          nargs='+',
                          type=str,
                          default=None)
    io_group.add_argument(
        "--transform_appendix",
        help=
        "appendix to transform folder name (e.g. SLURM_JOB_ID or datetime)",
        type=str,
        default=None)
    io_group.add_argument(
        "--word_column_index",
        help="index of word column in the file (zero-indexed)",
        type=int,
        default=0)
    io_group.add_argument(
        "--lemma_column_index",
        help="index of lemma column in the file (zero-indexed)",
        type=int,
        default=1)
    io_group.add_argument(
        "--tag_column_index",
        help="index of tag column in the file (zero-indexed)",
        type=int,
        default=2)
    io_group.add_argument('--debug',
                          dest='debug',
                          help="debug mode prints target/source file to stdout"
                          " instead of writing to the file system",
                          action='store_true')
    io_group.add_argument('--overwrite', dest='overwrite', action='store_true')
    io_group.add_argument(
        "--print_file",
        help="which file to output (source/target) in debug mode",
        choices=['source', 'target'],
        type=str,
        default=defaults["PRINT_FILE"])

    repr_group = parser.add_argument_group('representation')
    repr_group.add_argument(
        "--mode",
        help="mode of transformation",
        choices=['word_and_context', 'sentence_to_sentence'],
        type=str,
        default=defaults["MODE"])
    repr_group.add_argument("--word_unit",
                            help="type of word representation",
                            choices=['char', 'word', 'bpe'],
                            type=str,
                            default=defaults["WORD_UNIT"])
    repr_group.add_argument("--tag_unit",
                            help="type of tag representation",
                            choices=['char', 'word'],
                            type=str,
                            default=defaults["TAG_UNIT"])
    repr_group.add_argument("--context_unit",
                            help="type of context representation",
                            choices=['char', 'bpe', 'word'],
                            type=str,
                            default=defaults["CONTEXT_UNIT"])
    repr_group.add_argument(
        "--char_n_gram_mode",
        help="size of char-n-grams (only used if --context_unit is char"
        "or if --mode is sentence_to_sentence and --word_unit is char, default: %(default)s)",
        type=int,
        default=defaults["CHAR_N_GRAM"])
    repr_group.add_argument(
        "--sentence_size",
        help="maximum size of sentence in sentence_to_sentence mode",
        type=int,
        default=argparse.SUPPRESS)
    repr_group.add_argument('--tag_first',
                            action='store_true',
                            help="if true tags will be printed before "
                            "words in source and target files")

    ctx_group = parser.add_argument_group('context')
    ctx_group.add_argument(
        "--context_size",
        help=
        "size of context representation (in respective units) on left and right (0 to use full span)",
        type=int,
        default=defaults["CONTEXT_SIZE"])
    ctx_group.add_argument(
        "--context_char_size",
        help=
        "size of context representation (in characters) on left and right (0 to use full span, has precedence over --context_size)",
        type=int,
        default=argparse.SUPPRESS)
    ctx_group.add_argument(
        "--context_span",
        help=
        "maximum span of a word's context, in sentences, to the left and right of the word's sentence (default: %(default)s)",
        type=int,
        default=defaults["CONTEXT_SPAN"])
    ctx_group.add_argument(
        "--context_tags",
        help="whether and where to include tag in the context",
        choices=['none', 'left'],
        type=str,
        default=defaults["CONTEXT_TAGS"])

    bpe_group = parser.add_argument_group('bpe')
    bpe_group.add_argument(
        "--bpe_operations",
        help="number of BPE merge operations to be learned "
        "(corresponds to number of symbols/char-n-grams/codes)",
        type=int,
        default=defaults["BPE_OPERATIONS"])
    bpe_group.add_argument(
        "--bpe_codes_path",
        help=
        "full file path to export BPE codes to or to read them from if available",
        type=str,
        default=None)

    boundary_group = parser.add_argument_group('boundaries')
    boundary_group.add_argument(
        "--left_context_boundary",
        help="left context boundary special symbol (default: %(default)s)",
        type=str,
        default=defaults["LEFT_CONTEXT_BOUNDARY"])
    boundary_group.add_argument(
        "--example_boundary",
        help="example boundary special symbol (default: %(default)s)",
        type=str,
        default=defaults["EXAMPLE_BOUNDARY"])
    boundary_group.add_argument(
        "--right_context_boundary",
        help="right context boundary special symbol (default: %(default)s)",
        type=str,
        default=defaults["RIGHT_CONTEXT_BOUNDARY"])
    boundary_group.add_argument(
        "--word_boundary",
        help="word boundary special symbol (default: %(default)s)",
        type=str,
        default=defaults["WORD_BOUNDARY"])
    boundary_group.add_argument(
        "--tag_boundary",
        help="tag boundary special symbol (default: %(default)s)",
        type=str,
        default=defaults["TAG_BOUNDARY"])
    boundary_group.add_argument(
        '--subword_separator',
        type=str,
        default=defaults["SUBWORD_SEPARATOR"],
        metavar='STR',
        help=
        "separator between non-final BPE subword units (default: '%(default)s')"
    )

    args = parser.parse_args(argv)

    # determining input
    if args.input is None:
        if args.output is None:
            raise ValueError(
                "Can't derive a name for the transformation when reading from stdin. Use --output to specify a path."
            )
        args.input = sys.stdin
    else:
        input_folders = re.split("/+|\\\\+", args.input)
        if len(input_folders) < 2:
            raise ValueError(
                "Can't decide how to name the transformation. Use --output to specify path."
            )
        else:
            input_folder = input_folders[-2]
            input_filename = input_folders[-1].split(".")[0]

    # determining output
    if not args.debug:
        if args.output is None or (type(args.output) is list
                                   and len(args.output) == 1):
            transform_folder = "{}_{}_{}{}{}{}{}{}{}{}{}".format(
                input_folder, "w" + args.word_unit, "t" + args.tag_unit,
                ("_" + (("{:02d}u".format(args.context_size))
                        if not hasattr(args, 'context_char_size') else
                        (str(args.context_char_size) + "ch")))
                if args.mode == 'word_and_context' else "",
                ("_" + ("c" + args.context_unit))
                if args.mode == "word_and_context" else "",
                "_n{}".format(args.char_n_gram_mode) if
                ((args.mode == 'word_and_context'
                  and args.context_unit == "char") or
                 (args.mode == 'sentence_to_sentence'
                  and args.word_unit == 'char')) else "",
                "_n{}".format(args.bpe_operations) if
                ((args.mode == 'word_and_context'
                  and args.context_unit == "bpe") or
                 (args.mode == 'sentence_to_sentence'
                  and args.word_unit == 'bpe')) else "",
                "_ct" if args.context_tags == 'left' else "",
                "_tf" if args.tag_first else "",
                "_cs{}".format(args.context_span)
                if args.mode == 'word_and_context' else "",
                ".{}".format(args.transform_appendix)
                if args.transform_appendix else "")

            if args.output is None or not args.output or '' in args.output:
                full_transform_folder_path = os.path.join(
                    os.path.dirname(os.path.abspath(__file__)), 'input',
                    args.mode, transform_folder)
            else:
                full_transform_folder_path = os.path.join(
                    args.output[0], transform_folder)

            os.makedirs(full_transform_folder_path, exist_ok=True)

            output_source_path = os.path.join(
                full_transform_folder_path, '{}_source'.format(input_filename))
            output_target_path = os.path.join(
                full_transform_folder_path, '{}_target'.format(input_filename))

            print(full_transform_folder_path)

            if not args.overwrite and (os.path.isfile(output_source_path)
                                       or os.path.isfile(output_target_path)):
                raise ValueError(
                    "Output files for {} already exist in {}. Pass --overwrite or delete them."
                    .format(input_filename, full_transform_folder_path))

            # truncate output files or create them anew
            open(output_source_path, 'w').close()
            open(output_target_path, 'w').close()
        else:
            if len(args.output) != 2:
                raise ValueError(
                    "You must specify full target and source output file paths (including file name)."
                )
            full_transform_folder_path = None
            output_source_path = args.output[0]
            output_target_path = args.output[1]

    print(args, file=sys.stderr)

    # loading file
    infile_df = preprocess_dataset_for_train(
        pd.read_csv(args.input,
                    sep=r'\s+',
                    names=cols,
                    usecols=[
                        args.word_column_index, args.lemma_column_index,
                        args.tag_column_index
                    ],
                    skip_blank_lines=False,
                    comment='#',
                    quoting=3)[cols])
    infile_df = infile_df.reset_index(drop=True)

    # subword preprocessing of the input file
    if (args.mode == 'word_and_context' and args.context_unit == 'char') or (
            args.mode == 'sentence_to_sentence' and args.word_unit == 'char'):
        # uses subword-nmt to segment text into chargrams
        from types import SimpleNamespace
        import numpy as np
        sys.path.append(
            os.path.join(os.path.dirname(__file__), '..', 'subword-nmt'))
        from subword_nmt.segment_char_ngrams import segment_char_ngrams

        def segment(col):
            subword_nmt_output = StringIO()
            segment_char_ngrams(
                SimpleNamespace(input=infile_df[col].dropna().astype(str),
                                vocab={},
                                n=args.char_n_gram_mode,
                                output=subword_nmt_output,
                                separator=args.subword_separator))
            subword_nmt_output.seek(0)
            infile_df.loc[infile_df[col].notnull(), [col]] = np.array([
                line.rstrip(' \t\n\r') for line in subword_nmt_output
            ])[:, np.newaxis]
            subword_nmt_output.truncate(0)

        segment("word")
        if args.mode == 'sentence_to_sentence' and args.word_unit == 'char':
            segment("lemma")
    elif (args.mode == 'word_and_context' and args.context_unit == 'bpe') or (
            args.mode == 'sentence_to_sentence' and args.word_unit == 'bpe'):
        if args.bpe_codes_path:
            bpe_codes_file_path = args.bpe_codes_path
        elif full_transform_folder_path:
            bpe_codes_file_path = os.path.join(full_transform_folder_path,
                                               "bpe_codes")
        else:
            raise ValueError(
                "Specify a transformation output folder or a BPE codes file path (--bpe_codes_path) in order to export BPE codes."
            )

        # BPE processing
        sys.path.append(
            os.path.join(os.path.dirname(__file__), '..', 'subword-nmt'))
        from subword_nmt.apply_bpe import BPE

        # only learn BPEs if bpe_codes file is unavailable
        if not os.path.isfile(bpe_codes_file_path):
            # as advised in subword-nmt's readme, we learn BPE jointly on the sources and targets
            # because they share an alphabet (for the most part)
            from subword_nmt.learn_bpe import learn_bpe
            bpe_codes = open(bpe_codes_file_path, "w", encoding='utf-8')
            learn_bpe(
                infile_df[["word", "lemma"]].dropna().astype(str).to_string(
                    index=False, header=False).splitlines(), bpe_codes,
                args.bpe_operations)
            bpe_codes.close()

        with open(bpe_codes_file_path, encoding='utf-8') as bpe_codes:
            # apply all merge operations, without vocabulary and glossaries
            bpe = BPE(bpe_codes, -1, args.subword_separator, [], [])
            infile_df.loc[infile_df["word"].notnull(), ["word", "lemma"]] = \
                infile_df.loc[infile_df["word"].notnull(), ["word", "lemma"]].applymap(bpe.process_line)

    sentence_indices = pd.isna(infile_df).all(axis=1)
    sentence_end_iterator = (i for i, is_blank in sentence_indices.to_dict().items()
                             if is_blank)

    # per-mode specific processing
    if args.mode == 'word_and_context':
        sentence_dfs = []
        transformer_args = {
            'word_unit': args.word_unit,
            'tag_unit': args.tag_unit,
            'context_size': args.context_size,
            'context_char_size': (args.context_char_size
                                  if hasattr(args, 'context_char_size') else None),
            'context_tags': args.context_tags,
            'tag_first': args.tag_first,
            'left_context_boundary': args.left_context_boundary,
            'tag_boundary': args.tag_boundary,
            'right_context_boundary': args.right_context_boundary,
            'word_boundary': args.word_boundary,
            'example_boundary': args.example_boundary,
            'subword_separator': args.subword_separator
        }

        transformer = Transformer(**transformer_args)

        sentence_start = 0
        for sentence_end in sentence_end_iterator:
            sentence_dfs.append(infile_df.loc[sentence_start:sentence_end - 1])
            sentence_start = sentence_end + 1

        for sentence_df_idx, sentence_df in enumerate(sentence_dfs):
            # adds additional context according to CONTEXT_SPAN to sentence below
            lc_df = pd.DataFrame()
            rc_df = pd.DataFrame()

            if args.context_span > 0:
                lc_df_ls = sentence_dfs[
                    max(sentence_df_idx -
                        args.context_span, 0):sentence_df_idx]
                if lc_df_ls:
                    lc_df = pd.concat(lc_df_ls)

                rc_df_ls = sentence_dfs[
                    sentence_df_idx +
                    1:min(sentence_df_idx + 1 + args.context_span,
                          len(sentence_dfs) - 1)]
                if rc_df_ls:
                    rc_df = pd.concat(rc_df_ls)

            output_source_lines, output_target_lines = transformer.process_sentence(
                sentence_df, lc_df, rc_df)

            if not (output_source_lines or output_target_lines):
                continue

            if args.debug:
                if args.print_file == 'source':
                    print("\n".join(output_source_lines))
                else:
                    print("\n".join(output_target_lines))
            else:
                with open(output_source_path, 'a+', encoding='utf-8') as outsourcefile, \
                        open(output_target_path, 'a+', encoding='utf-8') as outtargetfile:
                    outsourcefile.write("\n".join(output_source_lines) + "\n")
                    outtargetfile.write("\n".join(output_target_lines) + "\n")
    elif args.mode == 'sentence_to_sentence':
        sentence_start = 0

        if args.example_boundary is not None:
            pos_close_tag = args.example_boundary.find('<') + 1
            open_tag = args.example_boundary
            close_tag = open_tag[:pos_close_tag] + '/' + open_tag[
                pos_close_tag:]

        for sentence_end in sentence_end_iterator:
            output_source_line = [open_tag]
            output_target_line = [open_tag]

            last_split_pos = 0

            for sentence_idx in range(sentence_start, sentence_end):
                subwords = re.split(r"\s*{}\s*".format(args.subword_separator),
                                    infile_df.at[sentence_idx, "word"])
                lemma = re.split(r"\s*{}\s*".format(args.subword_separator),
                                 infile_df.at[sentence_idx, "lemma"])
                tag = infile_df.at[sentence_idx, "tag"]

                # inserts a breaking point at the position before this word+tag were inserted in both the source and the target
                output_source_insertion_point = len(output_source_line)
                output_target_insertion_point = len(output_target_line)

                output_source_line.extend(subwords)
                output_source_line.append(args.word_boundary)

                if not args.tag_first:
                    output_target_line.extend(lemma)
                    output_target_line.append(args.tag_boundary)

                if args.tag_unit == "word":
                    output_target_line.append(tag)
                else:
                    output_target_line.extend(tag)

                if args.tag_first:
                    output_target_line.append(args.tag_boundary)
                    output_target_line.extend(lemma)

                output_target_line.append(args.word_boundary)

                # if the target translation overflows (target sentence is guaranteed to be longer in size)
                # sanity check: awk 'NF > 50 { print NR, NF }' dev_source | wc -l
                last_split_size = len(output_target_line) - last_split_pos
                if hasattr(args, 'sentence_size'
                           ) and last_split_size > args.sentence_size:
                    output_source_line.insert(output_source_insertion_point,
                                              defaults["SENTENCE_SPLIT_TAG"])
                    output_target_line.insert(output_target_insertion_point,
                                              defaults["SENTENCE_SPLIT_TAG"])
                    # remember where this split starts so the next overflow check measures only the new segment
                    last_split_pos = output_target_insertion_point

            sentence_start = sentence_end + 1

            output_source_line.pop()
            output_target_line.pop()

            output_source_line.append(close_tag)
            output_target_line.append(close_tag)

            assert output_source_line.count(defaults["SENTENCE_SPLIT_TAG"]) == output_target_line.count(defaults["SENTENCE_SPLIT_TAG"]), \
                "Sentence splits in sentence_to_sentence mode are wrong."

            # split sentences if necessary
            split_cond = True
            end_source_line_split_pos = 0
            end_target_line_split_pos = 0

            while split_cond:
                try:
                    start_source_line_pos = end_source_line_split_pos
                    start_target_line_pos = end_target_line_split_pos
                    end_source_line_split_pos = output_source_line.index(
                        defaults["SENTENCE_SPLIT_TAG"],
                        end_source_line_split_pos) + 1
                    end_target_line_split_pos = output_target_line.index(
                        defaults["SENTENCE_SPLIT_TAG"],
                        end_target_line_split_pos) + 1
                except ValueError:
                    split_cond = False
                    end_source_line_split_pos = len(output_source_line) + 1
                    end_target_line_split_pos = len(output_target_line) + 1

                output_source_line_split = output_source_line[
                    start_source_line_pos:end_source_line_split_pos - 1]
                output_target_line_split = output_target_line[
                    start_target_line_pos:end_target_line_split_pos - 1]

                if output_source_line_split[-1] == defaults["WORD_BOUNDARY"]:
                    output_source_line_split[-1] = close_tag
                if output_source_line_split[0] != open_tag:
                    output_source_line_split.insert(0, open_tag)

                if output_target_line_split[-1] == defaults["WORD_BOUNDARY"]:
                    output_target_line_split[-1] = close_tag
                if output_target_line_split[0] != open_tag:
                    output_target_line_split.insert(0, open_tag)

                if args.debug:
                    if args.print_file == 'source':
                        print(" ".join(output_source_line_split))
                    else:
                        print(" ".join(output_target_line_split))
                    print("\n")
                else:
                    with open(output_source_path, 'a+', encoding='utf-8') as outsourcefile, \
                            open(output_target_path, 'a+', encoding='utf-8') as outtargetfile:
                        outsourcefile.write(
                            " ".join(output_source_line_split) + "\n")
                        outtargetfile.write(
                            " ".join(output_target_line_split) + "\n")
    def greedy_decode(self,
                      input_seq: str,
                      delimiter: str = ' ',
                      use_bpe=False):
        def separate_punct(s):
            patt = r"[\w']+|[.,!?;]"
            return ' '.join(re.findall(patt, s))

        stop_tok = self.vocab['</s>']
        len_limit = 200

        # Prep input for feeding to model
        context, current = input_seq.split('\t')
        context = separate_punct(context.strip())  # context.split()
        current = separate_punct(current.strip())  # current.split()

        if use_bpe:
            bpe_args = BPEArgs()
            bpe = BPE(bpe_args.codes, bpe_args.merges, bpe_args.separator,
                      bpe_args.vocabulary, bpe_args.glossaries)
            context = bpe.segment_tokens(context.split())
            current = bpe.segment_tokens(current.split())
        else:
            context = context.split()
            current = current.split()

        context.insert(0, '<s>')
        current.insert(0, '<s>')
        print('context tokenized:', context)
        print('current tokenized:', current)

        src_context = np.asarray([
            self.vocab[w] if w in self.vocab.keys() else self.vocab['<UNK>']
            for w in context
        ])
        src_current = np.asarray([
            self.vocab[w] if w in self.vocab.keys() else self.vocab['<UNK>']
            for w in current
        ])
        print('context encoded:', src_context)
        print('current encoded:', src_current)
        src_context = np.reshape(a=src_context, newshape=(1, len(src_context)))
        src_current = np.reshape(a=src_current, newshape=(1, len(src_current)))

        # Set up decoder input data
        decoded_tokens = []
        target_seq = np.zeros((1, len_limit), dtype='int32')
        print(target_seq.shape)
        target_seq[0, 0] = self.vocab['<s>']
        print(target_seq)

        # Loop through and generate decoder tokens
        print('Generating output...')
        for i in range(len_limit - 1):
            print('=', end='', flush=True)
            output = self.model.predict_on_batch(
                [src_context, src_current, target_seq]).argmax(axis=2)
            # sampled_index = np.argmax(output[0, i, :])
            sampled_index = output[:, i]
            if sampled_index == stop_tok:
                break
            decoded_tokens.append(self.inverse_vocab[int(sampled_index)])
            target_seq[0, i + 1] = sampled_index
            print(target_seq)

        decoded = delimiter.join(decoded_tokens)
        decoded = decoded.replace('@@ ', '')
        return decoded
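For reference, a toy, self-contained sketch of the encoding step that greedy_decode performs before prediction; the vocabulary below is made up for illustration:

import re
import numpy as np

def separate_punct(s):
    # same idea as the helper inside greedy_decode
    return ' '.join(re.findall(r"[\w']+|[.,!?;]", s))

vocab = {'<s>': 0, '</s>': 1, '<UNK>': 2, 'hello': 3, ',': 4, 'world': 5}
tokens = ['<s>'] + separate_punct("hello, world!").split()
ids = np.asarray([vocab.get(w, vocab['<UNK>']) for w in tokens])
print(tokens)  # ['<s>', 'hello', ',', 'world', '!']
print(ids)     # [0 3 4 5 2] -- '!' falls back to <UNK>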
Example n. 20
0
def test_sber_onfly(config):
    print('Loading emb matrices')
    with open(config.word_emb_file, "r") as fh:
        word_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.char_emb_file, "r") as fh:
        char_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.bpe_emb_file, "r") as fh:
        bpe_mat = np.array(json.load(fh), dtype=np.float32)
    with open(config.pos_emb_file, "r") as fh:
        pos_mat = np.array(json.load(fh), dtype=np.float32)

    if config.use_bpe and config.use_bpe_pretrained_codes:
        bpe_model = BPE(open(config.bpe_pretrained_codes_file, 'r'))
    elif config.use_bpe and not config.use_bpe_pretrained_codes:
        bpe_model = BPE(open(config.bpe_codes_file, 'r'))
    else:
        bpe_model = None

    word2idx_dict = pickle.load(open(config.word2idx_dict_file, 'rb'))
    char2idx_dict = pickle.load(open(config.char2idx_dict_file, 'rb'))
    bpe2idx_dict = pickle.load(open(config.bpe2idx_dict_file, 'rb'))
    pos2idx_dict = pickle.load(open(config.pos2idx_dict_file, 'rb'))

    print("Loading model...")
    model = Model(config,
                  None,
                  word_mat,
                  char_mat,
                  bpe_mat,
                  pos_mat,
                  trainable=False,
                  use_tfdata=False)

    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.allow_growth = True
    sess = tf.Session(config=sess_config)
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    if config.model_name == 'latest':
        checkpoint = tf.train.latest_checkpoint(config.save_dir)
    else:
        checkpoint = os.path.join(config.save_dir, config.model_name)
    print('Restoring from: {}'.format(checkpoint))
    saver.restore(sess, checkpoint)
    sess.run(tf.assign(model.is_train, tf.constant(False, dtype=tf.bool)))

    for datafile, datatype in zip(
        [config.sber_public_file, config.sber_private_file],
        ['public', 'private']):

        datafile_squad = os.path.join(config.target_dir,
                                      "{}.json_squad".format(datatype))
        sber2squad(datafile, outfile=datafile_squad)
        data_examples, data_eval = process_file(
            config,
            datafile_squad,
            datatype,
            remove_unicode=config.remove_unicode,
            bpe_model=bpe_model,
            is_test=True)

        data_features, data_meta = build_features_notfdata(config,
                                                           data_examples,
                                                           datatype,
                                                           word2idx_dict,
                                                           char2idx_dict,
                                                           bpe2idx_dict,
                                                           pos2idx_dict,
                                                           is_test=True)

        total = data_meta["total"]

        answer_dict = {}
        remapped_dict = {}

        print(len(data_features))
        # hotfix: pad data_features up to a multiple of config.batch_size
        while len(data_features) % config.batch_size != 0:
            data_features.append(data_features[-1])

        print(len(data_features))

        for step in tqdm(range(total // config.batch_size + 1)):

            def get_batch():
                batch_items = data_features[step *
                                            config.batch_size:(step + 1) *
                                            config.batch_size]
                batch = dict()
                for key in batch_items[0].keys():
                    batch[key] = np.stack([el[key] for el in batch_items])
                return batch

            batch = get_batch()

            qa_id, loss, yp1, yp2 = sess.run(
                [model.qa_id, model.loss, model.yp1, model.yp2],
                feed_dict={
                    model.c_ph: batch['context_idxs'],
                    model.q_ph: batch['ques_idxs'],
                    model.ch_ph: batch['context_char_idxs'],
                    model.qh_ph: batch['ques_char_idxs'],
                    model.cb_ph: batch['context_bpe_idxs'],
                    model.qb_ph: batch['ques_bpe_idxs'],
                    model.cp_ph: batch['context_pos_idxs'],
                    model.qp_ph: batch['ques_pos_idxs'],
                    model.y1_ph: batch['y1'],
                    model.y2_ph: batch['y2'],
                    model.qa_id: batch['id'],
                })

            answer_dict_, remapped_dict_ = convert_tokens(
                data_eval, qa_id.tolist(), yp1.tolist(), yp2.tolist())
            answer_dict.update(answer_dict_)
            remapped_dict.update(remapped_dict_)

        path_to_save_answer = os.path.join(
            config.answer_dir, '{}.json_squad_ans'.format(datatype))
        with open(path_to_save_answer, "w") as fh:
            json.dump(remapped_dict, fh)

        sber_ans = '.'.join(path_to_save_answer.split('.')[0:-1]) + '.json_ans'
        squad_answer2sber(datafile, path_to_save_answer, outfile=sber_ans)

        print("Answer dumped: {}".format(path_to_save_answer))

    # evaluating
    # TODO: CHANGE TO ENG URL
    url = 'http://api.aibotbench.com/rusquad/qas'
    headers = {'Content-Type': 'application/json', 'Accept': 'text/plain'}
    metrics = dict()
    f1, em = 0.0, 0.0
    for datatype in ['public', 'private']:
        with open(
                os.path.join(config.answer_dir, '{}.json_ans'.format(datatype)),
                'r') as ans_fh:
            sber_ans = ans_fh.readline()
        res = requests.post(url, data=sber_ans, headers=headers)
        metrics[datatype] = eval(json.loads(res.text))
        f1 += metrics[datatype]['f1']
        em += metrics[datatype]['exact_match']
        print('{}: EM: {:.5f} F-1: {:.5f}'.format(
            datatype, metrics[datatype]['exact_match'],
            metrics[datatype]['f1']))
    print('EM avg: {:.5f} F-1 avg: {:.5f}'.format(em / 2, f1 / 2))
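A small self-contained sketch (toy features) of the padding-and-batching scheme used in test_sber_onfly above:

import numpy as np

batch_size = 4
features = [{"id": i, "x": np.zeros(3)} for i in range(10)]

# pad with copies of the last element until the length divides evenly,
# as in the hotfix above
while len(features) % batch_size != 0:
    features.append(features[-1])

def get_batch(step):
    items = features[step * batch_size:(step + 1) * batch_size]
    return {key: np.stack([el[key] for el in items]) for key in items[0]}

print(len(features))            # 12
print(get_batch(0)["x"].shape)  # (4, 3)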
             "all", "any", "both", "each", "few", "more", "most", "other", "some", "such", "no", "nor", "not", "only",
             "own", "same", "so", "than", "too", "very", "s", "t", "can", "will", "just", "don", "should", "now"]
stopwords = set(stopwords)
OLD_ENGLISH = {"thy": "your", "thou": "you", "Thy": "Your", "Thou": "You"}

# moses tokenizer
from sacremoses import MosesTruecaser, MosesTokenizer, MosesDetokenizer, MosesDetruecaser
mtok = MosesTokenizer(lang='en')
mtr = MosesTruecaser("vocab/truecase-model.en")
md = MosesDetokenizer(lang="en")
mdtr = MosesDetruecaser()

# bpe tokenizer
from subword_nmt.apply_bpe import BPE, read_vocabulary
vocabulary = read_vocabulary(codecs.open("vocab/vocab.bpe35000.chr", encoding='utf-8'), 10)
bpe = BPE(codes=codecs.open("vocab/codes_file_chr_35000", encoding='utf-8'), merges=35000, vocab=vocabulary)

# load nmt models
import onmt.opts
from translator_for_demo import build_translator
from onmt.utils.parse import ArgumentParser


def _parse_opt(opt):
    prec_argv = sys.argv
    sys.argv = sys.argv[:1]
    parser = ArgumentParser()
    onmt.opts.translate_opts(parser)

    opt['src'] = "dummy_src"
    opt['replace_unk'] = True
Example n. 22
0
# coding=utf-8

# @time    : 2019/7/15 3:56 PM
# @author  : alchemistlee
# @fileName: bpe_test.py
# @abstract:

import codecs
from subword_nmt.apply_bpe import BPE

if __name__ == '__main__':
    print('hi')
    c = codecs.open(
        '/root/workspace/translate_data/my_corpus_v6.zh-cut.processed6-bpe-code',
        encoding='utf-8')
    m = -1
    sp = '@@'
    voc = None
    a = '保罗 和 哈登'  # sample Chinese input: "Paul and Harden"
    bpe = BPE(c, m, sp, voc, None)
    print(bpe.process_line(a))
    pass
Example n. 23
0
from helper import utils
import paddle
from paddle import io
import os
import numpy as np
import pandas as pd
import codecs
from subword_nmt.apply_bpe import BPE

# Set global variable, drug max position, target max position
D_MAX = 50
T_MAX = 545

drug_vocab_path = './vocabulary/drug_bpe_chembl_freq_100.txt'
drug_codes_bpe = codecs.open(drug_vocab_path)
drug_bpe = BPE(drug_codes_bpe, merges=-1, separator='')
drug_temp = pd.read_csv('./vocabulary/subword_list_chembl_freq_100.csv')
drug_index2word = drug_temp['index'].values
drug_idx = dict(zip(drug_index2word, range(0, len(drug_index2word))))

target_vocab_path = './vocabulary/target_bpe_uniprot_freq_500.txt'
target_codes_bpe = codecs.open(target_vocab_path)
target_bpe = BPE(target_codes_bpe, merges=-1, separator='')
target_temp = pd.read_csv('./vocabulary/subword_list_uniprot_freq_500.csv')
target_index2word = target_temp['index'].values
target_idx = dict(zip(target_index2word, range(0, len(target_index2word))))


def drug_encoder(input_smiles):
    """
    Drug Encoder
Example n. 24
0
class Translator:
    def __init__(self):
        self.parser = None
        self.args = None
        self.task = None
        self.models = None
        self.model = None
        self.src_dict, self.tgt_dict = None, None
        self.generator = None
        self.align_dict = None
        self.max_positions = None
        self.decoder = None
        self.encode_fn = None
        self.use_cuda = True
        self.src = 'en'
        self.tgt = 'zh'

        self.bpe = None
        self.tokenizer = True

    def initialize(self,
                   data_dir=_data_dir,
                   model_path=_model,
                   user_dir=_user_dir,
                   task='xmasked_seq2seq',
                   s_lang='en',
                   t_lang='zh',
                   beam=5,
                   cpu=False,
                   align_dict=None,
                   bpe_codes=_bpe_codes_en,
                   tokenizer=True):
        self.parser = options.get_generation_parser(interactive=True)
        self.src, self.tgt = s_lang, t_lang

        # generate args
        input_args = [data_dir, '--path', model_path]
        if cpu:
            input_args.append('--cpu')
        if user_dir:
            input_args.append('--user-dir')
            input_args.append(user_dir)
        if task:
            input_args.append('--task')
            input_args.append(task)
        if align_dict:
            input_args.append('--replace-unk')
            input_args.append(align_dict)
        input_args.append('--langs')
        input_args.append('{},{}'.format(s_lang, t_lang))
        input_args.append('--source-langs')
        input_args.append(s_lang)
        input_args.append('--target-langs')
        input_args.append(t_lang)
        input_args.append('-s')
        input_args.append(s_lang)
        input_args.append('-t')
        input_args.append(t_lang)
        input_args.append('--beam')
        input_args.append(str(beam))
        input_args.append('--remove-bpe')

        self.bpe = BPE(open(bpe_codes, 'r'))
        self.tokenizer = tokenizer

        self.args = options.parse_args_and_arch(self.parser,
                                                input_args=input_args)

        # initialize model
        utils.import_user_module(self.args)

        if self.args.buffer_size < 1:
            self.args.buffer_size = 1
        if self.args.max_tokens is None and self.args.max_sentences is None:
            self.args.max_sentences = 1

        assert not self.args.sampling or self.args.nbest == self.args.beam, \
            '--sampling requires --nbest to be equal to --beam'
        assert not self.args.max_sentences or self.args.max_sentences <= self.args.buffer_size, \
            '--max-sentences/--batch-size cannot be larger than --buffer-size'

        self.use_cuda = torch.cuda.is_available() and not self.args.cpu

        # Setup task, e.g., translation
        self.task = tasks.setup_task(self.args)

        # Load ensemble
        self.models, _model_args = checkpoint_utils.load_model_ensemble(
            self.args.path.split(':'),
            arg_overrides=eval(self.args.model_overrides),
            task=self.task,
        )

        # Set dictionaries
        self.src_dict = self.task.source_dictionary
        self.tgt_dict = self.task.target_dictionary

        # Optimize ensemble for generation
        for model in self.models:
            model.make_generation_fast_(
                beamable_mm_beam_size=None
                if self.args.no_beamable_mm else self.args.beam,
                need_attn=self.args.print_alignment,
            )
            if self.args.fp16:
                model.half()
            if self.use_cuda:
                model.cuda()

        # Initialize generator
        self.generator = self.task.build_generator(self.args)

        def encode_fn(x):
            if tokenizer:
                x = tokenize(x, is_zh=(s_lang == 'zh'))
            if bpe_codes:
                x = self.bpe.process_line(x)
            return x

        # Hack to support GPT-2 BPE
        if self.args.remove_bpe == 'gpt2':
            pass
        else:
            self.decoder = None
            # self.encode_fn = lambda x: x
            self.encode_fn = encode_fn

        # Load alignment dictionary for unknown word replacement
        # (None if no unknown word replacement, empty if no path to align dictionary)
        self.align_dict = utils.load_align_dict(self.args.replace_unk)

        self.max_positions = utils.resolve_max_positions(
            self.task.max_positions(),
            *[model.max_positions() for model in self.models])

    def translate(self, text, verbose=False):
        start_id = 0
        inputs = [text]
        #inputs = [text.lower()]
        #inputs = [tokenize(text, is_zh=(self.src == 'zh'))]
        results = []
        outputs = []
        for batch in self.make_batches(inputs, self.args, self.task,
                                       self.max_positions, self.encode_fn):
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            if self.use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()

            sample = {
                'net_input': {
                    'src_tokens': src_tokens,
                    'src_lengths': src_lengths,
                },
            }
            translations = self.task.inference_step(self.generator,
                                                    self.models, sample)
            for i, (id,
                    hypos) in enumerate(zip(batch.ids.tolist(), translations)):
                src_tokens_i = utils.strip_pad(src_tokens[i],
                                               self.tgt_dict.pad())
                results.append((start_id + id, src_tokens_i, hypos))

        # sort output to match input order
        for id, src_tokens, hypos in sorted(results, key=lambda x: x[0]):
            if self.src_dict is not None:
                src_str = self.src_dict.string(src_tokens,
                                               self.args.remove_bpe)
                if verbose:
                    print('S-{}\t{}'.format(id, src_str))

            # Process top predictions
            for hypo in hypos[:min(len(hypos), self.args.nbest)]:
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu()
                    if hypo['alignment'] is not None else None,
                    align_dict=self.align_dict,
                    tgt_dict=self.tgt_dict,
                    remove_bpe=self.args.remove_bpe,
                )
                if self.decoder is not None:
                    hypo_str = self.decoder.decode(
                        map(int,
                            hypo_str.strip().split()))
                outputs.append(hypo_str)
                if verbose:
                    print('H-{}\t{}\t{}'.format(id, hypo['score'], hypo_str))
                    print('P-{}\t{}'.format(
                        id, ' '.join(
                            map(lambda x: '{:.4f}'.format(x),
                                hypo['positional_scores'].tolist()))))
                if self.args.print_alignment and verbose:
                    print('A-{}\t{}'.format(
                        id,
                        ' '.join(map(lambda x: str(utils.item(x)),
                                     alignment))))
        return ''.join(
            ''.join(outputs).split(' ')) if self.src == 'en' else ' '.join(
                ''.join(outputs).split(' '))

    def make_batches(self, lines, args, task, max_positions, encode_fn):
        tokens = [
            task.source_dictionary.encode_line(encode_fn(src_str),
                                               add_if_not_exist=False).long()
            for src_str in lines
        ]
        lengths = torch.LongTensor([t.numel() for t in tokens])
        itr = task.get_batch_iterator(
            dataset=task.build_dataset_for_inference(tokens, lengths),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=max_positions,
        ).next_epoch_itr(shuffle=False)
        for batch in itr:
            yield Batch(
                ids=batch['id'],
                src_tokens=batch['net_input']['src_tokens'],
                src_lengths=batch['net_input']['src_lengths'],
            )
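A hedged usage sketch for the Translator class above; every path below is a placeholder, and a working fairseq installation plus a trained xmasked_seq2seq checkpoint and BPE codes file are assumed:

translator = Translator()
translator.initialize(data_dir="data-bin/en-zh",          # placeholder binarized data dir
                      model_path="checkpoints/model.pt",  # placeholder checkpoint
                      user_dir=None,
                      task='xmasked_seq2seq',
                      s_lang='en', t_lang='zh',
                      beam=5, cpu=True,
                      bpe_codes="bpe/codes.en")           # placeholder BPE codes path
print(translator.translate("Hello world", verbose=True))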