Example #1
def main():
    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    model = DialogEncoderDecoder(state)
    sampler = search.Sampler(model)

    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")

    contexts = [[]]
    lines = open(args.context, "r").readlines()
    contexts = [x.strip().split('\t') for x in lines]

    context_samples, context_costs = sampler.sample(contexts,
                                            n_samples=args.n_samples,
                                            ignore_unk=args.ignore_unk,
                                            verbose=args.verbose)

    # Write to output file
    output_handle = open(args.output, "w")
    for context_sample in context_samples:
        print >> output_handle, '\t'.join(context_sample)
    output_handle.close()
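The examples on this page call a parse_args() helper that is not included in the snippets. Below is a minimal sketch of what it could look like for Example #1, using only the attributes that example actually reads (model_prefix, context, output, n_samples, ignore_unk, verbose); the flag names and defaults are assumptions, not the original interface.

import argparse

def parse_args():
    # Hypothetical parser: only the attribute names are taken from the example above.
    parser = argparse.ArgumentParser(description="Sample responses for a file of dialogue contexts")
    parser.add_argument("model_prefix", help="prefix of the <prefix>_state.pkl / <prefix>_model.npz files")
    parser.add_argument("context", help="input file with one tab-separated context per line")
    parser.add_argument("output", help="output file for the tab-separated samples")
    parser.add_argument("--n-samples", dest="n_samples", type=int, default=1)
    parser.add_argument("--ignore-unk", dest="ignore_unk", action="store_true")
    parser.add_argument("--verbose", action="store_true")
    return parser.parse_args()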
Example #2
def init(path):

    ##------------------------------------------------------------------------------##
    #                Compile and load the model.                                     #
    #                Compile the encoder.                                            #
    ##------------------------------------------------------------------------------##

    state = prototype_state()
    state_path = path + "_state.pkl"
    model_path = path + "_model.npz"

    with open(state_path) as src:
        state.update(pickle.load(src))

    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    model = DialogEncoderDecoder(state)

    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")

    encoding_function = model.build_encoder_function()

    return model, encoding_function
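A hypothetical way to call init() (the prefix below is made up): it returns the loaded DialogEncoderDecoder together with the compiled encoder function.

model, encoding_function = init("models/MyModelPrefix")
# encoding_function is whatever model.build_encoder_function() returned (a compiled
# Theano function); its exact inputs depend on the saved state, so it is not called here.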
Example #3
def init(path):
    
    ##------------------------------------------------------------------------------##
    #                Compile and load the model.                                     #          
    #                Compile the encoder.                                            #                                           
    ##------------------------------------------------------------------------------##

    state = prototype_state()
    state_path   = path  + "_state.pkl"
    model_path   = path  + "_model.npz"
    
    with open(state_path) as src:
        state.update(cPickle.load(src))
    
    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    model = DialogEncoderDecoder(state) 
    
    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")
    
    encoding_function = model.build_encoder_function()
    
    return model, encoding_function
Example #4
    def __init__(self, model_prefix, dict_file, name):
        # Load the HRED model.
        self.name = name
        state_path = '%s_state.pkl' % model_prefix
        model_path = '%s_model.npz' % model_prefix

        state = prototype_state()
        with open(state_path, 'r') as handle:
            state.update(cPickle.load(handle))
        state['dictionary'] = dict_file
        print 'Building %s model...' % name
        self.model = DialogEncoderDecoder(state)
        print 'Building sampler...'
        self.sampler = search.BeamSampler(self.model)
        print 'Loading model...'
        self.model.load(model_path)
        print 'Model built (%s).' % name

        self.speaker_token = '<first_speaker>'
        if name == 'reddit':
            self.speaker_token = '<speaker_1>'

        self.remove_tokens = ['<first_speaker>', '<at>', '<second_speaker>']
        for i in range(0, 10):
            self.remove_tokens.append('<speaker_%d>' % i)
Example #5
def main():
    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    model = DialogEncoderDecoder(state)
    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")

    logger.info("This model uses " + model.decoder_bias_type + " bias type")

    #sampler = search.RandomSampler(model)
    sampler = search.BeamSampler(model)

    # Start chat loop
    utterances = collections.deque()

    while (True):
        var = raw_input("User - ")

        # Truncate the dialogue history: we keep zero previous utterances for simplicity,
        # so the model has no memory. Keeping more utterances also works.
        while len(utterances) > 0:
            utterances.popleft()

        current_utterance = [model.end_sym_sentence] + [
            '<first_speaker>'
        ] + var.split() + [model.end_sym_sentence]
        utterances.append(current_utterance)

        #TODO Sample a random reply. To spice it up, we could pick the longest reply or the reply with the fewest placeholders...
        seqs = list(itertools.chain(*utterances))

        #TODO Retrieve only replies which are generated for second speaker...
        sentences = sample(model, \
             seqs=[seqs], ignore_unk=args.ignore_unk, \
             sampler=sampler, n_samples=5)

        if len(sentences) == 0:
            raise ValueError("Generation error, no sentences were produced!")

        utterances.append(sentences[0][0].split())

        reply = sentences[0][0].encode('utf-8')
        print "AI - ", remove_speaker_tokens(reply)
Example #6
def main():
    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    state['compute_training_updates'] = False

    model = DialogEncoderDecoder(state)

    sampler = search.RandomSampler(model)
    if args.beam_search:
        sampler = search.BeamSampler(model)
    if args.diverse_beam_search:
        sampler = search.DiverseBeamSampler(model, args.gamma)

    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")

    contexts = [[]]
    lines = open(args.context, "r").readlines()
    if len(lines):
        contexts = [x.strip() for x in lines]

    print('Sampling started...')
    context_samples, context_costs = sampler.sample(contexts,
                                                    n_samples=args.n_samples,
                                                    n_turns=args.n_turns,
                                                    ignore_unk=args.ignore_unk,
                                                    verbose=args.verbose,
                                                    return_words=True)
    print('Sampling finished.')
    print('Saving to file...')

    # Write to output file
    print type(context_samples)
    print type(context_samples[0])
    print context_samples[0]
    output_handle = open(args.output, "w")
    for context_sample in context_samples:

        print >> output_handle, '\t'.join(context_sample)
    output_handle.close()
    print('Saving to file finished.')
    print('All done!')
Example #7
class HRED_Wrapper(object):
    def __init__(self, model_prefix, dict_file, name):
        # Load the HRED model.
        self.name = name
        state_path = '%s_state.pkl' % model_prefix
        model_path = '%s_model.npz' % model_prefix

        state = prototype_state()
        with open(state_path, 'r') as handle:
            state.update(cPickle.load(handle))
        state['dictionary'] = dict_file
        print 'Building %s model...' % name
        self.model = DialogEncoderDecoder(state)
        print 'Building sampler...'
        self.sampler = search.BeamSampler(self.model)
        print 'Loading model...'
        self.model.load(model_path)
        print 'Model built (%s).' % name

        self.speaker_token = '<first_speaker>'
        if name == 'reddit':
            self.speaker_token = '<speaker_1>'

        self.remove_tokens = ['<first_speaker>', '<at>', '<second_speaker>']
        for i in range(0, 10):
            self.remove_tokens.append('<speaker_%d>' % i)

    def _preprocess(self, text):
        text = text.replace("'", " '")
        text = '%s %s </s>' % (self.speaker_token, text.strip().lower())
        return text

    def _format_output(self, text):
        text = text.replace(" '", "'")
        for token in self.remove_tokens:
            text = text.replace(token, '')
        return text

    def get_response(self, user_id, text):
        print '--------------------------------'
        print 'Generating HRED response for user %s.' % user_id
        text = self._preprocess(text)
        ai.history[user_id]['context'].append(text)
        context = list(ai.history[user_id]['context'])
        print 'Using context: %s' % ' '.join(context)
        samples, costs = self.sampler.sample(
            [' '.join(context)],
            ignore_unk=True,
            verbose=False,
            return_words=True)
        response = samples[0][0].replace('@@ ', '').replace('@@', '')
        ai.history[user_id]['context'].append(response)
        response = self._format_output(response)
        print 'Response: %s' % response
        return response
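A hypothetical usage sketch for the wrapper above (the paths, user id, and the ai.history store are assumptions, not part of the class): get_response() expects a per-user context container under ai.history, so a minimal stand-in is created first.

bot = HRED_Wrapper('models/twitter_hred', 'models/Dataset.dict.pkl', 'twitter')
ai.history['user-42'] = {'context': collections.deque(maxlen=10)}
print bot.get_response('user-42', 'how are you doing today ?')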
Example #8
def main():
    args = parse_args()
    state = prototype_state()
   
    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src)) 
    
    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
     
    model = DialogEncoderDecoder(state)
    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")
    
    logger.info("This model uses " + model.decoder_bias_type + " bias type")

    beam_search = None
    sampler = None

    beam_search = search.BeamSearch(model)
    beam_search.compile()

    # Start chat loop    
    utterances = collections.deque()
    
    while (True):
       var = raw_input("User - ")

       while len(utterances) > 2:
           utterances.popleft()
         
       current_utterance = [ model.start_sym_sentence ] + var.split() + [ model.end_sym_sentence ]
       utterances.append(current_utterance)
         
       # Sample a random reply. To spice it up, we could pick the longest reply or the reply with the fewest placeholders...
       seqs = list(itertools.chain(*utterances))

       sentences = sample(model, \
            seqs=[seqs], ignore_unk=args.ignore_unk, \
            beam_search=beam_search, n_samples=5)

       if len(sentences) == 0:
           raise ValueError("Generation error, no sentences were produced!")

       reply = " ".join(sentences[0]).encode('utf-8') 
       print "AI - ", reply
         
       utterances.append(sentences[0])
Example #9
def main():
    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(
        level=getattr(logging, state["level"]), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s"
    )

    model = DialogEncoderDecoder(state)
    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")

    logger.info("This model uses " + model.decoder_bias_type + " bias type")

    beam_search = None
    sampler = None

    beam_search = search.BeamSearch(model)
    beam_search.compile()

    # Start chat loop
    utterances = collections.deque()

    while True:
        var = raw_input("User - ")

        while len(utterances) > 2:
            utterances.popleft()

        current_utterance = [model.start_sym_sentence] + var.split() + [model.end_sym_sentence]
        utterances.append(current_utterance)

        # Sample a random reply. To spice it up, we could pick the longest reply or the reply with the fewest placeholders...
        seqs = list(itertools.chain(*utterances))

        sentences = sample(model, seqs=[seqs], ignore_unk=args.ignore_unk, beam_search=beam_search, n_samples=5)

        if len(sentences) == 0:
            raise ValueError("Generation error, no sentences were produced!")

        reply = " ".join(sentences[0]).encode("utf-8")
        print "AI - ", reply

        utterances.append(sentences[0])
Example #10
def main():
    args = parse_args()
    state = prototype_state()
   
    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src)) 
    
    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
     
    model = DialogEncoderDecoder(state)
    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")
    
    logger.info("This model uses " + model.decoder_bias_type + " bias type")

    #sampler = search.RandomSampler(model)
    sampler = search.BeamSampler(model)

    # Start chat loop    
    utterances = collections.deque()
    
    while (True):
       var = raw_input("User - ")

       # Truncate the dialogue history: we keep zero previous utterances for simplicity,
       # so the model has no memory. Keeping more utterances also works.
       while len(utterances) > 0:
           utterances.popleft()
         
       current_utterance = [ model.end_sym_sentence ] + ['<first_speaker>'] + var.split() + [ model.end_sym_sentence ]
       utterances.append(current_utterance)
         
       #TODO Sample a random reply. To spice it up, we could pick the longest reply or the reply with the fewest placeholders...
       seqs = list(itertools.chain(*utterances))

       #TODO Retrieve only replies which are generated for second speaker...
       sentences = sample(model, \
            seqs=[seqs], ignore_unk=args.ignore_unk, \
            sampler=sampler, n_samples=5)

       if len(sentences) == 0:
           raise ValueError("Generation error, no sentences were produced!")

       utterances.append(sentences[0][0].split())

       reply = sentences[0][0].encode('utf-8')
       print "AI - ", remove_speaker_tokens(reply)
Example #11
def main():
    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"
    timing_path = args.model_prefix + "_timing.npz"

    with open(state_path, 'r') as src:
        state.update(cPickle.load(src))
    with open(timing_path, 'r') as src:
        timings = dict(numpy.load(src))

    state['compute_training_updates'] = False

    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    print "\nLoaded previous state, model, timings:"
    print "state:"
    print state
    print "timings:"
    print timings

    print "\nBuilding model..."
    model = DialogEncoderDecoder(state)

    sampler = search.RandomSampler(model)
    if args.beam_search:
        sampler = search.BeamSampler(model)

    if os.path.isfile(model_path):
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")
    print "build.\n"

    context = []
    while True:
        line = raw_input("user: "******"<first_speaker> <at> " + line + " </s> ")
        print "context: ", [' '.join(context[-4:])]
        context_samples, context_costs = sampler.sample(
            [' '.join(context[-4:])],
            ignore_unk=args.ignore_unk,
            verbose=args.verbose,
            return_words=True)

        print "bot:", context_samples
        context.append(context_samples[0][0] + " </s> ")
        print "cost:", context_costs
Example #12
def main():
    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(
        level=getattr(logging, state["level"]), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s"
    )

    model = DialogEncoderDecoder(state)

    sampler = search.RandomSampler(model)
    if args.beam_search:
        sampler = search.BeamSampler(model)

    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")

    contexts = [[]]
    lines = open(args.context, "r").readlines()
    if len(lines):
        contexts = [x.strip() for x in lines]

    print("Sampling started...")
    context_samples, context_costs = sampler.sample(
        contexts, n_samples=args.n_samples, n_turns=args.n_turns, ignore_unk=args.ignore_unk, verbose=args.verbose
    )
    print("Sampling finished.")
    print("Saving to file...")

    # Write to output file
    output_handle = open(args.output, "w")
    for context_sample in context_samples:
        print >> output_handle, "\t".join(context_sample)
    output_handle.close()
    print("Saving to file finished.")
    print("All done!")
Example #13
def main():
    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    model = DialogEncoderDecoder(state)

    sampler = search.RandomSampler(model)
    if args.beam_search:
        sampler = search.BeamSampler(model)

    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")

    contexts = [[]]
    lines = open(args.context, "r").readlines()
    if len(lines):
        contexts = [x.strip().split('\t') for x in lines]

    context_samples, context_costs = sampler.sample(contexts,
                                                    n_samples=args.n_samples,
                                                    n_turns=args.n_turns,
                                                    ignore_unk=args.ignore_unk,
                                                    verbose=args.verbose)

    # Write to output file
    output_handle = open(args.output, "w")
    for context_sample in context_samples:
        print >> output_handle, '\t'.join(context_sample)
    output_handle.close()
Example #14
    def _build_vhred_model(self):
        # Update the state dictionary.
        state = VHRED_prototype_state()
        model_prefix = self.config['vhred_prefix']
        state_path = model_prefix + "_state.pkl"
        model_path = model_prefix + "_model.npz"
        with open(state_path, 'rb') as handle:
            state.update(cPickle.load(handle))
        # Update the bs for the current data.
        state['bs'] = 100
        state['dictionary'] = self.f_dict

        # Create the model:
        ## load trained parameters
        model = VHRED_DialogEncoderDecoder(state)
        model.load(model_path)
        enc_fn = model.build_encoder_function()
        dec_fn = model.build_decoder_encoding()

        return model, enc_fn, dec_fn
Example #15
def main():
    ####yawa add
    raw_dict = cPickle.load(open('./Data/Dataset.dict.pkl', 'r'))
    str_to_idx = dict([(tok, tok_id) for tok, tok_id, _, _ in raw_dict])
    idx_to_str = dict([(tok_id, tok) for tok, tok_id, _, _ in raw_dict])
    #########

    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    model = DialogEncoderDecoder(state)

    sampler = search.RandomSampler(model)
    if args.beam_search:
        sampler = search.BeamSampler(model)

    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")

    contexts = [[]]
    lines = open(args.context, "r").readlines()
    if len(lines):
        contexts = [x.strip() for x in lines]
    #contexts = cPickle.load(open('./Data/Test.dialogues.pkl', 'r'))
    print('Sampling started...')
    context_samples, context_costs, att_weights, att_context = sampler.sample(
        contexts,
        n_samples=args.n_samples,
        n_turns=args.n_turns,
        ignore_unk=args.ignore_unk,
        verbose=args.verbose)
    print('Sampling finished.')
    print('Saving to file...')

    # Write to output file
    output_handle = open(args.output, "w")
    for context_sample in context_samples:
        print >> output_handle, '\t'.join(context_sample)
    outline = ''
    #for att_weight in att_weights:
    #for att_in in att_weight:
    #print >> output_handle, str(att_in)
    print "number of weights:" + str(len(att_weights))
    #for i in range(len(att_weights)):
    #outline = att_weights[0]
    cPickle.dump(att_weights,
                 open('Data/beam_search_2000_2_weight.pkl', 'wb'),
                 protocol=cPickle.HIGHEST_PROTOCOL)
    cPickle.dump(att_context,
                 open('Data/beam_search_2000_2_context.pkl', 'wb'),
                 protocol=cPickle.HIGHEST_PROTOCOL)
    #for i in range(len(att_context)):
    #print att_context[i]
    #print numpy.array(att_weights[0])
    #print type(att_weights[0])
    #aa = numpy.array(att_weights[0])
    #size  = aa.shape[1]
    #bb = aa.reshape(5,5,size/5)
    #print bb.shape

    output_handle.close()
    print('Saving to file finished.')
    print('All done!')
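For reference, a small helper pair built on the str_to_idx / idx_to_str tables constructed at the top of this example; the '<unk>' fallback token is an assumption about the dictionary, not something shown above.

def words_to_ids(words):
    # Map tokens to ids, falling back to the (assumed) <unk> entry for unknown words.
    unk = str_to_idx.get('<unk>', 0)
    return [str_to_idx.get(w, unk) for w in words]

def ids_to_words(ids):
    return [idx_to_str[i] for i in ids]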
Example #16
    def preprocess_data(twitter_contexts, twitter_gtresponses, twitter_modelresponses, context_embedding_file, \
            gtresponses_embedding_file, modelresponses_embedding_file, use_precomputed_embeddings, liu=False):
        # Encode text into BPE format
        twitter_context_ids = strs_to_idxs(twitter_contexts, twitter_bpe, twitter_str_to_idx)
        twitter_gtresponse_ids = strs_to_idxs(twitter_gtresponses, twitter_bpe, twitter_str_to_idx)
        twitter_modelresponse_ids = strs_to_idxs(twitter_modelresponses, twitter_bpe, twitter_str_to_idx)
        
        # Compute VHRED embeddings
        if use_precomputed_embeddings:
            print 'Loading precomputed embeddings...'
            with open(context_embedding_file, 'r') as f1:
                twitter_context_embeddings = cPickle.load(f1)
            with open(gtresponses_embedding_file, 'r') as f1:
                twitter_gtresponse_embeddings = cPickle.load(f1)
            with open(modelresponses_embedding_file, 'r') as f1:
                twitter_modelresponse_embeddings = cPickle.load(f1)
        
        elif 'gpu' in theano.config.device.lower():
            print 'Loading model...'
            state = prototype_state()
            state_path = twitter_model_prefix + "_state.pkl"
            model_path = twitter_model_prefix + "_model.npz"

            with open(state_path) as src:
                state.update(cPickle.load(src))

            state['bs'] = 20
            state['dictionary'] = twitter_model_dictionary

            model = DialogEncoderDecoder(state) 
            
            print 'Computing context embeddings...'
            twitter_context_embeddings = compute_model_embeddings(twitter_context_ids, model, embedding_type)
            with open(context_embedding_file, 'w') as f1:
                cPickle.dump(twitter_context_embeddings, f1)
            print 'Computing ground truth response embeddings...'
            twitter_gtresponse_embeddings = compute_model_embeddings(twitter_gtresponse_ids, model, embedding_type)
            with open(gtresponses_embedding_file, 'w') as f1:
                cPickle.dump(twitter_gtresponse_embeddings, f1)
            print 'Computing model response embeddings...'
            twitter_modelresponse_embeddings = compute_model_embeddings(twitter_modelresponse_ids, model, embedding_type)
            with open(modelresponses_embedding_file, 'w') as f1:
                cPickle.dump(twitter_modelresponse_embeddings, f1)
       
        else:
            # Set embeddings to 0 for now. Alternatively, we could load them from disk...
            #embeddings = cPickle.load(open(embedding_file, 'rb'))
            print 'ERROR: No GPU specified!'
            print ' To save testing time, model will be trained with zero context / response embeddings...'
            twitter_context_embeddings = np.zeros((len(twitter_contexts), 3, emb_dim))
            twitter_gtresponse_embeddings = np.zeros((len(twitter_contexts), 3, emb_dim))
            twitter_modelresponse_embeddings = np.zeros((len(twitter_contexts), 3, emb_dim))

        if not liu:
            # Copy the contexts and gt responses 4 times (to align with the model responses)
            temp_c_emb = []
            temp_gt_emb = []
            temp_gt = []
            for i in xrange(len(twitter_context_embeddings)):
                temp_c_emb.append([twitter_context_embeddings[i]]*4)
                temp_gt_emb.append([twitter_gtresponse_embeddings[i]]*4)
                temp_gt.append([twitter_gtresponses[i]]*4)
            twitter_context_embeddings = flatten(temp_c_emb)
            twitter_gtresponse_embeddings = flatten(temp_gt_emb)
            twitter_gtresponses = flatten(temp_gt)

        assert len(twitter_context_embeddings) == len(twitter_gtresponse_embeddings)
        assert len(twitter_context_embeddings) == len(twitter_modelresponse_embeddings)

        emb_dim = twitter_context_embeddings[0].shape[0]
        
        twitter_dialogue_embeddings = np.zeros((len(twitter_context_embeddings), 3, emb_dim))
        for i in range(len(twitter_context_embeddings)):
            twitter_dialogue_embeddings[i, 0, :] =  twitter_context_embeddings[i]
            twitter_dialogue_embeddings[i, 1, :] =  twitter_gtresponse_embeddings[i]
            twitter_dialogue_embeddings[i, 2, :] =  twitter_modelresponse_embeddings[i]
     
        print 'Computing auxiliary features...'
        if use_aux_features:
            aux_features = get_auxiliary_features(twitter_contexts, twitter_gtresponses, twitter_modelresponses, len(twitter_modelresponses))
        else:
            aux_features = np.zeros((len(twitter_modelresponses), 5))
        
        return twitter_dialogue_embeddings, aux_features 
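flatten() is called above to collapse the duplicated per-example lists but is not shown. A minimal stand-in consistent with that use (a one-level flatten of a list of lists) could be:

def flatten(list_of_lists):
    # One-level flatten, e.g. [[a, a], [b, b]] -> [a, a, b, b].
    return [item for sublist in list_of_lists for item in sublist]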
Example #17
def main():
    args = parse_args()
    state = prototype_state()
   
    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src)) 
    
    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
     
    model = DialogEncoderDecoder(state)
    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")
    
    eval_batch = model.build_eval_function()
    eval_misclass_batch = model.build_eval_misclassification_function()
    
    if args.test_path:
        state['test_triples'] = args.test_path

    # Initialize list of stopwords to remove
    if args.exclude_stop_words:
        logger.debug("Initializing stop-word list")
        stopwords_lowercase = stopwords.lower().split(' ')
        stopwords_indices = []
        for word in stopwords_lowercase:
            if word in model.str_to_idx:
                stopwords_indices.append(model.str_to_idx[word])

    test_data = get_test_iterator(state)
    test_data.start()

    # Load document ids
    if args.document_ids:
        labels_file = open(args.document_ids, 'r')
        labels_text = labels_file.readlines()
        document_ids = numpy.zeros((len(labels_text)), dtype='int32')
        for i in range(len(labels_text)):
            document_ids[i] = int(labels_text[i].split('\t')[0])

        unique_document_ids = numpy.unique(document_ids)
        
        assert(test_data.data_len == document_ids.shape[0])

    else:
        print 'Warning: no file with document ids given... standard deviations cannot be computed.'
        document_ids = numpy.zeros((test_data.data_len), dtype='int32')
        unique_document_ids = numpy.unique(document_ids)
    
    # Variables to store test statistics
    test_cost = 0 # negative log-likelihood
    test_cost_first_utterances = 0 # marginal negative log-likelihood of first two utterances
    test_cost_last_utterance_marginal = 0 # marginal (approximate) negative log-likelihood of last utterances
    test_misclass = 0 # misclassification error-rate
    test_misclass_first_utterances = 0 # misclassification error-rate of first two utterances
    test_empirical_mutual_information = 0  # empirical mutual information between first two utterances and third utterance, where the marginal P(U_3) is approximated by P(U_3, empty, empty).

    if model.bootstrap_from_semantic_information:
        test_semantic_cost = 0
        test_semantic_misclass = 0

    test_wordpreds_done = 0 # number of words in total
    test_wordpreds_done_last_utterance = 0 # number of words in last utterances
    test_triples_done = 0 # number of triples evaluated

    # Variables to compute negative log-likelihood and empirical mutual information per genre
    compute_genre_specific_metrics = False
    if hasattr(model, 'semantic_information_dim'):
        compute_genre_specific_metrics = True
        test_cost_per_genre = numpy.zeros((model.semantic_information_dim, 1), dtype='float32')
        test_mi_per_genre = numpy.zeros((model.semantic_information_dim, 1), dtype='float32')
        test_wordpreds_done_per_genre = numpy.zeros((model.semantic_information_dim, 1), dtype='float32')
        test_triples_done_per_genre = numpy.zeros((model.semantic_information_dim, 1), dtype='float32')

    # Number of triples in dataset
    test_data_len = test_data.data_len

    # Correspond to the same variables as above, but now for each triple.
    # e.g. test_cost_list is a numpy array with the negative log-likelihood for each triple in the test set
    test_cost_list = numpy.zeros((test_data_len,))
    test_pmi_list = numpy.zeros((test_data_len,))
    test_cost_last_utterance_marginal_list = numpy.zeros((test_data_len,))
    test_misclass_list = numpy.zeros((test_data_len,))
    test_misclass_last_utterance_list = numpy.zeros((test_data_len,))

    # Array containing number of words in each triple
    words_in_triples_list = numpy.zeros((test_data_len,))

    # Array containing number of words in last utterance of each triple
    words_in_last_utterance_list = numpy.zeros((test_data_len,))

    # Prepare variables for printing the test examples the model performs best and worst on
    test_extrema_setsize = min(state['track_extrema_samples_count'], test_data_len)
    test_extrema_samples_to_print = min(state['print_extrema_samples_count'], test_extrema_setsize)

    test_lowest_costs = numpy.ones((test_extrema_setsize,))*1000
    test_lowest_triples = numpy.ones((test_extrema_setsize,state['seqlen']))*1000
    test_highest_costs = numpy.ones((test_extrema_setsize,))*(-1000)
    test_highest_triples = numpy.ones((test_extrema_setsize,state['seqlen']))*(-1000)

    logger.debug("[TEST START]") 

    while True:
        batch = test_data.next()
        # Train finished
        if not batch:
            break
         
        logger.debug("[TEST] - Got batch %d,%d" % (batch['x'].shape[1], batch['max_length']))

        x_data = batch['x']
        x_data_reversed = batch['x_reversed']
        max_length = batch['max_length']
        x_cost_mask = batch['x_mask']
        x_semantic = batch['x_semantic']
        x_semantic_nonempty_indices = numpy.where(x_semantic >= 0)

        # Hack to get rid of start of sentence token.
        if args.exclude_sos and model.sos_sym != -1:
            x_cost_mask[x_data == model.sos_sym] = 0

        if args.exclude_stop_words:
            for word_index in stopwords_indices:
                x_cost_mask[x_data == word_index] = 0

        batch['num_preds'] = numpy.sum(x_cost_mask)

        c, c_list = eval_batch(x_data, x_data_reversed, max_length, x_cost_mask, x_semantic)
        
        c_list = c_list.reshape((batch['x'].shape[1],max_length), order=(1,0))
        c_list = numpy.sum(c_list, axis=1)
       

        # Compute genre specific stats...
        if compute_genre_specific_metrics:
            non_nan_entries = numpy.array(c_list >= 0, dtype=int)
            c_list[numpy.where(non_nan_entries==0)] = 0
            test_cost_per_genre += (numpy.asmatrix(non_nan_entries*c_list) * numpy.asmatrix(x_semantic)).T
            test_wordpreds_done_per_genre += (numpy.asmatrix(non_nan_entries*numpy.sum(x_cost_mask, axis=0)) * numpy.asmatrix(x_semantic)).T

        if numpy.isinf(c) or numpy.isnan(c):
            continue
        
        test_cost += c

        # Store test costs in list
        nxt =  min((test_triples_done+batch['x'].shape[1]), test_data_len)
        triples_in_batch = nxt-test_triples_done

        words_in_triples = numpy.sum(x_cost_mask, axis=0)
        words_in_triples_list[(nxt-triples_in_batch):nxt] = words_in_triples[0:triples_in_batch]

        # We don't need to normalize by the number of words... not if we're computing standard deviations at least...
        test_cost_list[(nxt-triples_in_batch):nxt] = c_list[0:triples_in_batch]

        # Store best and worst test costs        
        con_costs = numpy.concatenate([test_lowest_costs, c_list[0:triples_in_batch]])
        con_triples = numpy.concatenate([test_lowest_triples, x_data[:, 0:triples_in_batch].T], axis=0)
        con_indices = con_costs.argsort()[0:test_extrema_setsize][::1]
        test_lowest_costs = con_costs[con_indices]
        test_lowest_triples = con_triples[con_indices]

        con_costs = numpy.concatenate([test_highest_costs, c_list[0:triples_in_batch]])
        con_triples = numpy.concatenate([test_highest_triples, x_data[:, 0:triples_in_batch].T], axis=0)
        con_indices = con_costs.argsort()[-test_extrema_setsize:][::-1]
        test_highest_costs = con_costs[con_indices]
        test_highest_triples = con_triples[con_indices]

        # Compute word-error rate
        miscl, miscl_list = eval_misclass_batch(x_data, x_data_reversed, max_length, x_cost_mask, x_semantic)
        if numpy.isinf(c) or numpy.isnan(c):
            continue

        test_misclass += miscl

        # Store misclassification errors in list
        miscl_list = miscl_list.reshape((batch['x'].shape[1],max_length), order=(1,0))
        miscl_list = numpy.sum(miscl_list, axis=1)
        test_misclass_list[(nxt-triples_in_batch):nxt] = miscl_list[0:triples_in_batch]

        # Equations to compute empirical mutual information

        # Compute marginal log-likelihood of last utterance in triple:
        # We approximate it with the marginal log-probability of the utterance being observed first in the triple
        x_data_last_utterance = batch['x_last_utterance']
        x_data_last_utterance_reversed = batch['x_last_utterance_reversed']
        x_cost_mask_last_utterance = batch['x_mask_last_utterance']
        x_start_of_last_utterance = batch['x_start_of_last_utterance']

        # Hack to get rid of start of sentence token.
        if args.exclude_sos and model.sos_sym != -1:
            x_cost_mask_last_utterance[x_data_last_utterance == model.sos_sym] = 0

        if args.exclude_stop_words:
            for word_index in stopwords_indices:
                x_cost_mask_last_utterance[x_data_last_utterance == word_index] = 0


        words_in_last_utterance = numpy.sum(x_cost_mask_last_utterance, axis=0)
        words_in_last_utterance_list[(nxt-triples_in_batch):nxt] = words_in_last_utterance[0:triples_in_batch]

        batch['num_preds_at_utterance'] = numpy.sum(x_cost_mask_last_utterance)

        marginal_last_utterance_loglikelihood, marginal_last_utterance_loglikelihood_list = eval_batch(x_data_last_utterance, x_data_last_utterance_reversed, max_length, x_cost_mask_last_utterance, x_semantic)

        marginal_last_utterance_loglikelihood_list = marginal_last_utterance_loglikelihood_list.reshape((batch['x'].shape[1],max_length), order=(1,0))
        marginal_last_utterance_loglikelihood_list = numpy.sum(marginal_last_utterance_loglikelihood_list, axis=1)
        test_cost_last_utterance_marginal_list[(nxt-triples_in_batch):nxt] = marginal_last_utterance_loglikelihood_list[0:triples_in_batch]

        # Compute marginal log-likelihood of first utterances in triple by masking the last utterance
        x_cost_mask_first_utterances = numpy.copy(x_cost_mask)
        for i in range(batch['x'].shape[1]):
            x_cost_mask_first_utterances[x_start_of_last_utterance[i]:max_length, i] = 0

        marginal_first_utterances_loglikelihood, marginal_first_utterances_loglikelihood_list = eval_batch(x_data, x_data_reversed, max_length, x_cost_mask_first_utterances, x_semantic)

        marginal_first_utterances_loglikelihood_list = marginal_first_utterances_loglikelihood_list.reshape((batch['x'].shape[1],max_length), order=(1,0))
        marginal_first_utterances_loglikelihood_list = numpy.sum(marginal_first_utterances_loglikelihood_list, axis=1)

        # Compute empirical mutual information and pointwise empirical mutual information
        test_empirical_mutual_information += -c + marginal_first_utterances_loglikelihood + marginal_last_utterance_loglikelihood
        test_pmi_list[(nxt-triples_in_batch):nxt] = (-c_list*words_in_triples + marginal_first_utterances_loglikelihood_list + marginal_last_utterance_loglikelihood_list)[0:triples_in_batch]

        # Compute genre specific stats...
        if compute_genre_specific_metrics:
            if triples_in_batch==batch['x'].shape[1]:
                mi_list = (-c_list*words_in_triples + marginal_first_utterances_loglikelihood_list + marginal_last_utterance_loglikelihood_list)[0:triples_in_batch]
                non_nan_entries = numpy.array(mi_list >= 0, dtype=int) * numpy.array(~numpy.isnan(mi_list), dtype=int)
                test_mi_per_genre += (numpy.asmatrix(non_nan_entries*mi_list) * numpy.asmatrix(x_semantic)).T
                test_triples_done_per_genre += numpy.reshape(numpy.sum(x_semantic, axis=0), test_triples_done_per_genre.shape)

        # Store log P(U_1, U_2) cost computed during mutual information
        test_cost_first_utterances += marginal_first_utterances_loglikelihood

        # Store marginal log P(U_3)
        test_cost_last_utterance_marginal += marginal_last_utterance_loglikelihood


        # Compute word-error rate for first utterances
        miscl_first_utterances, miscl_first_utterances_list = eval_misclass_batch(x_data, x_data_reversed, max_length, x_cost_mask_first_utterances, x_semantic)
        test_misclass_first_utterances += miscl_first_utterances
        if numpy.isinf(c) or numpy.isnan(c):
            continue

        # Store misclassification for last utterance
        miscl_first_utterances_list = miscl_first_utterances_list.reshape((batch['x'].shape[1],max_length), order=(1,0))
        miscl_first_utterances_list = numpy.sum(miscl_first_utterances_list, axis=1)

        miscl_last_utterance_list = miscl_list - miscl_first_utterances_list

        test_misclass_last_utterance_list[(nxt-triples_in_batch):nxt] = miscl_last_utterance_list[0:triples_in_batch]


        if model.bootstrap_from_semantic_information:
            # Compute cross-entropy error on predicting the semantic class and retrieve predictions
            sem_eval = eval_semantic_batch(x_data, x_data_reversed, max_length, x_cost_mask, x_semantic)

            # Evaluate only non-empty triples (empty triples are created to fill 
            #   the whole batch sometimes).
            sem_cost = sem_eval[0][-1, :, :]
            test_semantic_cost += numpy.sum(sem_cost[x_semantic_nonempty_indices])

            # Compute misclassified predictions on last timestep over all labels
            sem_preds = sem_eval[1][-1, :, :]
            sem_preds_misclass = len(numpy.where(((x_semantic-0.5)*(sem_preds-0.5))[x_semantic_nonempty_indices] < 0)[0])
            test_semantic_misclass += sem_preds_misclass


        test_wordpreds_done += batch['num_preds']
        test_wordpreds_done_last_utterance += batch['num_preds_at_utterance']
        test_triples_done += batch['num_triples']
     
    logger.debug("[TEST END]") 

    test_cost_last_utterance_marginal /= test_wordpreds_done_last_utterance
    test_cost_last_utterance = (test_cost - test_cost_first_utterances) / test_wordpreds_done_last_utterance
    test_cost /= test_wordpreds_done
    test_cost_first_utterances /= float(test_wordpreds_done - test_wordpreds_done_last_utterance)

    test_misclass_last_utterance = float(test_misclass - test_misclass_first_utterances) / float(test_wordpreds_done_last_utterance)
    test_misclass_first_utterances /= float(test_wordpreds_done - test_wordpreds_done_last_utterance)
    test_misclass /= float(test_wordpreds_done)
    test_empirical_mutual_information /= float(test_triples_done)

    if model.bootstrap_from_semantic_information:
        test_semantic_cost /= float(test_triples_done)
        test_semantic_misclass /= float(test_triples_done)
        print "** test semantic cost = %.4f, test semantic misclass error = %.4f" % (float(test_semantic_cost), float(test_semantic_misclass))

    print "** test cost (NLL) = %.4f, test word-perplexity = %.4f, test word-perplexity last utterance = %.4f, test word-perplexity marginal last utterance = %.4f, test mean word-error = %.4f, test mean word-error last utterance = %.4f, test emp. mutual information = %.4f" % (float(test_cost), float(math.exp(test_cost)), float(math.exp(test_cost_last_utterance)), float(math.exp(test_cost_last_utterance_marginal)), float(test_misclass), float(test_misclass_last_utterance), test_empirical_mutual_information)

    if compute_genre_specific_metrics:
        print '** test perplexity per genre', numpy.exp(test_cost_per_genre/test_wordpreds_done_per_genre)
        print '** test_mi_per_genre', test_mi_per_genre

        print '** words per genre', test_wordpreds_done_per_genre




    # Plot histogram over test costs
    if args.plot_graphs:
        try:
            pylab.figure()
            bins = range(0, 50, 1)
            pylab.hist(numpy.exp(test_cost_list), normed=1, histtype='bar')
            pylab.savefig(model.state['save_dir'] + '/' + model.state['run_id'] + "_" + model.state['prefix'] + 'Test_WordPerplexities.png')
        except:
            pass

    # Print 5 of 10% test samples with highest log-likelihood
    if args.plot_graphs:
        print " highest word log-likelihood test samples: " 
        numpy.random.shuffle(test_lowest_triples)
        for i in range(test_extrema_samples_to_print):
            print "      Sample: {}".format(" ".join(model.indices_to_words(numpy.ravel(test_lowest_triples[i,:]))))

        print " lowest word log-likelihood test samples: " 
        numpy.random.shuffle(test_highest_triples)
        for i in range(test_extrema_samples_to_print):
            print "      Sample: {}".format(" ".join(model.indices_to_words(numpy.ravel(test_highest_triples[i,:]))))


    # Plot histogram over empirical pointwise mutual informations
    if args.plot_graphs:
        try:
            pylab.figure()
            bins = range(0, 100, 1)
            pylab.hist(test_pmi_list, normed=1, histtype='bar')
            pylab.savefig(model.state['save_dir'] + '/' + model.state['run_id'] + "_" + model.state['prefix'] + 'Test_PMI.png')
        except:
            pass

    # To estimate the standard deviations, we assume that triples across documents (movies) are independent.
    # We compute the mean metric for each document, and then the variance between documents.
    # We then use the between-document variance to compute the variance of the corpus-level metric.
    # Let m be a metric:
    # Var[m] = Var[1/(words in total) \sum_d \sum_i m_{di}]
    #        = Var[1/(words in total) \sum_d (words in doc d)/(words in doc d) \sum_i m_{di}]
    #        = \sum_d (words in doc d)^2/(words in total)^2 Var[ 1/(words in doc d) \sum_i m_{di} ]
    #        = \sum_d (words in doc d)^2/(words in total)^2 sigma^2
    #
    # where sigma^2 is the variance computed for the means across documents.

    # negative log-likelihood for each document (movie)
    per_document_test_cost = numpy.zeros((len(unique_document_ids)), dtype='float32')
    # negative log-likelihood for last utterance for each document (movie)
    per_document_test_cost_last_utterance = numpy.zeros((len(unique_document_ids)), dtype='float32')
    # misclassification error for each document (movie)
    per_document_test_misclass = numpy.zeros((len(unique_document_ids)), dtype='float32')
    # misclassification error for last utterance for each document (movie)
    per_document_test_misclass_last_utterance = numpy.zeros((len(unique_document_ids)), dtype='float32')


    # Compute standard deviations based on means across documents (sigma^2 above)
    all_words_squared = 0 # \sum_d (words in doc d)^2
    all_words_in_last_utterance_squared = 0 # \sum_d (words in last utterance of doc d)^2
    for doc_id in range(len(unique_document_ids)):
        doc_indices = numpy.where(document_ids == unique_document_ids[doc_id])

        per_document_test_cost[doc_id] = numpy.sum(test_cost_list[doc_indices]) / numpy.sum(words_in_triples_list[doc_indices])
        per_document_test_cost_last_utterance[doc_id] = numpy.sum(test_cost_last_utterance_marginal_list[doc_indices]) / numpy.sum(words_in_last_utterance_list[doc_indices])

        per_document_test_misclass[doc_id] = numpy.sum(test_misclass_list[doc_indices]) / numpy.sum(words_in_triples_list[doc_indices])
        per_document_test_misclass_last_utterance[doc_id] = numpy.sum(test_misclass_last_utterance_list[doc_indices]) / numpy.sum(words_in_last_utterance_list[doc_indices])

        all_words_squared += float(numpy.sum(words_in_triples_list[doc_indices]))**2
        all_words_in_last_utterance_squared += float(numpy.sum(words_in_last_utterance_list[doc_indices]))**2

    # Sanity check that all documents are being used in the standard deviation calculations
    assert(numpy.sum(words_in_triples_list) == test_wordpreds_done)
    assert(numpy.sum(words_in_last_utterance_list) == test_wordpreds_done_last_utterance)

    # Compute final standard deviation equation and print the standard deviations
    per_document_test_cost_variance = numpy.var(per_document_test_cost) * float(all_words_squared) / float(test_wordpreds_done**2)
    per_document_test_cost_last_utterance_variance = numpy.var(per_document_test_cost_last_utterance) * float(all_words_in_last_utterance_squared) / float(test_wordpreds_done_last_utterance**2)
    per_document_test_misclass_variance = numpy.var(per_document_test_misclass) * float(all_words_squared) / float(test_wordpreds_done**2)
    per_document_test_misclass_last_utterance_variance = numpy.var(per_document_test_misclass_last_utterance) * float(all_words_in_last_utterance_squared) / float(test_wordpreds_done_last_utterance**2)

    print 'Standard deviations:'
    print "** test cost (NLL) = ", math.sqrt(per_document_test_cost_variance)
    print "** test perplexity (NLL) = ", math.sqrt((math.exp(per_document_test_cost_variance) - 1)*math.exp(2*test_cost+per_document_test_cost_variance))

    print "** test cost last utterance (NLL) = ", math.sqrt(per_document_test_cost_last_utterance_variance)
    print "** test perplexity last utterance  (NLL) = ", math.sqrt((math.exp(per_document_test_cost_last_utterance_variance) - 1)*math.exp(2*test_cost+per_document_test_cost_last_utterance_variance))

    print "** test word-error = ", math.sqrt(per_document_test_misclass_variance)
    print "** test last utterance word-error = ", math.sqrt(per_document_test_misclass_last_utterance_variance)

    logger.debug("All done, exiting...")
Example #18
def main():
    args = parse_args()
    state = prototype_state()
   
    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src)) 
    
    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    state['lr'] = 0.01
    state['bs'] = 2

    state['compute_training_updates'] = False
    state['apply_meanfield_inference'] = True

    model = DialogEncoderDecoder(state)
    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")
    


    mf_update_batch = model.build_mf_update_function()
    mf_reset_batch = model.build_mf_reset_function()

    saliency_batch = model.build_saliency_eval_function()

    test_dialogues = open(args.test_dialogues, 'r').readlines()
    for test_dialogue_idx, test_dialogue in enumerate(test_dialogues):
        #print 'Visualizing dialogue: ', test_dialogue
        test_dialogue_split = test_dialogue.split()
        # Convert dialogue into list of ids
        dialogue = []
        if len(test_dialogue) == 0:
            dialogue = [model.eos_sym]
        else:
            sentence_ids = model.words_to_indices(test_dialogue_split)
            # Add eos tokens
            if len(sentence_ids) > 0:
                if not sentence_ids[0] == model.eos_sym:
                    sentence_ids = [model.eos_sym] + sentence_ids
                if not sentence_ids[-1] == model.eos_sym:
                    sentence_ids += [model.eos_sym]
            else:
                sentence_ids = [model.eos_sym]


            dialogue += sentence_ids

        if len(dialogue) > 3:
            if ((dialogue[-1] == model.eos_sym)
             and (dialogue[-2] == model.eod_sym)
             and (dialogue[-3] == model.eos_sym)):
                del dialogue[-1]
                del dialogue[-1]

        
        dialogue = numpy.asarray(dialogue, dtype='int32').reshape((len(dialogue), 1))
        #print 'dialogue', dialogue
        dialogue_reversed = model.reverse_utterances(dialogue)

        max_batch_sequence_length = len(dialogue)
        bs = state['bs']

        # Initialize batch with zeros
        batch_dialogues = numpy.zeros((max_batch_sequence_length, bs), dtype='int32')
        batch_dialogues_reversed = numpy.zeros((max_batch_sequence_length, bs), dtype='int32')
        batch_dialogues_mask = numpy.zeros((max_batch_sequence_length, bs), dtype='float32')
        batch_dialogues_reset_mask = numpy.zeros((bs), dtype='float32')
        batch_dialogues_drop_mask = numpy.ones((max_batch_sequence_length, bs), dtype='float32')

        # Fill in batch with values
        batch_dialogues[:,0]  = dialogue[:, 0]
        batch_dialogues_reversed[:,0] = dialogue_reversed[:, 0]
        #batch_dialogues  = dialogue
        #batch_dialogues_reversed = dialogue_reversed


        batch_dialogues_ran_gaussian_vectors = numpy.zeros((max_batch_sequence_length, bs, model.latent_gaussian_per_utterance_dim), dtype='float32')
        batch_dialogues_ran_uniform_vectors = numpy.zeros((max_batch_sequence_length, bs, model.latent_piecewise_per_utterance_dim), dtype='float32')

        eos_sym_list = numpy.where(batch_dialogues[:, 0] == state['eos_sym'])[0]
        if len(eos_sym_list) > 1:
            second_last_eos_sym = eos_sym_list[-2]
        else:
            print 'WARNING: dialogue does not have at least two EOS tokens!'

        batch_dialogues_mask[:, 0] = 1.0
        batch_dialogues_mask[0:second_last_eos_sym+1, 0] = 0.0

        print '###'
        print '###'
        print '###'
        if False:  # disabled: iterative mean-field (SGD) updates of the latent variables
            mf_reset_batch()
            for i in range(10):
                print  '  SGD Update', i

                batch_dialogues_ran_gaussian_vectors[second_last_eos_sym:, 0, :] = model.rng.normal(loc=0, scale=1, size=model.latent_gaussian_per_utterance_dim)
                batch_dialogues_ran_uniform_vectors[second_last_eos_sym:, 0, :] = model.rng.uniform(low=0.0, high=1.0, size=model.latent_piecewise_per_utterance_dim)


                training_cost, kl_divergence_cost_acc, kl_divergences_between_piecewise_prior_and_posterior, kl_divergences_between_gaussian_prior_and_posterior = mf_update_batch(batch_dialogues, batch_dialogues_reversed, max_batch_sequence_length, batch_dialogues_mask, batch_dialogues_reset_mask, batch_dialogues_ran_gaussian_vectors, batch_dialogues_ran_uniform_vectors, batch_dialogues_drop_mask)

                print '     training_cost', training_cost
                print '     kl_divergence_cost_acc', kl_divergence_cost_acc
                print '     kl_divergences_between_gaussian_prior_and_posterior',  numpy.sum(kl_divergences_between_gaussian_prior_and_posterior)
                print '     kl_divergences_between_piecewise_prior_and_posterior', numpy.sum(kl_divergences_between_piecewise_prior_and_posterior)


        batch_dialogues_ran_gaussian_vectors[second_last_eos_sym:, 0, :] = model.rng.normal(loc=0, scale=1, size=model.latent_gaussian_per_utterance_dim)
        batch_dialogues_ran_uniform_vectors[second_last_eos_sym:, 0, :] = model.rng.uniform(low=0.0, high=1.0, size=model.latent_piecewise_per_utterance_dim)

        gaussian_saliency, piecewise_saliency = saliency_batch(batch_dialogues, batch_dialogues_reversed, max_batch_sequence_length, batch_dialogues_mask, batch_dialogues_reset_mask, batch_dialogues_ran_gaussian_vectors, batch_dialogues_ran_uniform_vectors, batch_dialogues_drop_mask)

        if test_dialogue_idx < 2:
            print 'gaussian_saliency', gaussian_saliency.shape, gaussian_saliency
            print 'piecewise_saliency', piecewise_saliency.shape, piecewise_saliency

        gaussian_sum = 0.0
        piecewise_sum = 0.0
        for i in range(second_last_eos_sym+1, max_batch_sequence_length):
            gaussian_sum += gaussian_saliency[dialogue[i, 0]]
            piecewise_sum += piecewise_saliency[dialogue[i, 0]]

        gaussian_sum = max(gaussian_sum, 0.0000000000001)
        piecewise_sum = max(piecewise_sum, 0.0000000000001)


        print '###'
        print '###'
        print '###'

        print 'Topic: ', ' '.join(test_dialogue_split[0:second_last_eos_sym])

        print ''
        print 'Response', ' '.join(test_dialogue_split[second_last_eos_sym+1:max_batch_sequence_length])
        gaussian_str = ''
        piecewise_str = ''
        for i in range(second_last_eos_sym+1, max_batch_sequence_length):
            gaussian_str += str(gaussian_saliency[dialogue[i, 0]]/gaussian_sum) + ' '
            piecewise_str += str(piecewise_saliency[dialogue[i, 0]]/piecewise_sum) + ' '




        print 'Gaussian_saliency', gaussian_str
        print 'Piecewise_saliency', piecewise_str

        #print ''
        #print 'HEY', gaussian_saliency[:, 0].argsort()[-3:][::-1]
        print ''
        print 'Gaussian Top 3 Words: ', model.indices_to_words(list(gaussian_saliency[:].argsort()[-3:][::-1]))
        print 'Piecewise Top 3 Words: ', model.indices_to_words(list(piecewise_saliency[:].argsort()[-3:][::-1]))

        print 'Gaussian Top 5 Words: ', model.indices_to_words(list(gaussian_saliency[:].argsort()[-5:][::-1]))
        print 'Piecewise Top 5 Words: ', model.indices_to_words(list(piecewise_saliency[:].argsort()[-5:][::-1]))

        print 'Gaussian Top 7 Words: ', model.indices_to_words(list(gaussian_saliency[:].argsort()[-7:][::-1]))
        print 'Piecewise Top 7 Words: ', model.indices_to_words(list(piecewise_saliency[:].argsort()[-7:][::-1]))


    logger.debug("All done, exiting...")
Example #19
0
def main():
    args = parse_args()
    state = prototype_ubuntu_HRED()  #prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))
    state['dictionary'] = "/home/ml/rlowe1/UbuntuData/Dataset.dict.pkl"

    # MODIFIED: removed, since logging must be configured before any logger object is constructed
    # logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
    logging.basicConfig(
        format='[%(asctime)s][%(levelname)s][%(filename)s][%(lineno)d] - %(message)s',
        datefmt='%d/%m/%Y %H:%M:%S',
        filename='/home/2016/pparth2/Desktop/gods/Goal-Oriented_Dialogue_Systems/Gods-master/agents/hred/log.chat',
        filemode='a',
        level=logging.DEBUG)

    logger = logging.getLogger(__name__)

    model = DialogEncoderDecoder(state)
    if os.path.isfile(model_path):

        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")

    logger.info("This model uses " + model.decoder_bias_type + " bias type")

    #sampler = search.RandomSampler(model)
    sampler = search.BeamSampler(model)

    # Start chat loop
    utterances = collections.deque()

    while (True):
        var = raw_input("User - ")

        # Reset the utterance history. We clear it here for simplicity, so the model has no memory
        # of previous turns; keeping a longer history also works.
        while len(utterances) > 0:
            utterances.popleft()

        current_utterance = [model.end_sym_utterance] + [
            '<first_speaker>'
        ] + var.split() + [model.end_sym_utterance]
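        # e.g. the input "hello there" yields
        #   [model.end_sym_utterance, '<first_speaker>', 'hello', 'there', model.end_sym_utterance]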
        utterances.append(current_utterance)

        #TODO Sample a random reply. To spice it up, we could pick the longest reply or the reply with the fewest placeholders...
        seqs = list(itertools.chain(*utterances))

        #TODO Retrieve only replies which are generated for second speaker...
        sentences = sample(model, \
             seqs= [seqs], ignore_unk=args.ignore_unk, \
             sampler=sampler, n_samples=1)

        if len(sentences) == 0:
            raise ValueError("Generation error, no sentences were produced!")

        utterances.append(sentences[0][0].split())

        reply = sentences[0][0].encode('utf-8')
        print "AI - ", remove_speaker_tokens(reply)
Example #20
0
def main():
    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    model = DialogEncoderDecoder(state)

    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")

    contexts = [[]]
    lines = open(args.dialogues, "r").readlines()
    if len(lines):
        contexts = [x.strip().split('\t') for x in lines]

    model_compute_encoding = model.build_encoder_function()
    dialogue_encodings = []

    # Start loop
    joined_contexts = []
    batch_index = 0
    batch_total = int(math.ceil(float(len(contexts)) / float(model.bs)))
    for context_id, context_sentences in enumerate(contexts):

        # Convert contexts into list of ids
        joined_context = []

        if len(context_sentences) == 0:
            joined_context = [model.eos_sym]
        else:
            for sentence in context_sentences:
                sentence_ids = model.words_to_indices(sentence.split())
                # Add sos and eos tokens
                joined_context += [model.sos_sym
                                   ] + sentence_ids + [model.eos_sym]

        # HACK
        for i in range(0, 50):
            joined_context += [model.sos_sym] + [0] + [model.eos_sym]

        joined_contexts.append(joined_context)

        if len(joined_contexts) == model.bs:
            batch_index = batch_index + 1
            logger.debug("[COMPUTE] - Got batch %d / %d" %
                         (batch_index, batch_total))
            encs = compute_encodings(joined_contexts, model,
                                     model_compute_encoding,
                                     args.use_second_last_state)
            for i in range(len(encs)):
                dialogue_encodings.append(encs[i])

            joined_contexts = []

    if len(joined_contexts) > 0:
        logger.debug("[COMPUTE] - Got batch %d / %d" %
                     (batch_total, batch_total))
        encs = compute_encodings(joined_contexts, model,
                                 model_compute_encoding,
                                 args.use_second_last_state)
        for i in range(len(encs)):
            dialogue_encodings.append(encs[i])

    # Save encodings to disc
    cPickle.dump(dialogue_encodings, open(args.output + '.pkl', 'w'))
Example #21
0
def main(model_prefix, dialogue_file, use_second_last_state):
    state = prototype_state()

    state_path = model_prefix + "_state.pkl"
    model_path = model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    state['bs'] = 10

    model = DialogEncoderDecoder(state)

    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")

    contexts = [[]]
    lines = open(dialogue_file, "r").readlines()
    if len(lines):
        contexts = [x.strip() for x in lines]

    model_compute_encoding = model.build_encoder_function()
    dialogue_encodings = []

    # Start loop
    joined_contexts = []
    batch_index = 0
    batch_total = int(math.ceil(float(len(contexts)) / float(model.bs)))
    for context_id, context_sentences in enumerate(contexts):
        # Convert contexts into list of ids
        joined_context = []

        if len(context_sentences) == 0:
            joined_context = [model.eos_sym]
        else:
            joined_context = model.words_to_indices(context_sentences.split())

            if joined_context[0] != model.eos_sym:
                joined_context = [model.eos_sym] + joined_context

            if joined_context[-1] != model.eos_sym:
                joined_context += [model.eos_sym]

        joined_contexts.append(joined_context)

        if len(joined_contexts) == model.bs:
            batch_index = batch_index + 1
            logger.debug("[COMPUTE] - Got batch %d / %d" %
                         (batch_index, batch_total))
            encs = compute_encodings(joined_contexts, model,
                                     model_compute_encoding,
                                     use_second_last_state)
            for i in range(len(encs)):
                dialogue_encodings.append(encs[i])

            joined_contexts = []

    if len(joined_contexts) > 0:
        logger.debug("[COMPUTE] - Got batch %d / %d" %
                     (batch_total, batch_total))
        encs = compute_encodings(joined_contexts, model,
                                 model_compute_encoding, use_second_last_state)
        for i in range(len(encs)):
            dialogue_encodings.append(encs[i])

    return dialogue_encodings
Example #22
0
def main():
    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    # For simplicity, we force the batch size to be one
    state['bs'] = 1
    model = DialogEncoderDecoder(state)

    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")

    eval_batch = model.build_eval_function()

    contexts = [[]]
    lines = open(args.context, "r").readlines()
    if len(lines):
        contexts = [x.strip() for x in lines]

    potential_responses_set = open(args.responses, "r").readlines()
    most_probable_responses_string = ''

    print('Retrieval started...')

    for context_idx, context in enumerate(contexts):
        if context_idx % 100 == 0:
            print '     processing example: ' + str(context_idx) + ' / ' + str(
                len(contexts))
        potential_responses = potential_responses_set[context_idx].strip(
        ).split('\t')

        most_probable_response_loglikelihood = -1.0
        most_probable_response = ''

        for potential_response_idx, potential_response in enumerate(
                potential_responses):
            # Convert contexts into list of ids
            dialogue = []
            if len(context) == 0:
                dialogue = [model.eos_sym]
            else:
                sentence_ids = model.words_to_indices(context.split())
                # Add eos tokens
                if len(sentence_ids) > 0:
                    if not sentence_ids[0] == model.eos_sym:
                        sentence_ids = [model.eos_sym] + sentence_ids
                    if not sentence_ids[-1] == model.eos_sym:
                        sentence_ids += [model.eos_sym]
                else:
                    sentence_ids = [model.eos_sym]

                dialogue += sentence_ids

            response = model.words_to_indices(potential_response.split())
            if len(response) > 0:
                if response[0] == model.eos_sym:
                    del response[0]
                if not response[-1] == model.eos_sym:
                    response += [model.eos_sym]

            dialogue += response
            dialogue = numpy.asarray(dialogue, dtype='int32').reshape(
                (len(dialogue), 1))

            dialogue_reversed = model.reverse_utterances(dialogue)

            dialogue_mask = numpy.ones((len(dialogue), 1), dtype='float32')
            dialogue_weight = numpy.ones((len(dialogue), 1), dtype='float32')
            dialogue_reset_mask = numpy.zeros((1), dtype='float32')
            dialogue_ran_vectors = model.rng.normal(size=(
                1, model.latent_gaussian_per_utterance_dim)).astype('float32')
            dialogue_ran_vectors = numpy.tile(dialogue_ran_vectors,
                                              (len(dialogue), 1, 1))

            dialogue_drop_mask = numpy.ones((len(dialogue), 1),
                                            dtype='float32')

            c, _, _, _, _ = eval_batch(dialogue, dialogue_reversed,
                                       len(dialogue), dialogue_mask,
                                       dialogue_weight, dialogue_reset_mask,
                                       dialogue_ran_vectors,
                                       dialogue_drop_mask)
            c = c / len(dialogue)

            print 'c', c

            if (potential_response_idx
                    == 0) or (-c > most_probable_response_loglikelihood):
                most_probable_response_loglikelihood = -c
                most_probable_response = potential_response

        most_probable_responses_string += most_probable_response + '\n'

    print('Retrieval finished.')
    print('Saving to file...')

    # Write to output file
    output_handle = open(args.output, "w")
    output_handle.write(most_probable_responses_string)
    output_handle.close()

    print('Saving to file finished.')
    print('All done!')
Example #23
0
def main():
    args = parse_args()
    state = prototype_state()
   
    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src)) 
    
    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
    # Force batch size to be one, so that we can condition the prediction at time t on its prediction at time t-1.
    state['bs'] = 1 
 
    model = DialogEncoderDecoder(state)
    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")
    
    eval_batch = model.build_eval_function()
    
    if args.test_path:
        state['test_dialogues'] = args.test_path

    sentence_break_symbols = [model.str_to_idx['.'], model.str_to_idx['?'], model.str_to_idx['!']]
    test_data = get_test_iterator(state, sentence_break_symbols)
    test_data.start()

    tokens_per_sample = 3

    # Load document ids
    if args.document_ids:
        print("Warning. Evaluation using document ids is not supported")
        labels_file = open(args.document_ids, 'r')
        labels_text = labels_file.readlines()
        document_ids = numpy.zeros((len(labels_text)), dtype='int32')
        for i in range(len(labels_text)):
            document_ids[i] = int(labels_text[i].split('\t')[0])

        unique_document_ids = numpy.unique(document_ids)
        
        assert(test_data.data_len == document_ids.shape[0])

    else:
        document_ids = numpy.zeros((test_data.data_len), dtype='int32')
        unique_document_ids = numpy.unique(document_ids)
    
    # Variables to store test statistics
    test_cost = 0 # negative log-likelihood
    test_misclass_first = 0 # misclassification error-rate
    test_misclass_second = 0 # misclassification error-rate
    test_samples_done = 0 # number of examples evaluated

    # Number of examples in dataset
    test_data_len = test_data.data_len

    logger.debug("[TEST START]") 

    prev_doc_id = -1
    prev_predicted_speaker = 4

    while True:
        batch = test_data.next()
        # Train finished
        if not batch:
            break
         
        logger.debug("[TEST] - Got batch %d,%d" % (batch['x_prev'].shape[1], batch['max_length']))

        x_data_prev = batch['x_prev']
        x_mask_prev = batch['x_mask_prev']
        x_data_next = batch['x_next']
        x_mask_next = batch['x_mask_next']
        x_precomputed_features = batch['x_precomputed_features']
        y_data = batch['y']
        y_data_prev_true = batch['y_prev']
        x_max_length = batch['max_length']


        doc_id = batch['document_id'][0]
        y_data_prev_estimate = numpy.zeros((2, 1), dtype='int32')
        # If we continue in the same dialogue, use previous prediction to inform current prediction
        if prev_doc_id == doc_id:
            y_data_prev_estimate[0,0] = prev_predicted_speaker
        else: # Otherwise, we assume the previous (non-existent) utterance was labelled as "minor_speaker"
            y_data_prev_estimate[0,0] = 4

        #print 'y_data_prev_estimate', y_data_prev_estimate
        #print 'y_data_prev_true', y_data_prev_true

        c, _, miscl_first, miscl_second, training_preds_first, training_preds_second = eval_batch(x_data_prev, x_mask_prev, x_data_next, x_mask_next, x_precomputed_features, y_data, y_data_prev_estimate, x_max_length)

        prev_doc_id = doc_id
        prev_predicted_speaker = training_preds_second[0]

        test_cost += c
        test_misclass_first += miscl_first
        test_misclass_second += miscl_second
        test_samples_done += batch['num_samples']
     
    logger.debug("[TEST END]") 

    test_cost /= float(test_samples_done*tokens_per_sample)
    test_misclass_first /= float(test_samples_done)
    test_misclass_second /= float(test_samples_done)

    print "** test cost (NLL) = %.4f, valid word-perplexity = %.4f, valid mean turn-taking class error = %.4f, valid mean speaker class error = %.4f" % (float(test_cost), float(math.exp(test_cost)), float(test_misclass_first), float(test_misclass_second))


    logger.debug("All done, exiting...")
Example #24
0
#MODEL_PREFIX = 'Output/1485188791.05_RedditHRED'
MODEL_PREFIX = '/home/ml/mnosew1/SavedModels/Twitter/1489857182.98_TwitterModel'

state_path = '%s_state.pkl' % MODEL_PREFIX
model_path = '%s_model.npz' % MODEL_PREFIX

state = prototype_state()
with open(state_path, 'r') as handle:
    state.update(cPickle.load(handle))

#state['dictionary'] = '/home/ml/mnosew1/data/twitter/hred_bpe/Dataset.dict.pkl'
state['dictionary'] = '/home/ml/mnosew1/SavedModels/Twitter/Dataset.dict-5k.pkl'
print 'Building model...'
model = DialogEncoderDecoder(state)
print 'Building sampler...'
sampler = search.BeamSampler(model)
print 'Loading model...'
model.load(model_path)
print 'Model built.'

HISTORY = []


@app.route('/hred', methods=['POST'])
def hred_response():
    print 'Generating HRED response...'
    text = request.json['result']['resolvedQuery']
    text = text.replace("'", " '")
    context = '<first_speaker> %s </s>' % text.strip().lower()
Example #25
0
def main():
    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(pickle.load(src))

    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    model = DialogEncoderDecoder(state)
    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")

    eval_batch = model.build_eval_function()

    if args.test_path:
        state['test_dialogues'] = args.test_path

    # Initialize list of stopwords to remove
    if args.exclude_stop_words:
        logger.debug("Initializing stop-word list")
        stopwords_lowercase = stopwords.lower().split(' ')
        stopwords_indices = []
        for word in stopwords_lowercase:
            if word in model.str_to_idx:
                stopwords_indices.append(model.str_to_idx[word])

    test_data = get_test_iterator(state)
    test_data.start()

    # Load document ids
    if args.document_ids:
        labels_file = open(args.document_ids, 'r')
        labels_text = labels_file.readlines()
        document_ids = numpy.zeros((len(labels_text)), dtype='int32')
        for i in range(len(labels_text)):
            document_ids[i] = int(labels_text[i].split('\t')[0])

        unique_document_ids = numpy.unique(document_ids)

        assert (test_data.data_len == document_ids.shape[0])

    else:
        print(
            'Warning no file with document ids given... standard deviations cannot be computed.'
        )
        document_ids = numpy.zeros((test_data.data_len), dtype='int32')
        unique_document_ids = numpy.unique(document_ids)

    # Variables to store test statistics
    test_cost = 0  # negative log-likelihood
    test_wordpreds_done = 0  # number of words in total

    # Number of triples in dataset
    test_data_len = test_data.data_len

    max_stored_len = 160  # Maximum number of tokens to store for dialogues with highest and lowest validation errors

    logger.debug("[TEST START]")

    while True:
        batch = next(test_data)
        # Train finished
        if not batch:
            break

        logger.debug("[TEST] - Got batch %d,%d" %
                     (batch['x'].shape[1], batch['max_length']))

        x_data = batch['x']
        x_data_reversed = batch['x_reversed']
        max_length = batch['max_length']
        x_cost_mask = batch['x_mask']
        reset_mask = batch['x_reset']
        ran_cost_utterance = batch['ran_var_constutterance']
        ran_decoder_drop_mask = batch['ran_decoder_drop_mask']

        if args.exclude_stop_words:
            for word_index in stopwords_indices:
                x_cost_mask[x_data == word_index] = 0

        batch['num_preds'] = numpy.sum(x_cost_mask)

        c, _, c_list, _, _ = eval_batch(x_data, x_data_reversed, max_length,
                                        x_cost_mask, reset_mask,
                                        ran_cost_utterance,
                                        ran_decoder_drop_mask)

        c_list = c_list.reshape((batch['x'].shape[1], max_length - 1),
                                order=(1, 0))
        c_list = numpy.sum(c_list, axis=1)

        if numpy.isinf(c) or numpy.isnan(c):
            continue

        test_cost += c

        words_in_triples = numpy.sum(x_cost_mask, axis=0)

        if numpy.isinf(c) or numpy.isnan(c):
            continue

        if numpy.isinf(c) or numpy.isnan(c):
            continue

        test_wordpreds_done += batch['num_preds']

    logger.debug("[TEST END]")

    print('test_wordpreds_done (number of words) ', test_wordpreds_done)
    test_cost /= test_wordpreds_done

    print("** test cost (NLL) = %.4f, test word-perplexity = %.4f " %
          (float(test_cost), float(math.exp(test_cost))))

    logger.debug("All done, exiting...")
Example #26
0
def main():
    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    model = DialogEncoderDecoder(state) 
    
    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")
    
    contexts = [[]]
    lines = open(args.dialogues, "r").readlines()
    if len(lines):
        contexts = [x.strip().split('\t') for x in lines]
   
    model_compute_encoding = model.build_encoder_function()
    dialogue_encodings = []

    # Start loop
    joined_contexts = []
    batch_index = 0
    batch_total = int(math.ceil(float(len(contexts)) / float(model.bs)))
    for context_id, context_sentences in enumerate(contexts):

        # Convert contexts into list of ids
        joined_context = []

        if len(context_sentences) == 0:
            joined_context = [model.eos_sym]
        else:
            joined_context += [model.eos_sym]
            for sentence in context_sentences:
                sentence_ids = model.words_to_indices(sentence.split())
                # Add eos tokens
                joined_context += sentence_ids + [model.eos_sym]

        # HACK
        #for i in range(0, 50):
        #    joined_context += [0] + [model.eos_sym]

        joined_contexts.append(joined_context)

        if len(joined_contexts) == model.bs:
            batch_index = batch_index + 1
            logger.debug("[COMPUTE] - Got batch %d / %d" % (batch_index, batch_total))
            encs = compute_encodings(joined_contexts, model, model_compute_encoding, args.use_second_last_state)
            for i in range(len(encs)):
                dialogue_encodings.append(encs[i])

            joined_contexts = []


    if len(joined_contexts) > 0:
        logger.debug("[COMPUTE] - Got batch %d / %d" % (batch_total, batch_total))
        encs = compute_encodings(joined_contexts, model, model_compute_encoding, args.use_second_last_state)
        for i in range(len(encs)):
            dialogue_encodings.append(encs[i])

    # Save encodings to disc
    cPickle.dump(dialogue_encodings, open(args.output + '.pkl', 'w'))
Example #27
0
def main():
    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    state['bs'] = 20
    state['compute_training_updates'] = False

    if args.mf_inference_steps > 0:
        state['lr'] = 0.01
        state['apply_meanfield_inference'] = True

    model = DialogEncoderDecoder(state) 

    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")
    

    eval_batch = model.build_eval_function()

    if args.mf_inference_steps > 0:
        mf_update_batch = model.build_mf_update_function()
        mf_reset_batch = model.build_mf_reset_function()
        
    contexts = [[]]
    lines = open(args.context, "r").readlines()
    if len(lines):
        contexts = [x.strip() for x in lines]

    lines = open(args.responses, "r").readlines()
    if len(lines):
        potential_responses_set = [x.strip() for x in lines]
    
    print('Building data batches...')

    # This is the maximum sequence length we can process on a 12 GB GPU with large HRED models
    max_sequence_length = 80*(80/state['bs'])
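    # With state['bs'] set to 20 above, this is 80 * (80 / 20) = 80 * 4 = 320 tokens (Python 2 integer division).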

    assert len(potential_responses_set) == len(contexts)

    # Note we assume that each example has the same number of potential responses
    # The code can be adapted, however, by simply counting the total number of responses to consider here...
    examples_to_process = len(contexts) * len(potential_responses_set[0].strip().split('\t'))
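    # For illustration (hypothetical numbers): 1000 contexts with 10 tab-separated candidates each
    # would give examples_to_process = 1000 * 10 = 10000.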

    all_dialogues = numpy.zeros((max_sequence_length, examples_to_process), dtype='int32')
    all_dialogues_reversed = numpy.zeros((max_sequence_length, examples_to_process), dtype='int32')
    all_dialogues_mask = numpy.zeros((max_sequence_length, examples_to_process), dtype='float32')
    all_dialogues_reset_mask = numpy.zeros((examples_to_process), dtype='float32')
    all_dialogues_drop_mask = numpy.ones((max_sequence_length, examples_to_process), dtype='float32')

    all_dialogues_len = numpy.zeros((examples_to_process), dtype='int32')

    example_idx = -1

    for context_idx, context in enumerate(contexts):
        if context_idx % 100 == 0:
            print '     processing context idx: ' + str(context_idx) + ' / ' + str(len(contexts))

        potential_responses = potential_responses_set[context_idx].strip().split('\t')

        most_probable_response_loglikelihood = -1.0
        most_probable_response = ''

        for potential_response_idx, potential_response in enumerate(potential_responses):
            example_idx += 1

            # Convert contexts into list of ids
            dialogue = []
            if len(context) == 0:
                dialogue = [model.eos_sym]
            else:
                sentence_ids = model.words_to_indices(context.split())
                # Add eos tokens
                if len(sentence_ids) > 0:
                    if not sentence_ids[0] == model.eos_sym:
                        sentence_ids = [model.eos_sym] + sentence_ids
                    if not sentence_ids[-1] == model.eos_sym:
                        sentence_ids += [model.eos_sym]
                else:
                    sentence_ids = [model.eos_sym]


                dialogue += sentence_ids

            response = model.words_to_indices(potential_response.split())
            if len(response) > 0:
                if response[0] == model.eos_sym:
                    del response[0]
                if not response[-1] == model.eos_sym:
                    response += [model.eos_sym]

            if len(response) > 3:
                if ((response[-1] == model.eos_sym)
                 and (response[-2] == model.eod_sym)
                 and (response[-3] == model.eos_sym)):
                    del response[-1]
                    del response[-1]

            dialogue += response

            if context_idx == 0:
                print 'DEBUG INFO'
                print 'dialogue', dialogue

            if len(numpy.where(numpy.asarray(dialogue) == state['eos_sym'])[0]) < 3:
                print 'POTENTIAL PROBLEM WITH DIALOGUE CONTEXT IDX', context_idx
                print '     dialogue', dialogue


            # Trim words from the beginning of the dialogue if it is too long... this should be a rare event.
            if len(dialogue) > max_sequence_length:
                if state['do_generate_first_utterance']:
                    dialogue = dialogue[len(dialogue)-max_sequence_length:len(dialogue)]
                else: # CreateDebate specific setting
                    dialogue = dialogue[0:max_sequence_length-1]
                    if not dialogue[-1] == state['eos_sym']:
                        dialogue += [state['eos_sym']]

                    


            dialogue = numpy.asarray(dialogue, dtype='int32').reshape((len(dialogue), 1))

            dialogue_reversed = model.reverse_utterances(dialogue)
            dialogue_mask = numpy.ones((len(dialogue)), dtype='float32') 
            dialogue_weight = numpy.ones((len(dialogue)), dtype='float32')
            dialogue_drop_mask = numpy.ones((len(dialogue)), dtype='float32')

            # Add example to large numpy arrays...
            all_dialogues[0:len(dialogue), example_idx] = dialogue[:,0]
            all_dialogues_reversed[0:len(dialogue), example_idx] = dialogue_reversed[:,0]
            all_dialogues_mask[0:len(dialogue), example_idx] = dialogue_mask
            all_dialogues_drop_mask[0:len(dialogue), example_idx] = dialogue_drop_mask

            all_dialogues_len[example_idx] = len(dialogue)


    sorted_dialogue_indices = numpy.argsort(all_dialogues_len)

    all_dialogues_cost = numpy.ones((examples_to_process), dtype='float32')

    batch_count = int(numpy.ceil(float(examples_to_process) / float(state['bs'])))

    print('Computing costs on batches...')
    for batch_idx in reversed(range(batch_count)):
        if batch_idx % 10 == 0:
            print '     processing batch idx: ' + str(batch_idx) + ' / ' + str(batch_count)

        example_index_start = batch_idx * state['bs']
        example_index_end = min((batch_idx+1) * state['bs'], examples_to_process)
        example_indices = sorted_dialogue_indices[example_index_start:example_index_end]

        current_batch_size = len(example_indices)

        max_batch_sequence_length = all_dialogues_len[example_indices[-1]]

        # Initialize batch with zeros
        batch_dialogues = numpy.zeros((max_batch_sequence_length, state['bs']), dtype='int32')
        batch_dialogues_reversed = numpy.zeros((max_batch_sequence_length, state['bs']), dtype='int32')
        batch_dialogues_mask = numpy.zeros((max_batch_sequence_length, state['bs']), dtype='float32')
        batch_dialogues_weight = numpy.ones((max_batch_sequence_length, state['bs']), dtype='float32')
        batch_dialogues_reset_mask = numpy.zeros((state['bs']), dtype='float32')
        batch_dialogues_drop_mask = numpy.ones((max_batch_sequence_length, state['bs']), dtype='float32')

        # Fill in batch with values
        batch_dialogues[:,0:current_batch_size]  = all_dialogues[0:max_batch_sequence_length, example_indices]
        batch_dialogues_reversed[:,0:current_batch_size] = all_dialogues_reversed[0:max_batch_sequence_length, example_indices]
        batch_dialogues_mask[:,0:current_batch_size] = all_dialogues_mask[0:max_batch_sequence_length, example_indices]
        batch_dialogues_drop_mask[:,0:current_batch_size] = all_dialogues_drop_mask[0:max_batch_sequence_length, example_indices]

        if batch_idx < 10:
            print '###'
            print '###'
            print '###'

            print 'DEBUG THIS:'
            print 'batch_dialogues', batch_dialogues[:, 0]
            print 'batch_dialogues_reversed',batch_dialogues_reversed[:, 0]
            print 'max_batch_sequence_length', max_batch_sequence_length
            print 'batch_dialogues_reset_mask', batch_dialogues_reset_mask


        batch_dialogues_ran_gaussian_vectors = numpy.zeros((max_batch_sequence_length, state['bs'], model.latent_gaussian_per_utterance_dim), dtype='float32')
        batch_dialogues_ran_uniform_vectors = numpy.zeros((max_batch_sequence_length, state['bs'], model.latent_piecewise_per_utterance_dim), dtype='float32')

        for idx in range(state['bs']):
            eos_sym_list = numpy.where(batch_dialogues[:, idx] == state['eos_sym'])[0]
            if len(eos_sym_list) > 1:
                second_last_eos_sym = eos_sym_list[-2]
            else:
                print 'WARNING: batch_idx ' + str(batch_idx) + ' example index ' + str(idx) + ' does not have at least two EOS tokens!'
                print '     batch was: ', batch_dialogues[:, idx]

                if len(eos_sym_list) > 0:
                    second_last_eos_sym = eos_sym_list[-1]
                else:
                    second_last_eos_sym = 0

            batch_dialogues_ran_gaussian_vectors[second_last_eos_sym:, idx, :] = model.rng.normal(loc=0, scale=1, size=model.latent_gaussian_per_utterance_dim)
            batch_dialogues_ran_uniform_vectors[second_last_eos_sym:, idx, :] = model.rng.uniform(low=0.0, high=1.0, size=model.latent_piecewise_per_utterance_dim)

            batch_dialogues_mask[0:second_last_eos_sym+1, idx] = 0.0




        if batch_idx < 10:
            print 'batch_dialogues_mask', batch_dialogues_mask[:, 0]
            print 'batch_dialogues_ran_gaussian_vectors', batch_dialogues_ran_gaussian_vectors[:, 0, 0]
            print 'batch_dialogues_ran_gaussian_vectors [-1]', batch_dialogues_ran_gaussian_vectors[:, 0, -1]
            print 'batch_dialogues_ran_uniform_vectors', batch_dialogues_ran_uniform_vectors[:, 0, 0]
            print 'batch_dialogues_ran_uniform_vectors [-1]', batch_dialogues_ran_uniform_vectors[:, 0, -1]


        # Carry out mean-field inference:
        if args.mf_inference_steps > 0:
            mf_reset_batch()
            for mf_step in range(args.mf_inference_steps):


                training_cost, kl_divergence_cost_acc, kl_divergences_between_piecewise_prior_and_posterior, kl_divergences_between_gaussian_prior_and_posterior = mf_update_batch(batch_dialogues, batch_dialogues_reversed, max_batch_sequence_length, batch_dialogues_mask, batch_dialogues_reset_mask, batch_dialogues_ran_gaussian_vectors, batch_dialogues_ran_uniform_vectors, batch_dialogues_drop_mask)
                if batch_idx % 10 == 0:
                    print  '  mf_step', mf_step
                    print '     training_cost', training_cost
                    print '     kl_divergence_cost_acc', kl_divergence_cost_acc
                    print '     kl_divergences_between_gaussian_prior_and_posterior',  numpy.sum(kl_divergences_between_gaussian_prior_and_posterior)
                    print '     kl_divergences_between_piecewise_prior_and_posterior', numpy.sum(kl_divergences_between_piecewise_prior_and_posterior)



        # Evaluate the batch and collect the per-example costs
        _, c_list, _ = eval_batch(batch_dialogues, batch_dialogues_reversed, max_batch_sequence_length, batch_dialogues_mask, batch_dialogues_reset_mask, batch_dialogues_ran_gaussian_vectors, batch_dialogues_ran_uniform_vectors, batch_dialogues_drop_mask)
        c_list = c_list.reshape((state['bs'],max_batch_sequence_length-1), order=(1,0))
        c_list = numpy.sum(c_list, axis=1) / numpy.sum(batch_dialogues_mask, axis=0)
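        # c_list now holds, for each example, the summed token NLL divided by the number of unmasked
        # (response) tokens, i.e. the per-word negative log-likelihood of the candidate response given its context.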
        all_dialogues_cost[example_indices] = c_list[0:current_batch_size]

        #print 'hd_input [0]', hd_input[:, 0, 0]
        #print 'hd_input [-]', hd_input[:, 0, -1]


    print('Ranking responses...')
    example_idx = -1
    ranked_all_responses_string = ''
    ranked_all_responses_costs_string = ''
    for context_idx, context in enumerate(contexts):
        if context_idx % 100 == 0:
            print '     processing context idx: ' + str(context_idx) + ' / ' + str(len(contexts))

        potential_responses = potential_responses_set[context_idx].strip().split('\t')
        potential_example_response_indices = numpy.zeros((len(potential_responses)), dtype='int32')

        for potential_response_idx, potential_response in enumerate(potential_responses):
            example_idx += 1
            potential_example_response_indices[potential_response_idx] = example_idx

        ranked_potential_example_response_indices = numpy.argsort(all_dialogues_cost[potential_example_response_indices])

        ranked_responses_costs = all_dialogues_cost[potential_example_response_indices]

        ranked_responses_string = ''
        ranked_responses_costs_string = ''
        for idx in ranked_potential_example_response_indices:
            ranked_responses_string += potential_responses[idx] + '\t'
            ranked_responses_costs_string += str(ranked_responses_costs[idx]) + '\t'

        ranked_all_responses_string += ranked_responses_string[0:len(ranked_responses_string)-1] + '\n'
        ranked_all_responses_costs_string += ranked_responses_costs_string[0:len(ranked_responses_costs_string)-1] + '\n'

    print('Finished all computations!')
    print('Saving to file...')

    # Write to output file
    output_handle = open(args.output + '.txt', "w")
    output_handle.write(ranked_all_responses_string)
    output_handle.close()

    output_handle = open(args.output + '_Costs.txt' , "w")
    output_handle.write(ranked_all_responses_costs_string)
    output_handle.close()

    print('Saving to file finished.')
    print('All done!')
Example #28
0
def main():
    args = parse_args()
    state = prototype_state()
   
    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src)) 
    
    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
     
    model = DialogEncoderDecoder(state)
    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")
    
    eval_batch = model.build_eval_function()
    
    if args.test_path:
        state['test_dialogues'] = args.test_path

    # Initialize list of stopwords to remove
    if args.exclude_stop_words:
        logger.debug("Initializing stop-word list")
        stopwords_lowercase = stopwords.lower().split(' ')
        stopwords_indices = []
        for word in stopwords_lowercase:
            if word in model.str_to_idx:
                stopwords_indices.append(model.str_to_idx[word])

    test_data = get_test_iterator(state)
    test_data.start()

    # Load document ids
    if args.document_ids:
        labels_file = open(args.document_ids, 'r')
        labels_text = labels_file.readlines()
        document_ids = numpy.zeros((len(labels_text)), dtype='int32')
        for i in range(len(labels_text)):
            document_ids[i] = int(labels_text[i].split('\t')[0])

        unique_document_ids = numpy.unique(document_ids)
        
        assert(test_data.data_len == document_ids.shape[0])

    else:
        print 'Warning no file with document ids given... standard deviations cannot be computed.'
        document_ids = numpy.zeros((test_data.data_len), dtype='int32')
        unique_document_ids = numpy.unique(document_ids)

    # Variables to store test statistics
    test_cost = 0 # negative log-likelihood
    test_wordpreds_done = 0 # number of words in total

    # Number of triples in dataset
    test_data_len = test_data.data_len

    max_stored_len = 160 # Maximum number of tokens to store for dialogues with highest and lowest validation errors

    logger.debug("[TEST START]") 

    while True:
        batch = test_data.next()
        # Train finished
        if not batch:
            break

        logger.debug("[TEST] - Got batch %d,%d" % (batch['x'].shape[1], batch['max_length']))

        x_data = batch['x']
        x_data_reversed = batch['x_reversed']
        max_length = batch['max_length']
        x_cost_mask = batch['x_mask']
        x_semantic = batch['x_semantic']
        reset_mask = batch['x_reset']
        ran_cost_utterance = batch['ran_var_constutterance']
        ran_decoder_drop_mask = batch['ran_decoder_drop_mask']

        if args.exclude_stop_words:
            for word_index in stopwords_indices:
                x_cost_mask[x_data == word_index] = 0

        batch['num_preds'] = numpy.sum(x_cost_mask)

        c, c_list, _, _  = eval_batch(x_data, x_data_reversed, max_length, x_cost_mask, x_semantic, reset_mask, ran_cost_utterance, ran_decoder_drop_mask)

        c_list = c_list.reshape((batch['x'].shape[1],max_length-1), order=(1,0))
        c_list = numpy.sum(c_list, axis=1)     

        if numpy.isinf(c) or numpy.isnan(c):
            continue
        
        test_cost += c

        words_in_triples = numpy.sum(x_cost_mask, axis=0)

        if numpy.isinf(c) or numpy.isnan(c):
            continue

        if numpy.isinf(c) or numpy.isnan(c):
            continue


        test_wordpreds_done += batch['num_preds']
     
    logger.debug("[TEST END]") 

    print 'test_wordpreds_done (number of words) ', test_wordpreds_done
    test_cost /= test_wordpreds_done

    print "** test cost (NLL) = %.4f, test word-perplexity = %.4f " % (float(test_cost), float(math.exp(test_cost)))  

    logger.debug("All done, exiting...")
Example #29
0
        train_context_embeddings = flatten_list(train_emb)
        test_context_embeddings = flatten_list(test_emb)


    elif 'gpu' in theano.config.device.lower():
        state = prototype_state()
        state_path = twitter_model_prefix + "_state.pkl"
        model_path = twitter_model_prefix + "_model.npz"
        
        with open(state_path) as src:
            state.update(cPickle.load(src))

        state['bs'] = 20
        state['dictionary'] = twitter_model_dictionary

        model = DialogEncoderDecoder(state) 
        
        calc_response_embeddings = False
        calc_context_embeddings = True
        calc_test = False
        start_batch = 24000
        max_batches = 35000

        if calc_response_embeddings:
            print 'Computing training response embeddings...'
            train_response_embeddings = compute_model_embeddings(train_responses, model, embedding_type, 'train_response', starting_batch=start_batch, max_batches=max_batches)
            if calc_test:
                print 'Computing test response embeddings...'
                test_response_embeddings = compute_model_embeddings(test_responses, model, embedding_type, 'test_response', max_batches=2000)

        # Computed up to batch 22420 for DECODER
Example #30
0
def main():
    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    model = DialogEncoderDecoder(state)
    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")

    eval_batch = model.build_eval_function()
    eval_misclass_batch = model.build_eval_misclassification_function()

    if args.test_path:
        state['test_triples'] = args.test_path

    # Initialize list of stopwords to remove
    if args.exclude_stop_words:
        logger.debug("Initializing stop-word list")
        stopwords_lowercase = stopwords.lower().split(' ')
        stopwords_indices = []
        for word in stopwords_lowercase:
            if word in model.str_to_idx:
                stopwords_indices.append(model.str_to_idx[word])

    test_data = get_test_iterator(state)
    test_data.start()

    # Load document ids
    if args.document_ids:
        labels_file = open(args.document_ids, 'r')
        labels_text = labels_file.readlines()
        document_ids = numpy.zeros((len(labels_text)), dtype='int32')
        for i in range(len(labels_text)):
            document_ids[i] = int(labels_text[i].split('\t')[0])

        unique_document_ids = numpy.unique(document_ids)

        assert (test_data.data_len == document_ids.shape[0])

    else:
        print 'Warning no file with document ids given... standard deviations cannot be computed.'
        document_ids = numpy.zeros((test_data.data_len), dtype='int32')
        unique_document_ids = numpy.unique(document_ids)

    # Variables to store test statistics
    test_cost = 0  # negative log-likelihood
    test_cost_first_utterances = 0  # marginal negative log-likelihood of first two utterances
    test_cost_last_utterance_marginal = 0  # marginal (approximate) negative log-likelihood of last utterances
    test_misclass = 0  # misclassification error-rate
    test_misclass_first_utterances = 0  # misclassification error-rate of first two utterances
    test_empirical_mutual_information = 0  # empirical mutual information between first two utterances and third utterance, where the marginal P(U_3) is approximated by P(U_3, empty, empty).

    if model.bootstrap_from_semantic_information:
        test_semantic_cost = 0
        test_semantic_misclass = 0

    test_wordpreds_done = 0  # number of words in total
    test_wordpreds_done_last_utterance = 0  # number of words in last utterances
    test_triples_done = 0  # number of triples evaluated

    # Variables to compute negative log-likelihood and empirical mutual information per genre
    compute_genre_specific_metrics = False
    if hasattr(model, 'semantic_information_dim'):
        compute_genre_specific_metrics = True
        test_cost_per_genre = numpy.zeros((model.semantic_information_dim, 1),
                                          dtype='float32')
        test_mi_per_genre = numpy.zeros((model.semantic_information_dim, 1),
                                        dtype='float32')
        test_wordpreds_done_per_genre = numpy.zeros(
            (model.semantic_information_dim, 1), dtype='float32')
        test_triples_done_per_genre = numpy.zeros(
            (model.semantic_information_dim, 1), dtype='float32')

    # Number of triples in dataset
    test_data_len = test_data.data_len

    # Correspond to the same variables as above, but now for each triple.
    # e.g. test_cost_list is a numpy array with the negative log-likelihood for each triple in the test set
    test_cost_list = numpy.zeros((test_data_len, ))
    test_pmi_list = numpy.zeros((test_data_len, ))
    test_cost_last_utterance_marginal_list = numpy.zeros((test_data_len, ))
    test_misclass_list = numpy.zeros((test_data_len, ))
    test_misclass_last_utterance_list = numpy.zeros((test_data_len, ))

    # Array containing number of words in each triple
    words_in_triples_list = numpy.zeros((test_data_len, ))

    # Array containing number of words in last utterance of each triple
    words_in_last_utterance_list = numpy.zeros((test_data_len, ))

    # Prepare variables for printing the test examples the model performs best and worst on
    test_extrema_setsize = min(state['track_extrema_samples_count'],
                               test_data_len)
    test_extrema_samples_to_print = min(state['print_extrema_samples_count'],
                                        test_extrema_setsize)

    test_lowest_costs = numpy.ones((test_extrema_setsize, )) * 1000
    test_lowest_triples = numpy.ones(
        (test_extrema_setsize, state['seqlen'])) * 1000
    test_highest_costs = numpy.ones((test_extrema_setsize, )) * (-1000)
    test_highest_triples = numpy.ones(
        (test_extrema_setsize, state['seqlen'])) * (-1000)

    logger.debug("[TEST START]")

    while True:
        batch = test_data.next()
        # Train finished
        if not batch:
            break

        logger.debug("[TEST] - Got batch %d,%d" %
                     (batch['x'].shape[1], batch['max_length']))

        x_data = batch['x']
        x_data_reversed = batch['x_reversed']
        max_length = batch['max_length']
        x_cost_mask = batch['x_mask']
        x_semantic = batch['x_semantic']
        x_semantic_nonempty_indices = numpy.where(x_semantic >= 0)

        # Hack to get rid of start of sentence token.
        if args.exclude_sos and model.sos_sym != -1:
            x_cost_mask[x_data == model.sos_sym] = 0

        if args.exclude_stop_words:
            for word_index in stopwords_indices:
                x_cost_mask[x_data == word_index] = 0

        batch['num_preds'] = numpy.sum(x_cost_mask)

        c, c_list = eval_batch(x_data, x_data_reversed, max_length,
                               x_cost_mask, x_semantic)

        c_list = c_list.reshape((batch['x'].shape[1], max_length),
                                order=(1, 0))
        c_list = numpy.sum(c_list, axis=1)

        # Compute genre specific stats...
        if compute_genre_specific_metrics:
            non_nan_entries = numpy.array(c_list >= 0, dtype=int)
            c_list[numpy.where(non_nan_entries == 0)] = 0
            test_cost_per_genre += (numpy.asmatrix(non_nan_entries * c_list) *
                                    numpy.asmatrix(x_semantic)).T
            test_wordpreds_done_per_genre += (numpy.asmatrix(
                non_nan_entries * numpy.sum(x_cost_mask, axis=0)) *
                                              numpy.asmatrix(x_semantic)).T

        if numpy.isinf(c) or numpy.isnan(c):
            continue

        test_cost += c

        # Store test costs in list
        nxt = min((test_triples_done + batch['x'].shape[1]), test_data_len)
        triples_in_batch = nxt - test_triples_done

        words_in_triples = numpy.sum(x_cost_mask, axis=0)
        words_in_triples_list[(
            nxt - triples_in_batch):nxt] = words_in_triples[0:triples_in_batch]

        # We don't need to normalize by the number of words... not if we're computing standard deviations at least...
        test_cost_list[(nxt -
                        triples_in_batch):nxt] = c_list[0:triples_in_batch]

        # Store best and worst test costs
        con_costs = numpy.concatenate(
            [test_lowest_costs, c_list[0:triples_in_batch]])
        con_triples = numpy.concatenate(
            [test_lowest_triples, x_data[:, 0:triples_in_batch].T], axis=0)
        con_indices = con_costs.argsort()[0:test_extrema_setsize][::1]
        test_lowest_costs = con_costs[con_indices]
        test_lowest_triples = con_triples[con_indices]

        con_costs = numpy.concatenate(
            [test_highest_costs, c_list[0:triples_in_batch]])
        con_triples = numpy.concatenate(
            [test_highest_triples, x_data[:, 0:triples_in_batch].T], axis=0)
        con_indices = con_costs.argsort()[-test_extrema_setsize:][::-1]
        test_highest_costs = con_costs[con_indices]
        test_highest_triples = con_triples[con_indices]

        # Compute word-error rate
        miscl, miscl_list = eval_misclass_batch(x_data, x_data_reversed,
                                                max_length, x_cost_mask,
                                                x_semantic)
        if numpy.isinf(c) or numpy.isnan(c):
            continue

        test_misclass += miscl

        # Store misclassification errors in list
        miscl_list = miscl_list.reshape((batch['x'].shape[1], max_length),
                                        order=(1, 0))
        miscl_list = numpy.sum(miscl_list, axis=1)
        test_misclass_list[(
            nxt - triples_in_batch):nxt] = miscl_list[0:triples_in_batch]

        # Equations to compute empirical mutual information
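        # Roughly, the quantity accumulated below is the pointwise mutual information
        #   PMI ~= log P(U_1, U_2, U_3) - log P(U_1, U_2) - log P(U_3),
        # with log P(U_3) approximated as described in the next comment.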

        # Compute marginal log-likelihood of last utterance in triple:
        # We approximate it with the marginal log-probability of the utterance being observed first in the triple
        x_data_last_utterance = batch['x_last_utterance']
        x_data_last_utterance_reversed = batch['x_last_utterance_reversed']
        x_cost_mask_last_utterance = batch['x_mask_last_utterance']
        x_start_of_last_utterance = batch['x_start_of_last_utterance']

        # Hack to get rid of start of sentence token.
        if args.exclude_sos and model.sos_sym != -1:
            x_cost_mask_last_utterance[x_data_last_utterance ==
                                       model.sos_sym] = 0

        if args.exclude_stop_words:
            for word_index in stopwords_indices:
                x_cost_mask_last_utterance[x_data_last_utterance ==
                                           word_index] = 0

        words_in_last_utterance = numpy.sum(x_cost_mask_last_utterance, axis=0)
        words_in_last_utterance_list[(
            nxt - triples_in_batch
        ):nxt] = words_in_last_utterance[0:triples_in_batch]

        batch['num_preds_at_utterance'] = numpy.sum(x_cost_mask_last_utterance)

        marginal_last_utterance_loglikelihood, marginal_last_utterance_loglikelihood_list = eval_batch(
            x_data_last_utterance, x_data_last_utterance_reversed, max_length,
            x_cost_mask_last_utterance, x_semantic)

        marginal_last_utterance_loglikelihood_list = marginal_last_utterance_loglikelihood_list.reshape(
            (batch['x'].shape[1], max_length), order=(1, 0))
        marginal_last_utterance_loglikelihood_list = numpy.sum(
            marginal_last_utterance_loglikelihood_list, axis=1)
        test_cost_last_utterance_marginal_list[(
            nxt - triples_in_batch
        ):nxt] = marginal_last_utterance_loglikelihood_list[0:triples_in_batch]

        # Compute marginal log-likelihood of first utterances in triple by masking the last utterance
        x_cost_mask_first_utterances = numpy.copy(x_cost_mask)
        for i in range(batch['x'].shape[1]):
            x_cost_mask_first_utterances[
                x_start_of_last_utterance[i]:max_length, i] = 0

        marginal_first_utterances_loglikelihood, marginal_first_utterances_loglikelihood_list = eval_batch(
            x_data, x_data_reversed, max_length, x_cost_mask_first_utterances,
            x_semantic)

        marginal_first_utterances_loglikelihood_list = marginal_first_utterances_loglikelihood_list.reshape(
            (batch['x'].shape[1], max_length), order=(1, 0))
        marginal_first_utterances_loglikelihood_list = numpy.sum(
            marginal_first_utterances_loglikelihood_list, axis=1)

        # Compute empirical mutual information and pointwise empirical mutual information
        test_empirical_mutual_information += -c + marginal_first_utterances_loglikelihood + marginal_last_utterance_loglikelihood
        test_pmi_list[(nxt - triples_in_batch):nxt] = (
            -c_list * words_in_triples +
            marginal_first_utterances_loglikelihood_list +
            marginal_last_utterance_loglikelihood_list)[0:triples_in_batch]

        # Compute genre specific stats...
        if compute_genre_specific_metrics:
            if triples_in_batch == batch['x'].shape[1]:
                mi_list = (-c_list * words_in_triples +
                           marginal_first_utterances_loglikelihood_list +
                           marginal_last_utterance_loglikelihood_list
                           )[0:triples_in_batch]
                non_nan_entries = numpy.array(mi_list >= 0, dtype=int) * numpy.array(~numpy.isnan(mi_list), dtype=int)
                test_mi_per_genre += (
                    numpy.asmatrix(non_nan_entries * mi_list) *
                    numpy.asmatrix(x_semantic)).T
                test_triples_done_per_genre += numpy.reshape(
                    numpy.sum(x_semantic, axis=0),
                    test_triples_done_per_genre.shape)

        # Store log P(U_1, U_2) cost computed during mutual information
        test_cost_first_utterances += marginal_first_utterances_loglikelihood

        # Store marginal log P(U_3)
        test_cost_last_utterance_marginal += marginal_last_utterance_loglikelihood

        # Compute word-error rate for first utterances
        miscl_first_utterances, miscl_first_utterances_list = eval_misclass_batch(
            x_data, x_data_reversed, max_length, x_cost_mask_first_utterances,
            x_semantic)
        test_misclass_first_utterances += miscl_first_utterances
        if numpy.isinf(c) or numpy.isnan(c):
            continue

        # Store misclassification for last utterance
        miscl_first_utterances_list = miscl_first_utterances_list.reshape(
            (batch['x'].shape[1], max_length), order='F')
        miscl_first_utterances_list = numpy.sum(miscl_first_utterances_list,
                                                axis=1)

        miscl_last_utterance_list = miscl_list - miscl_first_utterances_list

        test_misclass_last_utterance_list[(nxt - triples_in_batch):nxt] = \
            miscl_last_utterance_list[0:triples_in_batch]

        if model.bootstrap_from_semantic_information:
            # Compute cross-entropy error on predicting the semantic class and retrieve predictions
            sem_eval = eval_semantic_batch(x_data, x_data_reversed, max_length,
                                           x_cost_mask, x_semantic)

            # Evaluate only non-empty triples (empty triples are created to fill
            #   the whole batch sometimes).
            sem_cost = sem_eval[0][-1, :, :]
            test_semantic_cost += numpy.sum(
                sem_cost[x_semantic_nonempty_indices])

            # Compute misclassified predictions on last timestep over all labels
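            # Assuming binary targets in {0, 1} and predictions in [0, 1],
            # (target - 0.5) and (prediction - 0.5) have opposite signs exactly
            # when the thresholded prediction disagrees with the target.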
            sem_preds = sem_eval[1][-1, :, :]
            sem_preds_misclass = len(
                numpy.where(
                    ((x_semantic - 0.5) *
                     (sem_preds - 0.5))[x_semantic_nonempty_indices] < 0)[0])
            test_semantic_misclass += sem_preds_misclass

        test_wordpreds_done += batch['num_preds']
        test_wordpreds_done_last_utterance += batch['num_preds_at_utterance']
        test_triples_done += batch['num_triples']

    logger.debug("[TEST END]")

    test_cost_last_utterance_marginal /= test_wordpreds_done_last_utterance
    test_cost_last_utterance = (test_cost - test_cost_first_utterances
                                ) / test_wordpreds_done_last_utterance
    test_cost /= test_wordpreds_done
    test_cost_first_utterances /= float(test_wordpreds_done -
                                        test_wordpreds_done_last_utterance)

    test_misclass_last_utterance = float(
        test_misclass - test_misclass_first_utterances) / float(
            test_wordpreds_done_last_utterance)
    test_misclass_first_utterances /= float(test_wordpreds_done -
                                            test_wordpreds_done_last_utterance)
    test_misclass /= float(test_wordpreds_done)
    test_empirical_mutual_information /= float(test_triples_done)

    if model.bootstrap_from_semantic_information:
        test_semantic_cost /= float(test_triples_done)
        test_semantic_misclass /= float(test_triples_done)
        print "** test semantic cost = %.4f, test semantic misclass error = %.4f" % (
            float(test_semantic_cost), float(test_semantic_misclass))

    print "** test cost (NLL) = %.4f, test word-perplexity = %.4f, test word-perplexity last utterance = %.4f, test word-perplexity marginal last utterance = %.4f, test mean word-error = %.4f, test mean word-error last utterance = %.4f, test emp. mutual information = %.4f" % (
        float(test_cost), float(
            math.exp(test_cost)), float(math.exp(test_cost_last_utterance)),
        float(
            math.exp(test_cost_last_utterance_marginal)), float(test_misclass),
        float(test_misclass_last_utterance), test_empirical_mutual_information)

    if compute_genre_specific_metrics:
        print '** test perplexity per genre', numpy.exp(
            test_cost_per_genre / test_wordpreds_done_per_genre)
        print '** test_mi_per_genre', test_mi_per_genre

        print '** words per genre', test_wordpreds_done_per_genre

    # Plot histogram over test costs
    if args.plot_graphs:
        try:
            pylab.figure()
            bins = range(0, 50, 1)
            pylab.hist(numpy.exp(test_cost_list), bins=bins, normed=1, histtype='bar')
            pylab.savefig(model.state['save_dir'] + '/' +
                          model.state['run_id'] + "_" + model.state['prefix'] +
                          'Test_WordPerplexities.png')
        except Exception:
            # Plotting is best-effort; ignore failures (e.g. no display backend).
            pass

    # Print a random subset of the test samples with the highest and lowest word log-likelihood
    if args.plot_graphs:
        print " highest word log-likelihood test samples: "
        numpy.random.shuffle(test_lowest_triples)
        for i in range(test_extrema_samples_to_print):
            print "      Sample: {}".format(" ".join(
                model.indices_to_words(numpy.ravel(
                    test_lowest_triples[i, :]))))

        print " lowest word log-likelihood test samples: "
        numpy.random.shuffle(test_highest_triples)
        for i in range(test_extrema_samples_to_print):
            print "      Sample: {}".format(" ".join(
                model.indices_to_words(numpy.ravel(
                    test_highest_triples[i, :]))))

    # Plot histogram over empirical pointwise mutual informations
    if args.plot_graphs:
        try:
            pylab.figure()
            bins = range(0, 100, 1)
            pylab.hist(test_pmi_list, bins=bins, normed=1, histtype='bar')
            pylab.savefig(model.state['save_dir'] + '/' +
                          model.state['run_id'] + "_" + model.state['prefix'] +
                          'Test_PMI.png')
        except Exception:
            # Plotting is best-effort; ignore failures (e.g. no display backend).
            pass

    # To estimate the standard deviations, we assume that triples across documents (movies) are independent.
    # We compute the mean metric for each document, and then the variance between documents.
    # We then use this between-document variance to compute the variance of each metric.
    # Let m be a metric:
    # Var[m] = Var[1/(words in total) \sum_d \sum_i m_{di}]
    #        = Var[1/(words in total) \sum_d (words in doc d)/(words in doc d) \sum_i m_{di}]
    #        = \sum_d (words in doc d)^2/(words in total)^2 Var[ 1/(words in doc d) \sum_i m_{di} ]
    #        = \sum_d (words in doc d)^2/(words in total)^2 sigma^2
    #
    # where sigma^2 is the variance of the per-document means.
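    # Worked example with illustrative numbers: two documents of 100 and 300
    # words with per-document mean costs of 4.0 and 5.0 nats/word give
    # sigma^2 = Var([4.0, 5.0]) = 0.25 and
    # Var[m] = (100^2 + 300^2) / 400^2 * 0.25 = 0.15625,
    # i.e. a standard deviation of roughly 0.40 nats/word.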

    # negative log-likelihood for each document (movie)
    per_document_test_cost = numpy.zeros((len(unique_document_ids)),
                                         dtype='float32')
    # negative log-likelihood for last utterance for each document (movie)
    per_document_test_cost_last_utterance = numpy.zeros(
        (len(unique_document_ids)), dtype='float32')
    # misclassification error for each document (movie)
    per_document_test_misclass = numpy.zeros((len(unique_document_ids)),
                                             dtype='float32')
    # misclassification error for last utterance for each document (movie)
    per_document_test_misclass_last_utterance = numpy.zeros(
        (len(unique_document_ids)), dtype='float32')

    # Compute standard deviations based on means across documents (sigma^2 above)
    all_words_squared = 0  # \sum_d (words in doc d)^2
    all_words_in_last_utterance_squared = 0  # \sum_d (words in last utterance of doc d)^2
    for doc_id in range(len(unique_document_ids)):
        doc_indices = numpy.where(document_ids == unique_document_ids[doc_id])

        per_document_test_cost[doc_id] = numpy.sum(
            test_cost_list[doc_indices]) / numpy.sum(
                words_in_triples_list[doc_indices])
        per_document_test_cost_last_utterance[doc_id] = numpy.sum(
            test_cost_last_utterance_marginal_list[doc_indices]) / numpy.sum(
                words_in_last_utterance_list[doc_indices])

        per_document_test_misclass[doc_id] = numpy.sum(
            test_misclass_list[doc_indices]) / numpy.sum(
                words_in_triples_list[doc_indices])
        per_document_test_misclass_last_utterance[doc_id] = numpy.sum(
            test_misclass_last_utterance_list[doc_indices]) / numpy.sum(
                words_in_last_utterance_list[doc_indices])

        all_words_squared += float(
            numpy.sum(words_in_triples_list[doc_indices]))**2
        all_words_in_last_utterance_squared += float(
            numpy.sum(words_in_last_utterance_list[doc_indices]))**2

    # Sanity check that all documents are being used in the standard deviation calculations
    assert (numpy.sum(words_in_triples_list) == test_wordpreds_done)
    assert (numpy.sum(words_in_last_utterance_list) ==
            test_wordpreds_done_last_utterance)

    # Compute the final variances (weighted by document size) and print the corresponding standard deviations
    per_document_test_cost_variance = numpy.var(
        per_document_test_cost) * float(all_words_squared) / float(
            test_wordpreds_done**2)
    per_document_test_cost_last_utterance_variance = numpy.var(
        per_document_test_cost_last_utterance) * float(
            all_words_in_last_utterance_squared) / float(
                test_wordpreds_done_last_utterance**2)
    per_document_test_misclass_variance = numpy.var(
        per_document_test_misclass) * float(all_words_squared) / float(
            test_wordpreds_done**2)
    per_document_test_misclass_last_utterance_variance = numpy.var(
        per_document_test_misclass_last_utterance) * float(
            all_words_in_last_utterance_squared) / float(
                test_wordpreds_done_last_utterance**2)

    print 'Standard deviations:'
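    # The perplexity standard deviations below follow from treating the mean
    # per-word cost as approximately Gaussian: if X ~ N(mu, sigma^2), then
    # exp(X) is log-normal with variance (exp(sigma^2) - 1) * exp(2*mu + sigma^2).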
    print "** test cost (NLL) = ", math.sqrt(per_document_test_cost_variance)
    print "** test perplexity (NLL) = ", math.sqrt(
        (math.exp(per_document_test_cost_variance) - 1) *
        math.exp(2 * test_cost + per_document_test_cost_variance))

    print "** test cost last utterance (NLL) = ", math.sqrt(
        per_document_test_cost_last_utterance_variance)
    print "** test perplexity last utterance  (NLL) = ", math.sqrt(
        (math.exp(per_document_test_cost_last_utterance_variance) - 1) *
        math.exp(2 * test_cost +
                 per_document_test_cost_last_utterance_variance))

    print "** test word-error = ", math.sqrt(
        per_document_test_misclass_variance)
    print "** test last utterance word-error = ", math.sqrt(
        per_document_test_misclass_last_utterance_variance)

    logger.debug("All done, exiting...")