Example #1
    # NOTE: the opening lines of this snippet are truncated in the source; the
    # Supervisor construction below is reconstructed to match Example #2.
    sv = tf.train.Supervisor(init_fn=init_fn)
    sess = sv.PrepareSession(config=config)
    perplexitys = []
    start_time = datetime.now()
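    # Training loop: one optimization step per batch; accumulate perplexity for periodic reporting.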
    while True:
        input_feed = get_feed(graph.objs, traindata, model_config)
        # fetches = [graph.train_op, graph.loss, graph.global_step, graph.perplexity, graph.objs[0]['targets'], graph.objs[0]['attn_stick']]
        # _, loss, step, perplexity, target, attn_stick = sess.run(fetches, input_feed)
        fetches = [graph.train_op, graph.loss, graph.global_step, graph.perplexity]
        _, loss, step, perplexity = sess.run(fetches, input_feed)

        perplexitys.append(perplexity)

        if step % model_config.model_print_freq == 0:
            end_time = datetime.now()
            time_span = end_time - start_time
            start_time = end_time
            print('Perplexity:\t%f at step %d using %s.' % (np.mean(perplexitys), step, time_span))
            perplexitys.clear()
            # if model_config.subword_vocab_size > 0:
            #     print(traindata.voc_kword.describe(target[0][0]))
            # else:
            #     print(' '.join([traindata.voc_kword.describe(w) for w in target[0][0]]))


if __name__ == '__main__':
    from model.model_config import DefaultConfig, DummyConfig
    model_config = DefaultConfig()
    print(list_config(model_config))
    train(model_config)
Example #2
def eval(model_config=None, ckpt=None):
    model_config = (DefaultConfig() if model_config is None else model_config)
    if not exists(model_config.resultdir):
        makedirs(model_config.resultdir)
    print(list_config(model_config))

    val_data = ValData(model_config)
    graph = None
    if model_config.framework == 'transformer':
        graph = TransformerGraph(val_data, False, model_config)
    elif model_config.framework == 'seq2seq':
        graph = Seq2SeqGraph(val_data, False, model_config)
    tf.reset_default_graph()
    graph.create_model_multigpu()

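    # Accumulators for corpus-level metrics and decoded outputs.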
    ibleus_all = []
    perplexitys_all = []
    saris_all = []
    decode_outputs_all = []
    targets = []
    targets_raw = []
    sentence_simples = []
    sentence_complexs = []
    sentence_complexs_raw = []

    it = val_data.get_data_iter()

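    # Restore model weights from the given checkpoint when the Supervisor creates the session.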
    def init_fn(session):
        graph.saver.restore(session, ckpt)
        print('Restore ckpt:%s.' % ckpt)

    sv = tf.train.Supervisor(init_fn=init_fn)
    sess = sv.PrepareSession(config=session.get_session_config(model_config))
    while True:
        is_finish = False
        (input_feed, output_sentence_simple, output_sentence_complex,
         output_sentence_complex_raw, output_sentence_complex_raw_lines,
         output_mapper, output_ref_raw_lines, out_effective_batch_size,
         output_is_end) = get_graph_val_data(graph.objs, model_config, it,
                                             val_data)

        postprocess = PostProcess(model_config, val_data)
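        # Always fetch decoded token ids, loss, and global step; add encoder embeddings
        # or attention distributions when the corresponding UNK-replacement option is on.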
        fetches = {
            'decoder_target_list': [obj['decoder_target_list'] for obj in graph.objs],
            'loss': graph.loss,
            'global_step': graph.global_step
        }
        if model_config.replace_unk_by_emb:
            print("########REPLACING UNKS########")
            fetches.update({
                'encoder_embs': [obj['encoder_embs'] for obj in graph.objs],
                'final_outputs': [obj['final_outputs'] for obj in graph.objs]
            })
        if model_config.replace_unk_by_attn:
            fetches.update(
                {'attn_distr': [obj['attn_distr'] for obj in graph.objs]})
        results = sess.run(fetches, input_feed)
        output_target, loss, step = (results['decoder_target_list'],
                                     results['loss'], results['global_step'])
        if model_config.replace_unk_by_emb:
            print("########REPLACING UNKS########")
            output_encoder_embs = results['encoder_embs']
            output_final_outputs = results['final_outputs']
        if model_config.replace_unk_by_attn:
            attn_distr = results['attn_distr']
        batch_perplexity = math.exp(loss)
        perplexitys_all.append(batch_perplexity)

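        # Iterate over per-GPU output shards, dropping padded entries beyond the
        # effective batch size.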
        for i, effective_batch_size in enumerate(out_effective_batch_size):
            is_end = output_is_end[i]
            exclude_idxs = get_exclude_list(effective_batch_size,
                                            model_config.batch_size)
            sentence_simple = output_sentence_simple[i]
            sentence_complex = output_sentence_complex[i]
            sentence_complex_raw = output_sentence_complex_raw[i]
            sentence_complex_raw_lines = output_sentence_complex_raw_lines[i]
            mapper = output_mapper[i]
            ref_raw_lines = output_ref_raw_lines[i]

            target = output_target[i]
            if model_config.replace_unk_by_emb:
                encoder_embs = output_encoder_embs[i]
                final_outputs = output_final_outputs[i]

            if exclude_idxs:
                sentence_complex = exclude_list(sentence_complex, exclude_idxs)
                sentence_complex_raw = exclude_list(sentence_complex_raw,
                                                    exclude_idxs)
                sentence_complex_raw_lines = exclude_list(
                    sentence_complex_raw_lines, exclude_idxs)

                sentence_simple = exclude_list(sentence_simple, exclude_idxs)

                target = exclude_list(target, exclude_idxs)
                mapper = exclude_list(mapper, exclude_idxs)

                for ref_i in range(model_config.num_refs):
                    ref_raw_lines[ref_i] = exclude_list(
                        ref_raw_lines[ref_i], exclude_idxs)

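            # Decode id sequences back to tokens; target_raw additionally gets UNK and
            # NER placeholders restored below.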
            target = decode(target, val_data.vocab_simple,
                            model_config.subword_vocab_size > 0)
            target_raw = target

            sentence_complex_marker = [[
                val_data.vocab_simple.encode(w) ==
                val_data.vocab_simple.encode(constant.SYMBOL_UNK) for w in sent
            ] for sent in sentence_complex_raw]
            if model_config.replace_unk_by_attn:
                target_raw = postprocess.replace_unk_by_attn(
                    sentence_complex_raw, attn_distr[0], target_raw)
            elif model_config.replace_unk_by_emb:
                target_raw = postprocess.replace_unk_by_emb(
                    sentence_complex_raw, encoder_embs, final_outputs,
                    target_raw, sentence_complex_marker)
            elif model_config.replace_unk_by_cnt:
                target_raw = postprocess.replace_unk_by_cnt(
                    sentence_complex_raw, target_raw)
            if model_config.replace_ner:
                target_raw = postprocess.replace_ner(target_raw, mapper)
            target_raw = postprocess.replace_others(target_raw)
            sentence_simple = decode(sentence_simple, val_data.vocab_simple,
                                     model_config.subword_vocab_size > 0)
            sentence_complex = decode(sentence_complex, val_data.vocab_complex,
                                      model_config.subword_vocab_size > 0)

            # Replace UNK for sentence_complex_raw and ref_raw
            # Note that sentence_complex_raw_lines and ref_raw_lines are original file lines
            sentence_complex_raw = postprocess.replace_ner(
                sentence_complex_raw, mapper)
            sentence_complex_raw = truncate_sents(sentence_complex_raw)

            # Truncate decode results
            target = truncate_sents(target)
            target_raw = truncate_sents(target_raw)
            sentence_simple = truncate_sents(sentence_simple)
            sentence_complex = truncate_sents(sentence_complex)

            targets.extend(target)
            targets_raw.extend(target_raw)
            sentence_simples.extend(sentence_simple)
            sentence_complexs.extend(sentence_complex)
            sentence_complexs_raw.extend(sentence_complex_raw)

            ibleus = []
            saris = []
            fkgls = []

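            # Sentence-level metrics: iBLEU via NLTK's sentence_bleu, SARI against the
            # reference simplifications, and FKGL readability.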
            for batch_i in range(effective_batch_size):
                # Compute iBLEU
                try:
                    batch_ibleu = sentence_bleu([sentence_simple[batch_i]],
                                                target[batch_i])
                except Exception as e:
                    print('Bleu error:\t' + str(e) + '\n' +
                          str(target[batch_i]) + '\n')
                    batch_ibleu = 0
                ibleus_all.append(batch_ibleu)
                ibleus.append(batch_ibleu)

                # Compute SARI
                batch_sari = 0
                if model_config.num_refs > 0:
                    rsents = []
                    for ref_i in range(model_config.num_refs):
                        rsents.append(ref_raw_lines[ref_i][batch_i])
                    try:
                        batch_sari = SARIsent(
                            sentence_complex_raw_lines[batch_i],
                            ' '.join(target_raw[batch_i]), rsents)
                    except Exception:
                        print('sari error: %s \n %s \n %s. \n' %
                              (sentence_complex_raw_lines[batch_i], ' '.join(
                                  target_raw[batch_i]), rsents))
                saris.append(batch_sari)
                saris_all.append(batch_sari)

                # Compute FKGL
                target_text = ' '.join(target_raw[batch_i])
                batch_fkgl = 0
                if len(target_text) > 0:
                    batch_fkgl = get_fkgl(target_text)
                fkgls.append(batch_fkgl)

            # target_output = decode_to_output(target, sentence_simple, sentence_complex,
            #                                  effective_batch_size, ibleus, target_raw, sentence_complex_raw,
            #                                  saris, fkgls)
            target_output = decode_to_output(target, sentence_simple,
                                             sentence_complex,
                                             effective_batch_size, ibleus,
                                             target_raw, sentence_complex_raw,
                                             saris, fkgls, ref_raw_lines,
                                             model_config)
            decode_outputs_all.append(target_output)

            if is_end:
                is_finish = True
                break

        if is_finish:
            break

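    # Aggregate corpus-level metrics over all evaluated batches.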
    ibleu = np.mean(ibleus_all)
    perplexity = np.mean(perplexitys_all)
    sari = np.mean(saris_all)
    # Compute FKGL in Corpus level
    fkgl = CorpusFKGL(model_config).get_fkgl_from_joshua(step, targets_raw)

    print('Current iBLEU: \t%f' % ibleu)
    print('Current SARI: \t%f' % sari)
    print('Current FKGL: \t%f' % fkgl)
    print('Current perplexity: \t%f' % perplexity)
    print('Current eval done!')
    # MtEval Result
    mteval = MtEval_BLEU(model_config)

    # MtEval Result - Decode
    # bleu_oi_decode = mteval.get_bleu_from_decoderesult(step, sentence_complexs, sentence_simples, targets)
    # bleu_or_decode = bleu_oi_decode
    # if model_config.num_refs > 0:
    #     path_ref = model_config.val_dataset_simple_folder + model_config.val_dataset_simple_references
    #     #Decode evaluation must be lowercase because the processed files are all lowercased
    #     bleu_or_decode = mteval.get_bleu_from_decoderesult_multirefs(step, path_ref, targets,
    #                                                                  lowercase=True)
    # if model_config.num_refs > 0:
    #     bleu_decode = 0.9 * bleu_or_decode + 0.1 * bleu_oi_decode
    # else:
    #     bleu_decode = bleu_oi_decode
    # print('Current Mteval iBLEU decode: \t%f' % bleu_decode)

    # MtEval Result - raw
    bleu_oi_raw = mteval.get_bleu_from_rawresult(step, targets_raw)
    bleu_or_raw = bleu_oi_raw
    if model_config.num_refs > 0:
        path_ref = model_config.val_dataset_simple_folder + model_config.val_dataset_simple_rawlines_file_references
        bleu_or_raw = mteval.get_bleu_from_decoderesult_multirefs(
            step, path_ref, targets_raw, lowercase=model_config.lower_case)
    if model_config.num_refs > 0:
        bleu_raw = 0.9 * bleu_or_raw + 0.1 * bleu_oi_raw
    else:
        bleu_raw = bleu_oi_raw
    print('Current Mteval iBLEU raw: \t%f' % bleu_raw)

    bleu_joshua = mteval.get_bleu_from_joshua(
        step, model_config.val_dataset_simple_folder +
        model_config.val_dataset_simple_rawlines_file,
        model_config.val_dataset_simple_folder +
        model_config.val_dataset_simple_rawlines_file_references, targets_raw)

    # Use corpus-level sari
    corpus_sari = CorpusSARI(model_config)
    sari_joshua = corpus_sari.get_sari_from_joshua(
        step, model_config.val_dataset_simple_folder +
        model_config.val_dataset_simple_rawlines_file,
        model_config.val_dataset_simple_folder +
        model_config.val_dataset_simple_rawlines_file_references,
        model_config.val_dataset_complex_rawlines_file, targets_raw)

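    # Render all reported metrics with fixed decimal precision.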
    decimal_cnt = 5
    fmt = "%." + str(decimal_cnt) + "f"  # avoid shadowing the built-in format()
    bleu_raw = fmt % bleu_raw
    bleu_oi_raw = fmt % bleu_oi_raw
    bleu_or_raw = fmt % bleu_or_raw
    # bleu_decode = fmt % bleu_decode
    # bleu_oi_decode = fmt % bleu_oi_decode
    # bleu_or_decode = fmt % bleu_or_decode
    ibleu = fmt % ibleu
    bleu_joshua = fmt % bleu_joshua
    sari_joshua = fmt % sari_joshua
    fkgl = fmt % fkgl
    perplexity = fmt % perplexity

    content = '\n'.join([
        'bleu_raw\t' + str(bleu_raw),
        'bleu_oi_raw\t' + str(bleu_oi_raw),
        'bleu_or_raw\t' + str(bleu_or_raw),
        # 'bleu_decode\t' + str(bleu_decode),
        # 'bleu_oi_decode\t' + str(bleu_oi_decode),
        # 'bleu_or_decode\t' + str(bleu_or_decode),
        'ibleu\t' + str(ibleu),
        'bleu_joshua\t' + str(bleu_joshua),
        'sari\t' + str(sari_joshua),
        'fkgl\t' + str(fkgl)
    ])

    # Output Result
    result_path = (model_config.resultdir + '/step' + str(step) + '-bleuraw' +
                   str(bleu_raw) + '-bleurawoi' + str(bleu_oi_raw) + '-bleurawor' +
                   str(bleu_or_raw) + '-bleuj' + str(bleu_joshua) + '-perplexity' +
                   str(perplexity) + '-bleunltk' + str(ibleu) + '-sari' +
                   str(sari_joshua) + '-fkgl' + str(fkgl))
    # Write the metric summary, then the per-sentence decode results.
    with open(result_path, 'w', encoding='utf-8') as f:
        f.write(content)
    with open(result_path + '.result', 'w', encoding='utf-8') as f:
        f.write('\n'.join(decode_outputs_all))

    return sari_joshua

Example #3
        # NOTE: the opening lines of this snippet are truncated in the source; the
        # call below is reconstructed by assumption from the surviving arguments.
        input_feed = get_feed(graph.objs, model_config)

        # fetches = [graph.train_op, graph.loss, graph.global_step,
        #            graph.perplexity, graph.ops, graph.attn_dists, graph.targets, graph.cs]
        # _, loss, step, perplexity, _ops , attn_dists, targets, cs = sess.run(fetches, input_feed)
        fetches = [graph.train_op, graph.loss, graph.global_step,
                   graph.perplexity, graph.ops, graph.logits]
        _, loss, step, perplexity, _, logits = sess.run(fetches, input_feed)
        perplexitys.append(perplexity)

        if step % model_config.model_print_freq == 0:
            end_time = datetime.now()
            time_span = end_time - start_time
            start_time = end_time
            print('Perplexity:\t%f at step %d using %s.' % (np.mean(perplexitys), step, time_span))
            perplexitys.clear()

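# NOTE: `args` is not defined in these snippets; a typical definition (an
# assumption, not shown in the original) would be an argparse option such as:
#   import argparse
#   parser = argparse.ArgumentParser()
#   parser.add_argument('--mode', default='dress')
#   args = parser.parse_args()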
if __name__ == '__main__':
    config = None
    if args.mode == 'dummy':
        config = DefaultTrainConfig()
    elif args.mode == 'dressnew':
        config = WikiDressLargeNewTrainDefault()
    elif args.mode == 'wikihuge':
        config = WikiDressHugeNewTrainDefault()
    elif args.mode == 'dress':
        config = WikiDressLargeTrainDefault()
    print(list_config(config))
    train(config)