# Main training loop: update on each batch, periodically print sample
# translations, and periodically decode the validation set to track BLEU.
best_score = 0.
batch_count = 0
val_time = 0
for epoch in range(config['max_epoch']):
    for tr_data in tr_stream.get_epoch_iterator():
        batch_count += 1
        tr_fn(*tr_data)  # one gradient update on this batch
        # print a few sample translations from the current batch
        if batch_count % config['sampling_freq'] == 0:
            trans_sample(tr_data[0], tr_data[2], f_init, f_next,
                         config['hook_samples'], src_vocab_reverse,
                         trg_vocab_reverse, batch_count)
        # translate the validation set and score it with BLEU
        if batch_count > config['val_burn_in'] and batch_count % config['bleu_val_freq'] == 0:
            logger.info('epoch [{}]: {} batches processed, start translating the validation set!'.format(
                epoch, batch_count))
            val_time += 1
            val_save_out = '{}.{}'.format(config['val_set_out'], val_time)
            val_save_file = open(val_save_out, 'w')
            data_iter = dev_stream.get_epoch_iterator()
            trans_res = multi_process_sample(data_iter, f_init, f_next, k=12,
                                             vocab=trg_vocab_reverse, process=1)
            val_save_file.writelines(trans_res)
            val_save_file.close()
            logger.info('epoch [{}]: validation round {} translated!'.format(epoch, val_time))
            bleu_score = valid_bleu(config['eval_dir'], val_save_out)
            os.rename(val_save_out, '{}.{}.txt'.format(val_save_out, bleu_score))
            # keep the parameters that achieve the best validation BLEU so far
            if bleu_score > best_score:
                trans.savez(config['saveto'] + '/params.npz')
                best_score = bleu_score
                logger.info('epoch: {}, batches: {}, bleu_score: {}'.format(
                    epoch, batch_count, best_score))
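# valid_bleu() is called above but not defined in this snippet. Below is a
# minimal sketch of such a helper, assuming the plain-text reference file and
# Moses' multi-bleu.perl both live in eval_dir; the file name 'dev.ref' and the
# script path are illustrative assumptions, not part of the original code.
import os
import re
import subprocess


def valid_bleu(eval_dir, trans_file, ref_name='dev.ref', script='multi-bleu.perl'):
    """Score a translated file against a reference with multi-bleu.perl."""
    ref_path = os.path.join(eval_dir, ref_name)
    script_path = os.path.join(eval_dir, script)
    with open(trans_file) as f:
        proc = subprocess.Popen(['perl', script_path, ref_path],
                                stdin=f, stdout=subprocess.PIPE)
        out = proc.communicate()[0].decode('utf-8')
    # multi-bleu.perl prints e.g. "BLEU = 23.45, 60.1/29.3/16.8/9.9 (...)"
    match = re.search(r'BLEU = ([\d.]+)', out)
    return float(match.group(1)) if match else 0.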
# Inspect the loaded parameters, build the sampling functions, and load the
# source/target vocabularies (plus their id -> word reverses) for decoding.
params = trans.params
print params[0].get_value().sum()

logger.info('begin to build sample model: f_init, f_next')
f_init, f_next = trans.build_sample()
logger.info('end build sample model: f_init, f_next')

src_vocab = pickle.load(open(config['src_vocab'], 'rb'))
trg_vocab = pickle.load(open(config['trg_vocab'], 'rb'))
src_vocab = ensure_special_tokens(
    src_vocab, bos_idx=0, eos_idx=config['src_vocab_size'] - 1, unk_idx=config['unk_id'])
# the target EOS index is derived from the target vocabulary size
trg_vocab = ensure_special_tokens(
    trg_vocab, bos_idx=0, eos_idx=config['trg_vocab_size'] - 1, unk_idx=config['unk_id'])
trg_vocab_reverse = {index: word for word, index in trg_vocab.iteritems()}
src_vocab_reverse = {index: word for word, index in src_vocab.iteritems()}
logger.info('load dict finished! src dict size: {}, trg dict size: {}.'.format(
    len(src_vocab), len(trg_vocab)))

# val_set = sys.argv[1]
# config['val_set'] = val_set
dev_stream = get_dev_stream(**config)

# Load the saved parameters and translate the dev set with beam search (k=10).
logger.info('start translating the dev set!')
trans.load(config['saveto'] + '/params.npz')
val_save_file = open('trans', 'w')
data_iter = dev_stream.get_epoch_iterator()
trans_res = multi_process_sample(data_iter, f_init, f_next, k=10,
                                 vocab=trg_vocab_reverse, process=1, normalize=False)
val_save_file.writelines(trans_res)
val_save_file.close()
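# ensure_special_tokens() is only called, never defined, in this snippet. A
# minimal sketch of what such a helper typically does follows; the token
# spellings '<s>', '</s>' and '<unk>' are assumptions, not taken from the
# original code.
def ensure_special_tokens(vocab, bos_idx=0, eos_idx=0, unk_idx=1):
    """Force BOS/EOS/UNK onto fixed indices of a word -> id dictionary."""
    # drop whatever words currently occupy the reserved indices
    for word in [w for w, idx in vocab.items() if idx in (bos_idx, eos_idx, unk_idx)]:
        del vocab[word]
    vocab['<s>'] = bos_idx
    vocab['</s>'] = eos_idx
    vocab['<unk>'] = unk_idx
    return vocab

# With these fixed indices in place, the reversed dictionaries built above can
# map decoded id sequences back to words before writing them to val_save_file.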