Example #1
def encode(sess, model, config, sentences):
    # Load vocabularies.
    en_vocab_path = os.path.join(config.data_dir,
                                 "vocab%d" % config.vocab_size)
    en_vocab, rev_vocab = data_utils.initialize_vocabulary(en_vocab_path)

    means = []
    logvars = []
    for i, sentence in enumerate(sentences):
        # Get token-ids for the input sentence.
        token_ids = data_utils.sentence_to_token_ids(sentence, en_vocab)
        # Which bucket does it belong to?
        bucket_id = len(config.buckets) - 1
        for j, bucket in enumerate(config.buckets):
            if bucket[0] >= len(token_ids):
                bucket_id = j
                break
        else:
            logging.warning("Sentence truncated: %s", sentence)

        # Get a 1-element batch to feed the sentence to the model.
        encoder_inputs, _, _ = model.get_batch({bucket_id: [(token_ids, [])]},
                                               bucket_id)
        # Encode the sentence into its latent mean and log-variance.
        mean, logvar = model.encode_to_latent(sess, encoder_inputs, bucket_id)
        means.append(mean)
        logvars.append(logvar)

    return means, logvars
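
Nearly every example here repeats the same bucket-selection loop. A minimal, standalone sketch of that pattern, assuming buckets are (source_size, target_size) tuples as in the snippets; the helper name select_bucket is made up for illustration:

def select_bucket(token_ids, buckets):
    """Return the index of the smallest bucket whose source size fits the sentence,
    or the last bucket index (the sentence will be truncated) if none fits."""
    for idx, (source_size, _target_size) in enumerate(buckets):
        if source_size >= len(token_ids):
            return idx
    return len(buckets) - 1

# e.g. select_bucket(list(range(7)), [(5, 10), (10, 15), (20, 25)]) == 1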
Example #2
def test_decoder(gen_config):
    with tf.Session() as sess:
        model = create_model(sess,
                             gen_config,
                             forward_only=True,
                             name_scope=gen_config.name_model)
        model.batch_size = 1

        train_path = os.path.join(gen_config.train_dir, "chitchat.train")
        voc_file_path = [train_path + ".answer", train_path + ".query"]
        vocab_path = os.path.join(gen_config.train_dir,
                                  "vocab%d.all" % gen_config.vocab_size)
        data_utils.create_vocabulary(vocab_path, voc_file_path,
                                     gen_config.vocab_size)
        vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)

        sys.stdout.write("> ")
        sys.stdout.flush()
        sentence = sys.stdin.readline()
        while sentence:
            token_ids = data_utils.sentence_to_token_ids(
                tf.compat.as_bytes(sentence), vocab)
            print("token_id: ", token_ids)
            bucket_id = len(gen_config.buckets) - 1
            for i, bucket in enumerate(gen_config.buckets):
                if bucket[0] >= len(token_ids):
                    bucket_id = i
                    break
            else:
                print("Sentence truncated: %s", sentence)

            encoder_inputs, decoder_inputs, target_weights, _, _ = model.get_batch(
                {bucket_id: [(token_ids, [1])]},
                bucket_id,
                model.batch_size,
                type=0)

            print("bucket_id: ", bucket_id)
            print("encoder_inputs:", encoder_inputs)
            print("decoder_inputs:", decoder_inputs)
            print("target_weights:", target_weights)

            _, _, output_logits = model.step(sess, encoder_inputs,
                                             decoder_inputs, target_weights,
                                             bucket_id, True)

            print("output_logits", np.shape(output_logits))

            outputs = [
                int(np.argmax(logit, axis=1)) for logit in output_logits
            ]

            if data_utils.EOS_ID in outputs:
                outputs = outputs[:outputs.index(data_utils.EOS_ID)]

            print(" ".join(
                [tf.compat.as_str(rev_vocab[output]) for output in outputs]))
            print("> ", end="")
            sys.stdout.flush()
            sentence = sys.stdin.readline()
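
The greedy read-out used in this and several later examples (argmax at each decoder step, cut at the first EOS) can be isolated into one small helper. A sketch assuming output_logits is a list of [1, vocab_size] arrays and eos_id matches data_utils.EOS_ID; greedy_tokens is an illustrative name:

import numpy as np

def greedy_tokens(output_logits, eos_id):
    """Argmax token per decoder step (batch size 1), truncated at the first EOS."""
    tokens = [int(np.argmax(logit)) for logit in output_logits]  # each logit is [1, vocab]
    if eos_id in tokens:
        tokens = tokens[:tokens.index(eos_id)]
    return tokens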
Example #3
def decoder_online(sess, gen_config, model, vocab, rev_vocab, inputs):
    token_ids = data_utils.sentence_to_token_ids(inputs, vocab)

    # Which bucket does it belong to?
    bucket_id = min([b for b in xrange(len(_buckets)) if _buckets[b][0] > len(token_ids)])
    # bucket_id = min([i for i in xrange(len(train_buckets_scale))
    #                  if train_buckets_scale[i] > random_number_01])
    # Get a 1-element batch to feed the sentence to the model.
    encoder_inputs, decoder_inputs, target_weights, _, _ = model.get_batch(
        {bucket_id: [(token_ids, [])]}, bucket_id, gen_config.batch_size)

    # Get output logits for the sentence.
    _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True)
    
    # If there is an EOS symbol in outputs, cut them at that point.
    # Take the argmax token for every batch element at every decoder step.
    tokens = []
    resps = []
    for step_logits in output_logits:
        token = []
        for t in step_logits:
            token.append(int(np.argmax(t, axis=0)))
        tokens.append(token)

    # Transpose [num_steps][batch_size] into one token sequence per batch element.
    tokens_t = []
    for col in range(len(tokens[0])):
        tokens_t.append([tokens[row][col] for row in range(len(tokens))])

    # Cut each sequence at EOS and at the bucket's decoder length.
    max_len = gen_config.buckets[bucket_id][1]
    for seq in tokens_t:
        if data_utils.EOS_ID in seq:
            resps.append(seq[:seq.index(data_utils.EOS_ID)][:max_len])
        else:
            resps.append(seq[:max_len])
    # Note: only the response string for the last batch element is returned.
    resp_str = ""
    for resp in resps:
        resp_str = " ".join([tf.compat.as_str(rev_vocab[output]) for output in resp])
    return resp_str
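
Example #3 builds the per-example responses step by step; the same transpose-and-truncate logic can be written more directly with numpy. A sketch under the same assumptions (output_logits is a list of [batch_size, vocab] arrays, eos_id is data_utils.EOS_ID, max_len is the bucket's decoder length); batch_responses is an illustrative name:

import numpy as np

def batch_responses(output_logits, eos_id, max_len):
    """Convert per-step logits into one token list per batch element,
    truncated at EOS and at the bucket's decoder length."""
    steps = [np.argmax(step, axis=1) for step in output_logits]  # [num_steps][batch_size]
    responses = []
    for seq in np.array(steps).T.tolist():                       # [batch_size][num_steps]
        if eos_id in seq:
            seq = seq[:seq.index(eos_id)]
        responses.append(seq[:max_len])
    return responses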
Example #4
def test_decoder(gen_config):
    # vocab_path = os.path.join(gen_config.train_dir, "vocab%d.all" % gen_config.vocab_size)
    # vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)
    with open('vocab', 'rb') as fr_vocab:
        vocab = pickle.load(fr_vocab)
    with open('rev_vocab', 'rb') as fr_rev_vocab:
        rev_vocab = pickle.load(fr_rev_vocab)
    # seq2seq, optimizer = create_model(gen_config, vocab)
    seq2seq = torch.load('./pre_seq2seq.pth')
    sys.stdout.write("> ")
    sys.stdout.flush()
    sentence = sys.stdin.readline()
    while sentence:
        source = data_utils.sentence_to_token_ids(sentence, vocab)  # list
        encoder_pad = [data_utils.PAD_ID] * (gen_config.maxlen - len(source))
        encoder_inputs = [list(reversed(source + encoder_pad))]
        src = torch.from_numpy(np.array(encoder_inputs).T)
        src = Variable(src).cuda()
        probs, result = seq2seq.decode(src)  # maxlen*1*vocab, a list
        ans = []
        for idx in result:
            if idx == data_utils.EOS_ID:
                break
            if idx != data_utils.PAD_ID:
                ans.append(rev_vocab[idx])
        print(" ".join(ans))
        print("> ", end="")
        sys.stdout.flush()
        sentence = sys.stdin.readline()
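
The PyTorch half of this example feeds the seq2seq model a single padded, reversed sentence. A minimal sketch of that preprocessing step, assuming pad_id is data_utils.PAD_ID and the encoder expects time-major [maxlen, batch] input; prepare_encoder_input is a hypothetical name:

import numpy as np
import torch

def prepare_encoder_input(token_ids, maxlen, pad_id):
    """Pad a token-id list to maxlen, reverse it, and return a [maxlen, 1] LongTensor
    for a time-major seq2seq encoder, mirroring the preprocessing above."""
    padded = token_ids + [pad_id] * (maxlen - len(token_ids))
    batch = [list(reversed(padded))]                         # a batch of one sentence
    return torch.from_numpy(np.array(batch, dtype=np.int64).T)   # shape [maxlen, 1]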
def test_file_decoder(gen_config, input_file, output_file):
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        model = create_model(sess,
                             gen_config,
                             forward_only=True,
                             name_scope=gen_config.name_model)
        model.batch_size = 1
        train_path = os.path.join(gen_config.train_dir, "chitchat.train")
        voc_file_path = [train_path + ".answer", train_path + ".query"]
        vocab_path = os.path.join(gen_config.train_dir,
                                  "vocab%d.all" % gen_config.vocab_size)
        data_utils.create_vocabulary(vocab_path, voc_file_path,
                                     gen_config.vocab_size)
        vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)
        with open(output_file, 'w') as fout:
            with open(input_file, 'r') as fin:
                for sent in fin:
                    print(sent)
                    token_ids = data_utils.sentence_to_token_ids(
                        tf.compat.as_str(sent), vocab)
                    print("token_id: ", token_ids)
                    bucket_id = len(gen_config.buckets) - 1
                    for i, bucket in enumerate(gen_config.buckets):
                        if bucket[0] >= len(token_ids):
                            bucket_id = i
                            break
                    else:
                        print("Sentence truncated: %s", sentence)
                    encoder_inputs, decoder_inputs, target_weights, _, _ = model.get_batch(
                        {bucket_id: [(token_ids, [1])]},
                        bucket_id,
                        model.batch_size,
                        type=0)
                    _, _, output_logits = model.step(sess, encoder_inputs,
                                                     decoder_inputs,
                                                     target_weights, bucket_id,
                                                     True)

                    outputs = [
                        int(np.argmax(logit, axis=1))
                        for logit in output_logits
                    ]
                    if data_utils.EOS_ID in outputs:
                        outputs = outputs[:outputs.index(data_utils.EOS_ID)]
                    out_sent = " ".join([
                        tf.compat.as_str(rev_vocab[output])
                        for output in outputs
                    ])
                    fout.write(out_sent + '\n')
                    print(out_sent)
Example #6
def get_predicted_sentence(sess,
                           input_sentence,
                           vocab,
                           model,
                           beam_size,
                           buckets,
                           mc_search=False,
                           debug=False):
    def model_step(enc_inp, dec_inp, dptr, target_weights, bucket_id):

        _, _, logits = model.step(sess, enc_inp, dec_inp, target_weights,
                                  bucket_id, True)
        prob = softmax(logits[dptr][0])

        return prob

    def greedy_dec(output_logits):
        selected_token_ids = [
            int(np.argmax(logit, axis=1)) for logit in output_logits
        ]
        return selected_token_ids

    input_token_ids = data_utils.sentence_to_token_ids(input_sentence, vocab)
    # Which bucket does it belong to?
    print(input_token_ids)
    bucket_id = min([
        b for b in range(len(buckets)) if buckets[b][0] > len(input_token_ids)
    ])
    outputs = []
    feed_data = {bucket_id: [(input_token_ids, outputs)]}

    # Get a 1-element batch to feed the sentence to the model.
    encoder_inputs, decoder_inputs, target_weights, _, _ = model.get_batch(
        feed_data, bucket_id, 1)
    if debug:
        print("\n[get_batch]\n", encoder_inputs, decoder_inputs,
              target_weights)
    print(decoder_inputs)
    ### Original greedy decoding
    if beam_size == 1 or (not mc_search):
        _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                         target_weights, bucket_id, True)
        # [{"dec_inp": greedy_dec(output_logits), 'prob': 1}]
        outputs = greedy_dec(output_logits)
        return " ".join([tf.compat.as_str(vocab[output]) for output in outputs])
    # Beam-search decoding: initialize beams as (prob, candidate) tuples.
    beams, new_beams, results = [(1, {
        'eos': 0,
        'dec_inp': decoder_inputs,
        'prob': 1,
        'prob_ts': 1,
        'prob_t': 1
    })], [], []

    for dptr in range(len(decoder_inputs) - 1):
        if dptr > 0:
            target_weights[dptr] = [1.]
            beams, new_beams = new_beams[:beam_size], []
        if debug: print("=====[beams]=====", beams)
        # heapify so we can pop/replace the weakest candidate and keep only beam_size
        heapq.heapify(beams)
        for prob, cand in beams:
            if cand['eos']:
                results += [(prob, cand)]
                continue
            print(cand['dec_inp'])
            all_prob_ts = model_step(encoder_inputs, cand['dec_inp'], dptr,
                                     target_weights, bucket_id)
            all_prob_t = [0] * len(all_prob_ts)
            all_prob = all_prob_ts

            # suppress copy-cat (respond the same as input)
            if dptr < len(input_token_ids):
                all_prob[input_token_ids[dptr]] = all_prob[
                    input_token_ids[dptr]] * 0.01

            # beam search
            for c in np.argsort(all_prob)[::-1][:beam_size]:
                new_cand = {
                    'eos': (c == data_utils.EOS_ID),
                    'dec_inp': [(np.array([c]) if i == (dptr + 1) else k)
                                for i, k in enumerate(cand['dec_inp'])],
                    'prob_ts':
                    cand['prob_ts'] * all_prob_ts[c],
                    'prob_t':
                    cand['prob_t'] * all_prob_t[c],
                    'prob':
                    cand['prob'] * all_prob[c],
                }
                # heapq orders tuples by their first element, so key by prob
                new_cand = (new_cand['prob'], new_cand)

                if (len(new_beams) < beam_size):
                    heapq.heappush(new_beams, new_cand)
                elif (new_cand[0] > new_beams[0][0]):
                    heapq.heapreplace(new_beams, new_cand)

    results += new_beams  # flush last cands

    # post-process results
    res_cands = []
    for prob, cand in sorted(results, reverse=True):
        res_cands.append(cand)
    return res_cands
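
The heapq bookkeeping in Example #6 (keep only the beam_size highest-probability candidates) is easy to lift out. A sketch of that pattern with a tie-breaking counter added, which the original omits so that equal scores never force a comparison of the candidate dicts; push_candidate is an illustrative name:

import heapq
import itertools

_tiebreak = itertools.count()  # prevents heapq from ever comparing the candidate payloads

def push_candidate(beam, score, candidate, beam_size):
    """Maintain the beam_size best-scoring candidates in a min-heap keyed by score."""
    entry = (score, next(_tiebreak), candidate)
    if len(beam) < beam_size:
        heapq.heappush(beam, entry)
    elif score > beam[0][0]:
        heapq.heapreplace(beam, entry)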
Example #7
def reconstruct(sess, model, config):
    model.batch_size = 1  # We decode one sentence at a time.
    model.probabilistic = config.probabilistic
    beam_size = config.beam_size

    # Load vocabularies.
    vocab_path = os.path.join(config.data_dir,
                              "vocab%d.in" % config.vocab_size)
    en_vocab, rev_vocab = data_utils.initialize_vocabulary(vocab_path)

    # Decode from standard input.
    outputs = []
    with gfile.GFile(FLAGS.input, "r") as fs:
        sentences = fs.readlines()
    for i, sentence in enumerate(sentences):
        # Get token-ids for the input sentence.
        token_ids = data_utils.sentence_to_token_ids(sentence, en_vocab)
        # Which bucket does it belong to?
        bucket_id = len(config.buckets) - 1
        for j, bucket in enumerate(config.buckets):
            if bucket[0] >= len(token_ids):
                bucket_id = j
                break
        else:
            logging.warning("Sentence truncated: %s", sentence)

        encoder_inputs, decoder_inputs, target_weights = model.get_batch(
            {bucket_id: [(token_ids, [])]}, bucket_id)

        if beam_size > 1:
            path, symbol, output_logits = model.step(
                sess, encoder_inputs, decoder_inputs, target_weights,
                bucket_id, True, config.probabilistic, beam_size)

            paths = [[] for _ in range(beam_size)]
            curr = list(range(beam_size))  # list() so entries can be reassigned below
            num_steps = len(path)
            for step in range(num_steps - 1, -1, -1):
                for kk in range(beam_size):
                    paths[kk].append(symbol[step][curr[kk]])
                    curr[kk] = path[step][curr[kk]]
            recos = set()
            for kk in range(beam_size):
                output = [int(logit) for logit in paths[kk][::-1]]

                if EOS_ID in output:
                    output = output[:output.index(EOS_ID)]
                output = " ".join([rev_vocab[word] for word in output]) + "\n"
                outputs.append(output)

        else:
            # Get output logits for the sentence.
            _, _, _, output_logits = model.step(sess, encoder_inputs,
                                                decoder_inputs, target_weights,
                                                bucket_id, True,
                                                config.probabilistic)
            # This is a greedy decoder - outputs are just argmaxes of output_logits.
            output = [int(np.argmax(logit, axis=1)) for logit in output_logits]
            # If there is an EOS symbol in outputs, cut them at that point.
            if data_utils.EOS_ID in output:
                output = output[:output.index(data_utils.EOS_ID)]
            output = " ".join([rev_vocab[word] for word in output]) + "\n"
            outputs.append(output)
    with gfile.GFile(FLAGS.output, "w") as enc_dec_f:
        for output in outputs:
            enc_dec_f.write(output)
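
Example #7 reconstructs the beam-search hypotheses by walking the path/symbol arrays backwards from the last decoder step. The same backtracking as a standalone sketch, assuming path[step][k] is the parent beam index and symbol[step][k] the token emitted by beam k at that step; backtrack_beams is a hypothetical name:

def backtrack_beams(path, symbol, beam_size):
    """Recover beam_size token sequences from per-step parent pointers and symbols."""
    paths = [[] for _ in range(beam_size)]
    curr = list(range(beam_size))
    for step in range(len(path) - 1, -1, -1):
        for k in range(beam_size):
            paths[k].append(symbol[step][curr[k]])
            curr[k] = path[step][curr[k]]
    return [p[::-1] for p in paths]  # reverse: we walked from the last step to the first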
Example #8
def decode():
    with tf.Session() as sess:
        # Create model and load parameters.
        model = create_model(sess, True)
        model.batch_size = 1  # We decode one sentence at a time.

        # Load vocabularies.
        en_vocab_path = os.path.join(FLAGS.data_dir,
                                     "vocab%d.en" % FLAGS.en_vocab_size)
        ja_vocab_path = os.path.join(FLAGS.data_dir,
                                     "vocab%d.ja" % FLAGS.target_vocab_size)
        en_vocab, _ = data_utils.initialize_vocabulary(en_vocab_path)
        _, rev_ja_vocab = data_utils.initialize_vocabulary(ja_vocab_path)
        _, rev_en_vocab = data_utils.initialize_vocabulary(en_vocab_path)
        if len(rev_ja_vocab) < FLAGS.target_vocab_size:
            rev_ja_vocab += [
                "_" for i in range(FLAGS.target_vocab_size - len(rev_ja_vocab))
            ]

        # Prepare visual context integration
        for fname in os.listdir(FLAGS.data_dir):
            if fname.endswith("ids{0}.{1}".format(str(FLAGS.target_vocab_size),
                                                  FLAGS.target_language)):
                target_id_file = os.path.join(FLAGS.data_dir, fname)
        W, bias_vec = train_visual(target_id_file,
                                   os.path.join(FLAGS.data_dir,
                                                FLAGS.visual_vec_file_name),
                                   ja_vocab_path,
                                   FLAGS.target_vocab_size,
                                   num_epochs=100)

        # Decode from standard input.
        sys.stdout.write("> ")
        sys.stdout.flush()
        sentence = sys.stdin.readline()
        visual_vec = []
        while sentence:
            # Create visual context vector if given
            for el in reversed(sentence.split()):
                if el.isdigit():
                    visual_vec = [int(el)] + visual_vec
            # Get token-ids for the input sentence.
            sentence = " ".join(sentence.split()[:len(sentence.split()) -
                                                 len(visual_vec)])
            token_ids = data_utils.sentence_to_token_ids(sentence, en_vocab)
            print("token ids: " + str(token_ids))
            # Which bucket does it belong to?
            bucket_id = min([
                b for b in range(len(_buckets))
                if _buckets[b][0] > len(token_ids)
            ])
            # Get a 1-element batch to feed the sentence to the model.
            encoder_inputs, decoder_inputs, target_weights = model.get_batch(
                {bucket_id: [(token_ids, [])]}, bucket_id)

            # Get output logits for the sentence.
            _, _, output_logits = model.step(
                sess, encoder_inputs, decoder_inputs, target_weights,
                bucket_id, True
            )  # step(session, encoder_inputs, decoder_inputs, target_weights, bucket_id, forward_only)
            # output_logits = [[log_prob(target1), log_prob(target2), ..., log_prob(target_n)], ...]
            # one entry per decoder step of the bucket, e.g. "Hello!" -> input bucket 5 -> decode bucket 10

            # Store TensorFlow scores for each output step until the end-of-sentence symbol
            output_list = []
            for logits in output_logits:
                max_id = int(np.argmax(logits, axis=1)[0])
                if max_id == data_utils.EOS_ID:
                    break
                nmt_score_dict = dict(enumerate(logits[0]))
                output_list.append(nmt_score_dict)
                #logits[0] += visual_scores

            # Store visual scores
            if visual_vec != [] and FLAGS.visual:
                visual_scores = feedforward(W, visual_vec, bias_vec)
                visual_scores = np.array([
                    math.log(prob / (1 - prob)) for prob in visual_scores
                ])  #turn probabilities into logits
                visual_score_dict = dict(enumerate(visual_scores))
                # print ("--visual score--")
                # for k,v in sorted(visual_score_dict.items(), key=lambda x:x[1], reverse=True):
                #     print(rev_ja_vocab[k]+":"+str(v), end=" ")
                # print ("\n")
            else:
                visual_score_dict = {}

            # Integrate visual scores if given and output the result
            print("--result--")
            outputs = []
            for dic in output_list:
                for k in visual_score_dict:
                    dic[k] += visual_score_dict[k]
                outputs.append(
                    max(dic.items(), key=operator.itemgetter(1))[0])
                for k, v in sorted(dic.items(),
                                   key=lambda x: x[1],
                                   reverse=True):
                    print(rev_ja_vocab[k] + ":" + str(v), end="  ")
                print("\n")

            print(" ".join([rev_ja_vocab[output] for output in outputs]))
            print("> ", end="")
            visual_vec = []
            sys.stdout.flush()
            sentence = sys.stdin.readline()
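
The core of Example #8 is rescoring each decoder step by adding the visual log-odds scores to the NMT logits and taking the best token. A minimal sketch of that combination step, with both inputs assumed to be dicts keyed by vocabulary id; best_token is a hypothetical helper name:

import operator

def best_token(nmt_scores, visual_scores):
    """Add visual scores to the NMT logits per vocabulary id and return the
    highest-scoring token id, as Example #8 does for every decoder step."""
    combined = dict(nmt_scores)
    for token_id, score in visual_scores.items():
        combined[token_id] = combined.get(token_id, 0.0) + score
    return max(combined.items(), key=operator.itemgetter(1))[0]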