import heapq
import logging
import math
import operator
import os
import pickle
import sys

import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile

import torch
from torch.autograd import Variable

import data_utils

# create_model, train_visual, feedforward, FLAGS and _buckets are
# project-specific and defined elsewhere.


def encode(sess, model, config, sentences):
    # Load vocabularies.
    en_vocab_path = os.path.join(config.data_dir,
                                 "vocab%d" % config.vocab_size)
    en_vocab, rev_vocab = data_utils.initialize_vocabulary(en_vocab_path)
    means = []
    logvars = []
    for i, sentence in enumerate(sentences):
        # Get token-ids for the input sentence.
        token_ids = data_utils.sentence_to_token_ids(sentence, en_vocab)
        # Which bucket does it belong to?
        bucket_id = len(config.buckets) - 1
        for j, bucket in enumerate(config.buckets):
            if bucket[0] >= len(token_ids):
                bucket_id = j
                break
        else:
            logging.warning("Sentence truncated: %s", sentence)
        # Get a 1-element batch to feed the sentence to the model.
        encoder_inputs, _, _ = model.get_batch(
            {bucket_id: [(token_ids, [])]}, bucket_id)
        # Get the latent mean and log-variance for the sentence.
        mean, logvar = model.encode_to_latent(sess, encoder_inputs, bucket_id)
        means.append(mean)
        logvars.append(logvar)
    return means, logvars
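
# Hedged usage sketch for encode(): assuming the model parameterizes a
# Gaussian latent (VAE-style), a sample can be drawn from the returned
# mean/log-variance with the reparameterization trick. sample_latent is
# illustrative only and not part of the original API.
def sample_latent(mean, logvar):
    # z = mu + sigma * eps, with eps ~ N(0, I).
    eps = np.random.standard_normal(np.shape(mean))
    return mean + np.exp(0.5 * logvar) * eps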
def test_decoder(gen_config):
    with tf.Session() as sess:
        model = create_model(sess, gen_config, forward_only=True,
                             name_scope=gen_config.name_model)
        model.batch_size = 1

        train_path = os.path.join(gen_config.train_dir, "chitchat.train")
        voc_file_path = [train_path + ".answer", train_path + ".query"]
        vocab_path = os.path.join(gen_config.train_dir,
                                  "vocab%d.all" % gen_config.vocab_size)
        data_utils.create_vocabulary(vocab_path, voc_file_path,
                                     gen_config.vocab_size)
        vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)

        sys.stdout.write("> ")
        sys.stdout.flush()
        sentence = sys.stdin.readline()
        while sentence:
            token_ids = data_utils.sentence_to_token_ids(
                tf.compat.as_bytes(sentence), vocab)
            print("token_id: ", token_ids)
            # Which bucket does it belong to?
            bucket_id = len(gen_config.buckets) - 1
            for i, bucket in enumerate(gen_config.buckets):
                if bucket[0] >= len(token_ids):
                    bucket_id = i
                    break
            else:
                print("Sentence truncated: %s" % sentence)

            encoder_inputs, decoder_inputs, target_weights, _, _ = model.get_batch(
                {bucket_id: [(token_ids, [1])]}, bucket_id, model.batch_size,
                type=0)
            print("bucket_id: ", bucket_id)
            print("encoder_inputs:", encoder_inputs)
            print("decoder_inputs:", decoder_inputs)
            print("target_weights:", target_weights)

            _, _, output_logits = model.step(sess, encoder_inputs,
                                             decoder_inputs, target_weights,
                                             bucket_id, True)
            print("output_logits", np.shape(output_logits))
            # Greedy decoding: take the argmax token at every step.
            outputs = [int(np.argmax(logit, axis=1))
                       for logit in output_logits]
            # If there is an EOS symbol in outputs, cut them at that point.
            if data_utils.EOS_ID in outputs:
                outputs = outputs[:outputs.index(data_utils.EOS_ID)]
            print(" ".join(
                [tf.compat.as_str(rev_vocab[output]) for output in outputs]))
            print("> ", end="")
            sys.stdout.flush()
            sentence = sys.stdin.readline()
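
# The "which bucket does it belong to?" loop above also appears in
# encode(), test_file_decoder() and reconstruct(); this helper is a
# hedged restatement of the same for/else idiom, added for illustration
# only and not called by the original code.
def select_bucket(buckets, token_ids):
    # Return the first bucket wide enough for the input; if none fits,
    # fall back to the last bucket and let get_batch() truncate.
    for i, bucket in enumerate(buckets):
        if bucket[0] >= len(token_ids):
            return i
    logging.warning("Sentence longer than the largest bucket; truncating.")
    return len(buckets) - 1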
def decoder_online(sess, gen_config, model, vocab, rev_vocab, inputs):
    token_ids = data_utils.sentence_to_token_ids(inputs, vocab)
    # Which bucket does it belong to? (Raises ValueError if the input is
    # longer than the largest bucket.)
    bucket_id = min([b for b in range(len(gen_config.buckets))
                     if gen_config.buckets[b][0] > len(token_ids)])
    # Get a 1-element batch to feed the sentence to the model.
    encoder_inputs, decoder_inputs, target_weights, _, _ = model.get_batch(
        {bucket_id: [(token_ids, [])]}, bucket_id, gen_config.batch_size)
    # Get output logits for the sentence.
    _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                     target_weights, bucket_id, True)
    # Greedy decode: argmax over the vocabulary at every time step for
    # every batch element (output_logits is time-major).
    tokens = []
    resps = []
    for seq in output_logits:
        token = []
        for t in seq:
            token.append(int(np.argmax(t, axis=0)))
        tokens.append(token)
    # Transpose from time-major to batch-major.
    tokens_t = []
    for col in range(len(tokens[0])):
        tokens_t.append([tokens[row][col] for row in range(len(tokens))])
    # If there is an EOS symbol in a sequence, cut it at that point.
    for seq in tokens_t:
        if data_utils.EOS_ID in seq:
            resps.append(seq[:seq.index(data_utils.EOS_ID)]
                         [:gen_config.buckets[bucket_id][1]])
        else:
            resps.append(seq[:gen_config.buckets[bucket_id][1]])
    # Only the last batch element's response string is returned.
    resp_str = ""
    for resp in resps:
        resp_str = " ".join(
            [tf.compat.as_str(rev_vocab[output]) for output in resp])
    return resp_str
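
# Hedged usage sketch for decoder_online(): the session, config, model,
# and vocabularies are assumed to be set up the same way test_decoder()
# does above; demo_decoder_online is not part of the original API.
def demo_decoder_online(sess, gen_config, model, vocab, rev_vocab):
    # Answer a single hard-coded query and print the reply.
    reply = decoder_online(sess, gen_config, model, vocab, rev_vocab,
                           "how are you ?")
    print(reply)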
def test_decoder(gen_config):
    # vocab_path = os.path.join(gen_config.train_dir,
    #                           "vocab%d.all" % gen_config.vocab_size)
    # vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)
    with open('vocab', 'rb') as fr_vocab:
        vocab = pickle.load(fr_vocab)
    with open('rev_vocab', 'rb') as fr_rev_vocab:
        rev_vocab = pickle.load(fr_rev_vocab)

    # seq2seq, optimizer = create_model(gen_config, vocab)
    seq2seq = torch.load('./pre_seq2seq.pth')

    sys.stdout.write("> ")
    sys.stdout.flush()
    sentence = sys.stdin.readline()
    while sentence:
        source = data_utils.sentence_to_token_ids(sentence, vocab)  # list
        # Pad to maxlen and reverse the source, matching training time.
        encoder_pad = [data_utils.PAD_ID] * (gen_config.maxlen - len(source))
        encoder_inputs = [list(reversed(source + encoder_pad))]
        src = torch.from_numpy(np.array(encoder_inputs).T)
        src = Variable(src).cuda()
        probs, result = seq2seq.decode(src)  # maxlen*1*vocab, a list
        # Drop padding and stop at the first EOS.
        ans = []
        for idx in result:
            if idx == data_utils.EOS_ID:
                break
            if idx != data_utils.PAD_ID:
                ans.append(rev_vocab[idx])
        print(" ".join(ans))
        print("> ", end="")
        sys.stdout.flush()
        sentence = sys.stdin.readline()
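
# Hedged sketch of the pickling step the torch test_decoder() above
# depends on: the 'vocab' and 'rev_vocab' files are assumed to be plain
# pickles of the structures data_utils.initialize_vocabulary() returns.
# save_vocab_pickles is illustrative, not part of the original project.
def save_vocab_pickles(vocab_path):
    vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)
    with open('vocab', 'wb') as f:
        pickle.dump(vocab, f)
    with open('rev_vocab', 'wb') as f:
        pickle.dump(rev_vocab, f)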
def test_file_decoder(gen_config, input_file, output_file):
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        model = create_model(sess, gen_config, forward_only=True,
                             name_scope=gen_config.name_model)
        model.batch_size = 1

        train_path = os.path.join(gen_config.train_dir, "chitchat.train")
        voc_file_path = [train_path + ".answer", train_path + ".query"]
        vocab_path = os.path.join(gen_config.train_dir,
                                  "vocab%d.all" % gen_config.vocab_size)
        data_utils.create_vocabulary(vocab_path, voc_file_path,
                                     gen_config.vocab_size)
        vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)

        with open(output_file, 'w') as fout:
            with open(input_file, 'r') as fin:
                for sent in fin:
                    print(sent)
                    token_ids = data_utils.sentence_to_token_ids(
                        tf.compat.as_str(sent), vocab)
                    print("token_id: ", token_ids)
                    # Which bucket does it belong to?
                    bucket_id = len(gen_config.buckets) - 1
                    for i, bucket in enumerate(gen_config.buckets):
                        if bucket[0] >= len(token_ids):
                            bucket_id = i
                            break
                    else:
                        print("Sentence truncated: %s" % sent)

                    encoder_inputs, decoder_inputs, target_weights, _, _ = model.get_batch(
                        {bucket_id: [(token_ids, [1])]}, bucket_id,
                        model.batch_size, type=0)
                    _, _, output_logits = model.step(
                        sess, encoder_inputs, decoder_inputs, target_weights,
                        bucket_id, True)
                    # Greedy decoding, cut at the first EOS symbol.
                    outputs = [int(np.argmax(logit, axis=1))
                               for logit in output_logits]
                    if data_utils.EOS_ID in outputs:
                        outputs = outputs[:outputs.index(data_utils.EOS_ID)]
                    out_sent = " ".join([
                        tf.compat.as_str(rev_vocab[output])
                        for output in outputs
                    ])
                    fout.write(out_sent + '\n')
                    print(out_sent)
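
# get_predicted_sentence() below calls softmax() without defining it
# anywhere in this file; this numerically stable sketch assumes it maps
# a single logits vector to a probability vector.
def softmax(x):
    # Shift by the max before exponentiating so exp() cannot overflow.
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()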
def get_predicted_sentence(sess, input_sentence, vocab, model, beam_size,
                           buckets, mc_search=False, debug=False):
    def model_step(enc_inp, dec_inp, dptr, target_weights, bucket_id):
        _, _, logits = model.step(sess, enc_inp, dec_inp, target_weights,
                                  bucket_id, True)
        prob = softmax(logits[dptr][0])
        return prob

    def greedy_dec(output_logits):
        selected_token_ids = [int(np.argmax(logit, axis=1))
                              for logit in output_logits]
        return selected_token_ids

    input_token_ids = data_utils.sentence_to_token_ids(input_sentence, vocab)
    if debug:
        print(input_token_ids)
    # Which bucket does it belong to?
    bucket_id = min([b for b in range(len(buckets))
                     if buckets[b][0] > len(input_token_ids)])
    outputs = []
    feed_data = {bucket_id: [(input_token_ids, outputs)]}

    # Get a 1-element batch to feed the sentence to the model.
    encoder_inputs, decoder_inputs, target_weights, _, _ = model.get_batch(
        feed_data, bucket_id, 1)
    if debug:
        print("\n[get_batch]\n", encoder_inputs, decoder_inputs,
              target_weights)
        print(decoder_inputs)

    # Original greedy decoding.
    if beam_size == 1 or (not mc_search):
        _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                         target_weights, bucket_id, True)
        outputs = greedy_dec(output_logits)
        # Note: this lookup requires `vocab` to support id -> token lookup
        # (i.e., to act as a reverse vocabulary).
        return " ".join(
            [tf.compat.as_str(vocab[output]) for output in outputs])

    # Beam search: initialize beams as (prob, candidate) pairs.
    beams, new_beams, results = [(1, {
        'eos': 0,
        'dec_inp': decoder_inputs,
        'prob': 1,
        'prob_ts': 1,
        'prob_t': 1
    })], [], []
    for dptr in range(len(decoder_inputs) - 1):
        if dptr > 0:
            target_weights[dptr] = [1.]
            beams, new_beams = new_beams[:beam_size], []
        if debug:
            print("=====[beams]=====", beams)
        # Heapify since we will sort and drop entries to keep N elements.
        heapq.heapify(beams)
        for prob, cand in beams:
            if cand['eos']:
                results += [(prob, cand)]
                continue
            if debug:
                print(cand['dec_inp'])
            all_prob_ts = model_step(encoder_inputs, cand['dec_inp'], dptr,
                                     target_weights, bucket_id)
            all_prob_t = [0] * len(all_prob_ts)
            all_prob = all_prob_ts
            # Suppress copy-cat answers (responding with the input itself).
            if dptr < len(input_token_ids):
                all_prob[input_token_ids[dptr]] *= 0.01
            # Expand the beam with the top-beam_size next tokens.
            for c in np.argsort(all_prob)[::-1][:beam_size]:
                new_cand = {
                    'eos': (c == data_utils.EOS_ID),
                    'dec_inp': [(np.array([c]) if i == (dptr + 1) else k)
                                for i, k in enumerate(cand['dec_inp'])],
                    'prob_ts': cand['prob_ts'] * all_prob_ts[c],
                    'prob_t': cand['prob_t'] * all_prob_t[c],
                    'prob': cand['prob'] * all_prob[c],
                }
                # heapq sorts on the first tuple element only; probability
                # ties would fall through to comparing dicts, so distinct
                # probabilities are assumed here.
                new_cand = (new_cand['prob'], new_cand)
                if len(new_beams) < beam_size:
                    heapq.heappush(new_beams, new_cand)
                elif new_cand[0] > new_beams[0][0]:
                    heapq.heapreplace(new_beams, new_cand)
    results += new_beams  # flush the last candidates

    # Post-process results: best-scoring candidates first.
    res_cands = []
    for prob, cand in sorted(results, key=lambda r: r[0], reverse=True):
        res_cands.append(cand)
    return res_cands
def reconstruct(sess, model, config):
    model.batch_size = 1  # We decode one sentence at a time.
    model.probabilistic = config.probabilistic
    beam_size = config.beam_size

    # Load vocabularies.
    vocab_path = os.path.join(config.data_dir,
                              "vocab%d.in" % config.vocab_size)
    en_vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)

    # Decode every sentence in the input file.
    outputs = []
    with gfile.GFile(FLAGS.input, "r") as fs:
        sentences = fs.readlines()
    for i, sentence in enumerate(sentences):
        # Get token-ids for the input sentence.
        token_ids = data_utils.sentence_to_token_ids(sentence, en_vocab)
        # Which bucket does it belong to?
        bucket_id = len(config.buckets) - 1
        for j, bucket in enumerate(config.buckets):
            if bucket[0] >= len(token_ids):
                bucket_id = j
                break
        else:
            logging.warning("Sentence truncated: %s", sentence)

        encoder_inputs, decoder_inputs, target_weights = model.get_batch(
            {bucket_id: [(token_ids, [])]}, bucket_id)
        if beam_size > 1:
            path, symbol, output_logits = model.step(
                sess, encoder_inputs, decoder_inputs, target_weights,
                bucket_id, True, config.probabilistic, beam_size)
            # Trace the beam-search back-pointers from the last step to
            # recover each of the beam_size hypotheses.
            paths = [[] for _ in range(beam_size)]
            curr = list(range(beam_size))
            num_steps = len(path)
            for step in range(num_steps - 1, -1, -1):
                for kk in range(beam_size):
                    paths[kk].append(symbol[step][curr[kk]])
                    curr[kk] = path[step][curr[kk]]
            for kk in range(beam_size):
                output = [int(logit) for logit in paths[kk][::-1]]
                # If there is an EOS symbol in the output, cut it there.
                if data_utils.EOS_ID in output:
                    output = output[:output.index(data_utils.EOS_ID)]
                output = " ".join([rev_vocab[word]
                                   for word in output]) + "\n"
                outputs.append(output)
        else:
            # Get output logits for the sentence.
            _, _, _, output_logits = model.step(
                sess, encoder_inputs, decoder_inputs, target_weights,
                bucket_id, True, config.probabilistic)
            # This is a greedy decoder: outputs are argmaxes of the logits.
            output = [int(np.argmax(logit, axis=1))
                      for logit in output_logits]
            # If there is an EOS symbol in outputs, cut them at that point.
            if data_utils.EOS_ID in output:
                output = output[:output.index(data_utils.EOS_ID)]
            output = " ".join([rev_vocab[word] for word in output]) + "\n"
            outputs.append(output)

    with gfile.GFile(FLAGS.output, "w") as enc_dec_f:
        for output in outputs:
            enc_dec_f.write(output)
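
# Toy illustration of the back-pointer trace used in reconstruct() above:
# path[t][k] holds the parent beam index of candidate k at step t, and
# symbol[t][k] the token it emitted; walking backwards from the last step
# rebuilds a hypothesis. The two-step, two-beam data here is made up.
def demo_backpointer_trace():
    path = [[0, 0], [1, 0]]
    symbol = [[7, 8], [9, 10]]
    hyp, curr = [], 0  # follow the first final-step candidate
    for t in range(len(path) - 1, -1, -1):
        hyp.append(symbol[t][curr])
        curr = path[t][curr]
    return hyp[::-1]  # tokens in forward order: [8, 9]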
def decode():
    with tf.Session() as sess:
        # Create model and load parameters.
        model = create_model(sess, True)
        model.batch_size = 1  # We decode one sentence at a time.

        # Load vocabularies.
        en_vocab_path = os.path.join(FLAGS.data_dir,
                                     "vocab%d.en" % FLAGS.en_vocab_size)
        ja_vocab_path = os.path.join(FLAGS.data_dir,
                                     "vocab%d.ja" % FLAGS.target_vocab_size)
        en_vocab, _ = data_utils.initialize_vocabulary(en_vocab_path)
        _, rev_ja_vocab = data_utils.initialize_vocabulary(ja_vocab_path)
        _, rev_en_vocab = data_utils.initialize_vocabulary(en_vocab_path)
        if len(rev_ja_vocab) < FLAGS.target_vocab_size:
            rev_ja_vocab += ["_" for i in range(FLAGS.target_vocab_size -
                                                len(rev_ja_vocab))]

        # Prepare visual context integration.
        for fname in os.listdir(FLAGS.data_dir):
            if fname.endswith("ids{0}.{1}".format(
                    str(FLAGS.target_vocab_size), FLAGS.target_language)):
                target_id_file = os.path.join(FLAGS.data_dir, fname)
                W, bias_vec = train_visual(
                    target_id_file,
                    os.path.join(FLAGS.data_dir, FLAGS.visual_vec_file_name),
                    ja_vocab_path, FLAGS.target_vocab_size, num_epochs=100)

        # Decode from standard input.
        sys.stdout.write("> ")
        sys.stdout.flush()
        sentence = sys.stdin.readline()
        visual_vec = []
        while sentence:
            # Collect the trailing digits of the input, if any, as the
            # visual context vector.
            for el in reversed(sentence.split()):
                if el.isdigit():
                    visual_vec = [int(el)] + visual_vec
            # Get token-ids for the input sentence (minus the visual vector).
            sentence = " ".join(
                sentence.split()[:len(sentence.split()) - len(visual_vec)])
            token_ids = data_utils.sentence_to_token_ids(sentence, en_vocab)
            print("token ids: " + str(token_ids))
            # Which bucket does it belong to?
            bucket_id = min([b for b in range(len(_buckets))
                             if _buckets[b][0] > len(token_ids)])
            # Get a 1-element batch to feed the sentence to the model.
            encoder_inputs, decoder_inputs, target_weights = model.get_batch(
                {bucket_id: [(token_ids, [])]}, bucket_id)
            # Get output logits for the sentence. output_logits has the
            # length of the decoder bucket (e.g. input bucket 5 -> decoder
            # bucket 10); each entry holds log-probabilities over the
            # target vocabulary.
            _, _, output_logits = model.step(sess, encoder_inputs,
                                             decoder_inputs, target_weights,
                                             bucket_id, True)

            # Store model scores for each output position until the
            # end-of-sentence symbol.
            output_list = []
            for logits in output_logits:
                max_id = int(np.argmax(logits, axis=1))
                if max_id == data_utils.EOS_ID:
                    break
                nmt_score_dict = dict(enumerate(logits[0]))
                output_list.append(nmt_score_dict)

            # Store visual scores.
            if visual_vec != [] and FLAGS.visual:
                visual_scores = feedforward(W, visual_vec, bias_vec)
                # Turn probabilities into logits.
                visual_scores = np.array(
                    [math.log(prob / (1 - prob)) for prob in visual_scores])
                visual_score_dict = dict(enumerate(visual_scores))
            else:
                visual_score_dict = {}

            # Integrate visual scores if given and output the result.
            print("--result--")
            outputs = []
            for dic in output_list:
                for k in visual_score_dict:
                    dic[k] += visual_score_dict[k]
                outputs.append(max(dic.items(),
                                   key=operator.itemgetter(1))[0])
                for k, v in sorted(dic.items(), key=lambda x: x[1],
                                   reverse=True):
                    print(rev_ja_vocab[k] + ":" + str(v), end=" ")
                print("\n")
            print(" ".join([rev_ja_vocab[output] for output in outputs]))
            print("> ", end="")
            visual_vec = []
            sys.stdout.flush()
            sentence = sys.stdin.readline()
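
# decode() above calls feedforward(W, visual_vec, bias_vec), which is
# defined elsewhere in the project. A minimal sketch of the shape it is
# assumed to have (a single sigmoid layer producing one probability per
# target-vocabulary word); feedforward_sketch is a hypothetical stand-in,
# not the project's implementation.
def feedforward_sketch(W, visual_vec, bias_vec):
    # sigmoid(W x + b), elementwise over the target vocabulary.
    logits = np.dot(W, np.asarray(visual_vec, dtype=np.float64)) + bias_vec
    return 1.0 / (1.0 + np.exp(-logits))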