示例#1
0
    # Training loop: runs one optimisation step per batch, periodically samples
    # translations, and periodically scores the dev set with BLEU, saving the
    # model parameters whenever the BLEU score improves.
    # `batch_count` and `val_time` are counters initialised before this point.
    best_score = 0.
    for epoch in range(config['max_epoch']):
        for tr_data in tr_stream.get_epoch_iterator():
            batch_count += 1
            tr_fn(*tr_data)

            # Every `sampling_freq` batches, run trans_sample on elements 0 and 2
            # of the current batch (presumably source and target ids -- TODO confirm).
            if batch_count % config['sampling_freq'] == 0:
                trans_sample(tr_data[0], tr_data[2], f_init, f_next, config['hook_samples'],
                             src_vocab_reverse, trg_vocab_reverse, batch_count)

            # After the burn-in period, translate the validation set every
            # `bleu_val_freq` batches.
            if batch_count > config['val_burn_in'] and batch_count % config['bleu_val_freq'] == 0:
                logger.info('[{}]: {} has been tackled and start translate val set!'.format(epoch, batch_count))
                val_time += 1
                val_save_out = '{}.{}'.format(config['val_set_out'], val_time)

                # Translate before opening the output file, and let the context
                # manager close it, so a failure during translation neither leaks
                # a file handle nor leaves an empty output file behind.
                data_iter = dev_stream.get_epoch_iterator()
                trans_res = multi_process_sample(data_iter, f_init, f_next, k=12, vocab=trg_vocab_reverse, process=1)
                with open(val_save_out, 'w') as val_save_file:
                    val_save_file.writelines(trans_res)
                logger.info('[{}]: {} times val has been translated!'.format(epoch, val_time))

                bleu_score = valid_bleu(config['eval_dir'], val_save_out)
                # Keep the score in the file name: <val_set_out>.<n>.<bleu>.txt
                os.rename(val_save_out, "{}.{}.txt".format(val_save_out, bleu_score))

                # Checkpoint only when the validation BLEU beats the best so far.
                if bleu_score > best_score:
                    trans.savez(config['saveto']+'/params.npz')
                    best_score = bleu_score
                    logger.info('epoch:{}, batchs:{}, bleu_score:{}'.format(
                        epoch, batch_count, best_score))
示例#2
0
    # Decoding script body: builds the sampling functions, loads the vocabularies
    # and saved parameters, translates the dev set and writes the result to the
    # file "trans" in the current working directory.
    params = trans.params
    # Debug output: sum of the first parameter's values. This runs before
    # trans.load() below, so it reflects the values prior to loading the checkpoint.
    # Parenthesised single-argument form prints identically under Python 2.
    print(params[0].get_value().sum())

    logger.info("begin to build sample model : f_init, f_next")
    f_init, f_next = trans.build_sample()
    logger.info("end build sample model : f_init, f_next")

    # Load the vocabularies with context managers so the file handles are closed
    # deterministically instead of being left to the garbage collector.
    # NOTE(review): pickle.load can execute arbitrary code; only load vocabulary
    # files that come from a trusted source.
    with open(config["src_vocab"]) as src_vocab_file:
        src_vocab = pickle.load(src_vocab_file)
    with open(config["trg_vocab"]) as trg_vocab_file:
        trg_vocab = pickle.load(trg_vocab_file)
    src_vocab = ensure_special_tokens(
        src_vocab, bos_idx=0, eos_idx=config["src_vocab_size"] - 1, unk_idx=config["unk_id"]
    )
    # NOTE(review): the target vocabulary's eos_idx is derived from
    # config["src_vocab_size"]; this looks like a copy-paste slip for a target
    # vocabulary size. Left unchanged because the indices must match whatever was
    # used at training time -- confirm against the training script before fixing.
    trg_vocab = ensure_special_tokens(
        trg_vocab, bos_idx=0, eos_idx=config["src_vocab_size"] - 1, unk_idx=config["unk_id"]
    )
    # Index -> word lookups.
    trg_vocab_reverse = {index: word for word, index in trg_vocab.iteritems()}
    src_vocab_reverse = {index: word for word, index in src_vocab.iteritems()}
    logger.info("load dict finished ! src dic size : {} trg dic size : {}.".format(len(src_vocab), len(trg_vocab)))

    # val_set=sys.argv[1]
    # config['val_set']=val_set
    dev_stream = get_dev_stream(**config)
    # NOTE(review): the message says "training" but the visible code only loads
    # saved parameters and translates the dev set -- confirm the intended wording.
    logger.info("start training!!!")
    trans.load(config["saveto"] + "/params.npz")

    # Translate first, then open the output file, so a failure during translation
    # neither leaks a file handle nor leaves an empty "trans" file behind.
    data_iter = dev_stream.get_epoch_iterator()
    # NOTE(review): this rebinds `trans` from the model object to the translation
    # output; the name is kept in case code after this point relies on it.
    trans = multi_process_sample(data_iter, f_init, f_next, k=10, vocab=trg_vocab_reverse, process=1, normalize=False)
    with open("trans", "w") as val_save_file:
        val_save_file.writelines(trans)