def _build_vhred_model(self):
        """Load the pretrained VHRED model and compile its helper functions.

        Reads the pickled training state and trained parameters from the
        paths derived from ``self.config['vhred_prefix']``, then builds the
        encoder and decoder-encoding Theano functions.

        Returns:
            tuple: ``(model, encoder_fn, decoder_fn)`` where ``model`` is the
            loaded ``VHRED_DialogEncoderDecoder`` instance.
        """
        prefix = self.config['vhred_prefix']

        # Restore the saved training state on top of the prototype defaults.
        state = VHRED_prototype_state()
        with open(prefix + "_state.pkl", 'rb') as state_file:
            state.update(cPickle.load(state_file))

        # Override batch size and dictionary for the current data.
        state['bs'] = 100
        state['dictionary'] = self.f_dict

        # Instantiate the model and load its trained parameters.
        model = VHRED_DialogEncoderDecoder(state)
        model.load(prefix + "_model.npz")

        # Compile the encoding functions used downstream.
        encoder_fn = model.build_encoder_function()
        decoder_fn = model.build_decoder_encoding()

        return model, encoder_fn, decoder_fn
# Beispiel #2
def main(model_prefix, dialogue_file):
    """Encode every dialogue line in *dialogue_file* with a trained model.

    Args:
        model_prefix: Path prefix of the saved model; ``<prefix>_state.pkl``
            and ``<prefix>_model.npz`` must both exist.
        dialogue_file: Text file with one (whitespace-tokenized) dialogue
            context per line.

    Returns:
        list: One encoding per context, in file order, as produced by
        ``compute_encodings``.

    Raises:
        Exception: If ``<prefix>_model.npz`` does not exist.
    """
    state = prototype_state()

    state_path = model_prefix + "_state.pkl"
    model_path = model_prefix + "_model.npz"

    # Pickle files must be read in binary mode (the sibling VHRED loader
    # already does this; text mode breaks cPickle.load on Python 3).
    with open(state_path, 'rb') as src:
        state.update(cPickle.load(src))

    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    state['bs'] = 10

    model = DialogEncoderDecoder(state)

    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")

    contexts = [[]]
    # Use a context manager so the file handle is not leaked.
    with open(dialogue_file, "r") as f:
        lines = f.readlines()
    if lines:
        contexts = [x.strip() for x in lines]

    model_compute_encoder_state = model.build_encoder_function()
    model_compute_decoder_state = model.build_decoder_encoding()
    dialogue_encodings = []

    # Accumulate contexts into full batches and encode batch by batch.
    joined_contexts = []
    batch_index = 0
    batch_total = int(math.ceil(float(len(contexts)) / float(model.bs)))
    for context_id, context_sentences in enumerate(contexts):
        # Convert the context into a list of word ids, ensuring it is
        # bracketed by end-of-sentence symbols.
        if len(context_sentences) == 0:
            joined_context = [model.eos_sym]
        else:
            joined_context = model.words_to_indices(context_sentences.split())

            if joined_context[0] != model.eos_sym:
                joined_context = [model.eos_sym] + joined_context

            if joined_context[-1] != model.eos_sym:
                joined_context += [model.eos_sym]

        joined_contexts.append(joined_context)

        if len(joined_contexts) == model.bs:
            batch_index = batch_index + 1
            logger.debug("[COMPUTE] - Got batch %d / %d" %
                         (batch_index, batch_total))
            encs = compute_encodings(joined_contexts, model,
                                     model_compute_encoder_state,
                                     model_compute_decoder_state)
            dialogue_encodings.extend(encs)

            joined_contexts = []

    # Flush the final, possibly partial, batch.
    if len(joined_contexts) > 0:
        logger.debug("[COMPUTE] - Got batch %d / %d" %
                     (batch_total, batch_total))
        encs = compute_encodings(joined_contexts, model,
                                 model_compute_encoder_state,
                                 model_compute_decoder_state)
        dialogue_encodings.extend(encs)

    return dialogue_encodings