Esempio n. 1
0
def train(model_config=None):
    """Train a model in an endless loop under a ``tf.train.Supervisor``.

    Builds the graph selected by ``model_config.framework`` and optionally warm
    starts it from a checkpoint. ``model_config.warm_start`` is either a
    checkpoint path or ``'recent'``; ``'recent'`` is resolved with
    ``find_best_ckpt``. The function then repeatedly runs the training op and
    prints perplexity every ``model_config.model_print_freq`` steps. It never
    returns.

    Args:
        model_config: configuration object; ``DefaultConfig()`` when None.

    Raises:
        NotImplementedError: if ``model_config.framework`` is neither
            ``'transformer'`` nor ``'seq2seq'``.
    """
    model_config = (DefaultConfig()
                    if model_config is None else model_config)
    data = TrainData(model_config)

    if model_config.framework == 'transformer':
        graph = TransformerGraph(data, True, model_config)
    elif model_config.framework == 'seq2seq':
        graph = Seq2SeqGraph(data, True, model_config)
    else:
        raise NotImplementedError('Unknown Framework.')
    graph.create_model_multigpu()

    # Resolve the warm-start checkpoint once; init_fn below reuses it.
    ckpt_path = None
    if model_config.warm_start:
        if model_config.warm_start == 'recent':
            ckpt_path = find_best_ckpt(model_config)
        else:
            ckpt_path = model_config.warm_start

    partial_restore_ckpt = None
    if ckpt_path is not None:
        # Handling missing vars by ourselves: restore only the variables that
        # exist in the checkpoint with a compatible shape. The shape map avoids
        # loading every checkpoint tensor just to inspect its shape.
        available_vars = {}
        reader = tf.train.NewCheckpointReader(ckpt_path)
        ckpt_shapes = reader.get_variable_to_shape_map()
        var_dict = {var.op.name: var
                    for var in slim.get_variables_to_restore()}
        for name, var_cur in var_dict.items():
            # The step counter and optimizer variables are never warm started.
            if 'global_step' in name or 'optimization' in name:
                continue
            if name not in ckpt_shapes:
                print('Variable %s missing.' % name)
                continue
            ckpt_shape = ckpt_shapes[name]
            # Check rank first: a prefix-only comparison would miss a rank
            # mismatch or raise IndexError on it.
            if (len(var_cur.shape) != len(ckpt_shape) or
                    any(cur != saved
                        for cur, saved in zip(var_cur.shape, ckpt_shape))):
                print('Variable %s missing due to shape.' % name)
            else:
                available_vars[name] = var_cur

        partial_restore_ckpt = slim.assign_from_checkpoint_fn(
            ckpt_path, available_vars,
            ignore_missing_vars=False, reshape_variables=False)

    def init_fn(tf_session):
        # Restore the checkpoint resolved above. Re-reading
        # model_config.warm_start here would restore from the literal path
        # 'recent' when warm_start == 'recent'.
        if ckpt_path is None:
            return
        if model_config.use_partial_restore:
            partial_restore_ckpt(tf_session)
        else:
            try:
                graph.saver.restore(tf_session, ckpt_path)
            except Exception as ex:
                print('Fully restore failed, use partial restore instead. \n %s' % str(ex))
                partial_restore_ckpt(tf_session)
        print('Warm start with checkpoint %s' % ckpt_path)

    sv = tf.train.Supervisor(logdir=model_config.logdir,
                             global_step=graph.global_step,
                             saver=graph.saver,
                             init_fn=init_fn,
                             save_model_secs=model_config.save_model_secs)
    sess = sv.PrepareSession(config=session.get_session_config(model_config))

    # graph.ops is fetched only for its side effects; the loss value is unused.
    fetches = [graph.train_op, graph.loss, graph.global_step,
               graph.perplexity, graph.ops]
    start_time = datetime.now()
    while True:
        input_feed = get_graph_train_data(data, graph.objs, model_config)
        _, _, step, perplexity, _ = sess.run(fetches, input_feed)

        if step % model_config.model_print_freq == 0:
            end_time = datetime.now()
            time_span = end_time - start_time
            start_time = end_time
            print('Perplexity:\t%f at step %d using %s.' % (perplexity, step, time_span))
Esempio n. 2
0
def eval(model_config=None, ckpt=None):
    """Evaluate checkpoint ``ckpt`` on the validation data and write result files.

    Decodes the whole validation iterator and computes sentence-level iBLEU,
    SARI and FKGL. It then computes corpus-level FKGL, BLEU and SARI through
    the external scorers and prints a summary. Finally it writes two files into
    ``model_config.resultdir`` whose names encode the scores: a metrics summary
    and a ``.result`` file holding the decoded outputs.

    Args:
        model_config: evaluation config; ``DefaultConfig()`` when None.
        ckpt: checkpoint path restored before decoding.

    Returns:
        The corpus-level SARI formatted as a string with 5 decimals.

    Raises:
        NotImplementedError: if ``model_config.framework`` is neither
            ``'transformer'`` nor ``'seq2seq'``.
    """
    model_config = (DefaultConfig() if model_config is None else model_config)
    if not exists(model_config.resultdir):
        makedirs(model_config.resultdir)
    print(list_config(model_config))

    val_data = ValData(model_config)
    if model_config.framework == 'transformer':
        graph = TransformerGraph(val_data, False, model_config)
    elif model_config.framework == 'seq2seq':
        graph = Seq2SeqGraph(val_data, False, model_config)
    else:
        raise NotImplementedError('Unknown Framework.')
    tf.reset_default_graph()
    graph.create_model_multigpu()

    ibleus_all = []
    perplexitys_all = []
    saris_all = []
    decode_outputs_all = []
    targets = []
    targets_raw = []
    sentence_simples = []
    sentence_complexs = []
    sentence_complexs_raw = []

    it = val_data.get_data_iter()

    def init_fn(tf_session):
        graph.saver.restore(tf_session, ckpt)
        print('Restore ckpt:%s.' % ckpt)

    sv = tf.train.Supervisor(init_fn=init_fn)
    sess = sv.PrepareSession(config=session.get_session_config(model_config))

    # The fetched tensors are the same for every batch, so build them once.
    fetches = {
        'decoder_target_list':
        [obj['decoder_target_list'] for obj in graph.objs],
        'loss': graph.loss,
        'global_step': graph.global_step,
    }
    if model_config.replace_unk_by_emb:
        print("########REPLACING UNKS########")
        fetches['encoder_embs'] = [obj['encoder_embs'] for obj in graph.objs]
        fetches['final_outputs'] = [obj['final_outputs'] for obj in graph.objs]
    if model_config.replace_unk_by_attn:
        fetches['attn_distr'] = [obj['attn_distr'] for obj in graph.objs]

    is_finish = False
    while not is_finish:
        (input_feed, output_sentence_simple, output_sentence_complex,
         output_sentence_complex_raw, output_sentence_complex_raw_lines,
         output_mapper, output_ref_raw_lines, out_effective_batch_size,
         output_is_end) = get_graph_val_data(graph.objs, model_config, it,
                                             val_data)

        postprocess = PostProcess(model_config, val_data)
        results = sess.run(fetches, input_feed)
        output_target = results['decoder_target_list']
        loss = results['loss']
        step = results['global_step']
        if model_config.replace_unk_by_emb:
            output_encoder_embs = results['encoder_embs']
            output_final_outputs = results['final_outputs']
        if model_config.replace_unk_by_attn:
            attn_distr = results['attn_distr']
        perplexitys_all.append(math.exp(loss))

        for i, effective_batch_size in enumerate(out_effective_batch_size):
            is_end = output_is_end[i]
            exclude_idxs = get_exclude_list(effective_batch_size,
                                            model_config.batch_size)
            sentence_simple = output_sentence_simple[i]
            sentence_complex = output_sentence_complex[i]
            sentence_complex_raw = output_sentence_complex_raw[i]
            sentence_complex_raw_lines = output_sentence_complex_raw_lines[i]
            mapper = output_mapper[i]
            ref_raw_lines = output_ref_raw_lines[i]

            target = output_target[i]
            if model_config.replace_unk_by_emb:
                encoder_embs = output_encoder_embs[i]
                final_outputs = output_final_outputs[i]

            if exclude_idxs:
                # Drop the rows flagged by get_exclude_list (presumably the
                # padding beyond the effective batch size - verify).
                sentence_complex = exclude_list(sentence_complex, exclude_idxs)
                sentence_complex_raw = exclude_list(sentence_complex_raw,
                                                    exclude_idxs)
                sentence_complex_raw_lines = exclude_list(
                    sentence_complex_raw_lines, exclude_idxs)
                sentence_simple = exclude_list(sentence_simple, exclude_idxs)
                target = exclude_list(target, exclude_idxs)
                mapper = exclude_list(mapper, exclude_idxs)
                for ref_i in range(model_config.num_refs):
                    ref_raw_lines[ref_i] = exclude_list(
                        ref_raw_lines[ref_i], exclude_idxs)

            target = decode(target, val_data.vocab_simple,
                            model_config.subword_vocab_size > 0)
            target_raw = target

            unk_id = val_data.vocab_simple.encode(constant.SYMBOL_UNK)
            sentence_complex_marker = [
                [val_data.vocab_simple.encode(w) == unk_id for w in sent]
                for sent in sentence_complex_raw
            ]
            if model_config.replace_unk_by_attn:
                # NOTE(review): uses the first device's attention
                # (attn_distr[0]) for every i, unlike output_target[i];
                # confirm this is intended for multi-GPU evaluation.
                target_raw = postprocess.replace_unk_by_attn(
                    sentence_complex_raw, attn_distr[0], target_raw)
            elif model_config.replace_unk_by_emb:
                target_raw = postprocess.replace_unk_by_emb(
                    sentence_complex_raw, encoder_embs, final_outputs,
                    target_raw, sentence_complex_marker)
            elif model_config.replace_unk_by_cnt:
                target_raw = postprocess.replace_unk_by_cnt(
                    sentence_complex_raw, target_raw)
            if model_config.replace_ner:
                target_raw = postprocess.replace_ner(target_raw, mapper)
            target_raw = postprocess.replace_others(target_raw)
            sentence_simple = decode(sentence_simple, val_data.vocab_simple,
                                     model_config.subword_vocab_size > 0)
            sentence_complex = decode(sentence_complex, val_data.vocab_complex,
                                      model_config.subword_vocab_size > 0)

            # Replace UNK for sentence_complex_raw and ref_raw
            # Note that sentence_complex_raw_lines and ref_raw_lines are original file lines
            sentence_complex_raw = postprocess.replace_ner(
                sentence_complex_raw, mapper)
            sentence_complex_raw = truncate_sents(sentence_complex_raw)

            # Truncate decode results
            target = truncate_sents(target)
            target_raw = truncate_sents(target_raw)
            sentence_simple = truncate_sents(sentence_simple)
            sentence_complex = truncate_sents(sentence_complex)

            targets.extend(target)
            targets_raw.extend(target_raw)
            sentence_simples.extend(sentence_simple)
            sentence_complexs.extend(sentence_complex)
            sentence_complexs_raw.extend(sentence_complex_raw)

            ibleus = []
            saris = []
            fkgls = []

            for batch_i in range(effective_batch_size):
                # Compute iBLEU
                try:
                    batch_ibleu = sentence_bleu([sentence_simple[batch_i]],
                                                target[batch_i])
                except Exception as e:
                    print('Bleu error:\t' + str(e) + '\n' +
                          str(target[batch_i]) + '\n')
                    batch_ibleu = 0
                ibleus_all.append(batch_ibleu)
                ibleus.append(batch_ibleu)

                target_text = ' '.join(target_raw[batch_i])

                # Compute SARI
                batch_sari = 0
                if model_config.num_refs > 0:
                    rsents = [ref_raw_lines[ref_i][batch_i]
                              for ref_i in range(model_config.num_refs)]
                    try:
                        batch_sari = SARIsent(
                            sentence_complex_raw_lines[batch_i],
                            target_text, rsents)
                    except Exception:
                        print('sari error: %s \n %s \n %s. \n' %
                              (sentence_complex_raw_lines[batch_i],
                               target_text, rsents))
                saris.append(batch_sari)
                saris_all.append(batch_sari)

                # Compute FKGL (0 for an empty output)
                batch_fkgl = get_fkgl(target_text) if target_text else 0
                fkgls.append(batch_fkgl)

            target_output = decode_to_output(target, sentence_simple,
                                             sentence_complex,
                                             effective_batch_size, ibleus,
                                             target_raw, sentence_complex_raw,
                                             saris, fkgls, ref_raw_lines,
                                             model_config)
            decode_outputs_all.append(target_output)

            if is_end:
                is_finish = True
                break

    ibleu = np.mean(ibleus_all)
    perplexity = np.mean(perplexitys_all)
    sari = np.mean(saris_all)
    # Compute FKGL in Corpus level
    fkgl = CorpusFKGL(model_config).get_fkgl_from_joshua(step, targets_raw)

    print('Current iBLEU: \t%f' % ibleu)
    print('Current SARI: \t%f' % sari)
    print('Current FKGL: \t%f' % fkgl)
    print('Current perplexity: \t%f' % perplexity)
    print('Current eval done!')

    # MtEval Result - raw
    mteval = MtEval_BLEU(model_config)
    bleu_oi_raw = mteval.get_bleu_from_rawresult(step, targets_raw)
    if model_config.num_refs > 0:
        path_ref = model_config.val_dataset_simple_folder + model_config.val_dataset_simple_rawlines_file_references
        bleu_or_raw = mteval.get_bleu_from_decoderesult_multirefs(
            step, path_ref, targets_raw, lowercase=model_config.lower_case)
        bleu_raw = 0.9 * bleu_or_raw + 0.1 * bleu_oi_raw
    else:
        bleu_or_raw = bleu_oi_raw
        bleu_raw = bleu_oi_raw
    print('Current Mteval iBLEU raw: \t%f' % bleu_raw)

    simple_rawlines_path = (model_config.val_dataset_simple_folder +
                            model_config.val_dataset_simple_rawlines_file)
    simple_refs_path = (model_config.val_dataset_simple_folder +
                        model_config.val_dataset_simple_rawlines_file_references)
    bleu_joshua = mteval.get_bleu_from_joshua(
        step, simple_rawlines_path, simple_refs_path, targets_raw)

    # Use corpus-level sari over ALL decoded outputs (targets_raw), not just
    # the last batch's target_raw.
    corpus_sari = CorpusSARI(model_config)
    sari_joshua = corpus_sari.get_sari_from_joshua(
        step, simple_rawlines_path, simple_refs_path,
        model_config.val_dataset_complex_rawlines_file, targets_raw)

    decimal_cnt = 5
    score_format = '%.' + str(decimal_cnt) + 'f'
    bleu_raw = score_format % bleu_raw
    bleu_oi_raw = score_format % bleu_oi_raw
    bleu_or_raw = score_format % bleu_or_raw
    ibleu = score_format % ibleu
    bleu_joshua = score_format % bleu_joshua
    sari_joshua = score_format % sari_joshua
    fkgl = score_format % fkgl
    perplexity = score_format % perplexity

    content = '\n'.join([
        'bleu_raw\t' + bleu_raw,
        'bleu_oi_raw\t' + bleu_oi_raw,
        'bleu_or_raw\t' + bleu_or_raw,
        'ibleu\t' + ibleu,
        'bleu_joshua\t' + bleu_joshua,
        'sari\t' + sari_joshua,
        'fkgl\t' + fkgl
    ])

    # Output Result: the scores are encoded in the file name.
    result_path = (model_config.resultdir + '/step' + str(step) + '-bleuraw' +
                   bleu_raw + '-bleurawoi' + bleu_oi_raw + '-bleurawor' +
                   bleu_or_raw + '-bleuj' + bleu_joshua + '-perplexity' +
                   perplexity + '-bleunltk' + ibleu + '-sari' +
                   sari_joshua + '-fkgl' + fkgl)
    with open(result_path, 'w', encoding='utf-8') as f:
        f.write(content)
    with open(result_path + '.result', 'w', encoding='utf-8') as f:
        f.write('\n'.join(decode_outputs_all))

    return sari_joshua
Esempio n. 3
0
def train(model_config=None):
    """Train a model in an endless loop with ``tf.train.MonitoredTrainingSession``.

    Builds the data pipeline (a tf.Example dataset or ``TrainData``) and the
    graph selected by ``model_config.framework``. On a fresh run it optionally
    initializes BERT weights and/or partially restores from
    ``model_config.warm_start``. It then trains forever, printing perplexity
    every ``model_config.model_print_freq`` steps. Every
    ``model_config.model_eval_freq`` steps (when > 0) in ``args.mode ==
    'dress'`` it evaluates the latest checkpoint and prunes checkpoints.

    Args:
        model_config: training configuration; ``DefaultConfig()`` when None.

    Raises:
        NotImplementedError: if ``model_config.framework`` is neither
            ``'transformer'`` nor ``'seq2seq'``.
        TypeError: if ``dmode == 'listalter'`` but ``data.training_init_op``
            is not a list.
    """
    model_config = (DefaultConfig() if model_config is None else model_config)

    if model_config.fetch_mode == 'tf_example_dataset':
        data = TfExampleTrainDataset(model_config)
    else:
        data = TrainData(model_config)

    if model_config.framework == 'transformer':
        graph = TransformerGraph(data, True, model_config)
    elif model_config.framework == 'seq2seq':
        graph = Seq2SeqGraph(data, True, model_config)
    else:
        raise NotImplementedError('Unknown Framework.')
    graph.create_model_multigpu()

    ckpt_path = None
    if model_config.warm_start:
        ckpt_path = model_config.warm_start

    partial_restore_ckpt = None
    if ckpt_path is not None:
        # Handling missing vars by ourselves: restore only the variables that
        # exist in the checkpoint with a compatible shape. The shape map avoids
        # loading every checkpoint tensor just to inspect its shape.
        restore_optim = 'optim' in model_config.warm_config
        available_vars = {}
        reader = tf.train.NewCheckpointReader(ckpt_path)
        ckpt_shapes = reader.get_variable_to_shape_map()
        var_dict = {var.op.name: var
                    for var in slim.get_variables_to_restore()}
        for name, var_cur in var_dict.items():
            # Step counter / optimizer state is restored only when warm_config
            # asks for it ('optim').
            if not restore_optim and ('global_step' in name or
                                      'optimization' in name):
                print('Ignore var:', name)
                continue
            if name not in ckpt_shapes:
                print('Variable missing:', name)
                continue
            ckpt_shape = ckpt_shapes[name]
            # Check rank first: a prefix-only comparison would miss a rank
            # mismatch or raise IndexError on it.
            if (len(var_cur.shape) != len(ckpt_shape) or
                    any(cur != saved
                        for cur, saved in zip(var_cur.shape, ckpt_shape))):
                print('Variable missing due to shape.', name)
            else:
                available_vars[name] = var_cur

        partial_restore_ckpt = slim.assign_from_checkpoint_fn(
            ckpt_path,
            available_vars,
            ignore_missing_vars=False,
            reshape_variables=False)

    bert_mode = model_config.bert_mode
    # Guard against memory being None/empty before the substring test.
    use_direct_memory = bool(model_config.memory) and 'direct' in model_config.memory

    if bert_mode:
        bert_restore_ckpt = utils.restore_bert(ckpt=model_config.bert_ckpt)

    if use_direct_memory:
        bert_direct_restore_ckpt = utils.restore_bert(
            ckpt=model_config.bert_ckpt, model='direct/')

    # Decide fresh-vs-resume BEFORE the session exists. The CheckpointSaverHook
    # may (depending on the TF version) write a checkpoint into logdir as soon
    # as the session is created, which would make a fresh run look like a resume.
    is_fresh_run = checkpoint.is_fresh_run(model_config.logdir)

    sess = tf.train.MonitoredTrainingSession(
        checkpoint_dir=model_config.logdir,
        save_checkpoint_secs=model_config.save_model_secs,
        config=session.get_session_config(model_config),
        hooks=[
            tf.train.CheckpointSaverHook(
                model_config.logdir,
                save_secs=model_config.save_model_secs,
                saver=graph.saver)
        ],
        save_summaries_steps=None,
        save_summaries_secs=None,  # Disable tf.summary
    )

    # Check bert_mode is truthy before the substring test (it may be None/'').
    if is_fresh_run and bert_mode and 'init' in bert_mode:
        if use_direct_memory:
            bert_direct_restore_ckpt(sess)
        bert_restore_ckpt(sess)
        print('BERT init')

    if is_fresh_run and ckpt_path is not None:
        partial_restore_ckpt(sess)
        print('Restore from %s' % ckpt_path)

    # Initialize tf example dataset reader
    if model_config.fetch_mode == 'tf_example_dataset':
        if model_config.dmode == 'listalter':
            if not isinstance(data.training_init_op, list):
                raise TypeError('training_init_op must be a list in listalter mode.')
            for init_op in data.training_init_op:
                sess.run(init_op)
        else:
            sess.run(data.training_init_op)
            print('Init dataset interator.')
            if model_config.dmode == 'alter':
                sess.run(data.training_init_op2)
                print('Init dataset2 interator.')

    fetches = [
        graph.train_op, graph.loss, graph.global_step, graph.perplexity,
        graph.ops, graph.increment_global_step, graph.loss_style
    ]
    start_time = datetime.now()
    while True:
        if model_config.fetch_mode:
            results = sess.run(fetches)
        else:
            input_feed = get_graph_train_data(data, graph.objs, model_config)
            results = sess.run(fetches, input_feed)
        # Both paths fetch the same 7 tensors, so unpack them identically.
        _, loss, step, perplexity, _, _, loss_style = results

        if step % model_config.model_print_freq == 0:
            end_time = datetime.now()
            time_span = end_time - start_time
            start_time = end_time
            print('Perplexity:\t%f at step %d using %s.' %
                  (perplexity, step, time_span))
            if 'pred' in model_config.tune_mode:
                print('Loss:%s\tLoss_tyle:%s' % (loss, loss_style))
            if step == model_config.model_print_freq:
                # Report resource usage once, at the first print.
                print_cpu_usage()
                print_cpu_memory()
                print_gpu_memory()

        if model_config.model_eval_freq > 0 and step % model_config.model_eval_freq == 0:
            if args.mode == 'dress':
                from model.model_config import WikiDressLargeDefault, WikiDressLargeEvalDefault, \
                    WikiDressLargeTestDefault
                # Use a separate name so the training config driving this
                # loop is not replaced by the dress config.
                dress_config = WikiDressLargeDefault()
                ckpt = get_ckpt(dress_config.modeldir, dress_config.logdir)

                vconfig = WikiDressLargeEvalDefault()
                best_sari = get_best_sari(vconfig.resultdir)
                sari_point = eval(vconfig, ckpt)
                eval(WikiDressLargeTestDefault(), ckpt)
                if args.memory is not None and 'rule' in args.memory:
                    for rcand in [15, 30, 50]:
                        vconfig.max_cand_rules = rcand
                        vconfig.resultdir = get_path(
                            '../' + vconfig.output_folder +
                            '/result/eightref_val_cand' + str(rcand),
                            vconfig.environment)
                        eval(vconfig, ckpt)
                print(
                    '=====================Current Best SARI:%s====================='
                    % best_sari)
                if float(sari_point) < best_sari:
                    remove(ckpt + '.index')
                    remove(ckpt + '.meta')
                    remove(ckpt + '.data-00000-of-00001')
                    print('remove ckpt:%s' % ckpt)
                else:
                    # Kept in its own variable so the training `step` is not
                    # overwritten with a string.
                    # NOTE(review): the trailing [:-1] drops the last character
                    # of the step number; confirm this matches get_ckpt's format.
                    ckpt_step = ckpt[ckpt.rindex('model.ckpt-') +
                                     len('model.ckpt-'):-1]
                    for fname in listdir(dress_config.modeldir):
                        if ckpt_step not in fname:
                            remove(dress_config.modeldir + fname)
                    print('Get Best Model, remove ckpt except:%s.' % ckpt)