# Collected examples (Python 2 / Theano) of encoding dialogues and ranking
# candidate responses with a DialogEncoderDecoder model. Shared imports are
# gathered here; parse_args, prototype_state, DialogEncoderDecoder and
# compute_encodings are assumed to come from the surrounding project.
import os
import math
import logging
import cPickle

import numpy

logger = logging.getLogger(__name__)


# Example #1
def main():
    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    model = DialogEncoderDecoder(state) 
    
    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")
    
    contexts = [[]]
    lines = open(args.dialogues, "r").readlines()
    if len(lines):
        contexts = [x.strip().split('\t') for x in lines]
   
    model_compute_encoding = model.build_encoder_function()
    dialogue_encodings = []

    # Loop over contexts, encoding them in batches of size model.bs
    joined_contexts = []
    batch_index = 0
    batch_total = int(math.ceil(float(len(contexts)) / float(model.bs)))
    for context_id, context_sentences in enumerate(contexts):

        # Convert context into a list of token ids
        joined_context = []

        if len(context_sentences) == 0:
            joined_context = [model.eos_sym]
        else:
            joined_context += [model.eos_sym]
            for sentence in context_sentences:
                sentence_ids = model.words_to_indices(sentence.split())
                # Add eos tokens
                joined_context += sentence_ids + [model.eos_sym]

        # HACK
        #for i in range(0, 50):
        #    joined_context += [0] + [model.eos_sym]

        joined_contexts.append(joined_context)

        if len(joined_contexts) == model.bs:
            batch_index += 1
            logger.debug("[COMPUTE] - Got batch %d / %d" % (batch_index, batch_total))
            encs = compute_encodings(joined_contexts, model, model_compute_encoding, args.use_second_last_state)
            dialogue_encodings.extend(encs)

            joined_contexts = []


    if len(joined_contexts) > 0:
        logger.debug("[COMPUTE] - Got batch %d / %d" % (batch_total, batch_total))
        encs = compute_encodings(joined_contexts, model, model_compute_encoding, args.use_second_last_state)
        dialogue_encodings.extend(encs)

    # Save encodings to disk
    with open(args.output + '.pkl', 'wb') as out:
        cPickle.dump(dialogue_encodings, out)
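

# A minimal usage sketch (not part of the original examples): reload the
# encodings dumped above and rank the stored dialogues by cosine similarity
# against a query encoding. 'encodings.pkl' is a placeholder path matching
# args.output + '.pkl' above.
def rank_by_cosine_similarity(query_encoding, encodings_path='encodings.pkl'):
    with open(encodings_path, 'rb') as f:
        encodings = numpy.asarray(cPickle.load(f))
    # Normalize so the dot product becomes cosine similarity
    sims = numpy.dot(encodings, query_encoding)
    norms = numpy.linalg.norm(encodings, axis=1) * numpy.linalg.norm(query_encoding)
    sims = sims / numpy.maximum(norms, 1e-12)
    # Indices of stored dialogues, most similar first
    return numpy.argsort(-sims)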


# Example #2
def main():
    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    model = DialogEncoderDecoder(state)

    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")

    contexts = [[]]
    lines = open(args.dialogues, "r").readlines()
    if len(lines):
        contexts = [x.strip().split('\t') for x in lines]

    model_compute_encoding = model.build_encoder_function()
    dialogue_encodings = []

    # Loop over contexts, encoding them in batches of size model.bs
    joined_contexts = []
    batch_index = 0
    batch_total = int(math.ceil(float(len(contexts)) / float(model.bs)))
    for context_id, context_sentences in enumerate(contexts):

        # Convert context into a list of token ids
        joined_context = []

        if len(context_sentences) == 0:
            joined_context = [model.eos_sym]
        else:
            for sentence in context_sentences:
                sentence_ids = model.words_to_indices(sentence.split())
                # Add sos and eos tokens
                joined_context += [model.sos_sym] + sentence_ids + [model.eos_sym]

        # HACK: pad the context with 50 dummy single-token utterances
        for i in range(0, 50):
            joined_context += [model.sos_sym] + [0] + [model.eos_sym]

        joined_contexts.append(joined_context)

        if len(joined_contexts) == model.bs:
            batch_index += 1
            logger.debug("[COMPUTE] - Got batch %d / %d" % (batch_index, batch_total))
            encs = compute_encodings(joined_contexts, model, model_compute_encoding, args.use_second_last_state)
            dialogue_encodings.extend(encs)

            joined_contexts = []

    if len(joined_contexts) > 0:
        logger.debug("[COMPUTE] - Got batch %d / %d" % (batch_total, batch_total))
        encs = compute_encodings(joined_contexts, model, model_compute_encoding, args.use_second_last_state)
        dialogue_encodings.extend(encs)

    # Save encodings to disk
    with open(args.output + '.pkl', 'wb') as out:
        cPickle.dump(dialogue_encodings, out)
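

# Sketch of the token layout used in this variant (an illustration, not from
# the original source): each utterance is wrapped as <s> ... </s>, so a
# two-utterance context becomes: <s> w1 w2 </s> <s> w3 </s>.
def join_context(sentences, words_to_indices, sos_sym, eos_sym):
    joined = []
    for sentence in sentences:
        joined += [sos_sym] + words_to_indices(sentence.split()) + [eos_sym]
    return joined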


# Example #3
def main(model_prefix, dialogue_file, use_second_last_state):
    state = prototype_state()

    state_path = model_prefix + "_state.pkl"
    model_path = model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    state['bs'] = 10

    model = DialogEncoderDecoder(state)

    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")

    contexts = [[]]
    lines = open(dialogue_file, "r").readlines()
    if len(lines):
        contexts = [x.strip() for x in lines]

    model_compute_encoding = model.build_encoder_function()
    dialogue_encodings = []

    # Loop over contexts, encoding them in batches of size model.bs
    joined_contexts = []
    batch_index = 0
    batch_total = int(math.ceil(float(len(contexts)) / float(model.bs)))
    for context_id, context_sentences in enumerate(contexts):
        # Convert contexts into list of ids
        joined_context = []

        if len(context_sentences) == 0:
            joined_context = [model.eos_sym]
        else:
            joined_context = model.words_to_indices(context_sentences.split())

            if joined_context[0] != model.eos_sym:
                joined_context = [model.eos_sym] + joined_context

            if joined_context[-1] != model.eos_sym:
                joined_context += [model.eos_sym]

        joined_contexts.append(joined_context)

        if len(joined_contexts) == model.bs:
            batch_index += 1
            logger.debug("[COMPUTE] - Got batch %d / %d" % (batch_index, batch_total))
            encs = compute_encodings(joined_contexts, model, model_compute_encoding, use_second_last_state)
            dialogue_encodings.extend(encs)

            joined_contexts = []

    if len(joined_contexts) > 0:
        logger.debug("[COMPUTE] - Got batch %d / %d" % (batch_total, batch_total))
        encs = compute_encodings(joined_contexts, model, model_compute_encoding, use_second_last_state)
        dialogue_encodings.extend(encs)

    return dialogue_encodings
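

# Hypothetical driver (not in the original): this variant's main() returns
# the encodings instead of pickling them, so a caller can persist them
# itself. Assumes Example #3's main(model_prefix, dialogue_file,
# use_second_last_state) signature.
def encode_and_save(model_prefix, dialogue_file, output_path, use_second_last_state=False):
    encodings = main(model_prefix, dialogue_file, use_second_last_state)
    with open(output_path, 'wb') as f:
        cPickle.dump(encodings, f)
    return len(encodings)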


# Example #4
def main():
    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    # For simplicity, we force the batch size to be one
    state['bs'] = 1
    model = DialogEncoderDecoder(state)

    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")

    eval_batch = model.build_eval_function()

    contexts = [[]]
    lines = open(args.context, "r").readlines()
    if len(lines):
        contexts = [x.strip() for x in lines]

    potential_responses_set = open(args.responses, "r").readlines()
    most_probable_responses_string = ''

    print('Retrieval started...')

    for context_idx, context in enumerate(contexts):
        if context_idx % 100 == 0:
            print '     processing example: ' + str(context_idx) + ' / ' + str(len(contexts))
        potential_responses = potential_responses_set[context_idx].strip().split('\t')

        most_probable_response_loglikelihood = -1.0
        most_probable_response = ''

        for potential_response_idx, potential_response in enumerate(potential_responses):
            # Convert contexts into list of ids
            dialogue = []
            if len(context) == 0:
                dialogue = [model.eos_sym]
            else:
                sentence_ids = model.words_to_indices(context.split())
                # Add eos tokens
                if len(sentence_ids) > 0:
                    if not sentence_ids[0] == model.eos_sym:
                        sentence_ids = [model.eos_sym] + sentence_ids
                    if not sentence_ids[-1] == model.eos_sym:
                        sentence_ids += [model.eos_sym]
                else:
                    sentence_ids = [model.eos_sym]

                dialogue += sentence_ids

            response = model.words_to_indices(potential_response.split())
            if len(response) > 0:
                if response[0] == model.eos_sym:
                    del response[0]
                if not response[-1] == model.eos_sym:
                    response += [model.eos_sym]

            dialogue += response
            dialogue = numpy.asarray(dialogue, dtype='int32').reshape(
                (len(dialogue), 1))

            dialogue_reversed = model.reverse_utterances(dialogue)

            dialogue_mask = numpy.ones((len(dialogue), 1), dtype='float32')
            dialogue_weight = numpy.ones((len(dialogue), 1), dtype='float32')
            dialogue_reset_mask = numpy.zeros((1), dtype='float32')
            dialogue_ran_vectors = model.rng.normal(size=(
                1, model.latent_gaussian_per_utterance_dim)).astype('float32')
            dialogue_ran_vectors = numpy.tile(dialogue_ran_vectors,
                                              (len(dialogue), 1, 1))

            dialogue_drop_mask = numpy.ones((len(dialogue), 1),
                                            dtype='float32')

            c, _, _, _, _ = eval_batch(dialogue, dialogue_reversed,
                                       len(dialogue), dialogue_mask,
                                       dialogue_weight, dialogue_reset_mask,
                                       dialogue_ran_vectors,
                                       dialogue_drop_mask)
            c = c / len(dialogue)

            print 'c', c

            if (potential_response_idx == 0) or (-c > most_probable_response_loglikelihood):
                most_probable_response_loglikelihood = -c
                most_probable_response = potential_response

        most_probable_responses_string += most_probable_response + '\n'

    print('Retrieval finished.')
    print('Saving to file...')

    # Write to output file
    with open(args.output, "w") as output_handle:
        output_handle.write(most_probable_responses_string)

    print('Saving to file finished.')
    print('All done!')
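

# The selection rule above, factored out as a sketch (hypothetical helper):
# keep the response whose length-normalized cost c / len(dialogue) is lowest,
# i.e. whose per-token log-likelihood is highest.
def select_most_probable(responses, normalized_costs):
    best_idx = int(numpy.argmin(normalized_costs))
    return responses[best_idx]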


# Example #5
def main():
    args = parse_args()
    state = prototype_state()
   
    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src)) 
    
    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    state['lr'] = 0.01
    state['bs'] = 2

    state['compute_training_updates'] = False
    state['apply_meanfield_inference'] = True

    model = DialogEncoderDecoder(state)
    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")
    


    mf_update_batch = model.build_mf_update_function()
    mf_reset_batch = model.build_mf_reset_function()

    saliency_batch = model.build_saliency_eval_function()

    test_dialogues = open(args.test_dialogues, 'r').readlines()
    for test_dialogue_idx, test_dialogue in enumerate(test_dialogues):
        test_dialogue = test_dialogue.strip()
        #print 'Visualizing dialogue: ', test_dialogue
        test_dialogue_split = test_dialogue.split()
        # Convert dialogue into list of ids
        dialogue = []
        if len(test_dialogue) == 0:
            dialogue = [model.eos_sym]
        else:
            sentence_ids = model.words_to_indices(test_dialogue_split)
            # Add eos tokens
            if len(sentence_ids) > 0:
                if not sentence_ids[0] == model.eos_sym:
                    sentence_ids = [model.eos_sym] + sentence_ids
                if not sentence_ids[-1] == model.eos_sym:
                    sentence_ids += [model.eos_sym]
            else:
                sentence_ids = [model.eos_sym]


            dialogue += sentence_ids

        if len(dialogue) > 3:
            if ((dialogue[-1] == model.eos_sym)
             and (dialogue[-2] == model.eod_sym)
             and (dialogue[-3] == model.eos_sym)):
                del dialogue[-1]
                del dialogue[-1]

        
        dialogue = numpy.asarray(dialogue, dtype='int32').reshape((len(dialogue), 1))
        #print 'dialogue', dialogue
        dialogue_reversed = model.reverse_utterances(dialogue)

        max_batch_sequence_length = len(dialogue)
        bs = state['bs']

        # Initialize batch with zeros
        batch_dialogues = numpy.zeros((max_batch_sequence_length, bs), dtype='int32')
        batch_dialogues_reversed = numpy.zeros((max_batch_sequence_length, bs), dtype='int32')
        batch_dialogues_mask = numpy.zeros((max_batch_sequence_length, bs), dtype='float32')
        batch_dialogues_reset_mask = numpy.zeros((bs), dtype='float32')
        batch_dialogues_drop_mask = numpy.ones((max_batch_sequence_length, bs), dtype='float32')

        # Fill in batch with values
        batch_dialogues[:,0]  = dialogue[:, 0]
        batch_dialogues_reversed[:,0] = dialogue_reversed[:, 0]
        #batch_dialogues  = dialogue
        #batch_dialogues_reversed = dialogue_reversed


        batch_dialogues_ran_gaussian_vectors = numpy.zeros((max_batch_sequence_length, bs, model.latent_gaussian_per_utterance_dim), dtype='float32')
        batch_dialogues_ran_uniform_vectors = numpy.zeros((max_batch_sequence_length, bs, model.latent_piecewise_per_utterance_dim), dtype='float32')

        eos_sym_list = numpy.where(batch_dialogues[:, 0] == state['eos_sym'])[0]
        if len(eos_sym_list) > 1:
            second_last_eos_sym = eos_sym_list[-2]
        else:
            print 'WARNING: dialogue does not have at least two EOS tokens!'
            # Fall back so second_last_eos_sym is always defined
            second_last_eos_sym = eos_sym_list[-1] if len(eos_sym_list) > 0 else 0

        batch_dialogues_mask[:, 0] = 1.0
        batch_dialogues_mask[0:second_last_eos_sym+1, 0] = 0.0

        print '###'
        print '###'
        print '###'
        if False:  # disabled: optional mean-field SGD refinement of the latent variables
            mf_reset_batch()
            for i in range(10):
                print  '  SGD Update', i

                batch_dialogues_ran_gaussian_vectors[second_last_eos_sym:, 0, :] = model.rng.normal(loc=0, scale=1, size=model.latent_gaussian_per_utterance_dim)
                batch_dialogues_ran_uniform_vectors[second_last_eos_sym:, 0, :] = model.rng.uniform(low=0.0, high=1.0, size=model.latent_piecewise_per_utterance_dim)


                training_cost, kl_divergence_cost_acc, kl_divergences_between_piecewise_prior_and_posterior, kl_divergences_between_gaussian_prior_and_posterior = mf_update_batch(batch_dialogues, batch_dialogues_reversed, max_batch_sequence_length, batch_dialogues_mask, batch_dialogues_reset_mask, batch_dialogues_ran_gaussian_vectors, batch_dialogues_ran_uniform_vectors, batch_dialogues_drop_mask)

                print '     training_cost', training_cost
                print '     kl_divergence_cost_acc', kl_divergence_cost_acc
                print '     kl_divergences_between_gaussian_prior_and_posterior',  numpy.sum(kl_divergences_between_gaussian_prior_and_posterior)
                print '     kl_divergences_between_piecewise_prior_and_posterior', numpy.sum(kl_divergences_between_piecewise_prior_and_posterior)


        batch_dialogues_ran_gaussian_vectors[second_last_eos_sym:, 0, :] = model.rng.normal(loc=0, scale=1, size=model.latent_gaussian_per_utterance_dim)
        batch_dialogues_ran_uniform_vectors[second_last_eos_sym:, 0, :] = model.rng.uniform(low=0.0, high=1.0, size=model.latent_piecewise_per_utterance_dim)

        gaussian_saliency, piecewise_saliency = saliency_batch(batch_dialogues, batch_dialogues_reversed, max_batch_sequence_length, batch_dialogues_mask, batch_dialogues_reset_mask, batch_dialogues_ran_gaussian_vectors, batch_dialogues_ran_uniform_vectors, batch_dialogues_drop_mask)

        if test_dialogue_idx < 2:
            print 'gaussian_saliency', gaussian_saliency.shape, gaussian_saliency
            print 'piecewise_saliency', piecewise_saliency.shape, piecewise_saliency

        gaussian_sum = 0.0
        piecewise_sum = 0.0
        for i in range(second_last_eos_sym+1, max_batch_sequence_length):
            gaussian_sum += gaussian_saliency[dialogue[i, 0]]
            piecewise_sum += piecewise_saliency[dialogue[i, 0]]

        gaussian_sum = max(gaussian_sum, 1e-13)
        piecewise_sum = max(piecewise_sum, 1e-13)


        print '###'
        print '###'
        print '###'

        print 'Topic: ', ' '.join(test_dialogue_split[0:second_last_eos_sym])

        print ''
        print 'Response', ' '.join(test_dialogue_split[second_last_eos_sym+1:max_batch_sequence_length])
        gaussian_str = ''
        piecewise_str = ''
        for i in range(second_last_eos_sym+1, max_batch_sequence_length):
            gaussian_str += str(gaussian_saliency[dialogue[i, 0]]/gaussian_sum) + ' '
            piecewise_str += str(piecewise_saliency[dialogue[i, 0]]/piecewise_sum) + ' '

        print 'Gaussian_saliency', gaussian_str
        print 'Piecewise_saliency', piecewise_str

        #print ''
        #print 'HEY', gaussian_saliency[:, 0].argsort()[-3:][::-1]
        print ''
        print 'Gaussian Top 3 Words: ', model.indices_to_words(list(gaussian_saliency[:].argsort()[-3:][::-1]))
        print 'Piecewise Top 3 Words: ', model.indices_to_words(list(piecewise_saliency[:].argsort()[-3:][::-1]))

        print 'Gaussian Top 5 Words: ', model.indices_to_words(list(gaussian_saliency[:].argsort()[-5:][::-1]))
        print 'Piecewise Top 5 Words: ', model.indices_to_words(list(piecewise_saliency[:].argsort()[-5:][::-1]))

        print 'Gaussian Top 7 Words: ', model.indices_to_words(list(gaussian_saliency[:].argsort()[-7:][::-1]))
        print 'Piecewise Top 7 Words: ', model.indices_to_words(list(piecewise_saliency[:].argsort()[-7:][::-1]))


    logger.debug("All done, exiting...")


# Example #6
def main():
    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    state['bs'] = 20
    state['compute_training_updates'] = False

    if args.mf_inference_steps > 0:
        state['lr'] = 0.01
        state['apply_meanfield_inference'] = True

    model = DialogEncoderDecoder(state) 

    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")
    

    eval_batch = model.build_eval_function()

    if args.mf_inference_steps > 0:
        mf_update_batch = model.build_mf_update_function()
        mf_reset_batch = model.build_mf_reset_function()
        
    contexts = [[]]
    lines = open(args.context, "r").readlines()
    if len(lines):
        contexts = [x.strip() for x in lines]

    potential_responses_set = ['']
    lines = open(args.responses, "r").readlines()
    if len(lines):
        potential_responses_set = [x.strip() for x in lines]
    
    print('Building data batches...')

    # Maximum sequence length we can process on a 12 GB GPU with large HRED models;
    # the per-example length budget shrinks as the batch size grows.
    max_sequence_length = 80 * (80 // state['bs'])

    assert len(potential_responses_set) == len(contexts)

    # Note: we assume every example has the same number of potential responses.
    # To adapt the code, count the total number of responses to consider here instead.
    examples_to_process = len(contexts) * len(potential_responses_set[0].strip().split('\t'))

    all_dialogues = numpy.zeros((max_sequence_length, examples_to_process), dtype='int32')
    all_dialogues_reversed = numpy.zeros((max_sequence_length, examples_to_process), dtype='int32')
    all_dialogues_mask = numpy.zeros((max_sequence_length, examples_to_process), dtype='float32')
    all_dialogues_reset_mask = numpy.zeros((examples_to_process), dtype='float32')
    all_dialogues_drop_mask = numpy.ones((max_sequence_length, examples_to_process), dtype='float32')

    all_dialogues_len = numpy.zeros((examples_to_process), dtype='int32')

    example_idx = -1

    for context_idx, context in enumerate(contexts):
        if context_idx % 100 == 0:
            print '     processing context idx: ' + str(context_idx) + ' / ' + str(len(contexts))

        potential_responses = potential_responses_set[context_idx].strip().split('\t')

        for potential_response_idx, potential_response in enumerate(potential_responses):
            example_idx += 1

            # Convert contexts into list of ids
            dialogue = []
            if len(context) == 0:
                dialogue = [model.eos_sym]
            else:
                sentence_ids = model.words_to_indices(context.split())
                # Add eos tokens
                if len(sentence_ids) > 0:
                    if not sentence_ids[0] == model.eos_sym:
                        sentence_ids = [model.eos_sym] + sentence_ids
                    if not sentence_ids[-1] == model.eos_sym:
                        sentence_ids += [model.eos_sym]
                else:
                    sentence_ids = [model.eos_sym]


                dialogue += sentence_ids

            response = model.words_to_indices(potential_response.split())
            if len(response) > 0:
                if response[0] == model.eos_sym:
                    del response[0]
                if not response[-1] == model.eos_sym:
                    response += [model.eos_sym]

            if len(response) > 3:
                if ((response[-1] == model.eos_sym)
                 and (response[-2] == model.eod_sym)
                 and (response[-3] == model.eos_sym)):
                    del response[-1]
                    del response[-1]

            dialogue += response

            if context_idx == 0:
                print 'DEBUG INFO'
                print 'dialogue', dialogue

            if len(numpy.where(numpy.asarray(dialogue) == state['eos_sym'])[0]) < 3:
                print 'POTENTIAL PROBLEM WITH DIALOGUE CONTEXT IDX', context_idx
                print '     dialogue', dialogue


            # Trim words from the beginning of the dialogue if it is too long; this should be rare.
            if len(dialogue) > max_sequence_length:
                if state['do_generate_first_utterance']:
                    dialogue = dialogue[len(dialogue)-max_sequence_length:len(dialogue)]
                else: # CreateDebate specific setting
                    dialogue = dialogue[0:max_sequence_length-1]
                    if not dialogue[-1] == state['eos_sym']:
                        dialogue += [state['eos_sym']]

            dialogue = numpy.asarray(dialogue, dtype='int32').reshape((len(dialogue), 1))

            dialogue_reversed = model.reverse_utterances(dialogue)
            dialogue_mask = numpy.ones((len(dialogue)), dtype='float32') 
            dialogue_weight = numpy.ones((len(dialogue)), dtype='float32')
            dialogue_drop_mask = numpy.ones((len(dialogue)), dtype='float32')

            # Add example to large numpy arrays...
            all_dialogues[0:len(dialogue), example_idx] = dialogue[:,0]
            all_dialogues_reversed[0:len(dialogue), example_idx] = dialogue_reversed[:,0]
            all_dialogues_mask[0:len(dialogue), example_idx] = dialogue_mask
            all_dialogues_drop_mask[0:len(dialogue), example_idx] = dialogue_drop_mask

            all_dialogues_len[example_idx] = len(dialogue)


    sorted_dialogue_indices = numpy.argsort(all_dialogues_len)

    all_dialogues_cost = numpy.ones((examples_to_process), dtype='float32')

    batch_count = int(numpy.ceil(float(examples_to_process) / float(state['bs'])))

    print('Computing costs on batches...')
    for batch_idx in reversed(range(batch_count)):
        if batch_idx % 10 == 0:
            print '     processing batch idx: ' + str(batch_idx) + ' / ' + str(batch_count)

        example_index_start = batch_idx * state['bs']
        example_index_end = min((batch_idx+1) * state['bs'], examples_to_process)
        example_indices = sorted_dialogue_indices[example_index_start:example_index_end]

        current_batch_size = len(example_indices)

        # Examples were sorted by length, so the last index carries the batch's max length
        max_batch_sequence_length = all_dialogues_len[example_indices[-1]]

        # Initialize batch with zeros
        batch_dialogues = numpy.zeros((max_batch_sequence_length, state['bs']), dtype='int32')
        batch_dialogues_reversed = numpy.zeros((max_batch_sequence_length, state['bs']), dtype='int32')
        batch_dialogues_mask = numpy.zeros((max_batch_sequence_length, state['bs']), dtype='float32')
        batch_dialogues_weight = numpy.ones((max_batch_sequence_length, state['bs']), dtype='float32')
        batch_dialogues_reset_mask = numpy.zeros((state['bs']), dtype='float32')
        batch_dialogues_drop_mask = numpy.ones((max_batch_sequence_length, state['bs']), dtype='float32')

        # Fill in batch with values
        batch_dialogues[:,0:current_batch_size]  = all_dialogues[0:max_batch_sequence_length, example_indices]
        batch_dialogues_reversed[:,0:current_batch_size] = all_dialogues_reversed[0:max_batch_sequence_length, example_indices]
        batch_dialogues_mask[:,0:current_batch_size] = all_dialogues_mask[0:max_batch_sequence_length, example_indices]
        batch_dialogues_drop_mask[:,0:current_batch_size] = all_dialogues_drop_mask[0:max_batch_sequence_length, example_indices]

        if batch_idx < 10:
            print '###'
            print '###'
            print '###'

            print 'DEBUG THIS:'
            print 'batch_dialogues', batch_dialogues[:, 0]
            print 'batch_dialogues_reversed',batch_dialogues_reversed[:, 0]
            print 'max_batch_sequence_length', max_batch_sequence_length
            print 'batch_dialogues_reset_mask', batch_dialogues_reset_mask


        batch_dialogues_ran_gaussian_vectors = numpy.zeros((max_batch_sequence_length, state['bs'], model.latent_gaussian_per_utterance_dim), dtype='float32')
        batch_dialogues_ran_uniform_vectors = numpy.zeros((max_batch_sequence_length, state['bs'], model.latent_piecewise_per_utterance_dim), dtype='float32')

        for idx in range(state['bs']):
            eos_sym_list = numpy.where(batch_dialogues[:, idx] == state['eos_sym'])[0]
            if len(eos_sym_list) > 1:
                second_last_eos_sym = eos_sym_list[-2]
            else:
                print 'WARNING: batch_idx ' + str(batch_idx) + ' example index ' + str(idx) + ' does not have at least two EOS tokens!'
                print '     batch was: ', batch_dialogues[:, idx]

                if len(eos_sym_list) > 0:
                    second_last_eos_sym = eos_sym_list[-1]
                else:
                    second_last_eos_sym = 0

            batch_dialogues_ran_gaussian_vectors[second_last_eos_sym:, idx, :] = model.rng.normal(loc=0, scale=1, size=model.latent_gaussian_per_utterance_dim)
            batch_dialogues_ran_uniform_vectors[second_last_eos_sym:, idx, :] = model.rng.uniform(low=0.0, high=1.0, size=model.latent_piecewise_per_utterance_dim)

            batch_dialogues_mask[0:second_last_eos_sym+1, idx] = 0.0

        if batch_idx < 10:
            print 'batch_dialogues_mask', batch_dialogues_mask[:, 0]
            print 'batch_dialogues_ran_gaussian_vectors', batch_dialogues_ran_gaussian_vectors[:, 0, 0]
            print 'batch_dialogues_ran_gaussian_vectors [-1]', batch_dialogues_ran_gaussian_vectors[:, 0, -1]
            print 'batch_dialogues_ran_uniform_vectors', batch_dialogues_ran_uniform_vectors[:, 0, 0]
            print 'batch_dialogues_ran_uniform_vectors [-1]', batch_dialogues_ran_uniform_vectors[:, 0, -1]


        # Carry out mean-field inference (reset, then a fixed number of update steps per batch):
        if args.mf_inference_steps > 0:
            mf_reset_batch()
            for mf_step in range(args.mf_inference_steps):
                training_cost, kl_divergence_cost_acc, kl_divergences_between_piecewise_prior_and_posterior, kl_divergences_between_gaussian_prior_and_posterior = mf_update_batch(batch_dialogues, batch_dialogues_reversed, max_batch_sequence_length, batch_dialogues_mask, batch_dialogues_reset_mask, batch_dialogues_ran_gaussian_vectors, batch_dialogues_ran_uniform_vectors, batch_dialogues_drop_mask)
                if batch_idx % 10 == 0:
                    print  '  mf_step', mf_step
                    print '     training_cost', training_cost
                    print '     kl_divergence_cost_acc', kl_divergence_cost_acc
                    print '     kl_divergences_between_gaussian_prior_and_posterior',  numpy.sum(kl_divergences_between_gaussian_prior_and_posterior)
                    print '     kl_divergences_between_piecewise_prior_and_posterior', numpy.sum(kl_divergences_between_piecewise_prior_and_posterior)



        # Evaluate the batch and collect per-word costs
        _, c_list, _ = eval_batch(batch_dialogues, batch_dialogues_reversed, max_batch_sequence_length, batch_dialogues_mask, batch_dialogues_reset_mask, batch_dialogues_ran_gaussian_vectors, batch_dialogues_ran_uniform_vectors, batch_dialogues_drop_mask)
        # Costs come back time-major; reorder to (batch, time) and normalize by each example's mask length
        c_list = c_list.reshape((state['bs'], max_batch_sequence_length-1), order='F')
        c_list = numpy.sum(c_list, axis=1) / numpy.sum(batch_dialogues_mask, axis=0)
        all_dialogues_cost[example_indices] = c_list[0:current_batch_size]

    print('Ranking responses...')
    example_idx = -1
    ranked_all_responses_string = ''
    ranked_all_responses_costs_string = ''
    for context_idx, context in enumerate(contexts):
        if context_idx % 100 == 0:
            print '     processing context idx: ' + str(context_idx) + ' / ' + str(len(contexts))

        potential_responses = potential_responses_set[context_idx].strip().split('\t')
        potential_example_response_indices = numpy.zeros((len(potential_responses)), dtype='int32')

        for potential_response_idx, potential_response in enumerate(potential_responses):
            example_idx += 1
            potential_example_response_indices[potential_response_idx] = example_idx

        ranked_potential_example_response_indices = numpy.argsort(all_dialogues_cost[potential_example_response_indices])

        ranked_responses_costs = all_dialogues_cost[potential_example_response_indices]

        ranked_responses_string = ''
        ranked_responses_costs_string = ''
        for idx in ranked_potential_example_response_indices:
            ranked_responses_string += potential_responses[idx] + '\t'
            ranked_responses_costs_string += str(ranked_responses_costs[idx]) + '\t'

        ranked_all_responses_string += ranked_responses_string.rstrip('\t') + '\n'
        ranked_all_responses_costs_string += ranked_responses_costs_string.rstrip('\t') + '\n'

    print('Finished all computations!')
    print('Saving to file...')

    # Write ranked responses and their costs to the output files
    with open(args.output + '.txt', "w") as output_handle:
        output_handle.write(ranked_all_responses_string)

    with open(args.output + '_Costs.txt', "w") as output_handle:
        output_handle.write(ranked_all_responses_costs_string)

    print('Saving to file finished.')
    print('All done!')
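

# Hypothetical post-processing sketch (not in the original): read the ranked
# responses and their costs back into aligned per-context lists. The prefix
# is a placeholder matching args.output above.
def load_ranked(prefix='output'):
    with open(prefix + '.txt') as f:
        responses = [line.rstrip('\n').split('\t') for line in f]
    with open(prefix + '_Costs.txt') as f:
        costs = [[float(c) for c in line.rstrip('\n').split('\t')] for line in f]
    return responses, costs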