# Example #1
def main():
    args = parse_args()
    state = prototype_state()
   
    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src)) 
    
    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    state['lr'] = 0.01
    state['bs'] = 2

    state['compute_training_updates'] = False
    state['apply_meanfield_inference'] = True

    model = DialogEncoderDecoder(state)
    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")
    


    mf_update_batch = model.build_mf_update_function()
    mf_reset_batch = model.build_mf_reset_function()

    saliency_batch = model.build_saliency_eval_function()

    test_dialogues = open(args.test_dialogues, 'r').readlines()
    for test_dialogue_idx, test_dialogue in enumerate(test_dialogues):
        #print 'Visualizing dialogue: ', test_dialogue
        test_dialogue_split = test_dialogue.split()
        # Convert dialogue into list of ids
        dialogue = []
        if len(test_dialogue) == 0:
            dialogue = [model.eos_sym]
        else:
            sentence_ids = model.words_to_indices(test_dialogue_split)
            # Add eos tokens
            if len(sentence_ids) > 0:
                if not sentence_ids[0] == model.eos_sym:
                    sentence_ids = [model.eos_sym] + sentence_ids
                if not sentence_ids[-1] == model.eos_sym:
                    sentence_ids += [model.eos_sym]
            else:
                sentence_ids = [model.eos_sym]


            dialogue += sentence_ids

        if len(dialogue) > 3:
            if ((dialogue[-1] == model.eos_sym)
             and (dialogue[-2] == model.eod_sym)
             and (dialogue[-3] == model.eos_sym)):
                del dialogue[-1]
                del dialogue[-1]

        
        dialogue = numpy.asarray(dialogue, dtype='int32').reshape((len(dialogue), 1))
        #print 'dialogue', dialogue
        dialogue_reversed = model.reverse_utterances(dialogue)

        max_batch_sequence_length = len(dialogue)
        bs = state['bs']

        # Initialize batch with zeros
        batch_dialogues = numpy.zeros((max_batch_sequence_length, bs), dtype='int32')
        batch_dialogues_reversed = numpy.zeros((max_batch_sequence_length, bs), dtype='int32')
        batch_dialogues_mask = numpy.zeros((max_batch_sequence_length, bs), dtype='float32')
        batch_dialogues_reset_mask = numpy.zeros((bs), dtype='float32')
        batch_dialogues_drop_mask = numpy.ones((max_batch_sequence_length, bs), dtype='float32')

        # Fill in batch with values
        batch_dialogues[:,0]  = dialogue[:, 0]
        batch_dialogues_reversed[:,0] = dialogue_reversed[:, 0]
        #batch_dialogues  = dialogue
        #batch_dialogues_reversed = dialogue_reversed


        batch_dialogues_ran_gaussian_vectors = numpy.zeros((max_batch_sequence_length, bs, model.latent_gaussian_per_utterance_dim), dtype='float32')
        batch_dialogues_ran_uniform_vectors = numpy.zeros((max_batch_sequence_length, bs, model.latent_piecewise_per_utterance_dim), dtype='float32')

        eos_sym_list = numpy.where(batch_dialogues[:, 0] == state['eos_sym'])[0]
        if len(eos_sym_list) > 1:
            second_last_eos_sym = eos_sym_list[-2]
        else:
            print 'WARNING: dialogue does not have at least two EOS tokens!'

        batch_dialogues_mask[:, 0] = 1.0
        batch_dialogues_mask[0:second_last_eos_sym+1, 0] = 0.0

        print '###'
        print '###'
        print '###'
        if False==True:
            mf_reset_batch()
            for i in range(10):
                print  '  SGD Update', i

                batch_dialogues_ran_gaussian_vectors[second_last_eos_sym:, 0, :] = model.rng.normal(loc=0, scale=1, size=model.latent_gaussian_per_utterance_dim)
                batch_dialogues_ran_uniform_vectors[second_last_eos_sym:, 0, :] = model.rng.uniform(low=0.0, high=1.0, size=model.latent_piecewise_per_utterance_dim)


                training_cost, kl_divergence_cost_acc, kl_divergences_between_piecewise_prior_and_posterior, kl_divergences_between_gaussian_prior_and_posterior = mf_update_batch(batch_dialogues, batch_dialogues_reversed, max_batch_sequence_length, batch_dialogues_mask, batch_dialogues_reset_mask, batch_dialogues_ran_gaussian_vectors, batch_dialogues_ran_uniform_vectors, batch_dialogues_drop_mask)

                print '     training_cost', training_cost
                print '     kl_divergence_cost_acc', kl_divergence_cost_acc
                print '     kl_divergences_between_gaussian_prior_and_posterior',  numpy.sum(kl_divergences_between_gaussian_prior_and_posterior)
                print '     kl_divergences_between_piecewise_prior_and_posterior', numpy.sum(kl_divergences_between_piecewise_prior_and_posterior)


        batch_dialogues_ran_gaussian_vectors[second_last_eos_sym:, 0, :] = model.rng.normal(loc=0, scale=1, size=model.latent_gaussian_per_utterance_dim)
        batch_dialogues_ran_uniform_vectors[second_last_eos_sym:, 0, :] = model.rng.uniform(low=0.0, high=1.0, size=model.latent_piecewise_per_utterance_dim)

        gaussian_saliency, piecewise_saliency = saliency_batch(batch_dialogues, batch_dialogues_reversed, max_batch_sequence_length, batch_dialogues_mask, batch_dialogues_reset_mask, batch_dialogues_ran_gaussian_vectors, batch_dialogues_ran_uniform_vectors, batch_dialogues_drop_mask)

        if test_dialogue_idx < 2:
            print 'gaussian_saliency', gaussian_saliency.shape, gaussian_saliency
            print 'piecewise_saliency', piecewise_saliency.shape, piecewise_saliency

        gaussian_sum = 0.0
        piecewise_sum = 0.0
        for i in range(second_last_eos_sym+1, max_batch_sequence_length):
            gaussian_sum += gaussian_saliency[dialogue[i, 0]]
            piecewise_sum += piecewise_saliency[dialogue[i, 0]]

        gaussian_sum = max(gaussian_sum, 0.0000000000001)
        piecewise_sum = max(piecewise_sum, 0.0000000000001)


        print '###'
        print '###'
        print '###'

        print 'Topic: ', ' '.join(test_dialogue_split[0:second_last_eos_sym])

        print ''
        print 'Response', ' '.join(test_dialogue_split[second_last_eos_sym+1:max_batch_sequence_length])
        gaussian_str = ''
        piecewise_str = ''
        for i in range(second_last_eos_sym+1, max_batch_sequence_length):
            gaussian_str += str(gaussian_saliency[dialogue[i, 0]]/gaussian_sum) + ' '
            piecewise_str += str(piecewise_saliency[dialogue[i, 0]]/piecewise_sum) + ' '




        print 'Gaussian_saliency', gaussian_str
        print 'Piecewise_saliency', piecewise_str

        #print ''
        #print 'HEY', gaussian_saliency[:, 0].argsort()[-3:][::-1]
        print ''
        print 'Gaussian Top 3 Words: ', model.indices_to_words(list(gaussian_saliency[:].argsort()[-3:][::-1]))
        print 'Piecewise Top 3 Words: ', model.indices_to_words(list(piecewise_saliency[:].argsort()[-3:][::-1]))

        print 'Gaussian Top 5 Words: ', model.indices_to_words(list(gaussian_saliency[:].argsort()[-5:][::-1]))
        print 'Piecewise Top 5 Words: ', model.indices_to_words(list(piecewise_saliency[:].argsort()[-5:][::-1]))

        print 'Gaussian Top 7 Words: ', model.indices_to_words(list(gaussian_saliency[:].argsort()[-7:][::-1]))
        print 'Piecewise Top 7 Words: ', model.indices_to_words(list(piecewise_saliency[:].argsort()[-7:][::-1]))


    logger.debug("All done, exiting...")