Example #1
    def load(self):
        self.loading_lock.clear()

        timer = Timer()
        self.logger.info("Loading model %d" % self.model_id)
        timer.start()

        try:
            self.translator = build_translator(self.opt,
                                               report_score=False,
                                               out_file=codecs.open(
                                                   os.devnull, "w", "utf-8"))
        except RuntimeError as e:
            raise ServerModelError("Runtime Error: %s" % str(e))

        timer.tick("model_loading")
        if self.tokenizer_opt is not None:
            self.logger.info("Loading tokenizer")

            if "type" not in self.tokenizer_opt:
                raise ValueError(
                    "Missing mandatory tokenizer option 'type'")

            if self.tokenizer_opt['type'] == 'sentencepiece':
                if "model" not in self.tokenizer_opt:
                    raise ValueError(
                        "Missing mandatory tokenizer option 'model'")
                import sentencepiece as spm
                sp = spm.SentencePieceProcessor()
                model_path = os.path.join(self.model_root,
                                          self.tokenizer_opt['model'])
                sp.Load(model_path)
                self.tokenizer = sp
            elif self.tokenizer_opt['type'] == 'pyonmttok':
                if "params" not in self.tokenizer_opt:
                    raise ValueError(
                        "Missing mandatory tokenizer option 'params'")
                import pyonmttok
                # "mode" may be absent from the config; avoid a KeyError
                mode = self.tokenizer_opt.get("mode")
                # load can be called multiple times: modify copy
                tokenizer_params = dict(self.tokenizer_opt["params"])
                for key, value in self.tokenizer_opt["params"].items():
                    if key.endswith("path"):
                        tokenizer_params[key] = os.path.join(
                            self.model_root, value)
                tokenizer = pyonmttok.Tokenizer(mode,
                                                **tokenizer_params)
                self.tokenizer = tokenizer
            else:
                raise ValueError("Invalid value for tokenizer type")

        self.load_time = timer.tick()
        self.reset_unload_timer()
        self.loading_lock.set()
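
For reference, a minimal sketch of the tokenizer_opt dicts this load() method expects; the file names below are placeholders, not paths from the original project.

# Hypothetical configs matching the two branches above; "model" and keys
# ending in "path" are resolved against self.model_root.
sentencepiece_opt = {
    "type": "sentencepiece",
    "model": "sentencepiece.model",
}
pyonmttok_opt = {
    "type": "pyonmttok",
    "mode": "conservative",
    "params": {
        "bpe_model_path": "codes.bpe",
        "joiner_annotate": True,
    },
}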
Example #2
def _translate(tokenized, translate_model):
    # OpenNMT is awkwardly designed, which makes it annoying to load a model
    # programmatically. Their server uses this approach: temporarily rewrite
    # the command-line arguments, parse them, then restore the originals.
    parser = argparse.ArgumentParser()
    onmt.opts.translate_opts(parser)

    prec_argv = sys.argv
    sys.argv = sys.argv[:1]

    opt = {'verbose': True, 'replace_unk': True, 'gpu': 0}
    opt['model'] = os.path.join(ROOT_DIR, 'models', translate_model)
    opt['src'] = "dummy_src"

    for (k, v) in opt.items():
        if isinstance(v, bool):
            sys.argv += ['-%s' % k]
        else:
            sys.argv += ['-%s' % k, str(v)]

    opt = parser.parse_args()
    opt.cuda = opt.gpu > -1
    sys.argv = prec_argv

    translator = build_translator(opt,
                                  report_score=False,
                                  out_file=open(os.devnull, "w"))

    scores, predictions = translator.translate(src_data_iter=[tokenized],
                                               batch_size=1)

    del translator
    return scores[0], predictions[0]
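
A hedged usage sketch for the helper above; the checkpoint name is a placeholder, and the input is assumed to be pre-tokenized the way the model expects.

# Hypothetical call; assumes models/demo-model.pt exists under ROOT_DIR.
score, prediction = _translate("this is a tokenized sentence", "demo-model.pt")
print(score, prediction)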
Example #3
    doc = doc_dicts[doc_id]
    print(doc.keys())
    text_to_extract = doc['abstract']
    print(doc_id)
    print(text_to_extract)

    parser = _get_parser()
    config_path = 'config/translate/config-rnn-keyphrase.yml'
    print(os.path.abspath('../config/translate/config-rnn-keyphrase.yml'))
    print(os.path.exists(config_path))
    # one2seq_ckpt_path = '/zfs1/pbrusilovsky/rum20/kp/OpenNMT-kpg/models/keyphrase/meng17-one2seq/meng17-one2seq-kp20k/kp20k-meng17-verbatim_append-rnn-BS64-LR0.05-Layer1-Dim150-Emb100-Dropout0.0-Copytrue-Reusetrue-Covtrue-PEfalse-Contboth-IF1_step_50000.pt'
    one2seq_ckpt_path = 'models/keyphrase/meng17-one2seq-kp20k-topmodels/kp20k-meng17-verbatim_append-rnn-BS64-LR0.05-Layer1-Dim150-Emb100-Dropout0.0-Copytrue-Reusetrue-Covtrue-PEfalse-Contboth-IF1_step_50000.pt'
    # parse_args expects a list of argument tokens, not a single string
    opt = parser.parse_args(['-config', config_path])
    setattr(opt, 'models', [one2seq_ckpt_path])

    translator = translator.build_translator(opt, report_score=False)

    scores, predictions = translator.translate(src=[text_to_extract],
                                               tgt=None,
                                               src_dir=opt.src_dir,
                                               batch_size=opt.batch_size,
                                               attn_debug=opt.attn_debug,
                                               opt=opt)
    print('Paragraph:\n\t' + text_to_extract)
    print('Top predictions:')
    keyphrases = [
        kp.strip() for kp in predictions[0]
        if (not kp.lower().strip() in stoplist) and (kp != '<unk>')
    ]
    for kp_id, kp in enumerate(keyphrases[:min(len(keyphrases), 20)]):
        print('\t%d: %s' % (kp_id + 1, kp))
Example #4
def main(opt):
    translator = build_translator(opt, report_score=True)

    data = []
    with open(opt.src) as f:
        sample = {"unrolled": [], "original": ("", "")}
        for line in f:
            [line_type, source, target] = line.split("\t")
            source = source.strip().split()
            target = target.strip().split()
            if line_type == "unrolled":
                sample[line_type].append((source, target))
            else:
                sample[line_type] = (source, target)
                data.append(sample)
                sample = {"unrolled": [], "original": ("", "")}

    predictions_equal = []
    performance = []
    scores_per_input_length = defaultdict(list)
    scores_per_target_length = defaultdict(list)
    lengths = []
    shorts = []
    length_is_score = []
    all_pairs = []

    random.shuffle(data)
    for sample in data[:10000]:
        unrolled_predicted, is_long, pairs = process_unrolled(
            sample["unrolled"], translator)
        source, target = sample["original"]
        source = " ".join(source)
        original_predicted = translator.translate_one(src_data_iter=[source],
                                                      batch_size=1)
        if "<eos>" in original_predicted:
            original_predicted.remove("<eos>")
        if "<eos>" in unrolled_predicted:
            unrolled_predicted.remove("<eos>")
        local_score = original_predicted == unrolled_predicted
        all_pairs.append((pairs, local_score))

        if is_long:
            lengths.append(local_score)
        else:
            shorts.append(local_score)
        predictions_equal.append(local_score)
        length_is_score.append(local_score == (not is_long))
        performance.append(unrolled_predicted == target)
        scores_per_input_length[len(source)].append(
            original_predicted == unrolled_predicted)
        scores_per_target_length[len(target)].append(
            original_predicted == unrolled_predicted)

    logging.info("Localism {}, Performance {}".format(
        sum(predictions_equal) / len(predictions_equal),
        sum(performance) / len(performance)))
    logging.info("Score of sequences containing strings > 5: {}".format(
        sum(lengths) / len(lengths)))
    logging.info("Length equals score: {}".format(
        sum(length_is_score) / len(length_is_score)))
    logging.info("Percentage of long ones: {}".format(
        len(lengths) / len(predictions_equal)))
    logging.info("Percentage of short ones: {}".format(
        len(shorts) / len(predictions_equal)))
    logging.info("Score of sequences containing strings <= 5: {}".format(
        sum(shorts) / len(shorts)))
Example #5
def translate_sentence(sentence):
    translator = build_translator()
    translator.translate(src=[sentence])
Example #6
def _get_translator(opt):
    # ArgumentParser.validate_translate_opts(opt)
    print('hello 1')
    translator = build_translator(opt, report_score=True)
    return translator
Example #7
def main(opt):
    # # This is just a dummy to test the implementation
    # # TODO this needs to be fixed
    # write_dummy_generated_node_types(opt.tgt, 'tmp/generated_node_types.nt')
    # ####################################################################################
    #TODO 1. Extract grammar and build initial atc
    #TODO 2. Extract atc_file and enhance atc

    grammar_atc = extract_atc_from_grammar(opt.grammar)
    all_node_type_seq_str, node_seq_scores = get_all_node_type_str(
        opt.tmp_file)
    total_number_of_test_examples = len(all_node_type_seq_str)
    all_atcs = [grammar_atc for _ in range(total_number_of_test_examples)]
    if opt.atc is not None:
        all_atcs = refine_atc(all_atcs, opt.atc)
    # exit()
    # for i, atc in enumerate(all_atcs):
    #     for key in atc.keys():
    #         debug(i, key, len(atc[key]))
    #     exit()
    #     debug('')
    all_file_set = set()
    correct_cand_file_set = set()
    translator = build_translator(opt,
                                  report_score=True,
                                  multi_feature_translator=True)
    all_scores, all_cands = translator.translate(
        src_path=opt.src,
        tgt_path=opt.tgt,
        src_dir=opt.src_dir,
        batch_size=opt.batch_size,
        attn_debug=opt.attn_debug,
        node_type_seq=[all_node_type_seq_str, node_seq_scores],
        atc=all_atcs)
    beam_size = len(all_scores[0])
    exp_name = opt.name
    all_sources = []
    all_targets = []
    all_files, all_parent_trees, all_child_trees = [], [], []
    tgt_file = open(opt.tgt)
    src_file = open(opt.src)
    files_file = open(opt.files_file)
    parent_tree_file = open(opt.parent_tree)
    child_tree_file = open(opt.child_tree)
    for src, tgt, file_path, parent_tree, child_tree in \
            zip(src_file, tgt_file, files_file, parent_tree_file, child_tree_file):
        all_sources.append(process_source(src.strip()))
        all_targets.append(process_source(tgt.strip()))
        all_files.append(file_path.strip())
        all_parent_trees.append(parent_tree.strip())
        all_child_trees.append(child_tree.strip())
    tgt_file.close()
    src_file.close()
    files_file.close()
    parent_tree_file.close()
    child_tree_file.close()
    correct = 0
    no_change = 0
    if not os.path.exists('defj_experiment/results'):
        os.mkdir('defj_experiment/results')

    if not os.path.exists('defj_experiment/result_eds'):
        os.mkdir('defj_experiment/result_eds')

    decode_res_file = open(
        'defj_experiment/results/' + exp_name + '_' + str(beam_size) +
        '_decode_res.txt', 'w')
    bleu_file = open(
        'defj_experiment/result_eds/' + exp_name + '_' + str(beam_size) +
        '_bleus.csv', 'w')

    all_eds = []
    total_example = 0
    for idx, (src, tgt, file_path, parent_tree, child_tree, cands, scores) in \
            enumerate(zip(all_sources, all_targets, all_files,
                          all_parent_trees, all_child_trees, all_cands, all_scores)):
        decode_res_file.write(
            '========================================================================\n'
        )
        total_example += 1
        decode_res_file.write('Example Number: ' + str(idx + 1) + '\n')
        decode_res_file.write('Parent Code is: \n' + src + '\n')
        decode_res_file.write(
            '-------------------------------------------------------------------------------\n'
        )
        decode_res_file.write('Child code is: \n' + tgt + '\n')
        decode_res_file.write(
            '-------------------------------------------------------------------------------\n'
        )
        decode_res_file.write('Parent Tree is : \n' + str(parent_tree) + '\n')
        decode_res_file.write(
            '-------------------------------------------------------------------------------\n'
        )
        decode_res_file.write('Child Tree is : \n' + str(child_tree) + '\n')
        decode_res_file.write(
            '-------------------------------------------------------------------------------\n'
        )
        decode_res_file.write(str(file_path) + '\n')
        decode_res_file.write(
            '-------------------------------------------------------------------------------\n'
        )
        decode_res_file.write('Edit Distance : ' +
                              str(get_edit_dist(src, tgt)) + '\n')
        decode_res_file.write(
            '-------------------------------------------------------------------------------\n'
        )

        if src == tgt:
            no_change += 1

        eds = []
        found = False
        cands_reformatted = re_organize_candidates(cands, scores, src,
                                                   opt.n_best)
        decode_res_file.write('Candidate List Length : ' +
                              str(len(cands_reformatted)) + '\n')
        #debug(idx, 'Candidate List Length : ' + str(len(cands_reformatted)))
        # print(len(cands_reformatted))
        for cand in cands_reformatted:
            ed = get_edit_dist(tgt, cand)
            if cand == tgt:
                found = True
            eds.append(ed)

        fn = str(file_path)
        parts_of_fn = fn.split('/')
        if 'parent' in parts_of_fn:
            pidx = parts_of_fn.index('parent')
            project_name = parts_of_fn[pidx - 2]
            bugid = parts_of_fn[pidx - 1]
            project_bug_id = project_name + " " + bugid
        else:
            project_bug_id = fn

        all_file_set.add(project_bug_id)
        if found:
            correct_cand_file_set.add(project_bug_id)
            print(project_bug_id)
            #print(src)
            #print(tgt)
            decode_res_file.write("Correct\n")
            correct += 1
        else:
            decode_res_file.write("Wrong\n")
        decode_res_file.write(
            '========================================================================\n'
        )

        all_eds.append(eds)
        decode_res_file.write(str(found) + '\n\n')
        decode_res_file.flush()
        for cid, cand in enumerate(cands_reformatted):
            code_s = cand
            tree_s = code_s
            decode_res_file.write(
                str(cid) + '\nTree\t' + str(tree_s) + '\nCode\t' + code_s +
                '\nDistance : ' + str(eds[cid]) + '\n\n')
        decode_res_file.write(
            '========================================================================\n\n\n'
        )

    all_eds = np.asarray(all_eds)
    print_bleu_res_to_file(bleu_file, all_eds)
    decode_res_file.close()
    bleu_file.close()
    print('Correct : ', correct, '\tTotal : ', total_example, '\tTotal Bug : ',
          len(all_file_set), '\tCorrect bug : ', len(correct_cand_file_set))
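
get_edit_dist is used above but not shown; a plausible token-level Levenshtein implementation, offered only as an assumption about what it computes.

def get_edit_dist(a, b):
    # Token-level Levenshtein distance computed with a single rolling row.
    a, b = a.split(), b.split()
    dp = list(range(len(b) + 1))
    for i in range(1, len(a) + 1):
        prev, dp[0] = dp[0], i
        for j in range(1, len(b) + 1):
            cur = dp[j]
            dp[j] = min(dp[j] + 1,                      # delete a[i-1]
                        dp[j - 1] + 1,                  # insert b[j-1]
                        prev + (a[i - 1] != b[j - 1]))  # substitute
            prev = cur
    return dp[-1]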
Example #8

def _get_parser():
    parser = ArgumentParser(description='kp_generate.py')

    opts.config_opts(parser)
    opts.translate_opts(parser)
    return parser


if __name__ == "__main__":
    parser = _get_parser()

    opt = parser.parse_args()
    logger = init_logger(opt.log_file)
    translator = build_translator(opt, logger)

    src_shards = split_corpus(opt.src, opt.shard_size)
    tgt_shards = split_corpus(opt.tgt, opt.shard_size) \
        if opt.tgt is not None else repeat(None)
    shard_pairs = zip(src_shards, tgt_shards)

    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
        logger.info("Translating shard %d." % i)
        translator.translate(src=src_shard,
                             tgt=tgt_shard,
                             src_dir=opt.src_dir,
                             batch_size=opt.batch_size,
                             attn_debug=opt.attn_debug,
                             opt=opt)
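
split_corpus comes from onmt.utils.misc; roughly, it yields the file in chunks of shard_size lines (one big shard when shard_size <= 0). A simplified sketch for illustration:

from itertools import islice

def split_corpus_sketch(path, shard_size):
    # Yield lists of at most shard_size lines each.
    with open(path, "rb") as f:
        if shard_size <= 0:
            yield f.readlines()
        else:
            while True:
                shard = list(islice(f, shard_size))
                if not shard:
                    return
                yield shard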
Example #9
step = os.path.basename(model_path)[:-3].split('step_')[-1]
temp = opt.random_sampling_temp

if opt.extra_output_str:
    opt.extra_output_str = '_' + opt.extra_output_str

if opt.output is None:
    output_path = '/'.join(
        model_path.split('/')
        [:-2]) + '/output_%s_%s%s.encoded' % (step, temp, opt.extra_output_str)
    opt.output = output_path

ArgumentParser.validate_translate_opts(opt)
logger = init_logger(opt.log_file)

translator = build_translator(opt, report_score=True)

BASE_LIB = 'html5lib'
UA = 'Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/61.0.3163.100 Safari/537.36'
HEADERS = {'user-agent': UA}

print("prepared")


@app.route('/index.html', methods=['POST'])
def main():

    if request.method == 'POST':

        url = str(request.form['message']).strip()
        if len(url) == 0 or not url.startswith("https://"):
Example #10
def pytorch_to_onnx(opt):

    export = True

    opt.model = model_file_path
    opt.models = [opt.model]
    opt.n_best = 5
    opt.beam_size = 5
    opt.report_bleu = False
    opt.report_rouge = False

    translator = build_translator(opt, report_score=True)
    result = translator.translate(src='src-test.txt', batch_size=1)
    print(result)

    # return

    model = onmt.model_builder.load_test_model(opt)
    # print(model)

    if export:
        with open(os.path.join(model_file_folder, 'params.json'), 'w') as outfile:
            json.dump(vars(model[2]), outfile, ensure_ascii=False, indent=2)

        with open(os.path.join(model_file_folder, 'src_vocab.json'), 'w') as outfile:
            json.dump(dict(translator.fields)["src"].base_field.vocab.itos, outfile, ensure_ascii=False)

        with open(os.path.join(model_file_folder, 'tgt_vocab.json'), 'w') as outfile:
            json.dump(dict(translator.fields)["tgt"].base_field.vocab.itos, outfile, ensure_ascii=False)


    src_embeddings = model[1].encoder.embeddings
    if export:
        with open(os.path.join(model_file_folder, 'src_embeddings_half_binary'), 'wb') as weights_file:
            weights_file.write(src_embeddings.state_dict()['make_embedding.emb_luts.0.weight'].numpy().astype('float16').tobytes())
        embeddings_input = torch.zeros((1, 1, 1)).long()
        input_names = ['input_index']
        output_names = ['embedding']
        torch.onnx.export(src_embeddings, embeddings_input,
                          os.path.join(model_file_folder, 'src_embeddings.onnx'),
                          verbose=True, input_names=input_names, output_names=output_names
                          )


    encoder = model[1].encoder.rnn
    encoder_models = decompose_bnn_lstm(encoder)
    coreml_encoder = EncoderForCoreMLExport(encoder.input_size, encoder.hidden_size, decomposed_model_list=encoder_models, num_layers=encoder.num_layers, bidirectional=encoder.bidirectional)
    test_rnn_and_coreml_models_equality(encoder, coreml_encoder)
    num_directions = 1 + encoder.bidirectional
    if export:
        for layer_index in range(encoder.num_layers):
            for direction in range(num_directions):
                encoder_model_part = encoder_models[layer_index * num_directions + direction]
                encoder_input = (torch.randn(1, encoder.input_size if layer_index == 0 else encoder.hidden_size * num_directions),
                                 encoder_model_part.init_hidden(batch_size=1))
                input_names = ['input', 'h', 'c']
                output_names = ['h', 'c']
                torch.onnx.export(encoder_model_part, encoder_input,
                                  os.path.join(model_file_folder, 'encoder_model_{}.onnx'.format(layer_index * num_directions + direction)),
                                  verbose=True, input_names=input_names, output_names=output_names)


    tgt_embeddings = model[1].decoder.embeddings
    if export:
        with open(os.path.join(model_file_folder, 'tgt_embeddings_half_binary'), 'wb') as weights_file:
            weights_file.write(tgt_embeddings.state_dict()['make_embedding.emb_luts.0.weight'].numpy().astype('float16').tobytes())
        embeddings_input = torch.zeros((1, 1, 1)).long()
        input_names = ['input_index']
        output_names = ['embedding']
        torch.onnx.export(tgt_embeddings, embeddings_input,
                          os.path.join(model_file_folder, 'tgt_embeddings.onnx'),
                          verbose=True, input_names=input_names, output_names=output_names
                          )


    decoder_rnn = model[1].decoder.rnn
    coreml_decoder_rnn = MultiLayerLSTMForCoreMLExport(decoder_rnn.input_size, decoder_rnn.hidden_size, num_layers=decoder_rnn.num_layers)
    state_dict = decoder_rnn.state_dict()
    coreml_model_state_dict = coreml_decoder_rnn.state_dict()
    for layer_index in range(decoder_rnn.num_layers):
        for coreml_key in coreml_model_state_dict:
            if 'rnn_cell' in coreml_key:
                rnn_key = re.sub(r'rnn_cell\d\.', '', coreml_key) + '_l' + coreml_key[len('rnn_cell')]
                coreml_model_state_dict[coreml_key] = state_dict[rnn_key]
        coreml_decoder_rnn.load_state_dict(coreml_model_state_dict)

    if export:
        decoder_rnn_input = (torch.randn(1, decoder_rnn.input_size),
                             coreml_decoder_rnn.init_hidden(batch_size=1))
        input_names = ['input']
        output_names = []
        for layer_index in range(decoder_rnn.num_layers):
            input_names += ['h{}'.format(layer_index), 'c{}'.format(layer_index)]
            output_names += ['h{}'.format(layer_index), 'c{}'.format(layer_index)]
        torch.onnx.export(coreml_decoder_rnn, decoder_rnn_input,
                          os.path.join(model_file_folder, 'decoder_rnn_model.onnx'),
                          verbose=True, input_names=input_names, output_names=output_names)


    attn = model[1].decoder.attn
    if export:
        input_names = ['input']
        output_names = ['output']
        rnn_output = torch.rand(1, decoder_rnn.hidden_size)
        torch.onnx.export(attn.linear_in, (rnn_output, ),
                          os.path.join(model_file_folder, 'attn_linear_in.onnx'),
                          verbose=True, input_names=input_names, output_names=output_names
                          )
        input_names = ['input']
        output_names = ['output']
        input = torch.rand(1, 2 * decoder_rnn.hidden_size)
        torch.onnx.export(attn.linear_out, (input, ),
                          os.path.join(model_file_folder, 'attn_linear_out.onnx'),
                          verbose=True, input_names=input_names, output_names=output_names
                          )


    generator = model[1].generator[0]
    if export:
        input_names = ['input']
        output_names = ['output']
        input = torch.rand(1, decoder_rnn.hidden_size)
        torch.onnx.export(generator, (input, ),
                          os.path.join(model_file_folder, 'generator.onnx'),
                          verbose=True, input_names=input_names, output_names=output_names
                          )

    translate_with_submodels(src_file='src-test.txt', fields=translator.fields, src_embeddings=src_embeddings, decomposed_encoders=encoder_models,
                             tgt_embeddings=tgt_embeddings, decoder_rnn=coreml_decoder_rnn, attn_linear_in=attn.linear_in,
                             attn_linear_out=attn.linear_out, generator=generator)
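
One way to sanity-check the exported graphs (not part of the original script) is to load them with onnxruntime; a minimal sketch, assuming onnxruntime is installed, using the input/output names from the export above.

import os
import numpy as np
import onnxruntime

model_file_folder = '.'  # wherever the .onnx files were exported
sess = onnxruntime.InferenceSession(
    os.path.join(model_file_folder, 'src_embeddings.onnx'))
out = sess.run(['embedding'],
               {'input_index': np.zeros((1, 1, 1), dtype=np.int64)})
print(out[0].shape)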
Example #11
def main(opt):

    #TODO: delete all lines related to WALS
    #begin
    SimulationLanguages = [opt.wals_src, opt.wals_tgt]

    print('Loading WALS features from databases...')

    cwd = os.getcwd()

    db = sqlite3.connect(cwd + '/onmt/WalsValues.db')
    cursor = db.cursor()
    cursor.execute('SELECT * FROM WalsValues')
    WalsValues = cursor.fetchall()

    db = sqlite3.connect(cwd + '/onmt/FeaturesList.db')
    cursor = db.cursor()
    cursor.execute('SELECT * FROM FeaturesList')
    FeaturesList = cursor.fetchall()

    db = sqlite3.connect(cwd + '/onmt/FTInfos.db')
    cursor = db.cursor()
    cursor.execute('SELECT * FROM FTInfos')
    FTInfos = cursor.fetchall()

    db = sqlite3.connect(cwd + '/onmt/FTList.db')
    cursor = db.cursor()
    cursor.execute('SELECT * FROM FTList')
    FTList = cursor.fetchall()

    ListLanguages = []
    for i in WalsValues:
        ListLanguages.append(i[0])

    FeatureTypes = []
    for i in FTList:
        FeatureTypes.append((i[0], i[1].split(',')))

    FeatureNames = []
    for i in FeatureTypes:
        FeatureNames += i[1]

    FeatureTypesNames = []
    for i in FeatureTypes:
        FeatureTypesNames.append(i[0])

    FeatureValues, FeatureTensors = get_feat_values(SimulationLanguages,
                                                    WalsValues, FeaturesList,
                                                    ListLanguages,
                                                    FeatureTypes, FeatureNames)
    #end

    #TODO: load wals features from command-line (wals.npz)
    #TODO: remove all parameters related to WALS features and include four numpy vectors that describe WALS
    translator = build_translator(opt,
                                  FeatureValues,
                                  FeatureTensors,
                                  FeatureTypes,
                                  FeaturesList,
                                  FeatureNames,
                                  FTInfos,
                                  FeatureTypesNames,
                                  SimulationLanguages,
                                  report_score=True)
    translator.translate(src_path=opt.src,
                         tgt_path=opt.tgt,
                         src_dir=opt.src_dir,
                         batch_size=opt.batch_size,
                         attn_debug=opt.attn_debug)
Example #12
File: trainer.py  Project: takatomo-k/s2s
def build_trainer(fields,
                  opt,
                  device_id,
                  model,
                  optim,
                  model_saver=None,
                  optim_ad=None,
                  loss_ad=None):
    """
    Simplify `Trainer` creation based on user `opt`s.

    Args:
        opt (:obj:`Namespace`): user options (usually from argument parsing)
        model (:obj:`onmt.models.NMTModel`): the model to train
        optim (:obj:`onmt.utils.Optimizer`): optimizer used during training
        data_type (str): string describing the type of data
            e.g. "text", "img", "audio"
        model_saver(:obj:`onmt.models.ModelSaverBase`): the utility object
            used to save the model
    """
    train_loss = onmt.utils.loss.build_loss_compute(model, opt)
    valid_loss = onmt.utils.loss.build_loss_compute(model, opt, train=False)

    #trunc_size = opt.truncated_decoder  # Badly named...
    shard_size = opt.max_generator_batches if opt.model_dtype == 'fp32' else 0
    norm_method = opt.normalization
    accum_count = opt.accum_count
    accum_steps = opt.accum_steps
    n_gpu = opt.world_size
    average_decay = opt.average_decay
    average_every = opt.average_every
    dropout = opt.dropout
    dropout_steps = opt.dropout_steps
    if device_id >= 0:
        gpu_rank = opt.gpu_ranks[device_id]
    else:
        gpu_rank = 0
        n_gpu = 0
    gpu_verbose_level = opt.gpu_verbose_level

    earlystopper = onmt.utils.EarlyStopping(
        opt.early_stopping, scorers=onmt.utils.scorers_from_opts(opt)) \
        if opt.early_stopping > 0 else None
    #import pdb; pdb.set_trace()
    report_manager = onmt.utils.build_report_manager(opt, fields)

    translator = build_translator(model,
                                  torch.load(opt.data + "vocab.pt"),
                                  opt,
                                  report_score=True,
                                  device_id=device_id)

    trainer = onmt.Trainer(model,
                           translator,
                           train_loss,
                           valid_loss,
                           optim,
                           shard_size,
                           norm_method,
                           accum_count,
                           accum_steps,
                           n_gpu,
                           gpu_rank,
                           gpu_verbose_level,
                           report_manager,
                           model_saver=model_saver if gpu_rank == 0 else None,
                           average_decay=average_decay,
                           average_every=average_every,
                           model_dtype=opt.model_dtype,
                           earlystopper=earlystopper,
                           dropout=dropout,
                           dropout_steps=dropout_steps,
                           device_id=device_id,
                           optim_ad=optim_ad,
                           loss_ad=loss_ad)
    return trainer
Example #13
def main(opt):
    testset = loads_json(opt.data, 'load test file')
    f = open("data/corpus.json", "r")
    data = f.readlines()
    samples = []
    for i in data:
        samples.append(json.loads(i)["sents"])
    corpus = []
    for i in samples:
        corpus.append(" ".join([" ".join(sent) for sent in i]))
    vectorizer = TfidfVectorizer()
    tfidf = vectorizer.fit(corpus)

    translator = build_translator(opt, report_score=False)
    translated = translator.translate(data_path=opt.data,
                                      batch_size=opt.batch_size,
                                      tfidf=tfidf)

    # find first 3 not similar distractors
    hypothesis = {}
    for translation in translated:
        sample_id = str(translation.ex_raw.id).split("_")
        question_id = sample_id[0] + "_" + sample_id[1]
        pred1 = translation.pred_sents[0]
        pred2, pred3 = None, None
        for pred in translation.pred_sents[1:]:
            if jaccard_similarity(pred1, pred) < 0.5:
                if pred2 is None:
                    pred2 = pred
                else:
                    if pred3 is None:
                        if jaccard_similarity(pred2, pred) < 0.5:
                            pred3 = pred
            if pred2 is not None and pred3 is not None:
                break

        if pred2 is None:
            pred2 = translation.pred_sents[1]
            if pred3 is None:
                pred3 = translation.pred_sents[2]
        else:
            if pred3 is None:
                pred3 = translation.pred_sents[1]

        hypothesis[question_id] = [pred1, pred2, pred3]
    torch.save(hypothesis, "translated.hypothesis.pt")

    reference = {}
    f = open(opt.output, "w")
    for sample in testset:
        line = copy.deepcopy(sample)
        sample_id = sample["id"].split("_")
        question_id = sample_id[0] + "_" + sample_id[1]
        if question_id not in reference.keys():
            reference[question_id] = [sample['distractor']]
        else:
            reference[question_id].append(sample['distractor'])
        line["pred"] = hypothesis[question_id]
        f.write(json.dumps(line) + "\n")
    f.close()

    _ = eval(hypothesis, reference)
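
jaccard_similarity is not defined in the snippet; given that pred_sents are token sequences, a plausible set-based implementation looks like this.

def jaccard_similarity(a, b):
    # |A ∩ B| / |A ∪ B| over the two token sets; 0.0 for two empty inputs.
    a, b = set(a), set(b)
    union = a | b
    return len(a & b) / len(union) if union else 0.0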
Example #14
def main(opt, grammar, actual_n_best):
    translator = build_translator(opt, report_score=True)
    all_scores, all_cands, all_tree_cands = translate_all(
        opt, grammar, actual_n_best)
    # debug(len(all_cands[0]), len(all_tree_cands[0]))
    beam_size = actual_n_best  #len(all_scores[0])
    exp_name = opt.name
    all_sources = []
    all_targets = []
    tgt_file = open(opt.tgt)
    src_file = open(opt.src)
    for a, b in zip(src_file, tgt_file):
        all_sources.append(a.strip())
        all_targets.append(b.strip())
    tgt_file.close()
    src_file.close()
    correct = 0
    no_change = 0
    decode_res_file = open(
        'results/' + exp_name + '_' + str(beam_size) + '_decode_res.txt', 'w')
    bleu_file = open(
        'result_bleus/' + exp_name + '_' + str(beam_size) + '_bleus.csv', 'w')
    correct_id_file = open(
        'correct_ids/' + exp_name + '_' + str(beam_size) + '.txt', 'w')

    all_bleus = []
    total_example = 0
    for idx, (src, tgt, cands, trees) in enumerate(
            zip(all_sources, all_targets, all_cands, all_tree_cands)):
        total_example += 1
        decode_res_file.write(str(idx) + '\n')
        decode_res_file.write(src + '\n')
        decode_res_file.write(
            '-------------------------------------------------------------------------------------\n'
        )
        decode_res_file.write(tgt + '\n')
        if src == tgt:
            no_change += 1
        decode_res_file.write(
            '=====================================================================================\n'
        )
        decode_res_file.write('Candidate Size : ' + str(len(cands)) + '\n')
        decode_res_file.write(
            '-------------------------------------------------------------------------------------\n'
        )
        bleus = []
        found = False
        # debug(len(cands), len(trees))
        for cand, tree in zip(cands, trees):
            ed = get_edit_dist(tgt, cand)
            if cand == tgt:
                found = True
            bleus.append(ed)
            decode_res_file.write(cand + '\n')
            decode_res_file.write(' '.join([str(x) for x in tree]) + '\n')
            decode_res_file.write(str(ed) + '\n')
        if found:
            correct += 1
            correct_id_file.write(str(idx) + '\n')
        all_bleus.append(bleus)
        decode_res_file.write(str(found) + '\n\n')

    all_bleus = np.asarray(all_bleus)
    print_bleu_res_to_file(bleu_file, all_bleus)
    decode_res_file.close()
    bleu_file.close()
    correct_id_file.close()
    print(correct, no_change, total_example)
Example #15
def main(opt):
    # # This is just a dummy to test the implementation
    # # TODO this needs to be fixed
    # write_dummy_generated_node_types(opt.tgt, 'tmp/generated_node_types.nt')
    # ####################################################################################
    #TODO 1. Extract grammar and build initial atc
    #TODO 2. Extract atc_file and enhance atc

    grammar_atc = extract_atc_from_grammar(opt.grammar)
    all_node_type_seq_str, node_seq_scores = get_all_node_type_str(
        opt.tmp_file)
    total_number_of_test_examples = len(all_node_type_seq_str)
    all_atcs = [grammar_atc for _ in range(total_number_of_test_examples)]
    if opt.atc is not None:
        all_atcs = refine_atc(all_atcs, opt.atc)
    # exit()
    # for i, atc in enumerate(all_atcs):
    #     for key in atc.keys():
    #         debug(i, key, len(atc[key]))
    #     exit()
    #     debug('')

    translator = build_translator(opt,
                                  report_score=True,
                                  multi_feature_translator=True)
    all_scores, all_cands = translator.translate(
        src_path=opt.src,
        tgt_path=opt.tgt,
        src_dir=opt.src_dir,
        batch_size=opt.batch_size,
        attn_debug=opt.attn_debug,
        node_type_seq=[all_node_type_seq_str, node_seq_scores],
        atc=all_atcs)
    beam_size = len(all_scores[0])
    exp_name = opt.name
    all_sources = []
    all_targets = []
    tgt_file = open(opt.tgt)
    src_file = open(opt.src)
    for a, b in zip(src_file, tgt_file):
        all_sources.append(process_source(a.strip()))
        all_targets.append(process_source(b.strip()))
    tgt_file.close()
    src_file.close()
    correct = 0
    no_change = 0
    if not os.path.exists('results'):
        os.mkdir('results')

    if not os.path.exists('result_eds'):
        os.mkdir('result_eds')

    decode_res_file = open(
        'results/' + exp_name + '_' + str(beam_size) + '_decode_res.txt', 'w')

    all_eds = []
    total_example = 0
    for idx, (src, tgt, cands, scores) in enumerate(
            zip(all_sources, all_targets, all_cands, all_scores)):
        total_example += 1
        decode_res_file.write(str(idx) + '\n')
        decode_res_file.write(src + '\n')
        decode_res_file.write(
            '-------------------------------------------------------------------------------------\n'
        )
        decode_res_file.write(tgt + '\n')
        if src == tgt:
            no_change += 1
        eds = []
        o_ed = get_edit_dist(src, tgt)
        eds.append(o_ed)
        decode_res_file.write(
            '=====================================================================================\n'
        )
        decode_res_file.write('Candidate Size : ' + str(len(cands)) + '\n')
        decode_res_file.write(
            '-------------------------------------------------------------------------------------\n'
        )

        found = False
        cands_reformatted = re_organize_candidates(cands, scores, src,
                                                   opt.n_best)
        for cand in cands_reformatted:
            ed = get_edit_dist(tgt, cand)
            if cand == tgt:
                found = True
            eds.append(ed)
            decode_res_file.write(cand + '\n')
            decode_res_file.write(str(ed) + '\n')
        if found:
            correct += 1
        all_eds.append(eds)
        decode_res_file.write(str(found) + '\n\n')
        decode_res_file.flush()

    ed_file = open(
        'result_eds/' + exp_name + '-' + str(correct) + '-' +
        str(opt.tree_count) + '-' + str(opt.n_best) + 'eds.csv', 'w')
    all_eds = np.asarray(all_eds)
    print_bleu_res_to_file(ed_file, all_eds)
    decode_res_file.close()
    ed_file.close()
    print(correct, no_change, total_example)
Example #16
def translate(opt):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)
    val_file_path = os.path.join(opt.train_dir, opt.val_result_file)

    with open(val_file_path, 'r') as f:
        val_results = yaml.safe_load(f)  # plain yaml.load needs an explicit Loader

    all_ckpts = []
    for key, result in val_results.items():
        loss = result['Validation loss']
        ckpt = opt.model_prefix + '_' + key.split()[-1] + '.pt'
        all_ckpts.append((ckpt, loss))

    all_ckpts.sort(key=lambda x: x[1])

    assert len(all_ckpts) >= opt.best_n_ckpts
    all_ckpts = all_ckpts[:opt.best_n_ckpts]

    print('Evaluating the following checkpoints')
    print(all_ckpts)
    local_ckpts = []

    for ckpt in all_ckpts:
        local_ckpt = os.path.join(opt.train_dir, ckpt[0])
        local_ckpts.append(local_ckpt)

    opt.models = local_ckpts

    train_config = {}
    checkpoint = torch.load(local_ckpts[0],
                            map_location=lambda storage, loc: storage)
    train_config = checkpoint.get('train_config')
    config = checkpoint.get('opt')
    print(config)

    eval_dir_local = os.path.join(opt.train_dir, 'eval')
    data_dir_local = opt.data_dir
    if not os.path.exists(eval_dir_local):
        os.mkdir(eval_dir_local)

    vocab_model = os.path.join(data_dir_local, opt.vocab_model)
    source_files = opt.src.split(' ')
    target_files = opt.tgt.split(' ')
    results = {
        "datasets": [],
        "configs": {
            "detok_model": opt.vocab_model,
            "detok": opt.detok_type,
            "models": opt.models,
            "max_dec_length": opt.max_length,
            "train_config": train_config,
        }
    }

    for src, tgt in zip(source_files, target_files):
        opt.src = os.path.join(data_dir_local, src)
        opt.tgt = os.path.join(data_dir_local, tgt)
        opt.output = os.path.join(eval_dir_local, src + '.output')

        translator = build_translator(opt, report_score=True)
        src_shards = split_corpus(opt.src, opt.shard_size)
        tgt_shards = split_corpus(opt.tgt, opt.shard_size)
        shard_pairs = zip(src_shards, tgt_shards)

        for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
            logger.info("Translating shard %d." % i)
            translator.translate(src=src_shard,
                                 tgt=tgt_shard,
                                 src_dir=opt.src_dir,
                                 batch_size=opt.batch_size,
                                 batch_type=opt.batch_type,
                                 attn_debug=opt.attn_debug,
                                 align_debug=opt.align_debug)

        bleu_score = calc_bleu(opt.output, opt.tgt, opt.bleu_script,
                               opt.detok_type, {'spm_model': vocab_model})
        results["datasets"].append({
            'src': src,
            'tgt': tgt,
            'bleu_score': bleu_score
        })

    results_file = os.path.join(eval_dir_local, 'results.yml')
    with open(results_file, 'w') as f:
        f.write(yaml.dump(results))
    print(results)
Example #17
                                    % (elapsed_time, opt.test_interval))
                            else:
                                os.remove(pred_path)
                                logger.info(
                                    'Removed a bad pred file, #(line)=%d, #(elapsed_time)=%ds: %s'
                                    %
                                    (len(lines), int(elapsed_time), pred_path))
                    except Exception as e:
                        logger.exception(
                            'Error while validating or deleting pred file: %s'
                            % pred_path)

                if 'pred' in opt.tasks:
                    if do_trans_flag or opt.ignore_existing:
                        if translator is None:
                            translator = build_translator(
                                opt, report_score=opt.verbose, logger=logger)
                        # create an empty file to indicate that the translator is working on it
                        codecs.open(pred_path, 'w+', 'utf-8').close()
                        # set output_file for each dataset (instead of outputting to opt.output)
                        translator.out_file = codecs.open(
                            pred_path, 'w+', 'utf-8')
                        logger.info("Start translating [%s] for %s." %
                                    (dataname, ckpt_name))
                        all_scores, all_predictions = translator.translate(
                            src=src_shard,
                            tgt=tgt_shard,
                            src_dir=opt.src_dir,
                            batch_size=opt.batch_size,
                            attn_debug=opt.attn_debug,
                            opt=opt)
                        savePreds(src_shard, tgt_shard, all_predictions,
Example #18
def translate(opt, expert_id):
    ArgumentParser.validate_translate_opts(opt)
    logger = init_logger(opt.log_file)

    opt.gpu = expert_id % opt.ngpu
    logger.info('OPT_GPU: %d' % opt.gpu)
    translator = build_translator(opt,
                                  logger=logger,
                                  report_score=True,
                                  log_score=True)
    translator.expert_id = expert_id

    desired_output_length = linecount(opt.src) * opt.n_best
    logger.info(opt.src)
    logger.info(opt.tgt)

    # tiled src file
    opt.tiled_src = opt.src + '.x%d' % opt.n_best
    tile_lines_n_times(opt.src, opt.tiled_src, n=opt.n_best)

    logger.info("=== FWD ===")
    src_path = opt.src
    tgt_path = None  #opt.tgt
    out_path = opt.output + '/fwd_out%d.txt' % expert_id
    out_can_path = opt.output + '/fwd_out%d_can.txt' % expert_id

    if linecount(out_path) == desired_output_length:
        logger.info("Already translated. Pass.")
    else:
        # data preparation
        src_shards = split_corpus(src_path, opt.shard_size)
        tgt_shards = split_corpus(tgt_path, opt.shard_size)
        shard_pairs = zip(src_shards, tgt_shards)
        # translate
        translator.out_file = codecs.open(out_path, 'w+', 'utf-8')
        for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
            logger.info("Translating shard %d." % i)
            translator.translate(direction='x2y',
                                 src=src_shard,
                                 tgt=tgt_shard,
                                 src_dir=opt.src_dir,
                                 batch_size=opt.batch_size,
                                 batch_type=opt.batch_type,
                                 attn_debug=opt.attn_debug,
                                 align_debug=opt.align_debug)
        # canonicalize
        canonicalize_smitxt(out_path, out_can_path, remove_score=True)

    logger.info("=== BWD ===")
    translator.beam_size = 1
    translator.n_best = 1

    src_path = opt.output + '/fwd_out%d_can.txt' % expert_id
    tgt_path = opt.tiled_src
    out_path = opt.output + '/bwd_out%d.txt' % expert_id

    if linecount(out_path) == desired_output_length:
        logger.info("Already translated. Pass.")
    else:
        # data preparation
        src_shards = split_corpus(src_path, opt.shard_size)
        tgt_shards = split_corpus(tgt_path, opt.shard_size)
        shard_pairs = zip(src_shards, tgt_shards)
        # translate
        translator.out_file = codecs.open(out_path, 'w+', 'utf-8')
        for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
            logger.info("Translating shard %d." % i)
            translator.translate(direction='y2x',
                                 src=src_shard,
                                 tgt=tgt_shard,
                                 src_dir=opt.src_dir,
                                 batch_size=opt.batch_size,
                                 batch_type=opt.batch_type,
                                 attn_debug=opt.attn_debug,
                                 align_debug=opt.align_debug,
                                 only_gold_score=True)  # for speed
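
linecount and tile_lines_n_times are helpers not shown here; plausible implementations, inferred from the fact that the tiled file must align line-for-line with linecount(opt.src) * opt.n_best forward candidates.

import os

def linecount(path):
    # Number of lines in a text file; 0 if the file does not exist yet.
    if not os.path.exists(path):
        return 0
    with open(path) as f:
        return sum(1 for _ in f)

def tile_lines_n_times(src, dst, n):
    # Repeat every source line n times so line i of dst aligns with the
    # i-th n-best candidate produced in the forward pass.
    with open(src) as fin, open(dst, 'w') as fout:
        for line in fin:
            for _ in range(n):
                fout.write(line)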
Example #19
def main(opt):
    translator = build_translator(opt, report_score=True)
    all_scores, all_cands = translator.translate(src_path=opt.src,
                                                 tgt_path=opt.tgt,
                                                 src_dir=opt.src_dir,
                                                 batch_size=opt.batch_size,
                                                 attn_debug=opt.attn_debug)
    beam_size = len(all_scores[0])  # actual_n_best is undefined in this scope
    exp_name = opt.name
    all_sources = []
    all_targets = []
    tgt_file = open(opt.tgt)
    src_file = open(opt.src)
    for a, b in zip(src_file, tgt_file):
        all_sources.append(a.strip())
        all_targets.append(b.strip())
    tgt_file.close()
    src_file.close()
    correct = 0
    no_change = 0
    decode_res_file = open(
        'results/' + exp_name + '_' + str(beam_size) + '_decode_res.txt', 'w')
    bleu_file = open(
        'result_bleus/' + exp_name + '_' + str(beam_size) + '_bleus.csv', 'w')
    correct_id_file = open(
        'correct_ids/' + exp_name + '_' + str(beam_size) + '.txt', 'w')

    all_bleus = []
    total_example = 0
    new_token_added_total_count = 0
    new_token_added_of_correct_examples = 0
    for idx, (src, tgt,
              cands) in enumerate(zip(all_sources, all_targets, all_cands)):
        total_example += 1
        decode_res_file.write(str(idx) + '\n')
        decode_res_file.write(src + '\n')
        decode_res_file.write(
            '-------------------------------------------------------------------------------------\n'
        )
        decode_res_file.write(tgt + '\n')
        n_token_added = is_there_new_token_in_tgt(src, tgt)
        if n_token_added:
            new_token_added_total_count += 1
        if src == tgt:
            no_change += 1
        decode_res_file.write(
            '=====================================================================================\n'
        )
        decode_res_file.write('Candidate Size : ' + str(len(cands)) + '\n')
        decode_res_file.write(
            '-------------------------------------------------------------------------------------\n'
        )
        bleus = []
        found = False
        for cand in cands:
            ed = get_edit_dist(tgt, cand)
            if ed == 0:
                found = True
            bleus.append(ed)
            decode_res_file.write(cand + '\n')
            decode_res_file.write(str(ed) + '\n')
        if found:
            if n_token_added:
                new_token_added_of_correct_examples += 1
            correct += 1
            correct_id_file.write(str(idx) + '\n')
        all_bleus.append(bleus)
        decode_res_file.write(str(found) + '\n\n')

    all_bleus = np.asarray(all_bleus)
    print_bleu_res_to_file(bleu_file, all_bleus)
    decode_res_file.close()
    bleu_file.close()
    correct_id_file.close()
    print(correct, no_change, total_example,
          new_token_added_of_correct_examples, new_token_added_total_count)
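
is_there_new_token_in_tgt is not shown; a plausible reading, labeled as an assumption, is that it checks whether the target introduces any token absent from the source.

def is_there_new_token_in_tgt(src, tgt):
    # True if the target contains at least one token the source lacks.
    return bool(set(tgt.split()) - set(src.split()))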
Example #20
def predict(invocations, model_dir, model_file, result_cnt=5):
    """
    Function called by the evaluation script to interface the participants submission_code
    `predict` function accepts the natural language invocations as input, and returns
    the predicted commands along with confidences as output. For each invocation,
    `result_cnt` number of predicted commands are expected to be returned.

    Args:
        1. invocations : `list (str)` : list of `n_batch` (default 16) natural language invocations
        2. result_cnt : `int` : number of predicted commands to return for each invocation

    Returns:
        1. commands : `list [ list (str) ]` : a list of list of strings of shape (n_batch, result_cnt)
        2. confidences: `list[ list (float) ]` : confidences corresponding to the predicted commands
                                                 confidence values should be between 0.0 and 1.0.
                                                 Shape: (n_batch, result_cnt)
    """
    opt = Namespace(
        models=[os.path.join(model_dir, file) for file in model_file],
        n_best=5,
        avg_raw_probs=False,
        alpha=0.0,
        batch_type='sents',
        beam_size=5,
        beta=-0.0,
        block_ngram_repeat=0,
        coverage_penalty='none',
        data_type='text',
        dump_beam='',
        fp32=True,
        gpu=-1,
        ignore_when_blocking=[],
        length_penalty='none',
        max_length=100,
        max_sent_length=None,
        min_length=0,
        output='/dev/null',
        phrase_table='',
        random_sampling_temp=1.0,
        random_sampling_topk=1,
        ratio=-0.0,
        replace_unk=True,
        report_align=False,
        report_time=False,
        seed=829,
        stepwise_penalty=False,
        tgt=None,
        verbose=False,
        tgt_prefix=None)
    translator = build_translator(opt, report_score=False)

    n_batch = len(invocations)
    commands = [[''] * result_cnt for _ in range(n_batch)]
    confidences = [[1, 0, 0, 0, 0] for _ in range(n_batch)]

    ################################################################################################
    #     Participants should add their codes to fill predict `commands` and `confidences` here    #
    ################################################################################################
    for idx, inv in enumerate(invocations):
        new_inv = tokenize_eng(inv)
        new_inv = ' '.join(new_inv)
        translated = translator.translate([new_inv], batch_size=1)
        for i in range(result_cnt):
            commands[idx][i] = translated[1][0][i]
            confidences[idx][i] = math.exp(translated[0][0][i].item()) / 2
        confidences[idx][0] = 1.0
    ################################################################################################
    #                               Participant code block ends                                    #
    ################################################################################################
    return commands, confidences
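
A hedged usage sketch for predict; the model directory and checkpoint name below are placeholders.

# Hypothetical call with a single invocation; paths are placeholders.
commands, confidences = predict(
    ["list all files in the current directory"],
    model_dir="models",
    model_file=["nl2bash_step_10000.pt"],
    result_cnt=5)
print(commands[0][0], confidences[0][0])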
Example #21
    def __init__(self, model, src, tgt, rep, label_rep, gpuid):
        self.model = model
        self.representation = rep
        self.label_representation = label_rep
        self.src = src
        self.tgt = tgt
        self.gpuid = gpuid

        dummy_parser = argparse.ArgumentParser(description='train.py')
        onmt.opts.model_opts(dummy_parser)
        onmt.opts.translate_opts(dummy_parser)
        param = ["-model", self.model, "-src", self.src]

        if (gpuid != ""):
            param += ["-gpu", self.gpuid]
        if (self.tgt != ""):
            param += ["-tgt", self.tgt]

        self.opt = dummy_parser.parse_known_args(param)[0]

        self.translator = build_translator(self.opt)
        self.data = representation.Dataset.Dataset("testdata")
        if (self.label_representation != ""):
            self.data.target_representation = []

        if (self.representation == "EncoderWordEmbeddings"
                or self.representation == "EncoderHiddenLayer"):
            self.translator.model.encoder._vivisect = {
                "iteration": 0,
                "rescore": 1,
                "sentence": 0,
                "model_name": "OpenNMT",
                "framework": "pytorch"
            }
            probe(self.translator.model.encoder,
                  select=self.monitorONMT,
                  perform=self.performONMT,
                  cb=self.storeData)
        elif (self.representation == "ContextVector"
              or self.representation == "AttentionWeights"
              or self.representation == "DecoderWordEmbeddings"
              or self.representation == "DecoderHiddenLayer"):
            #need to use the encoder to see when a sentence start
            self.translator.model.decoder._vivisect = {
                "iteration": 0,
                "sentence": 0,
                "model_name": "OpenNMT",
                "framework": "pytorch"
            }
            probe(self.translator.model.decoder,
                  select=self.monitorONMT,
                  perform=self.performONMT,
                  cb=self.storeData)
            self.translator.model.encoder._vivisect = {
                "iteration": 0,
                "rescore": 1,
                "model_name": "OpenNMT",
                "framework": "pytorch"
            }
            probe(self.translator.model.encoder,
                  select=self.monitorONMT,
                  perform=self.performONMT,
                  cb=self.storeData)
        else:
            print("Unknown representation:", self.representation)
Example #22
File: adr.py  Project: Just4Ease/iranlowo
def diacritize_text(undiacritized_text, verbose=False):
    # manually construct the options so we don't have to pass them in.
    opt = Namespace()
    opt.alpha = 0.0
    opt.attn_debug = False
    opt.avg_raw_probs = False
    opt.batch_size = 30
    opt.beam_size = 5
    opt.beta = -0.0
    opt.block_ngram_repeat = 0
    opt.config = None
    opt.coverage_penalty = 'none'
    opt.data_type = 'text'
    opt.dump_beam = ''
    opt.dynamic_dict = False
    opt.fp32 = False
    opt.gpu = -1
    opt.ignore_when_blocking = []
    opt.image_channel_size = 3
    opt.length_penalty = 'none'
    opt.log_file = ''
    opt.log_file_level = '0'
    opt.max_length = 100
    opt.max_sent_length = None
    opt.min_length = 0
    opt.n_best = 1
    opt.output = 'pred.txt'
    opt.phrase_table = ''
    opt.random_sampling_temp = 1.0
    opt.random_sampling_topk = 1
    opt.ratio = -0.0
    opt.replace_unk = True
    opt.report_bleu = False
    opt.report_rouge = False
    opt.report_time = False
    opt.sample_rate = 16000
    opt.save_config = None
    opt.seed = 829
    opt.shard_size = 10000
    opt.share_vocab = False
    opt.src = 'one_phrase.txt'
    opt.src_dir = ''
    opt.stepwise_penalty = False
    opt.tgt = None
    opt.verbose = verbose
    opt.window = 'hamming'
    opt.window_size = 0.02
    opt.window_stride = 0.01

    model_path = 'models/yo_adr_soft_attention_release.pt'
    opt.models = [pkg_resources.resource_filename(__name__, model_path)]

    # do work
    ArgumentParser.validate_translate_opts(opt)
    translator = build_translator(opt, report_score=True)

    # src_shard = ["awon okunrin nse ise agbara bi ise ode".encode('ascii')]
    src_shard = [undiacritized_text.encode('ascii')]
    tgt_shard = None

    score, prediction = translator.translate(src=src_shard,
                                             tgt=tgt_shard,
                                             src_dir=opt.src_dir,
                                             batch_size=opt.batch_size,
                                             attn_debug=opt.attn_debug)
    return prediction[0][0]
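
A usage sketch based on the sample input left in the comment above; the input must be plain ASCII, since the function encodes it with 'ascii'.

print(diacritize_text("awon okunrin nse ise agbara bi ise ode"))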
Example #23
    def load(self):
        self.loading_lock.clear()

        timer = Timer()
        self.logger.info("Loading model %d" % self.model_id)
        timer.start()

        try:
            self.translator = build_translator(self.opt,
                                               report_score=False,
                                               out_file=codecs.open(
                                                   os.devnull, "w", "utf-8"))
        except RuntimeError as e:
            raise ServerModelError("Runtime Error: %s" % str(e))

        timer.tick("model_loading")
        if self.preprocess_opt is not None:
            self.logger.info("Loading preprocessor")
            self.preprocessor = []

            for function_path in self.preprocess_opt:
                function = get_function_by_path(function_path)
                self.preprocessor.append(function)

        if self.tokenizer_opt is not None:
            self.logger.info("Loading tokenizer")

            if "type" not in self.tokenizer_opt:
                raise ValueError("Missing mandatory tokenizer option 'type'")

            if self.tokenizer_opt['type'] == 'sentencepiece':
                if "model" not in self.tokenizer_opt:
                    raise ValueError(
                        "Missing mandatory tokenizer option 'model'")
                import sentencepiece as spm
                sp = spm.SentencePieceProcessor()
                model_path = os.path.join(self.model_root,
                                          self.tokenizer_opt['model'])
                sp.Load(model_path)
                self.tokenizer = sp
            elif self.tokenizer_opt['type'] == 'pyonmttok':
                if "params" not in self.tokenizer_opt:
                    raise ValueError(
                        "Missing mandatory tokenizer option 'params'")
                import pyonmttok
                # .get() avoids a KeyError when "mode" is absent
                mode = self.tokenizer_opt.get("mode")
                # load can be called multiple times: modify a copy
                tokenizer_params = dict(self.tokenizer_opt["params"])
                for key, value in self.tokenizer_opt["params"].items():
                    if key.endswith("path"):
                        tokenizer_params[key] = os.path.join(
                            self.model_root, value)
                tokenizer = pyonmttok.Tokenizer(mode, **tokenizer_params)
                self.tokenizer = tokenizer
            else:
                raise ValueError("Invalid value for tokenizer type")

        if self.postprocess_opt is not None:
            self.logger.info("Loading postprocessor")
            self.postprocessor = []

            for function_path in self.postprocess_opt:
                function = get_function_by_path(function_path)
                self.postprocessor.append(function)

        self.load_time = timer.tick()
        self.reset_unload_timer()
        self.loading_lock.set()
Example #24
def main(opt):
    grammar_atc = extract_atc_from_grammar(opt.grammar)
    all_node_type_seq_str, node_seq_scores = get_all_node_type_str(opt.tmp_file)
    total_number_of_test_examples = len(all_node_type_seq_str)
    all_atcs = [grammar_atc for _ in range(total_number_of_test_examples)]
    if opt.atc is not None:
        all_atcs = refine_atc(all_atcs, opt.atc)
    translator = build_translator(opt, report_score=True, multi_feature_translator=True)
    all_scores, all_cands = translator.translate(
        src_path=opt.src, tgt_path=opt.tgt, src_dir=opt.src_dir, batch_size=opt.batch_size, attn_debug=opt.attn_debug,
        node_type_seq=[all_node_type_seq_str, node_seq_scores], atc=all_atcs)
    beam_size = len(all_scores[0])
    exp_name = opt.name
    all_sources, all_targets = [], []
    tgt_file = open(opt.tgt)
    src_file = open(opt.src)
    for a, b in zip(src_file, tgt_file):
        all_sources.append(process_source(a.strip()))
        all_targets.append(process_source(b.strip()))
    tgt_file.close()
    src_file.close()
    correct, no_change = 0, 0
    decode_res_file = open('full_report/details/' + exp_name + '_' + str(beam_size) + '_codit_result.txt', 'w')
    all_eds = []
    total_example = 0
    correct_ids_file = open('full_report/correct_ids/' + exp_name + '_' + str(beam_size) + '_codit_result.txt', 'w')
    for idx, (src, tgt, cands, scores) in enumerate(zip(all_sources, all_targets, all_cands, all_scores)):
        total_example += 1
        decode_res_file.write(str(idx) + '\n')
        decode_res_file.write(src + '\n')
        decode_res_file.write('-------------------------------------------------------------------------------------\n')
        decode_res_file.write(tgt + '\n')
        if src == tgt:
            no_change += 1
        eds = []
        o_ed = get_edit_dist(src, tgt)
        eds.append(o_ed)
        decode_res_file.write('=====================================================================================\n')
        decode_res_file.write('Candidate Size : ' + str(len(cands)) + '\n')
        decode_res_file.write('-------------------------------------------------------------------------------------\n')

        found = False
        cands_reformatted = re_organize_candidates(cands, scores, src, opt.n_best)
        for cand in cands_reformatted:
            ed = get_edit_dist(tgt, cand)
            if cand == tgt:
                found = True
            eds.append(ed)
            decode_res_file.write(cand + '\n')
            decode_res_file.write(str(ed) + '\n')
        if found:
            correct += 1
            correct_ids_file.write(str(idx) + '\n')
        all_eds.append(eds)
        decode_res_file.write(str(found) + '\n\n')
        decode_res_file.flush()
        if idx % 100 == 0:
            debug("Processed %d examples so far, found %d correct!" % (idx, correct))

    ed_file = open('full_report/edit_distances/' + exp_name +
                   '-' + str(correct) + '-' + str(opt.tree_count) +
                   '-' + str(opt.n_best) + 'eds.csv', 'w')
    all_eds = np.asarray(all_eds)
    print_bleu_res_to_file(ed_file, all_eds)
    decode_res_file.close()
    ed_file.close()
    correct_ids_file.close()
    print(correct, total_example)
Example #25
    def validate(self, valid_iter_fct, step=0):
        """ Validate model.
            valid_iter: validate data iterator
        Returns:
            :obj:`nmt.Statistics`: validation loss statistics
        """
        # Set model in validating mode.
        self.model.eval()
        # get some bleu scores
        self.model.generator.eval()

        print('GETTING BLEU...')
        valid_iter = valid_iter_fct()
        output_fname = os.path.dirname(
            self.opt.save_model) + '/pred.%d.%d' % (self.device_id, step)
        translator = build_translator(self.opt,
                                      report_score=True,
                                      out_file=codecs.open(
                                          output_fname, 'w+', 'utf-8'),
                                      fields=self.fields,
                                      model=self.model,
                                      model_opt=self.opt)
        translator.translate(src_path=self.opt.valid_src,
                             tgt_path=self.opt.valid_tgt,
                             batch_size=self.opt.batch_size,
                             iterator=valid_iter)
        bleu = get_bleu(output_fname, self.opt.valid_tgt)
        print(bleu,
              file=open(
                  os.path.dirname(self.opt.save_model) +
                  '/%d.bleu_log' % self.device_id, 'a'))

        stats = onmt.utils.Statistics()

        for batch in valid_iter:
            src = inputters.make_features(batch, 'src', self.data_type)
            if self.data_type == 'text':
                _, src_lengths = batch.src
            else:
                src_lengths = None

            tgt = inputters.make_features(batch, 'tgt')

            # TODO -- MAKE THIS STUFF DYNAMIC!!!
            if hasattr(batch, 'ctx_0'):
                ctx_0 = inputters.make_features(batch, 'ctx_0', self.data_type)
                _, ctx_0_lens = batch.ctx_0
            else:
                ctx_0 = None
                ctx_0_lens = None
            if hasattr(batch, 'ctx_1'):
                ctx_1 = inputters.make_features(batch, 'ctx_1', self.data_type)
                _, ctx_1_lens = batch.ctx_1
            else:
                ctx_1 = None
                ctx_1_lens = None

            # F-prop through the model.
            outputs, attns, _ = self.model(src, tgt, src_lengths, ctx_0,
                                           ctx_0_lens, ctx_1, ctx_1_lens)

            # Compute loss.
            batch_stats = self.valid_loss.monolithic_compute_loss(
                batch, outputs, attns)

            # Update statistics.
            stats.update(batch_stats, bleu=bleu)

        # Set model back to training mode.
        self.model.train()

        return stats
Example #26
def main(opt):
    freedecoding = [True]
    pre_tgt = [""]

    #font_align = ('Courier', 20)# 'Roboto Mono''Letter Gothic'#font_align
    #font_label = ('Times',14)

    translator = build_translator(opt, report_score=True)
    word_translator = word_trans()

    def alignment(zh_list, pinyin_list):
        zh = []
        py = []
        for i, j in zip(zh_list, pinyin_list):
            zh_len, pin_len = len(i) * 2, len(j)
            len_diff = math.ceil((pin_len - zh_len) * 2)
            i = i + " " * len_diff
            zh.append(i)
            py.append(j)

        return " ".join(zh) + "\n" + " ".join(py)

    def count_space(l1, l2, l3):
        ratio = 2
        l1 = int(l1 * ratio)
        w = max(l1, l2, l3) + 1
        if w & 1: w += 1
        return w - l1, w - l2, w - l3

    def trans_whole(src_txt_list, finishall=False):

        # 3 lines of Chn, pinyin, word_trans, with alignment
        pinyin_list = [pinyin.get(i) for i in src_txt_list]
        out_word_list = [
            word_translator.translate(i, dest='en').text.lower()
            for i in src_txt_list
        ]
        '''
        BLANK = ' '
        chn_str, pinyin_str, word_str = '','',''
        for a,b,c in zip(src_txt_list, pinyin_list, out_word_list):
            if ord(a[0]) > 122:
                x,y,z = count_space(len(a), len(b), len(c))
                chn_str     += a + chr(12288)*(x>>1) #chr(12288)
                pinyin_str  += b + BLANK*y
                word_str    += c + BLANK*z
            else:
                chn_str     += a + '\t'
                pinyin_str  += b + '\t'
                word_str    += c + '\t'
        '''
        #pinyin_str = pinyin_str.replace('。','.')
        #word_str   = word_str.replace('。','.')

        chn_str = ' '.join(src_txt_list)
        pinyin_str = ' '.join(pinyin_list)
        word_str = ' '.join(out_word_list)

        # end of 3 lines

        src_txtidx_list = []
        res_baseline = ""
        res = ""
        # last version
        #src_txtidx_list = [translator.fields["src"].vocab.stoi[x] for x in src_txt_list]
        #translated_sent = translator.translate_batch(src_txtidx_list,translator.model_opt.waitk) # input is word id list

        # 8/8/2018 version
        if len(src_txt_list) >= translator.waitk:
            current_input_bpe_lst = bpe.segment(" ".join(src_txt_list)).split()
            src_txtidx_list = []
            src_txtidx_list = [
                translator.fields["src"].vocab.stoi[wordi]
                for wordi in current_input_bpe_lst
            ]

            ### detect the end of sentence
            #if sentwidx == len(sampleWholeInput)-1:
            translated_sent, pre_tgt[0] = translator._translate_batch(
                src_txtidx_list,
                pre_tgt[0],
                freedecoding[0],
                finishall=finishall)
            print('#' * 10, current_input_bpe_lst)
            print('#' * 20, translated_sent)

            res = clean_eos(translated_sent)

            if finishall:
                res_baseline_list, pre_tgt[0] = translator._translate_batch(
                    src_txtidx_list,
                    pre_tgt[0],
                    freedecoding=True,
                    finishall=True)
                #res_baseline = " ".join(clean_eos(res_baseline_list)).replace("@@ ","") + '.'
                res_baseline = clean_eos(res_baseline_list)
                #pinyin_str += '。'
                #word_str += '.'
            freedecoding[0] = False

        #print('temp: {} {} {} [{} {}] {}'.format(src_txt_list, res, src_txtidx_list, freedecoding[0], finishall, pre_tgt[0]))
        #print('\rChinese: {}'.format(chn_str), end = '')
        if finishall:
            print('\n\nChinese: {}'.format(chn_str))
            print('Pinyin: {}\nW-Trans:{}\nk-waits:{}\ngreedy: {}\n{}'.format(
                pinyin_str, word_str, res, res_baseline,
                "".join(src_txt_list)))

        return chn_str, pinyin_str, word_str, res, res_baseline

    def clean_eos(s):
        # replace a trailing ".</s>" or "</s>" with "."
        if len(s) >= 2 and s[-1] == "</s>":
            if s[-2] == ".":
                s = s[:-1]
            else:
                s[-1] = "."
        #res = " ".join(translated_sent)
        return " ".join(s).replace("@@ ", "")

    ## kaibo: flush the current line as each new word arrives, then print the final transcript once `is_final` is set
    def listen_print_loop(responses):
        """Iterates through server responses and prints them.

        The responses passed is a generator that will block until a response
        is provided by the server.

        Each response may contain multiple results, and each result may contain
        multiple alternatives; for details, see https://goo.gl/tjCPAU.  Here we
        print only the transcription for the top alternative of the top result.

        In this case, responses are provided for interim results as well. If the
        response is an interim one, print a line feed at the end of it, to allow
        the next result to overwrite it, until the response is a final one. For the
        final one, print a newline to preserve the finalized transcription.
        """

        num_words = 0  # number of words in the whole sentence (maybe multi lines)
        num_chars = 0  # number of char in current line
        #num_chars_printed = 0
        base_txt_list, txt_list = [], None
        #chn_str, pinyin_str, word_str = '','',''
        stream_end[0] = False
        Dict['chnLine'] = ""
        Dict['pinyin'] = ""
        Dict['word'] = ""

        for response in responses:
            if not response.results:

                continue
            stream_end[0] = False
            # The `results` list is consecutive. For streaming, we only care about
            # the first result being considered, since once it's `is_final`, it
            # moves on to considering the next utterance.
            asr_result = response.results[0]
            if not asr_result.alternatives:
                continue

            # Display the transcription of the top alternative.
            transcript = asr_result.alternatives[0].transcript

            #print('+'*20,transcript, stream_end[0])

            # kaibo: cut only if the character count increases and the
            # last character is not English
            if len(transcript) > num_chars and ord(transcript[-1]) > 128:
                num_chars = len(transcript)
                transcript_cut = jieba.cut(transcript, cut_all=False)
                #txt_list = list(w.lower() if ord(w[0])<128 else w for w in chain(base_txt_list, transcript_cut))
                txt_list = [w.lower()
                            for w in chain(base_txt_list, transcript_cut)]
                # kaibo: update the GUI only if the word count increases
                if len(txt_list) > num_words:
                    num_words = len(txt_list)
                    if num_words > 1:
                        #chn_str, pinyin_str, word_str = trans(txt_list[:-1], chn_str, pinyin_str, word_str)
                        (Dict['chnLine'], Dict['pinyin'], Dict['word'],
                         Dict['engLine'], Dict['bslLine']) = trans_whole(txt_list[:-1])
            #print('-'*20, num_chars, num_words, Dict['chnLine'])

            if asr_result.is_final:
                #trans(txt_list + ['。'], chn_str, pinyin_str, word_str)
                #trans_whole(txt_list + ['。'])
                (Dict['chnLine'], Dict['pinyin'], Dict['word'],
                 eng, bsl) = trans_whole(txt_list, finishall=True)

                Dict['chnLine'] += '。'
                Dict['pinyin'] += '.'
                Dict['word'] += '.'

                #update_base('chn', Dict['chnLine'].replace('\t','').replace(chr(12288),''), True)
                update_base('chn', Dict['chnLine'].replace(' ', ''), True)
                update_base('eng', eng)
                update_base('bsl', bsl)

                stream_end[0] = True

                #stream_end[0] = False
                freedecoding[0] = True  # kaibo: added for 8/8/2018 version
                pre_tgt[0] = []  # kaibo: added for 8/8/2018 version
                src_txt_list = []
                num_words = 0  # number of words in the whole sentence (maybe multi lines)
                num_chars = 0  # number of char in current line
                base_txt_list, txt_list = [], None
                #print(". [end]")

    def update_base(key, s, ischn=False):
        s1, n = lineBreak(s, ischn=ischn)
        Dict[key + 'Base'] += s1
        Dict[key + 'Lines'] += n
        Dict[key + 'Line'] += ""

        #print('{}:{}'.format(key+'Base', Dict[key+'Base']))

    # Audio recording parameters
    RATE = 16000
    #RATE = 10000
    CHUNK = int(RATE / 10)  # 100ms
    language_code = "cmn-Hans-CN"
    client = speech.SpeechClient()
    config = types.RecognitionConfig(
        encoding=enums.RecognitionConfig.AudioEncoding.LINEAR16,
        sample_rate_hertz=RATE,
        language_code=language_code)
    streaming_config = types.StreamingRecognitionConfig(config=config,
                                                        interim_results=True)

    #stream_end[0] = False
    print('say something within 65 seconds; a pause ends the utterance')
    modelStatus[0] = "model loaded successfully, say something within"
    duration[0] = time.time()
    with MicrophoneStream(RATE, CHUNK) as stream:
        audio_generator = stream.generator()
        requests = (types.StreamingRecognizeRequest(audio_content=content)
                    for content in audio_generator)
        responses = client.streaming_recognize(streaming_config, requests)
        # Now, put the transcription responses to use.
        listen_print_loop(responses)
Example #27
def main(opt):
    translator = build_translator(opt)
    translator.translate(opt.corpora)
Example #28
    def __build_translator(self, model_addr, src_addr):
        parser = ArgumentParser()
        opts.config_opts(parser)
        opts.translate_opts(parser)
        opt = parser.parse_args(['-model', model_addr, '-src', src_addr])
        return build_translator(opt, report_score=False)