def train(): """Train a en->fr translation model using WMT data.""" #with tf.device("/gpu:0"): # Prepare WMT data. train_path = os.path.join(FLAGS.data_dir, "chitchat.train") fixed_path = os.path.join(FLAGS.data_dir, "chitchat.fixed") weibo_path = os.path.join(FLAGS.data_dir, "chitchat.weibo") qa_path = os.path.join(FLAGS.data_dir, "chitchat.qa") voc_file_path = [ train_path + ".answer", fixed_path + ".answer", weibo_path + ".answer", qa_path + ".answer", train_path + ".query", fixed_path + ".query", weibo_path + ".query", qa_path + ".query" ] vocab_path = os.path.join(FLAGS.data_dir, "vocab%d.all" % FLAGS.vocab_size) data_utils.create_vocabulary(vocab_path, voc_file_path, FLAGS.vocab_size) vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path) print("Preparing Chitchat data in %s" % FLAGS.data_dir) train_query, train_answer, dev_query, dev_answer = data_utils.prepare_chitchat_data( FLAGS.data_dir, vocab, FLAGS.vocab_size) print("Preparing Fixed data in %s" % FLAGS.fixed_set_path) fixed_path = os.path.join(FLAGS.fixed_set_path, "chitchat.fixed") fixed_query, fixed_answer = data_utils.prepare_defined_data( fixed_path, vocab, FLAGS.vocab_size) print("Preparing Weibo data in %s" % FLAGS.weibo_set_path) weibo_path = os.path.join(FLAGS.weibo_set_path, "chitchat.weibo") weibo_query, weibo_answer = data_utils.prepare_defined_data( weibo_path, vocab, FLAGS.vocab_size) print("Preparing QA data in %s" % FLAGS.qa_set_path) qa_path = os.path.join(FLAGS.qa_set_path, "chitchat.qa") qa_query, qa_answer = data_utils.prepare_defined_data( qa_path, vocab, FLAGS.vocab_size) dummy_path = os.path.join(FLAGS.data_dir, "chitchat.dummy") dummy_set = data_utils.get_dummy_set(dummy_path, vocab, FLAGS.vocab_size) print("Get Dummy Set : ", dummy_set) with tf.Session() as sess: #with tf.device("/gpu:1"): # Create model. print("Creating %d layers of %d units." % (FLAGS.num_layers, FLAGS.size)) model = create_model(sess, dummy_set, False) # Read data into buckets and compute their sizes. print("Reading development and training data (limit: %d)." % FLAGS.max_train_data_size) dev_set = read_data(dev_query, dev_answer) train_set = read_data(train_query, train_answer, FLAGS.max_train_data_size) fixed_set = read_data(fixed_query, fixed_answer, FLAGS.max_train_data_size) weibo_set = read_data(weibo_query, weibo_answer, FLAGS.max_train_data_size) qa_set = read_data(qa_query, qa_answer, FLAGS.max_train_data_size) train_bucket_sizes = [len(train_set[b]) for b in xrange(len(_buckets))] train_total_size = float(sum(train_bucket_sizes)) train_buckets_scale = [ sum(train_bucket_sizes[:i + 1]) / train_total_size for i in xrange(len(train_bucket_sizes)) ] # This is the training loop. 
step_time, loss = 0.0, 0.0 current_step = 0 previous_losses = [] en_dict_cover = {} fr_dict_cover = {} if model.global_step.eval() > FLAGS.steps_per_checkpoint: try: with open(FLAGS.en_cover_dict_path, "rb") as ef: en_dict_cover = pickle.load(ef) # for line in ef.readlines(): # line = line.strip() # key, value = line.strip(",") # en_dict_cover[int(key)]=int(value) except Exception: print("no find query_cover_file") try: with open(FLAGS.ff_cover_dict_path, "rb") as ff: fr_dict_cover = pickle.load(ff) # for line in ff.readlines(): # line = line.strip() # key, value = line.strip(",") # fr_dict_cover[int(key)]=int(value) except Exception: print("no find answer_cover_file") step_loss_summary = tf.Summary() #merge = tf.merge_all_summaries() writer = tf.summary.FileWriter("../logs/", sess.graph) while True: # Choose a bucket according to data distribution. We pick a random number # in [0, 1] and use the corresponding interval in train_buckets_scale. random_number_01 = np.random.random_sample() bucket_id = min([ i for i in xrange(len(train_buckets_scale)) if train_buckets_scale[i] > random_number_01 ]) # Get a batch and make a step. start_time = time.time() encoder_inputs, decoder_inputs, target_weights, batch_source_encoder, batch_source_decoder = model.get_batch( train_set, bucket_id, 0, fixed_set, weibo_set, qa_set) if FLAGS.reinforce_learning: _, step_loss, _ = model.step_rl(sess, _buckets, encoder_inputs, decoder_inputs, target_weights, batch_source_encoder, batch_source_decoder, bucket_id) else: _, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, forward_only=False, force_dec_input=True) step_time += (time.time() - start_time) / FLAGS.steps_per_checkpoint loss += step_loss / FLAGS.steps_per_checkpoint current_step += 1 query_size, answer_size = _buckets[bucket_id] for batch_index in xrange(FLAGS.batch_size): for query_index in xrange(query_size): query_word = encoder_inputs[query_index][batch_index] if en_dict_cover.has_key(query_word): en_dict_cover[query_word] += 1 else: en_dict_cover[query_word] = 0 for answer_index in xrange(answer_size): answer_word = decoder_inputs[answer_index][batch_index] if fr_dict_cover.has_key(answer_word): fr_dict_cover[answer_word] += 1 else: fr_dict_cover[answer_word] = 0 # Once in a while, we save checkpoint, print statistics, and run evals. if current_step % FLAGS.steps_per_checkpoint == 0: bucket_value = step_loss_summary.value.add() bucket_value.tag = "loss" bucket_value.simple_value = float(loss) writer.add_summary(step_loss_summary, current_step) print("query_dict_cover_num: %s" % (str(en_dict_cover.__len__()))) print("answer_dict_cover_num: %s" % (str(fr_dict_cover.__len__()))) ef = open(FLAGS.en_cover_dict_path, "wb") pickle.dump(en_dict_cover, ef) ff = open(FLAGS.ff_cover_dict_path, "wb") pickle.dump(fr_dict_cover, ff) # Print statistics for the previous epoch. perplexity = math.exp(loss) if loss < 300 else float('inf') print( "global step %d learning rate %.4f step-time %.2f perplexity " "%.2f" % (model.global_step.eval(), model.learning_rate.eval(), step_time, perplexity)) # Decrease learning rate if no improvement was seen over last 3 times. if len(previous_losses) > 2 and loss > max( previous_losses[-3:]): sess.run(model.learning_rate_decay_op) previous_losses.append(loss) # Save checkpoint and zero timer and loss. 
checkpoint_path = os.path.join(FLAGS.train_dir, "chitchat.model") model.saver.save(sess, checkpoint_path, global_step=model.global_step) step_time, loss = 0.0, 0.0 # Run evals on development set and print their perplexity. # for bucket_id in xrange(len(_buckets)): # encoder_inputs, decoder_inputs, target_weights = model.get_batch( # dev_set, bucket_id) # _, eval_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, # target_weights, bucket_id, True) # eval_ppx = math.exp(eval_loss) if eval_loss < 300 else float('inf') # print(" eval: bucket %d perplexity %.2f" % (bucket_id, eval_ppx)) sys.stdout.flush()
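# The training loop above samples a bucket in proportion to how much data it holds:
# train_buckets_scale is the cumulative distribution of the bucket sizes, and a uniform
# draw in [0, 1) is mapped to the first cumulative share that exceeds it. A minimal,
# self-contained sketch of that sampling step (the helper name sample_bucket_id is
# illustrative and not part of this code base):
import random

def sample_bucket_id(bucket_sizes):
    """Return a bucket index drawn with probability proportional to its size."""
    total = float(sum(bucket_sizes))
    # Cumulative shares, e.g. sizes [10, 30, 60] -> [0.1, 0.4, 1.0].
    scale = [sum(bucket_sizes[:i + 1]) / total for i in range(len(bucket_sizes))]
    r = random.random()
    # First bucket whose cumulative share exceeds the draw.
    return min(i for i in range(len(scale)) if scale[i] > r)

# Example: with sizes [10, 30, 60] the buckets are picked roughly 10%, 30% and 60%
# of the time, so buckets holding more training pairs are sampled more often.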
def train(): """Train a en->fr translation model using WMT data.""" #with tf.device("/gpu:0"): # Prepare WMT data. train_path = os.path.join(FLAGS.data_dir, "weibo") fixed_path = os.path.join(FLAGS.data_dir, "fixed") weibo_path = os.path.join(FLAGS.data_dir, "wb") qa_path = os.path.join(FLAGS.data_dir, "qa") voc_file_path = [ train_path + ".answer", fixed_path + ".answer", weibo_path + ".answer", qa_path + ".answer", train_path + ".query", fixed_path + ".query", weibo_path + ".query", qa_path + ".query" ] vocab_path = os.path.join(FLAGS.data_dir, "vocab%d.txt" % FLAGS.vocab_size) data_utils.create_vocabulary(vocab_path, voc_file_path, FLAGS.vocab_size) vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path) print(len(vocab)) print("Preparing Chitchat data in %s" % FLAGS.data_dir) train_query, train_answer, dev_query, dev_answer = data_utils.prepare_chitchat_data( FLAGS.data_dir, vocab, FLAGS.vocab_size) print("Preparing Fixed data in %s" % FLAGS.fixed_set_path) fixed_path = os.path.join(FLAGS.fixed_set_path, "wb") fixed_query, fixed_answer = data_utils.prepare_defined_data( fixed_path, vocab, FLAGS.vocab_size) print("Preparing Weibo data in %s" % FLAGS.weibo_set_path) weibo_path = os.path.join(FLAGS.weibo_set_path, "wb") weibo_query, weibo_answer = data_utils.prepare_defined_data( weibo_path, vocab, FLAGS.vocab_size) print("Preparing QA data in %s" % FLAGS.qa_set_path) qa_path = os.path.join(FLAGS.qa_set_path, "wb") qa_query, qa_answer = data_utils.prepare_defined_data( qa_path, vocab, FLAGS.vocab_size) dummy_path = os.path.join(FLAGS.data_dir, "dummy") dummy_set = data_utils.get_dummy_set(dummy_path, vocab, FLAGS.vocab_size) print("Get Dummy Set : ", dummy_set) if FLAGS.reinforce_learning == True and FLAGS.dual_learning == False: import data0_utils as du config = {} config['fill_word'] = du._PAD_ config['embedding'] = du.embedding config['fold'] = 1 config['model_file'] = "model_mp" config['log_file'] = "dis.log" config['train_iters'] = 50000 config['model_tag'] = "mxnet" config['batch_size'] = 64 config['data1_maxlen'] = 46 config['data2_maxlen'] = 74 config['data1_psize'] = 5 config['data2_psize'] = 5 from importlib import import_module mo = import_module(config['model_file']) disModel = mo.Model(config) disSess = tf.Session() disModel.init_step(disSess) if sys.argv[1] != "no": disModel.saver.restore(disSess, sys.argv[1]) outputFile = open("RL_ouput.txt", "w") lofFile = open("log.txt", "w") tfconfig = tf.ConfigProto() tfconfig.gpu_options.allow_growth = True with tf.Session(config=tfconfig) as sess: #with tf.device("/gpu:1"): # Create model. print("Creating %d layers of %d units." % (FLAGS.num_layers, FLAGS.size)) model = create_model(sess, dummy_set, False, False) if FLAGS.dual_learning: du_model = create_model(sess, dummy_set, False, True) #sess.run(model.learning_rate_set_op) # Read data into buckets and compute their sizes. print("Reading development and training data (limit: %d)." % FLAGS.max_train_data_size) # This is the training loop. 
step_time, loss = 0.0, 0.0 current_step = 0 previous_losses = [] en_dict_cover = {} fr_dict_cover = {} if model.global_step.eval() > FLAGS.steps_per_checkpoint: try: with open(FLAGS.en_cover_dict_path, "rb") as ef: en_dict_cover = pickle.load(ef) # for line in ef.readlines(): # line = line.strip() # key, value = line.strip(",") # en_dict_cover[int(key)]=int(value) except Exception: print("no find query_cover_file") try: with open(FLAGS.ff_cover_dict_path, "rb") as ff: fr_dict_cover = pickle.load(ff) # for line in ff.readlines(): # line = line.strip() # key, value = line.strip(",") # fr_dict_cover[int(key)]=int(value) except Exception: print("no find answer_cover_file") step_loss_summary = tf.Summary() #merge = tf.merge_all_summaries() writer = tf.summary.FileWriter("./logs/", sess.graph) while True: # Choose a bucket according to data distribution. We pick a random number # in [0, 1] and use the corresponding interval in train_buckets_scale. for ind in range(30): dev_set = read_data(dev_query, dev_answer, 0, 3000000) train_set = read_data(train_query, train_answer, ind * 100000, (ind + 1) * 100000) fixed_set = read_data(fixed_query, fixed_answer, FLAGS.max_train_data_size) weibo_set = read_data(weibo_query, weibo_answer, FLAGS.max_train_data_size) qa_set = read_data(qa_query, qa_answer, FLAGS.max_train_data_size) train_bucket_sizes = [ len(train_set[b]) for b in xrange(len(_buckets)) ] train_total_size = float(sum(train_bucket_sizes)) train_buckets_scale = [ sum(train_bucket_sizes[:i + 1]) / train_total_size for i in xrange(len(train_bucket_sizes)) ] for kk in range(500): random_number_01 = np.random.random_sample() bucket_id = min([ i for i in xrange(len(train_buckets_scale)) if train_buckets_scale[i] > random_number_01 ]) # Get a batch and make a step. start_time = time.time() encoder_inputs, decoder_inputs, target_weights, batch_source_encoder, batch_source_decoder = model.get_batch( train_set, bucket_id, 0, fixed_set, weibo_set, qa_set) inv_encoder_inputs, inv_decoder_inputs, inv_target_weights, inv_batch_source_encoder, inv_batch_source_decoder = model.inverse( batch_source_encoder, batch_source_decoder, bucket_id) if FLAGS.reinforce_learning: if FLAGS.dual_learning: _, step_loss1, _ = model.step_dual( sess, _buckets, encoder_inputs, decoder_inputs, target_weights, batch_source_encoder, batch_source_decoder, bucket_id, du_model, rev_vocab=rev_vocab) _, step_loss2, _ = du_model.step_dual( sess, _buckets, inv_encoder_inputs, inv_decoder_inputs, inv_target_weights, inv_batch_source_encoder, inv_batch_source_decoder, bucket_id, model, rev_vocab=rev_vocab) step_loss = [] for ii in range(len(step_loss1)): step_loss.append(step_loss1[ii] + step_loss2[ii]) else: _, step_loss, _ = model.step_rl( sess, _buckets, encoder_inputs, decoder_inputs, target_weights, batch_source_encoder, batch_source_decoder, bucket_id, rev_vocab=rev_vocab, disSession=disSess, disModel=disModel) else: _, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, forward_only=False, force_dec_input=True) lossmean = 0. 
for ii in step_loss: lossmean = lossmean + ii lossmean = lossmean / len(step_loss) loss += lossmean / FLAGS.steps_per_checkpoint step_time += (time.time() - start_time) / FLAGS.steps_per_checkpoint current_step += 1 query_size, answer_size = _buckets[bucket_id] for batch_index in xrange(FLAGS.batch_size): for query_index in xrange(query_size): query_word = encoder_inputs[query_index][ batch_index] if en_dict_cover.has_key(query_word): en_dict_cover[query_word] += 1 else: en_dict_cover[query_word] = 0 for answer_index in xrange(answer_size): answer_word = decoder_inputs[answer_index][ batch_index] if fr_dict_cover.has_key(answer_word): fr_dict_cover[answer_word] += 1 else: fr_dict_cover[answer_word] = 0 # Once in a while, we save checkpoint, print statistics, and run evals. if current_step % FLAGS.steps_per_checkpoint == 0: outputFile = open( "OpenSubData/RL_" + str(model.global_step.eval()) + ".txt", "w") bucket_value = step_loss_summary.value.add() bucket_value.tag = "loss" bucket_value.simple_value = float(loss) writer.add_summary(step_loss_summary, current_step) print("query_dict_cover_num: %s" % (str(en_dict_cover.__len__()))) print("answer_dict_cover_num: %s" % (str(fr_dict_cover.__len__()))) ef = open(FLAGS.en_cover_dict_path, "wb") pickle.dump(en_dict_cover, ef) ff = open(FLAGS.ff_cover_dict_path, "wb") pickle.dump(fr_dict_cover, ff) num = 0 pick = 0. mmm = 1 eval_loss = 0 dictt = {} dictt_b = {} for idd in range(2): bucket_id = idd + 2 batch_num = 1 + int( len(dev_set[bucket_id]) / FLAGS.batch_size) for mm in range(batch_num): encoder_inputs, decoder_inputs, target_weights, batch_source_encoder, batch_source_decoder = model.get_batch_dev( dev_set, bucket_id, mm * FLAGS.batch_size, fixed_set, weibo_set, qa_set) _, eval_loss_per, output_logits = model.step( sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, forward_only=True, force_dec_input=False) #_, eval_loss_per, _ = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, forward_only=False, force_dec_input=True) eval_loss += np.mean(eval_loss_per) resp_tokens = model.remove_type( output_logits, model.buckets[bucket_id], type=1) #prob = model.calprob(sess,_buckets, encoder_inputs, decoder_inputs, target_weights,batch_source_encoder, batch_source_decoder, bucket_id,rev_vocab=rev_vocab) resp_c = model.ids2tokens( resp_tokens, rev_vocab) resp_b = model.ids2tokens( batch_source_decoder, rev_vocab) resp_a = model.ids2tokens( batch_source_encoder, rev_vocab) for ii in range(len(resp_a)): aa = "" for ww in resp_a[ii]: aa = aa + " " + ww bb = "" for ww in resp_b[ii]: bb = bb + " " + ww cc = "" pre = "" for ww in resp_c[ii]: cc = cc + " " + ww if ww not in dictt: dictt[ww] = 0 if pre + ww not in dictt_b: dictt_b[pre + ww] = 0 dictt[ww] += 1 dictt_b[pre + ww] += 1 pre = ww #print("Q:",aa) #print("A1:",bb) #print("A2:",cc) #print("\n") outputFile.write("%s\n%s\n%s \n\n" % (aa, bb, cc)) outputFile.flush() BLEUscore = nltk.translate.bleu_score.sentence_bleu( [resp_c[ii]], resp_b[ii]) print(BLEUscore) #eval_loss += BLEUscore mmm += 1 #dummy = model.caldummy(sess,_buckets, encoder_inputs, decoder_inputs, target_weights,batch_source_encoder, batch_source_decoder, bucket_id,rev_vocab=rev_vocab) #print(dummy) #eval_loss +=dummy eval_loss = eval_loss / mmm # Print statistics for the previous epoch. 
perplexity = math.exp(loss) if loss < 300 else float( 'inf') print( "global step %d learning rate %.4f step-time %.2f loss " "%.2f" % (model.global_step.eval(), model.learning_rate.eval(), step_time, loss)) # Decrease learning rate if no improvement was seen over last 3 times. if len(previous_losses) > 2 and loss > max( previous_losses[-3:]): sess.run(model.learning_rate_decay_op) sess.run(du_model.learning_rate_decay_op) previous_losses.append(loss) # Save checkpoint and zero timer and loss. checkpoint_path = os.path.join(FLAGS.train_dir, "weibo.model") model.saver.save(sess, checkpoint_path, global_step=model.global_step) checkpoint_path2 = os.path.join( FLAGS.train_dir2, "weibo.du_model") du_model.saver.save(sess, checkpoint_path2, global_step=model.global_step) eval_ppx = math.exp( eval_loss) if eval_loss < 300 else float('inf') summ = [dictt[w] for w in dictt] summ = 1.0 * sum(summ) print( " eval: %.5f bucket %d distinct-1 %.5f distinct-2 %.5f " % (eval_loss, bucket_id, len(dictt) / summ, len(dictt_b) / summ)) lofFile.write("%.2f %.2f\n" % (loss, eval_loss)) lofFile.flush() step_time, loss = 0.0, 0.0 # Run evals on development set and print their perplexity. # for bucket_id in xrange(len(_buckets)): # encoder_inputs, decoder_inputs, target_weights = model.get_batch( # dev_set, bucket_id) # _, eval_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, # target_weights, bucket_id, True) # eval_ppx = math.exp(eval_loss) if eval_loss < 300 else float('inf') # print(" eval: bucket %d perplexity %.2f" % (bucket_id, eval_ppx)) sys.stdout.flush()
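# The eval block above reports distinct-1 / distinct-2: the number of unique unigrams and
# (roughly) bigrams in the generated responses divided by the total number of generated
# tokens, a standard diversity measure for dialogue models. A self-contained sketch of
# that computation, mirroring the dictt / dictt_b counting above (the helper name
# distinct_ngram_ratios is illustrative and not part of this code base):
def distinct_ngram_ratios(responses):
    """responses: list of token lists. Returns (distinct-1, distinct-2)."""
    unigrams = {}
    bigrams = {}
    total = 0
    for tokens in responses:
        prev = ""
        for tok in tokens:
            unigrams[tok] = unigrams.get(tok, 0) + 1
            # As in the loop above, the "bigram" key is the concatenation with the
            # previous token (empty for the first token of a response).
            bigrams[prev + tok] = bigrams.get(prev + tok, 0) + 1
            prev = tok
            total += 1
    total = float(max(total, 1))
    return len(unigrams) / total, len(bigrams) / total

# Example: two identical responses ["i", "do", "not", "know"] give 4 distinct unigrams
# and 4 distinct bigram keys over 8 tokens, i.e. distinct-1 = distinct-2 = 0.5.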
def test_decoder(config):
    """Interactive decoding: read sentences from stdin and print model responses."""
    train_path = os.path.join(config.train_dir, "movie_subtitle.train")
    data_path_list = [train_path + ".answer", train_path + ".query"]
    vocab_path = os.path.join(config.train_dir, "vocab%d.all" % config.vocab_size)
    # data_utils.create_vocabulary(vocab_path, data_path_list, config.vocab_size)
    vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)
    dummy_set = data_utils.get_dummy_set("grl_data/dummy_sentence", vocab, 25000)

    with tf.Session() as sess:
        if config.name_model in [gst_config.name_model, gcc_config.name_model,
                                 gbk_config.name_model]:
            model = create_st_model(sess, config, forward_only=True,
                                    name_scope=config.name_model)
        elif config.name_model in [grl_config.name_model, pre_grl_config.name_model]:
            model = create_rl_model(sess, config, forward_only=True,
                                    name_scope=config.name_model, dummy_set=dummy_set)

        model.batch_size = 1

        sys.stdout.write("> ")
        sys.stdout.flush()
        sentence = sys.stdin.readline()
        while sentence:
            token_ids = data_utils.sentence_to_token_ids(tf.compat.as_bytes(sentence), vocab)
            print("token_id: ", token_ids)

            # Which bucket does the sentence belong to?
            bucket_id = len(config.buckets) - 1
            for i, bucket in enumerate(config.buckets):
                if bucket[0] >= len(token_ids):
                    bucket_id = i
                    break
            else:
                print("Sentence truncated: %s" % sentence)

            encoder_inputs, decoder_inputs, target_weights, _, _ = model.get_batch(
                {bucket_id: [(token_ids, [1])]}, bucket_id)

            # st_model step
            if config.name_model in [gst_config.name_model, gcc_config.name_model,
                                     gbk_config.name_model]:
                output_logits, _ = model.step(sess, encoder_inputs, decoder_inputs,
                                              target_weights, bucket_id, True)
                outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
                if data_utils.EOS_ID in outputs:
                    outputs = outputs[:outputs.index(data_utils.EOS_ID)]
                print(" ".join([str(rev_vocab[output]) for output in outputs]))

            # beam_search step
            elif config.name_model in [grl_config.name_model, pre_grl_config.name_model]:
                _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                                 target_weights, reward=1,
                                                 bucket_id=bucket_id, forward_only=True)
                output_logits = np.squeeze(output_logits)
                outputs = list(np.argmax(output_logits, axis=1))
                if data_utils.EOS_ID in outputs:
                    outputs = outputs[:outputs.index(data_utils.EOS_ID)]
                print(outputs)
                # Replace any UNK token with the next most likely word, skipping the
                # special symbols in the first 4 vocabulary slots.
                while data_utils.UNK_ID in outputs:
                    sub_max = np.argmax(
                        output_logits[outputs.index(data_utils.UNK_ID)][4:]) + 4
                    outputs[outputs.index(data_utils.UNK_ID)] = sub_max
                print(" ".join([str(rev_vocab[out]) for out in outputs]))

            print("> ", end="")
            sys.stdout.flush()
            sentence = sys.stdin.readline()
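# Both decoding paths above patch _UNK tokens after the argmax: for every position that
# decoded to UNK, the logits at that position are re-ranked with the reserved ids in the
# first four vocabulary slots masked out, and the best remaining word is substituted.
# A minimal sketch of that post-processing, assuming numpy logits of shape
# [length, vocab_size] and the usual data_utils layout where UNK_ID == 3 (the helper
# name replace_unk is illustrative only):
import numpy as np

def replace_unk(outputs, output_logits, unk_id=3, num_special=4):
    """Replace every unk_id in `outputs` with the best non-special token at that step."""
    outputs = list(outputs)
    while unk_id in outputs:
        pos = outputs.index(unk_id)
        # Argmax over the vocabulary with the first `num_special` ids skipped.
        outputs[pos] = int(np.argmax(output_logits[pos][num_special:])) + num_special
    return outputs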
def read_file_test(config, test_model_name, input_path, output_path):
    """Decode every sentence in input_path with the chosen checkpoint and write the responses."""
    vocab_path = os.path.join(config.train_dir, "vocab%d.all" % config.vocab_size)
    vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)
    dummy_set = data_utils.get_dummy_set("grl_data/dummy_sentence", vocab, 25000)
    forward_only = True

    with tf.Session() as sess:
        with tf.variable_scope(name_or_scope=config.name_model):
            model = grl_rnn_model.grl_model(grl_config=config,
                                            name_scope=config.name_model,
                                            forward=forward_only,
                                            dummy_set=dummy_set)
        # ckpt = tf.train.get_checkpoint_state(os.path.join(rl_config.train_dir, "checkpoints"))
        # print(ckpt.model_checkpoint_path)
        model.batch_size = 1

        # Restore the requested checkpoint.
        if test_model_name == 'S2S':
            model.saver.restore(sess, "grl_data/movie_subtitle.model-118000")
        elif test_model_name == 'RL':
            model.saver.restore(sess, "grl_data/movie_subtitle.model-127200")
        else:
            model.saver.restore(sess, "grl_data/movie_subtitle.model-127200")

        with open(input_path) as f:
            sentences = f.readlines()

        output_file = []
        for sentence in sentences:
            token_ids = data_utils.sentence_to_token_ids(tf.compat.as_bytes(sentence), vocab)
            print("token_id: ", token_ids)

            # Which bucket does the sentence belong to?
            bucket_id = len(config.buckets) - 1
            for i, bucket in enumerate(config.buckets):
                if bucket[0] >= len(token_ids):
                    bucket_id = i
                    break
            else:
                print("Sentence truncated: %s" % sentence)

            encoder_inputs, decoder_inputs, target_weights, _, _ = model.get_batch(
                {bucket_id: [(token_ids, [1])]}, bucket_id)
            _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                             target_weights, reward=1,
                                             bucket_id=bucket_id, forward_only=True)

            output_logits = np.squeeze(output_logits)
            outputs = list(np.argmax(output_logits, axis=1))
            if data_utils.EOS_ID in outputs:
                outputs = outputs[:outputs.index(data_utils.EOS_ID)]
            print(outputs)

            # Replace any UNK token with the next most likely word, skipping the
            # special symbols in the first 4 vocabulary slots.
            while data_utils.UNK_ID in outputs:
                sub_max = np.argmax(output_logits[outputs.index(data_utils.UNK_ID)][4:]) + 4
                outputs[outputs.index(data_utils.UNK_ID)] = sub_max
            # Drop token id 30 and strip the "$" symbol from the output sentence.
            while 30 in outputs:
                outputs.remove(30)

            output_sentence = " ".join([str(rev_vocab[out]) for out in outputs]) + '\n'
            output_sentence = output_sentence.replace("$", "")
            print(output_sentence)
            output_file.append(output_sentence)

        with open(output_path, 'w') as f:
            f.writelines(output_file)
def train():
    """Train the grl RL dialogue model, using the pretrained seq2seq model for rewards."""
    vocab, rev_vocab, train_set = prepare_data(grl_config)
    for b_set in train_set:
        print("b_set length: ", len(b_set))
    dummy_set = data_utils.get_dummy_set("grl_data/dummy_sentence", vocab, 25000)

    with tf.Session() as sess:
        rl_model = create_rl_model(sess, grl_config, False, grl_config.name_model, dummy_set)
        st_model = create_st_model(sess, gst_config, True, gst_config.name_model)
        # bk_model = create_st_model(sess, gbk_config, True, gbk_config.name_model)
        # cc_model = create_st_model(sess, gcc_config, True, gcc_config.name_model)

        # Bucket sizes and the cumulative distribution used for bucket sampling.
        train_bucket_sizes = [len(train_set[b]) for b in range(len(grl_config.buckets))]
        train_total_size = float(sum(train_bucket_sizes))
        train_buckets_scale = [
            sum(train_bucket_sizes[:i + 1]) / train_total_size
            for i in range(len(train_bucket_sizes))
        ]

        step_time, loss = 0.0, 0.0
        current_step = 0
        previous_losses = []
        step_loss_summary = tf.Summary()
        rl_writer = tf.summary.FileWriter(grl_config.tensorboard_dir, sess.graph)

        while True:
            # Choose a bucket according to the data distribution.
            random_number_01 = np.random.random_sample()
            bucket_id = min([i for i in range(len(train_buckets_scale))
                             if train_buckets_scale[i] > random_number_01])

            # Get a batch and make a step.
            start_time = time.time()
            encoder_inputs, decoder_inputs, target_weights, batch_source_encoder, _ = \
                rl_model.get_batch(train_set, bucket_id)
            update, norm, step_loss = rl_model.step_rl(
                sess, st_model=st_model, bk_model=st_model, cc_model=st_model,
                encoder_inputs=encoder_inputs, decoder_inputs=decoder_inputs,
                target_weights=target_weights,
                batch_source_encoder=batch_source_encoder, bucket_id=bucket_id)

            step_time += (time.time() - start_time) / grl_config.steps_per_checkpoint
            loss += step_loss / grl_config.steps_per_checkpoint
            current_step += 1

            # Once in a while, we save a checkpoint, print statistics, and run evals.
            if current_step % grl_config.steps_per_checkpoint == 0:
                bucket_value = step_loss_summary.value.add()
                bucket_value.tag = grl_config.name_loss
                bucket_value.simple_value = float(loss)
                rl_writer.add_summary(step_loss_summary, int(sess.run(rl_model.global_step)))

                # Print statistics for the previous epoch.
                perplexity = math.exp(loss) if loss < 300 else float('inf')
                print("global step %d learning rate %.4f step-time %.2f perplexity %.2f"
                      % (rl_model.global_step.eval(), rl_model.learning_rate.eval(),
                         step_time, perplexity))

                # Decrease the learning rate if no improvement was seen over the last 3 checkpoints.
                if len(previous_losses) > 2 and loss > max(previous_losses[-3:]):
                    sess.run(rl_model.learning_rate_decay_op)
                previous_losses.append(loss)

                # Save checkpoint and zero the timer and loss.
                gen_ckpt_dir = os.path.abspath(os.path.join(grl_config.train_dir, "checkpoints"))
                if not os.path.exists(gen_ckpt_dir):
                    os.makedirs(gen_ckpt_dir)
                checkpoint_path = os.path.join(gen_ckpt_dir, "movie_subtitle.model")
                rl_model.saver.save(sess, checkpoint_path, global_step=rl_model.global_step)
                step_time, loss = 0.0, 0.0
                sys.stdout.flush()