Example #1
def main(args):
    model_path = args.model_path
    hparams.set_hparam('batch_size', 1)
    hparams.add_hparam('is_training', False)
    check_vocab(args)
    src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
    datasets = load_dataset(args, src_placeholder)
    iterator = iterator_utils.get_inference_iterator(hparams, datasets)
    src_vocab, tgt_vocab, _, tgt_reverse_vocab, src_vocab_size, tgt_vocab_size = datasets
    hparams.add_hparam('vocab_size_source', src_vocab_size)
    hparams.add_hparam('vocab_size_target', tgt_vocab_size)

    sess, model = load_model(hparams, tf.contrib.learn.ModeKeys.INFER, iterator, src_vocab, tgt_vocab, tgt_reverse_vocab)

    ckpt = tf.train.latest_checkpoint(args.model_path)
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=5)
    if ckpt:
        saver.restore(sess, ckpt)
    else:
        raise Exception("cannot find checkpoint file")

    src_vocab_file = os.path.join(model_path, 'vocab.src')
    src_reverse_vocab = build_reverse_vocab_table(src_vocab_file, hparams)
    sess.run(tf.tables_initializer())

    index = 1
    inputs = np.array(get_data(args), dtype=np.str)
    with sess:
        logger.info("starting inference...")
        sess.run(iterator.initializer, feed_dict={src_placeholder: inputs})
        eos = hparams.eos.encode()
        pad = hparams.pad.encode()
        while True:
            try:
                predictions, confidence, source = model.inference(sess)
                source_sent = src_reverse_vocab.lookup(tf.constant(list(source[0]), tf.int64))
                source_sent = sess.run(source_sent)
                print(index, text_utils.format_bpe_text(source_sent, [eos, pad]))
                if hparams.beam_width == 1:
                    print(bytes2sent(list(predictions[0]), [eos, pad]))
                else:
                    print(bytes2sent(list(predictions[0][:, 0]), [eos, pad]))
                if confidence is not None:
                    print(confidence[0])
                print()
                if index > args.max_data_size:
                    break
                index += 1
            except tf.errors.OutOfRangeError:
                logger.info('Done inference')
                break
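
Examples #1 and #4 both call a bytes2sent helper that is not defined in the snippets. Below is a minimal sketch of what it presumably does, assuming the decoder emits byte tokens and that BPE subwords are marked with a trailing '@@'; the actual implementation in wyb330/nmt may differ.

def bytes2sent(tokens, stop_symbols):
    """Hypothetical sketch: join decoder byte tokens into a plain sentence."""
    # Accept either a single stop symbol (e.g. eos) or a list (e.g. [eos, pad]),
    # matching the two call sites above.
    if not isinstance(stop_symbols, (list, tuple)):
        stop_symbols = [stop_symbols]
    words = []
    for token in tokens:
        if token in stop_symbols:
            break
        words.append(token.decode('utf-8'))
    # Undo BPE segmentation ('wor@@ d' -> 'word'); this marker is an assumption.
    return ' '.join(words).replace('@@ ', '')

For instance, bytes2sent([b'He@@', b'llo', b'</s>'], [b'</s>', b'<pad>']) would return 'Hello'.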
Example #2
def main():
    # Pulled from tacotron-2's synthesize.py
    hparams.add_hparam('max_abs_value', 4.0)
    hparams.add_hparam('power', 1.1)
    hparams.add_hparam('outputs_per_step', 1)

    # Do all the rest
    apply_hparams(hparams)

    from tacotron.synthesize import run_eval
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    output_dir = 'tacotron_' + 'output/'

    # try:
    checkpoint_path = tf.train.get_checkpoint_state(
        '/opt/Tacotron-2/logs-Tacotron/pretrained').model_checkpoint_path
    print('loaded model at {}'.format(checkpoint_path))
    #except:
    #    raise AssertionError('Cannot restore checkpoint: {}, did you train a model?'.format(args.checkpoint))

    run_eval(None, checkpoint_path, output_dir, 'Hello, Tim Winter.')
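
Both examples mix set_hparam and add_hparam on a tf.contrib.training.HParams object (TF 1.x): set_hparam only overwrites keys that already exist, while add_hparam refuses to add a key that is already present. That is why pre-existing values such as batch_size are set, and new flags such as is_training or max_abs_value are added. A minimal illustration, with made-up values:

from tensorflow.contrib.training import HParams

hparams = HParams(batch_size=32, eos='</s>')

hparams.set_hparam('batch_size', 1)       # OK: key exists, value is overwritten
hparams.add_hparam('is_training', False)  # OK: key did not exist yet

# hparams.add_hparam('batch_size', 64)    # would fail: key already exists
# hparams.set_hparam('power', 1.1)        # would fail: key was never added

print(hparams.values())  # {'batch_size': 1, 'eos': '</s>', 'is_training': False}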
Example #3
load_path = 'hccho-ckpt\\DCT-2019-08-03_08-59-35'
#####

load_path, restore_path, checkpoint_path = prepare_dirs(hp, load_path)

#### Write the log file
log_path = os.path.join(load_path, 'train.log')
infolog.init(log_path, hp.model_name)

infolog.set_tf_log(load_path)  # Estimator --> save logs
tf.logging.set_verbosity(tf.logging.INFO)  # required for the training log to be printed

# load data
inputs, targets, word_to_index, index_to_word, VOCAB_SIZE, INPUT_LENGTH, OUTPUT_LENGTH = load_data(
    hp)  # (50000, 29), (50000, 12)
hp.add_hparam('VOCAB_SIZE', VOCAB_SIZE)
hp.add_hparam('INPUT_LENGTH', INPUT_LENGTH)  # 29
hp.add_hparam('OUTPUT_LENGTH', OUTPUT_LENGTH)  # 11

train_input, test_input, train_target, test_target = train_test_split(
    inputs, targets, test_size=0.1, random_state=13371447)

datfeeder = DataFeeder(train_input,
                       train_target,
                       test_input,
                       test_target,
                       batch_size=hp.BATCH_SIZE,
                       num_epoch=hp.NUM_EPOCHS)


def seq_accuracy(
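
The snippet above is cut off in the middle of the seq_accuracy definition, and the original body is not recoverable from this excerpt. Purely as a hypothetical sketch (not the project's code), a sequence-level accuracy over padded ID arrays could be computed like this:

import numpy as np

def seq_accuracy_sketch(predictions, targets, pad_id=0):
    """Hypothetical: fraction of sequences predicted exactly right, ignoring padding."""
    predictions = np.asarray(predictions)
    targets = np.asarray(targets)
    mask = targets != pad_id                                 # positions that actually count
    correct = np.logical_or(predictions == targets, ~mask)   # padded positions always count as correct
    return float(np.mean(np.all(correct, axis=1)))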
Example #4
File: eval.py  Project: wyb330/nmt
def main(args, max_data_size=0, shuffle=True, display=False):
    hparams.set_hparam('batch_size', 10)
    hparams.add_hparam('is_training', False)
    check_vocab(args)
    datasets, src_data_size = load_dataset(args)
    iterator = iterator_utils.get_eval_iterator(hparams, datasets, hparams.eos, shuffle=shuffle)
    src_vocab, tgt_vocab, src_dataset, tgt_dataset, tgt_reverse_vocab, src_vocab_size, tgt_vocab_size = datasets
    hparams.add_hparam('vocab_size_source', src_vocab_size)
    hparams.add_hparam('vocab_size_target', tgt_vocab_size)

    sess, model = load_model(hparams, tf.contrib.learn.ModeKeys.EVAL, iterator, src_vocab, tgt_vocab, tgt_reverse_vocab)

    if args.restore_step:
        checkpoint_path = os.path.join(args.model_path, 'nmt.ckpt')
        ckpt = '%s-%d' % (checkpoint_path, args.restore_step)
    else:
        ckpt = tf.train.latest_checkpoint(args.model_path)
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=5)
    if ckpt:
        saver.restore(sess, ckpt)
    else:
        raise Exception("cannot find checkpoint file")

    src_vocab_file = os.path.join(args.model_path, 'vocab.src')
    src_reverse_vocab = build_reverse_vocab_table(src_vocab_file, hparams)
    sess.run(tf.tables_initializer())

    step_count = 1
    with sess:
        logger.info("starting evaluating...")
        sess.run(iterator.initializer)
        eos = hparams.eos.encode()
        references = []
        translations = []
        start_time = time.time()
        while True:
            try:
                if (max_data_size > 0) and (step_count * hparams.batch_size > max_data_size):
                    break
                if step_count % 10 == 0:
                    t = time.time() - start_time
                    logger.info('step={0} total={1} time={2:.3f}'.format(step_count, step_count * hparams.batch_size, t))
                    start_time = time.time()
                predictions, source, target, source_text, confidence = model.eval(sess)
                reference = bpe2sent(target, eos)
                if hparams.beam_width == 1:
                    translation = bytes2sent(list(predictions), eos)
                else:
                    translation = bytes2sent(list(predictions[:, 0]), eos)

                for s, r, t in zip(source, reference, translation):
                    if display:
                        source_sent = src_reverse_vocab.lookup(tf.constant(list(s), tf.int64))
                        source_sent = sess.run(source_sent)
                        source_sent = text_utils.format_bpe_text(source_sent, eos)
                        print('{}\n{}\n{}\n'.format(source_sent, r, t))
                    references.append(r)
                    translations.append(t)

                if step_count % 100 == 0:
                    bleu_score = moses_multi_bleu(references, translations, args.model_path)
                    logger.info('bleu score = {0:.3f}'.format(bleu_score))

                step_count += 1
            except tf.errors.OutOfRangeError:
                logger.info('Done eval data')
                break

        logger.info('compute bleu score...')
        # bleu_score = compute_bleu_score(references, translations)
        bleu_score = moses_multi_bleu(references, translations, args.model_path)
        logger.info('bleu score = {0:.3f}'.format(bleu_score))
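
Example #4 relies on the project's moses_multi_bleu wrapper and leaves a compute_bleu_score call commented out; neither is shown here. As an illustrative stand-in only (not the project's implementation), a corpus-level BLEU over the collected reference and translation strings could be computed with NLTK:

from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction

def corpus_bleu_score(references, translations):
    """Corpus BLEU over whitespace-tokenized sentence strings, on a 0-100 scale."""
    list_of_refs = [[ref.split()] for ref in references]  # one reference per hypothesis
    hypotheses = [hyp.split() for hyp in translations]
    smooth = SmoothingFunction().method1                   # avoids zero scores on short corpora
    return 100.0 * corpus_bleu(list_of_refs, hypotheses, smoothing_function=smooth)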
Example #5
def main(args, max_data_size=0):
    vocab_dir = args.vocab_dir
    log_file_handler = logging.FileHandler(os.path.join(vocab_dir, 'train.log'))
    logger.addHandler(log_file_handler)

    check_vocab(args, vocab_dir)
    datasets = load_dataset(args, vocab_dir)
    iterator = iterator_utils.get_iterator(hparams, datasets, max_rows=max_data_size)
    src_vocab, tgt_vocab, _, _, src_vocab_size, tgt_vocab_size = datasets
    hparams.add_hparam('is_training', True)
    hparams.add_hparam('vocab_size_source', src_vocab_size)
    hparams.add_hparam('vocab_size_target', tgt_vocab_size)
    pprint(hparams.values())
    sess, model = load_model(hparams, tf.contrib.learn.ModeKeys.TRAIN, iterator, src_vocab, tgt_vocab)

    if args.restore_step > 0:
        checkpoint_path = os.path.join(vocab_dir, 'nmt.ckpt')
        ckpt = '%s-%d' % (checkpoint_path, args.restore_step)
    else:
        ckpt = tf.train.latest_checkpoint(vocab_dir)
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=5)
    if ckpt:
        saver.restore(sess, ckpt)
    else:
        sess.run(tf.global_variables_initializer())
        print("Created model with fresh parameters.")

    sess.run(tf.tables_initializer())
    with sess:
        writer = tf.summary.FileWriter(vocab_dir, sess.graph)
        logger.info("starting training...")
        epochs = 1
        step_in_epoch = 0
        learning_rate = hparams.learning_rate
        checkpoint_path = os.path.join(vocab_dir, "nmt.ckpt")

        sess.run(iterator.initializer)
        while epochs <= args.num_train_epochs:
            start_time = time.time()
            try:
                loss, global_step, learning_rate, accuracy, summary = model.step(sess)
                step_in_epoch += 1
                if global_step % args.summary_per_steps == 0:
                    write_summary(writer, summary, global_step)

            except tf.errors.OutOfRangeError:
                logger.info('{} epochs finished'.format(epochs))
                # saver.save(sess, checkpoint_path, global_step=global_step)
                epochs += 1
                step_in_epoch = 1
                sess.run(iterator.initializer)
                continue

            sec_per_step = time.time() - start_time
            logger.info("Epoch %-3d Step %-d - %-d [%.3f sec, loss=%.4f, acc=%.3f, lr=%f]" %
                        (epochs, global_step, step_in_epoch, sec_per_step, loss, accuracy, learning_rate))

            if global_step % args.steps_per_checkpoint == 0:
                model_checkpoint_path = saver.save(sess, checkpoint_path, global_step=global_step)
                logger.info("Saved checkpoint to {}".format(model_checkpoint_path))

            if math.isnan(loss) or math.isinf(loss):
                raise Exception('loss overflow')

        writer.close()
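
Example #5's main only reads a few attributes from args (vocab_dir, restore_step, num_train_epochs, summary_per_steps, steps_per_checkpoint). A minimal command-line entry point could look like the following; the flag names match the attributes used above, but the defaults and help texts are assumptions:

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train the NMT model.')
    parser.add_argument('--vocab_dir', required=True,
                        help='directory holding the vocab files and checkpoints')
    parser.add_argument('--restore_step', type=int, default=0,
                        help='checkpoint step to restore (0 = use latest, or start fresh)')
    parser.add_argument('--num_train_epochs', type=int, default=10)
    parser.add_argument('--summary_per_steps', type=int, default=100)
    parser.add_argument('--steps_per_checkpoint', type=int, default=1000)
    main(parser.parse_args())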