Example #1
def get_models():
    args = parse_args()

    state_en2fr = prototype_state()
    if hasattr(args, 'state_en2fr') and args.state_en2fr is not None:
        with open(args.state_en2fr) as src:
            state_en2fr.update(cPickle.load(src))
    state_en2fr.update(eval("dict({})".format(args.changes)))

    state_fr2en = prototype_state()
    if hasattr(args, 'state_fr2en') and args.state_fr2en is not None:
        with open(args.state_fr2en) as src:
            state_fr2en.update(cPickle.load(src))
    state_fr2en.update(eval("dict({})".format(args.changes)))

    rng = numpy.random.RandomState(state_en2fr['seed'])
    enc_dec_en_2_fr = RNNEncoderDecoder(state_en2fr, rng, skip_init=True)
    enc_dec_en_2_fr.build()
    lm_model_en_2_fr = enc_dec_en_2_fr.create_lm_model()
    lm_model_en_2_fr.load(args.model_path_en2fr)
    indx_word_src = cPickle.load(open(state_en2fr['word_indx'],'rb'))
    indx_word_trgt = cPickle.load(open(state_en2fr['word_indx_trgt'], 'rb'))

    if hasattr(args, 'state_fr2en') and args.state_fr2en is not None:
        rng = numpy.random.RandomState(state_fr2en['seed'])
        enc_dec_fr_2_en = RNNEncoderDecoder(state_fr2en, rng, skip_init=True)
        enc_dec_fr_2_en.build()
        lm_model_fr_2_en = enc_dec_fr_2_en.create_lm_model()
        lm_model_fr_2_en.load(args.model_path_fr2en)

        return [lm_model_en_2_fr, enc_dec_en_2_fr, indx_word_src, indx_word_trgt, state_en2fr, \
            lm_model_fr_2_en, enc_dec_fr_2_en, state_fr2en]
    else:
        return [lm_model_en_2_fr, enc_dec_en_2_fr, indx_word_src, indx_word_trgt, state_en2fr,\
                None, None, None]
Example #2
File: suggest.py  Project: kdjyss/hred-qs
def main():
    args = parse_args()
    state = prototype_state()
    seqs = [[]]
    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))
    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
    model = SessionEncoderDecoder(state)
    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")
    beam_search = BeamSampler(model)
    beam_search.compile()
    for ctx_file in args.ext_file:
        lines = open(ctx_file, "r").readlines()
        seqs = context_to_indices(lines, model)
        sugg_text, sugg_ranks, sugg_costs = \
            sample(model, seqs=seqs, ignore_unk=args.ignore_unk,
                   beam_search=beam_search, n_samples=args.n_samples, session=args.session)
        output_path = ctx_file + "_" + model.state['model_id']
        print_output_suggestions(output_path, lines, sugg_text, sugg_ranks, sugg_costs)
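The context_to_indices helper used above is not part of the listing; the following is a minimal sketch, assuming the model exposes words_to_indices and eos_sym as in the later examples:

def context_to_indices(lines, model):
    # Hypothetical sketch: turn tab-separated context lines into lists of
    # word indices, with the model's end-of-sequence symbol between utterances.
    seqs = []
    for line in lines:
        joined = [model.eos_sym]
        for utterance in line.strip().split('\t'):
            joined += model.words_to_indices(utterance.split()) + [model.eos_sym]
        seqs.append(joined)
    return seqs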
Example #3
def init(path):

    ##------------------------------------------------------------------------------##
    #                Compile and load the model.                                     #
    #                Compile the encoder.                                            #
    ##------------------------------------------------------------------------------##

    state = prototype_state()
    state_path   = path  + "_state.pkl"
    model_path   = path  + "_model.npz"
    
    with open(state_path) as src:
        state.update(cPickle.load(src))
    
    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    model = DialogEncoderDecoder(state) 
    
    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")
    
    encoding_function = model.build_encoder_function()
    
    return model, encoding_function
Example #4
def init(path):

    ##------------------------------------------------------------------------------##
    #                Compile and load the model.                                     #
    #                Compile the encoder.                                            #
    ##------------------------------------------------------------------------------##

    state = prototype_state()
    state_path = path + "_state.pkl"
    model_path = path + "_model.npz"

    with open(state_path) as src:
        state.update(pickle.load(src))

    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    model = DialogEncoderDecoder(state)

    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")

    encoding_function = model.build_encoder_function()

    return model, encoding_function
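A usage sketch for init(); the model prefix below is only illustrative and must point at existing *_state.pkl and *_model.npz files:

# Hypothetical usage of init(); the prefix is illustrative.
model, encode_fn = init("Output/1485188791.05_RedditHRED")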
Example #5
def main():
    args = parse_args()
    state = prototype_state()
    seqs = [[]]
    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))
    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
    model = SessionEncoderDecoder(state)
    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")
    beam_search = BeamSampler(model)
    beam_search.compile()
    for ctx_file in args.ext_file:
        lines = open(ctx_file, "r").readlines()
        seqs = context_to_indices(lines, model)
        sugg_text, sugg_ranks, sugg_costs = \
            sample(model, seqs=seqs, ignore_unk=args.ignore_unk,
                   beam_search=beam_search, n_samples=args.n_samples, session=args.session)
        output_path = ctx_file + "_" + model.state['model_id']
        print_output_suggestions(output_path, lines, sugg_text, sugg_ranks,
                                 sugg_costs)
Example #6
def main():
    args = parse_args()
    state = prototype_state()
    seqs = [[]]
     
    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src)) 
    
    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
     
    model = SessionEncoderDecoder(state)
    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")
     
    beam_search = BeamSearch(model)
    beam_search.compile()
    
    sugg_text, sugg_ranks, sugg_costs = \
        sample(model, seqs=[[]], ignore_unk=args.ignore_unk, 
                beam_search=beam_search, n_samples=args.n_samples)
         
    print sugg_text
Example #7
File: sample.py  Project: hydercps/hred-qs
def main():
    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(
        level  = getattr(logging, state['level']),
        format = "%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    model = SessionEncoderDecoder(state)
    sampler = search.BeamSampler(model)

    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")
    contexts = [[]]
    lines = open(args.context, "r").readlines()
    contexts = [x.strip().split('\t') for x in lines]
    context_samples, context_costs = sampler.sample(
        contexts, n_samples=args.n_samples, ignore_unk=args.ignore_unk,
        verbose=args.verbose)
    # Write to output file
    output_handle = open(args.context + "_HED_" + model.run_id + ".gen", "w")
    for context_sample in context_samples:
        print >> output_handle, '\t'.join(context_sample)
    output_handle.close()
Example #8
File: sample.py  Project: wqj111186/hred-qs
def main():
    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    model = SessionEncoderDecoder(state)
    sampler = search.BeamSampler(model)

    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")
    contexts = [[]]
    lines = open(args.context, "r").readlines()
    contexts = [x.strip().split('\t') for x in lines]
    context_samples, context_costs = sampler.sample(contexts,
                                                    n_samples=args.n_samples,
                                                    ignore_unk=args.ignore_unk,
                                                    verbose=args.verbose)
    # Write to output file
    output_handle = open(args.context + "_HED_" + model.run_id + ".gen", "w")
    for context_sample in context_samples:
        print >> output_handle, '\t'.join(context_sample)
    output_handle.close()
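The parse_args helper is not included in these listings; below is a minimal argparse sketch covering only the options the sampling script reads (flag names are inferred from the code above, defaults are assumptions):

import argparse

def parse_args():
    # Hypothetical argument parser assumed by sample.py.
    parser = argparse.ArgumentParser(description="Sample suggestions with a beam sampler")
    parser.add_argument("model_prefix", help="Prefix of the *_state.pkl / *_model.npz files")
    parser.add_argument("context", help="File with one tab-separated context per line")
    parser.add_argument("--n-samples", dest="n_samples", type=int, default=1)
    parser.add_argument("--ignore-unk", dest="ignore_unk", action="store_true")
    parser.add_argument("--verbose", action="store_true")
    return parser.parse_args()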
Example #9
    def __init__(self, model_prefix, dict_file, name):
        # Load the HRED model.
        self.name = name
        state_path = '%s_state.pkl' % model_prefix
        model_path = '%s_model.npz' % model_prefix

        state = prototype_state()
        with open(state_path, 'r') as handle:
            state.update(cPickle.load(handle))
        state['dictionary'] = dict_file
        print 'Building %s model...' % name
        self.model = DialogEncoderDecoder(state)
        print 'Building sampler...'
        self.sampler = search.BeamSampler(self.model)
        print 'Loading model...'
        self.model.load(model_path)
        print 'Model built (%s).' % name

        self.speaker_token = '<first_speaker>'
        if name == 'reddit':
            self.speaker_token = '<speaker_1>'

        self.remove_tokens = ['<first_speaker>', '<at>', '<second_speaker>']
        for i in range(0, 10):
            self.remove_tokens.append('<speaker_%d>' % i)
Example #10
File: score.py  Project: wqj111186/hred-qs
def main():
    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    model = SessionEncoderDecoder(state)
    scorer = Scorer(model)

    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")

    contexts = [[]]
    lines = open(args.context, "r").readlines()
    contexts = [x.strip().split('\t') for x in lines]

    targets = [[]]
    lines = open(args.targets, "r").readlines()
    targets = [x.strip().split('\t') for x in lines]

    logging.info('Normalizing by length = {}'.format(args.normalize_by_length))
    logging.info('Multi feature = {}'.format(args.multi_feature))

    costs = scorer.score(contexts,
                         targets,
                         verbose=args.verbose,
                         normalize_by_length=args.normalize_by_length,
                         N=args.multi_feature)

    output_handle = open(args.targets + "_HED_" + ("nn_" if not args.normalize_by_length else "") + \
                         model.run_id + (".f" if args.feature_gen else ".gen"), "w")

    if args.feature_gen:
        print >> output_handle, ' '.join(
            ["%d_HED_" % i + model.run_id for i in range(args.multi_feature)])

    for num_target, target in enumerate(targets):
        reranked = numpy.array(target)[numpy.argsort(costs[num_target])]

        if args.feature_gen:
            for cost in numpy.array(costs[num_target]).T:
                print >> output_handle, ' '.join(map(str, cost))
        else:
            print >> output_handle, '\t'.join(reranked)

        output_handle.flush()
    output_handle.close()
Example #11
def main():
    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    model = DialogEncoderDecoder(state)
    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")

    logger.info("This model uses " + model.decoder_bias_type + " bias type")

    #sampler = search.RandomSampler(model)
    sampler = search.BeamSampler(model)

    # Start chat loop
    utterances = collections.deque()

    while (True):
        var = raw_input("User - ")

        # Increase number of utterances. We just set it to zero for simplicity so that model has no memory.
        # But it works fine if we increase this number
        while len(utterances) > 0:
            utterances.popleft()

        current_utterance = [model.end_sym_sentence] + [
            '<first_speaker>'
        ] + var.split() + [model.end_sym_sentence]
        utterances.append(current_utterance)

        #TODO Sample a random reply. To spice it up, we could pick the longest reply or the reply with the fewest placeholders...
        seqs = list(itertools.chain(*utterances))

        #TODO Retrieve only replies which are generated for second speaker...
        sentences = sample(model, \
             seqs=[seqs], ignore_unk=args.ignore_unk, \
             sampler=sampler, n_samples=5)

        if len(sentences) == 0:
            raise ValueError("Generation error, no sentences were produced!")

        utterances.append(sentences[0][0].split())

        reply = sentences[0][0].encode('utf-8')
        print "AI - ", remove_speaker_tokens(reply)
Example #12
def main():
    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    state['compute_training_updates'] = False

    model = DialogEncoderDecoder(state)

    sampler = search.RandomSampler(model)
    if args.beam_search:
        sampler = search.BeamSampler(model)
    if args.diverse_beam_search:
        sampler = search.DiverseBeamSampler(model, args.gamma)

    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")

    contexts = [[]]
    lines = open(args.context, "r").readlines()
    if len(lines):
        contexts = [x.strip() for x in lines]

    print('Sampling started...')
    context_samples, context_costs = sampler.sample(contexts,
                                                    n_samples=args.n_samples,
                                                    n_turns=args.n_turns,
                                                    ignore_unk=args.ignore_unk,
                                                    verbose=args.verbose,
                                                    return_words=True)
    print('Sampling finished.')
    print('Saving to file...')

    # Write to output file
    print type(context_samples)
    print type(context_samples[0])
    print context_samples[0]
    output_handle = open(args.output, "w")
    for context_sample in context_samples:

        print >> output_handle, '\t'.join(context_sample)
    output_handle.close()
    print('Saving to file finished.')
    print('All done!')
Example #13
def main():
    args = parse_args()
    state = prototype_state()
   
    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src)) 
    
    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
     
    model = DialogEncoderDecoder(state)
    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")
    
    logger.info("This model uses " + model.decoder_bias_type + " bias type")

    beam_search = None
    sampler = None

    beam_search = search.BeamSearch(model)
    beam_search.compile()

    # Start chat loop    
    utterances = collections.deque()
    
    while (True):
       var = raw_input("User - ")

       while len(utterances) > 2:
           utterances.popleft()
         
       current_utterance = [ model.start_sym_sentence ] + var.split() + [ model.end_sym_sentence ]
       utterances.append(current_utterance)
         
       # Sample a random reply. To spice it up, we could pick the longest reply or the reply with the fewest placeholders...
       seqs = list(itertools.chain(*utterances))

       sentences = sample(model, \
            seqs=[seqs], ignore_unk=args.ignore_unk, \
            beam_search=beam_search, n_samples=5)

       if len(sentences) == 0:
           raise ValueError("Generation error, no sentences were produced!")

       reply = " ".join(sentences[0]).encode('utf-8') 
       print "AI - ", reply
         
       utterances.append(sentences[0])
Example #14
def main():
    args = parse_args()
    state = prototype_state()
   
    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src)) 
    
    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
     
    model = DialogEncoderDecoder(state)
    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")
    
    logger.info("This model uses " + model.decoder_bias_type + " bias type")

    #sampler = search.RandomSampler(model)
    sampler = search.BeamSampler(model)

    # Start chat loop    
    utterances = collections.deque()
    
    while (True):
       var = raw_input("User - ")

       # Increase number of utterances. We just set it to zero for simplicity so that model has no memory. 
       # But it works fine if we increase this number
       while len(utterances) > 0:
           utterances.popleft()
         
       current_utterance = [ model.end_sym_sentence ] + ['<first_speaker>'] + var.split() + [ model.end_sym_sentence ]
       utterances.append(current_utterance)
         
       #TODO Sample a random reply. To spice it up, we could pick the longest reply or the reply with the fewest placeholders...
       seqs = list(itertools.chain(*utterances))

       #TODO Retrieve only replies which are generated for second speaker...
       sentences = sample(model, \
            seqs=[seqs], ignore_unk=args.ignore_unk, \
            sampler=sampler, n_samples=5)

       if len(sentences) == 0:
           raise ValueError("Generation error, no sentences were produced!")

       utterances.append(sentences[0][0].split())

       reply = sentences[0][0].encode('utf-8')
       print "AI - ", remove_speaker_tokens(reply)
Example #15
def main():
    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(
        level=getattr(logging, state["level"]), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s"
    )

    model = DialogEncoderDecoder(state)
    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")

    logger.info("This model uses " + model.decoder_bias_type + " bias type")

    beam_search = None
    sampler = None

    beam_search = search.BeamSearch(model)
    beam_search.compile()

    # Start chat loop
    utterances = collections.deque()

    while True:
        var = raw_input("User - ")

        while len(utterances) > 2:
            utterances.popleft()

        current_utterance = [model.start_sym_sentence] + var.split() + [model.end_sym_sentence]
        utterances.append(current_utterance)

        # Sample a random reply. To spice it up, we could pick the longest reply or the reply with the fewest placeholders...
        seqs = list(itertools.chain(*utterances))

        sentences = sample(model, seqs=[seqs], ignore_unk=args.ignore_unk, beam_search=beam_search, n_samples=5)

        if len(sentences) == 0:
            raise ValueError("Generation error, no sentences were produced!")

        reply = " ".join(sentences[0]).encode("utf-8")
        print "AI - ", reply

        utterances.append(sentences[0])
Example #16
def main():
    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"
    timing_path = args.model_prefix + "_timing.npz"

    with open(state_path, 'r') as src:
        state.update(cPickle.load(src))
    with open(timing_path, 'r') as src:
        timings = dict(numpy.load(src))

    state['compute_training_updates'] = False

    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    print "\nLoaded previous state, model, timings:"
    print "state:"
    print state
    print "timings:"
    print timings

    print "\nBuilding model..."
    model = DialogEncoderDecoder(state)

    sampler = search.RandomSampler(model)
    if args.beam_search:
        sampler = search.BeamSampler(model)

    if os.path.isfile(model_path):
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")
    print "build.\n"

    context = []
    while True:
        line = raw_input("user: "******"<first_speaker> <at> " + line + " </s> ")
        print "context: ", [' '.join(context[-4:])]
        context_samples, context_costs = sampler.sample(
            [' '.join(context[-4:])],
            ignore_unk=args.ignore_unk,
            verbose=args.verbose,
            return_words=True)

        print "bot:", context_samples
        context.append(context_samples[0][0] + " </s> ")
        print "cost:", context_costs
Example #17
def main():
    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(
        level=getattr(logging, state["level"]), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s"
    )

    model = DialogEncoderDecoder(state)

    sampler = search.RandomSampler(model)
    if args.beam_search:
        sampler = search.BeamSampler(model)

    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")

    contexts = [[]]
    lines = open(args.context, "r").readlines()
    if len(lines):
        contexts = [x.strip() for x in lines]

    print("Sampling started...")
    context_samples, context_costs = sampler.sample(
        contexts, n_samples=args.n_samples, n_turns=args.n_turns, ignore_unk=args.ignore_unk, verbose=args.verbose
    )
    print("Sampling finished.")
    print("Saving to file...")

    # Write to output file
    output_handle = open(args.output, "w")
    for context_sample in context_samples:
        print >> output_handle, "\t".join(context_sample)
    output_handle.close()
    print("Saving to file finished.")
    print("All done!")
Example #18
def main():
    args = parse_args()

    # Load state file
    state = prototype_state()
    state_path = args.model_prefix + "_state.pkl"
    with open(state_path) as src:
        state.update(pickle.load(src))

    # Load dictionary

    # Load dictionaries to convert str to idx and vice-versa
    raw_dict = pickle.load(open(state['dictionary'], 'rb'))

    str_to_idx = dict([(tok, tok_id) for tok, tok_id, _, _ in raw_dict])
    idx_to_str = dict([(tok_id, tok) for tok, tok_id, freq, _ in raw_dict])

    assert len(args.test_file) > 3
    test_contexts = ''
    test_responses = ''
    utterances_to_predict = args.utterances_to_predict
    assert args.utterances_to_predict > 0

    # Is it a pickle file? Then process using model dictionaries..
    if args.test_file.endswith('.pkl'):
        test_dialogues = pickle.load(open(args.test_file, 'rb'))
        for test_dialogueid, test_dialogue in enumerate(test_dialogues):
            if test_dialogueid % 100 == 0:
                print('test_dialogue', test_dialogueid)

            utterances = []
            current_utterance = []
            for word in test_dialogue:
                current_utterance += [word]
                if word == state['eos_sym']:
                    utterances += [current_utterance]
                    current_utterance = []

            if args.leave_out_short_dialogues:
                if len(utterances) <= utterances_to_predict + 1:
                    continue

            context_utterances = []
            prediction_utterances = []
            for utteranceid, utterance in enumerate(utterances):
                if utteranceid >= len(utterances) - utterances_to_predict:
                    prediction_utterances += utterance
                else:
                    context_utterances += utterance

            if args.max_words_in_context > 0:
                while len(context_utterances) > args.max_words_in_context:
                    del context_utterances[0]

            test_contexts += indices_to_words(idx_to_str,
                                              context_utterances) + '\n'
            test_responses += indices_to_words(idx_to_str,
                                               prediction_utterances) + '\n'

    else:  # Assume it's a text file

        test_dialogues = [[]]
        lines = open(args.test_file, "r").readlines()
        if len(lines):
            test_dialogues = [x.strip() for x in lines]

        for test_dialogueid, test_dialogue in enumerate(test_dialogues):
            if test_dialogueid % 100 == 0:
                print('test_dialogue', test_dialogueid)

            utterances = []
            current_utterance = []
            for word in test_dialogue.split():
                current_utterance += [word]
                if word == state['end_sym_utterance']:
                    utterances += [current_utterance]
                    current_utterance = []

            if args.leave_out_short_dialogues:
                if len(utterances) <= utterances_to_predict + 1:
                    continue

            context_utterances = []
            prediction_utterances = []
            for utteranceid, utterance in enumerate(utterances):
                if utteranceid >= len(utterances) - utterances_to_predict:
                    prediction_utterances += utterance
                else:
                    context_utterances += utterance

            if args.max_words_in_context > 0:
                while len(context_utterances) > args.max_words_in_context:
                    del context_utterances[0]

            test_contexts += ' '.join(context_utterances) + '\n'
            test_responses += ' '.join(prediction_utterances) + '\n'

    print('Writing to files...')
    f = open('test_contexts.txt', 'w')
    f.write(test_contexts)
    f.close()

    f = open('test_responses.txt', 'w')
    f.write(test_responses)
    f.close()

    print('All done!')
Example #19
    def preprocess_data(twitter_contexts, twitter_gtresponses, twitter_modelresponses, context_embedding_file, \
            gtresponses_embedding_file, modelresponses_embedding_file, use_precomputed_embeddings, liu=False):
        # Encode text into BPE format
        twitter_context_ids = strs_to_idxs(twitter_contexts, twitter_bpe, twitter_str_to_idx)
        twitter_gtresponse_ids = strs_to_idxs(twitter_gtresponses, twitter_bpe, twitter_str_to_idx)
        twitter_modelresponse_ids = strs_to_idxs(twitter_modelresponses, twitter_bpe, twitter_str_to_idx)
        
        # Compute VHRED embeddings
        if use_precomputed_embeddings:
            print 'Loading precomputed embeddings...'
            with open(context_embedding_file, 'r') as f1:
                twitter_context_embeddings = cPickle.load(f1)
            with open(gtresponses_embedding_file, 'r') as f1:
                twitter_gtresponse_embeddings = cPickle.load(f1)
            with open(modelresponses_embedding_file, 'r') as f1:
                twitter_modelresponse_embeddings = cPickle.load(f1)
        
        elif 'gpu' in theano.config.device.lower():
            print 'Loading model...'
            state = prototype_state()
            state_path = twitter_model_prefix + "_state.pkl"
            model_path = twitter_model_prefix + "_model.npz"

            with open(state_path) as src:
                state.update(cPickle.load(src))

            state['bs'] = 20
            state['dictionary'] = twitter_model_dictionary

            model = DialogEncoderDecoder(state) 
            
            print 'Computing context embeddings...'
            twitter_context_embeddings = compute_model_embeddings(twitter_context_ids, model, embedding_type)
            with open(context_embedding_file, 'w') as f1:
                cPickle.dump(twitter_context_embeddings, f1)
            print 'Computing ground truth response embeddings...'
            twitter_gtresponse_embeddings = compute_model_embeddings(twitter_gtresponse_ids, model, embedding_type)
            with open(gtresponses_embedding_file, 'w') as f1:
                cPickle.dump(twitter_gtresponse_embeddings, f1)
            print 'Computing model response embeddings...'
            twitter_modelresponse_embeddings = compute_model_embeddings(twitter_modelresponse_ids, model, embedding_type)
            with open(modelresponses_embedding_file, 'w') as f1:
                cPickle.dump(twitter_modelresponse_embeddings, f1)
       
        else:
            # Set embeddings to 0 for now. alternatively, we can load them from disc...
            #embeddings = cPickle.load(open(embedding_file, 'rb'))
            print 'ERROR: No GPU specified!'
            print ' To save testing time, model will be trained with zero context / response embeddings...'
            twitter_context_embeddings = np.zeros((len(twitter_contexts), 3, emb_dim))
            twitter_gtresponse_embeddings = np.zeros((len(twitter_contexts), 3, emb_dim))
            twitter_modelresponse_embeddings = np.zeros((len(twitter_contexts), 3, emb_dim))

        if not liu:
            # Copy the contexts and gt responses 4 times (to align with the model responses)
            temp_c_emb = []
            temp_gt_emb = []
            temp_gt = []
            for i in xrange(len(twitter_context_embeddings)):
                temp_c_emb.append([twitter_context_embeddings[i]]*4)
                temp_gt_emb.append([twitter_gtresponse_embeddings[i]]*4)
                temp_gt.append([twitter_gtresponses[i]]*4)
            twitter_context_embeddings = flatten(temp_c_emb)
            twitter_gtresponse_embeddings = flatten(temp_gt_emb)
            twitter_gtresponses = flatten(temp_gt)

        assert len(twitter_context_embeddings) == len(twitter_gtresponse_embeddings)
        assert len(twitter_context_embeddings) == len(twitter_modelresponse_embeddings)

        emb_dim = twitter_context_embeddings[0].shape[0]
        
        twitter_dialogue_embeddings = np.zeros((len(twitter_context_embeddings), 3, emb_dim))
        for i in range(len(twitter_context_embeddings)):
            twitter_dialogue_embeddings[i, 0, :] =  twitter_context_embeddings[i]
            twitter_dialogue_embeddings[i, 1, :] =  twitter_gtresponse_embeddings[i]
            twitter_dialogue_embeddings[i, 2, :] =  twitter_modelresponse_embeddings[i]
     
        print 'Computing auxiliary features...'
        if use_aux_features:
            aux_features = get_auxiliary_features(twitter_contexts, twitter_gtresponses, twitter_modelresponses, len(twitter_modelresponses))
        else:
            aux_features = np.zeros((len(twitter_modelresponses), 5))
        
        return twitter_dialogue_embeddings, aux_features 
Example #20
def main():
    args = parse_args()

    # Load state file
    state = prototype_state()
    state_path = args.model_prefix + "_state.pkl"
    with open(state_path) as src:
        state.update(cPickle.load(src))

    # Load dictionary

    # Load dictionaries to convert str to idx and vice-versa
    #raw_dict = cPickle.load(open(state['dictionary'], 'r'))
    raw_dict = cPickle.load(open('../Data/Dataset.dict.pkl', 'r')) # HACK

    str_to_idx = dict([(tok, tok_id) for tok, tok_id, _, _ in raw_dict])
    idx_to_str = dict([(tok_id, tok) for tok, tok_id, freq, _ in raw_dict])



    test_dialogues = cPickle.load(open(args.test_file, 'r'))
    assert args.utterances_to_predict > 0
    utterances_to_predict = args.utterances_to_predict

    test_contexts = ''
    test_responses = ''


    for test_dialogueid,test_dialogue in enumerate(test_dialogues):
        if test_dialogueid % 100 == 0:
            print 'test_dialogue', test_dialogueid

        utterances = []
        current_utterance = []
        for word in test_dialogue:
            current_utterance += [word]
            if word == state['eos_sym']:
                utterances += [current_utterance]
                current_utterance = []



        context_utterances = []
        prediction_utterances = []
        for utteranceid, utterance in enumerate(utterances):
            if utteranceid >= len(utterances) - utterances_to_predict:
                prediction_utterances += utterance
            else:
                context_utterances += utterance

        if args.max_words_in_context > 0:
            while len(context_utterances) > args.max_words_in_context:
                del context_utterances[0]


        test_contexts += indices_to_words(idx_to_str, context_utterances) + '\n'
        test_responses += indices_to_words(idx_to_str, prediction_utterances) + '\n'

    print('Writing to files...')
    f = open('test_contexts.txt','w')
    f.write(test_contexts)
    f.close()

    f = open('test_responses.txt','w')
    f.write(test_responses)
    f.close()

    print('All done!')
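indices_to_words, used in the two examples above, is not shown; a minimal sketch, assuming unknown ids simply fall back to a placeholder token:

def indices_to_words(idx_to_str, indices):
    # Hypothetical sketch: map word ids back to tokens with the idx_to_str
    # dictionary and join them into a single space-separated string.
    return ' '.join(idx_to_str.get(idx, '<unk>') for idx in indices)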
Example #21
def main():
    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    # For simplicity, we force the batch size to be one
    state['bs'] = 1
    model = DialogEncoderDecoder(state)

    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")

    eval_batch = model.build_eval_function()

    contexts = [[]]
    lines = open(args.context, "r").readlines()
    if len(lines):
        contexts = [x.strip() for x in lines]

    potential_responses_set = open(args.responses, "r").readlines()
    most_probable_responses_string = ''

    print('Retrieval started...')

    for context_idx, context in enumerate(contexts):
        if context_idx % 100 == 0:
            print '     processing example: ' + str(context_idx) + ' / ' + str(len(contexts))
        potential_responses = potential_responses_set[context_idx].strip().split('\t')

        most_probable_response_loglikelihood = -1.0
        most_probable_response = ''

        for potential_response_idx, potential_response in enumerate(
                potential_responses):
            # Convert contexts into list of ids
            dialogue = []
            if len(context) == 0:
                dialogue = [model.eos_sym]
            else:
                sentence_ids = model.words_to_indices(context.split())
                # Add eos tokens
                if len(sentence_ids) > 0:
                    if not sentence_ids[0] == model.eos_sym:
                        sentence_ids = [model.eos_sym] + sentence_ids
                    if not sentence_ids[-1] == model.eos_sym:
                        sentence_ids += [model.eos_sym]
                else:
                    sentence_ids = [model.eos_sym]

                dialogue += sentence_ids

            response = model.words_to_indices(potential_response.split())
            if len(response) > 0:
                if response[0] == model.eos_sym:
                    del response[0]
                if not response[-1] == model.eos_sym:
                    response += [model.eos_sym]

            dialogue += response
            dialogue = numpy.asarray(dialogue, dtype='int32').reshape(
                (len(dialogue), 1))

            dialogue_reversed = model.reverse_utterances(dialogue)

            dialogue_mask = numpy.ones((len(dialogue), 1), dtype='float32')
            dialogue_weight = numpy.ones((len(dialogue), 1), dtype='float32')
            dialogue_reset_mask = numpy.zeros((1), dtype='float32')
            dialogue_ran_vectors = model.rng.normal(size=(
                1, model.latent_gaussian_per_utterance_dim)).astype('float32')
            dialogue_ran_vectors = numpy.tile(dialogue_ran_vectors,
                                              (len(dialogue), 1, 1))

            dialogue_drop_mask = numpy.ones((len(dialogue), 1),
                                            dtype='float32')

            c, _, _, _, _ = eval_batch(dialogue, dialogue_reversed,
                                       len(dialogue), dialogue_mask,
                                       dialogue_weight, dialogue_reset_mask,
                                       dialogue_ran_vectors,
                                       dialogue_drop_mask)
            c = c / len(dialogue)

            print 'c', c

            if (potential_response_idx
                    == 0) or (-c > most_probable_response_loglikelihood):
                most_probable_response_loglikelihood = -c
                most_probable_response = potential_response

        most_probable_responses_string += most_probable_response + '\n'

    print('Retrieval finished.')
    print('Saving to file...')

    # Write to output file
    output_handle = open(args.output, "w")
    output_handle.write(most_probable_responses_string)
    output_handle.close()

    print('Saving to file finished.')
    print('All done!')
Example #22
def main():
    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    model = DialogEncoderDecoder(state) 
    
    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")
    
    contexts = [[]]
    lines = open(args.dialogues, "r").readlines()
    if len(lines):
        contexts = [x.strip().split('\t') for x in lines]
   
    model_compute_encoding = model.build_encoder_function()
    dialogue_encodings = []

    # Start loop
    joined_contexts = []
    batch_index = 0
    batch_total = int(math.ceil(float(len(contexts)) / float(model.bs)))
    for context_id, context_sentences in enumerate(contexts):

        # Convert contextes into list of ids
        joined_context = []

        if len(context_sentences) == 0:
            joined_context = [model.eos_sym]
        else:
            joined_context += [model.eos_sym]
            for sentence in context_sentences:
                sentence_ids = model.words_to_indices(sentence.split())
                # Add eos tokens
                joined_context += sentence_ids + [model.eos_sym]

        # HACK
        #for i in range(0, 50):
        #    joined_context += [0] + [model.eos_sym]

        joined_contexts.append(joined_context)

        if len(joined_contexts) == model.bs:
            batch_index = batch_index + 1
            logger.debug("[COMPUTE] - Got batch %d / %d" % (batch_index, batch_total))
            encs = compute_encodings(joined_contexts, model, model_compute_encoding, args.use_second_last_state)
            for i in range(len(encs)):
                dialogue_encodings.append(encs[i])

            joined_contexts = []


    if len(joined_contexts) > 0:
        logger.debug("[COMPUTE] - Got batch %d / %d" % (batch_total, batch_total))
        encs = compute_encodings(joined_contexts, model, model_compute_encoding, args.use_second_last_state)
        for i in range(len(encs)):
            dialogue_encodings.append(encs[i])

    # Save encodings to disc
    cPickle.dump(dialogue_encodings, open(args.output + '.pkl', 'w'))
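compute_encodings is not part of the listing; the sketch below pads a batch of index sequences and runs the compiled encoder. The call signature of model_compute_encoding and the layout of its output are assumptions here:

import numpy

def compute_encodings(joined_contexts, model, model_compute_encoding, use_second_last_state):
    # Hypothetical sketch: pad the batch into a (time, batch) matrix of ids,
    # run the compiled encoder function and keep one hidden state per dialogue.
    maxlen = max(len(context) for context in joined_contexts)
    batch = numpy.zeros((maxlen, len(joined_contexts)), dtype='int32')
    for i, context in enumerate(joined_contexts):
        batch[:len(context), i] = context

    # Assumed to return hidden states of shape (time, batch, dim).
    hidden_states = model_compute_encoding(batch, maxlen)

    encodings = []
    for i, context in enumerate(joined_contexts):
        # Take the last (or second-to-last) state of each dialogue.
        position = len(context) - 2 if use_second_last_state else len(context) - 1
        encodings.append(hidden_states[position, i])
    return encodings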
Example #23
from flask import Flask, jsonify, request
app = Flask(__name__)

from dialog_encdec import DialogEncoderDecoder
from state import prototype_state
import cPickle
import search

#MODEL_PREFIX = 'Output/1485188791.05_RedditHRED'
MODEL_PREFIX = '/home/ml/mnosew1/SavedModels/Twitter/1489857182.98_TwitterModel'

state_path = '%s_state.pkl' % MODEL_PREFIX
model_path = '%s_model.npz' % MODEL_PREFIX

state = prototype_state()
with open(state_path, 'r') as handle:
    state.update(cPickle.load(handle))

#state['dictionary'] = '/home/ml/mnosew1/data/twitter/hred_bpe/Dataset.dict.pkl'
state['dictionary'] = '/home/ml/mnosew1/SavedModels/Twitter/Dataset.dict-5k.pkl'
print 'Building model...'
model = DialogEncoderDecoder(state)
print 'Building sampler...'
sampler = search.BeamSampler(model)
print 'Loading model...'
model.load(model_path)
print 'Model built.'

HISTORY = []
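The listing sets up the Flask app, model and sampler but cuts off before any route; a minimal hedged sketch of a reply endpoint (the route name and request fields are assumptions, the sampler call mirrors the other examples):

@app.route('/respond', methods=['POST'])
def respond():
    # Hypothetical endpoint: sample a reply for the posted context string.
    context = request.json.get('context', '')
    HISTORY.append(context)
    samples, costs = sampler.sample([' '.join(HISTORY)], n_samples=5, ignore_unk=True)
    reply = samples[0][0]
    HISTORY.append(reply)
    return jsonify({'response': reply})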
Example #24
def main():
    args = parse_args()

    state = prototype_state()
    with open(args.state) as src:
        state.update(cPickle.load(src))
    state.update(eval("dict({})".format(args.changes)))

    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    rng = numpy.random.RandomState(state['seed'])
    enc_dec = RNNEncoderDecoder(state, rng, skip_init=True)
    enc_dec.build()
    lm_model = enc_dec.create_lm_model()
    lm_model.load(args.model_path)
    indx_word = cPickle.load(open(state['word_indx'],'rb'))

    sampler = None
    beam_search = None
    if args.beam_search:
        beam_search = BeamSearch(enc_dec)
        beam_search.compile()
    else:
        sampler = enc_dec.create_sampler(many_samples=True)

    idict_src = cPickle.load(open(state['indx_word'],'r'))

    if args.source and args.trans:
        # Actually only beam search is currently supported here
        assert beam_search
        assert args.beam_size

        fsrc = open(args.source, 'r')
        ftrans = open(args.trans, 'w')

        start_time = time.time()

        n_samples = args.beam_size
        total_cost = 0.0
        logging.debug("Beam size: {}".format(n_samples))
        for i, line in enumerate(fsrc):
            seqin = line.strip()
            seq, parsed_in = parse_input(state, indx_word, seqin, idx2word=idict_src)
            if args.verbose:
                print "Parsed Input:", parsed_in
            trans, costs, _ = sample(lm_model, seq, n_samples, sampler=sampler,
                    beam_search=beam_search, ignore_unk=args.ignore_unk, normalize=args.normalize)
            best = numpy.argmin(costs)
            print >>ftrans, trans[best]
            if args.verbose:
                print "Translation:", trans[best]
            total_cost += costs[best]
            if (i + 1)  % 100 == 0:
                ftrans.flush()
                logger.debug("Current speed is {} per sentence".
                        format((time.time() - start_time) / (i + 1)))
        print "Total cost of the translations: {}".format(total_cost)

        fsrc.close()
        ftrans.close()
    else:
        while True:
            try:
                seqin = raw_input('Input Sequence: ')
                n_samples = int(raw_input('How many samples? '))
                alpha = None
                if not args.beam_search:
                    alpha = float(raw_input('Inverse Temperature? '))
                seq,parsed_in = parse_input(state, indx_word, seqin, idx2word=idict_src)
                print "Parsed Input:", parsed_in
            except Exception:
                print "Exception while parsing your input:"
                traceback.print_exc()
                continue

            sample(lm_model, seq, n_samples, sampler=sampler,
                    beam_search=beam_search,
                    ignore_unk=args.ignore_unk, normalize=args.normalize,
                    alpha=alpha, verbose=True)
Example #25
def main():
    args = parse_args()
    # Sample args:
    # --state .\dataset\phrase_state.pkl
    # --beam-search --model_path .\dataset\github\phrase_model.npz

    state = prototype_state()
    with open(args.state, 'rb') as src:
        state.update(pickle.load(src))
    state.update(eval("dict({})".format(args.changes)))

    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    rng = numpy.random.RandomState(state['seed'])
    enc_dec = RNNEncoderDecoder(state, rng, skip_init=True)
    enc_dec.build()
    lm_model = enc_dec.create_lm_model()
    lm_model.load(args.model_path)
    indx_word = json.loads(open(state['word_indx'], 'r').readline())
    sampler = None
    beam_search = None
    if args.beam_search:
        beam_search = BeamSearch(enc_dec)
        beam_search.compile()
    else:
        sampler = enc_dec.create_sampler(many_samples=True)

    idict_src = {v: k for k, v in indx_word.items()}

    if args.source and args.trans:
        # Actually only beam search is currently supported here
        assert beam_search
        assert args.beam_size

        fsrc = open(args.source, 'r')
        ftrans = open(args.trans, 'w')
        if args.vec: fvec = open(args.vec, 'w')
        top_num = args.top_num

        start_time = time.time()

        n_samples = args.beam_size
        total_cost = 0.0
        logging.debug("Beam size: {}".format(n_samples))
        for i, line in enumerate(fsrc):
            seqin = line.strip()
            seq, parsed_in = parse_input(state,
                                         indx_word,
                                         seqin,
                                         idx2word=idict_src)
            if args.verbose: print("Parsed Input {}:".format(i), parsed_in)
            context_vec, trans, costs, _ = sample(lm_model,
                                                  seq,
                                                  n_samples,
                                                  sampler=sampler,
                                                  beam_search=beam_search,
                                                  ignore_unk=args.ignore_unk,
                                                  normalize=args.normalize)
            if not trans:  #if no translation
                for ss in range(top_num):
                    print >> ftrans, "a"
            else:
                top = numpy.array(costs).argsort()[0:top_num]
                total_cost += costs[top[0]]
                for k in top:
                    print >> ftrans, trans[k]
                if len(top) < top_num:
                    for ss in range(top_num - len(top)):
                        print >> ftrans, "a"
            if args.verbose and trans:
                print("Translation:{}".format(trans[top[0]]))
                #print ("Context Vector:%d",context_vec)

            if args.vec:  #print context vectors
                numpy.set_printoptions(threshold='nan',
                                       suppress=True,
                                       precision=12,
                                       linewidth=100000)
                if state['forward']:
                    assert context_vec.shape[1] >= state['dim']
                    forwardvec = context_vec[-1][0:state['dim']]
                    vec = forwardvec
                    if state['backward']:
                        assert context_vec.shape[1] == 2 * state['dim']
                        backwardvec = context_vec[0][state['dim']:2 *
                                                     state['dim']]
                        vec = numpy.concatenate((forwardvec, backwardvec))
                    print >> fvec, vec
            if (i + 1) % 100 == 0:
                ftrans.flush()
                if args.vec: fvec.flush()
                logger.debug("Current speed is {} per sentence".format(
                    (time.time() - start_time) / (i + 1)))
        print("Total cost of the translations: {}".format(total_cost))
        fsrc.close()
        ftrans.close()
        if args.vec: fvec.close()
        '''Validate the results and show BLEU results'''
        if args.validate:
            ftrans = open(args.trans, 'r')
            fvalid = open(args.validate, 'r')
            avg_bleu = bleu_analyze(ftrans.readlines(), fvalid.readlines(),
                                    top_num)
            ftrans.close()
            fvalid.close()
            print("Avg bleu of the translations: {}".format(avg_bleu))

    else:
        while True:
            try:
                seqin = raw_input('Input Sequence: ')
                n_samples = int(raw_input('How many samples? '))
                alpha = None
                if not args.beam_search:
                    alpha = float(raw_input('Inverse Temperature? '))
                seq, parsed_in = parse_input(state,
                                             indx_word,
                                             seqin,
                                             idx2word=idict_src)
                print("Parsed Input: {}".format(parsed_in))
            except Exception:
                print("Exception while parsing your input:")
                traceback.print_exc()
                continue

            sample(lm_model,
                   seq,
                   n_samples,
                   sampler=sampler,
                   beam_search=beam_search,
                   ignore_unk=args.ignore_unk,
                   normalize=args.normalize,
                   alpha=alpha,
                   verbose=True)
Example #26
def main():
    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    state['bs'] = 20
    state['compute_training_updates'] = False

    if args.mf_inference_steps > 0:
        state['lr'] = 0.01
        state['apply_meanfield_inference'] = True

    model = DialogEncoderDecoder(state) 

    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")
    

    eval_batch = model.build_eval_function()

    if args.mf_inference_steps > 0:
        mf_update_batch = model.build_mf_update_function()
        mf_reset_batch = model.build_mf_reset_function()
        
    contexts = [[]]
    lines = open(args.context, "r").readlines()
    if len(lines):
        contexts = [x.strip() for x in lines]

    lines = open(args.responses, "r").readlines()
    if len(lines):
        potential_responses_set = [x.strip() for x in lines]
    
    print('Building data batches...')

    # This is the maximum sequence length we can process on a 12 GB GPU with large HRED models
    max_sequence_length = 80*(80/state['bs'])

    assert len(potential_responses_set) == len(contexts)

    # Note we assume that each example has the same number of potential responses
    # The code can be adapted, however, by simply counting the total number of responses to consider here...
    examples_to_process = len(contexts) * len(potential_responses_set[0].strip().split('\t'))

    all_dialogues = numpy.zeros((max_sequence_length, examples_to_process), dtype='int32')
    all_dialogues_reversed = numpy.zeros((max_sequence_length, examples_to_process), dtype='int32')
    all_dialogues_mask = numpy.zeros((max_sequence_length, examples_to_process), dtype='float32')
    all_dialogues_reset_mask = numpy.zeros((examples_to_process), dtype='float32')
    all_dialogues_drop_mask = numpy.ones((max_sequence_length, examples_to_process), dtype='float32')

    all_dialogues_len = numpy.zeros((examples_to_process), dtype='int32')

    example_idx = -1

    for context_idx, context in enumerate(contexts):
        if context_idx % 100 == 0:
            print '     processing context idx: ' + str(context_idx) + ' / ' + str(len(contexts))

        potential_responses = potential_responses_set[context_idx].strip().split('\t')

        most_probable_response_loglikelihood = -1.0
        most_probable_response = ''

        for potential_response_idx, potential_response in enumerate(potential_responses):
            example_idx += 1

            # Convert contexts into list of ids
            dialogue = []
            if len(context) == 0:
                dialogue = [model.eos_sym]
            else:
                sentence_ids = model.words_to_indices(context.split())
                # Add eos tokens
                if len(sentence_ids) > 0:
                    if not sentence_ids[0] == model.eos_sym:
                        sentence_ids = [model.eos_sym] + sentence_ids
                    if not sentence_ids[-1] == model.eos_sym:
                        sentence_ids += [model.eos_sym]
                else:
                    sentence_ids = [model.eos_sym]


                dialogue += sentence_ids

            response = model.words_to_indices(potential_response.split())
            if len(response) > 0:
                if response[0] == model.eos_sym:
                    del response[0]
                if not response[-1] == model.eos_sym:
                    response += [model.eos_sym]

            if len(response) > 3:
                if ((response[-1] == model.eos_sym)
                 and (response[-2] == model.eod_sym)
                 and (response[-3] == model.eos_sym)):
                    del response[-1]
                    del response[-1]

            dialogue += response

            if context_idx == 0:
                print 'DEBUG INFO'
                print 'dialogue', dialogue

            if len(numpy.where(numpy.asarray(dialogue) == state['eos_sym'])[0]) < 3:
                print 'POTENTIAL PROBLEM WITH DIALOGUE CONTEXT IDX', context_idx
                print '     dialogue', dialogue


            # Trim beginning of dialogue words if the dialogue is too long... this should be a rare event.
            if len(dialogue) > max_sequence_length:
                if state['do_generate_first_utterance']:
                    dialogue = dialogue[len(dialogue)-max_sequence_length:len(dialogue)]
                else: # CreateDebate specific setting
                    dialogue = dialogue[0:max_sequence_length-1]
                    if not dialogue[-1] == state['eos_sym']:
                        dialogue += [state['eos_sym']]

                    


            dialogue = numpy.asarray(dialogue, dtype='int32').reshape((len(dialogue), 1))

            dialogue_reversed = model.reverse_utterances(dialogue)
            dialogue_mask = numpy.ones((len(dialogue)), dtype='float32') 
            dialogue_weight = numpy.ones((len(dialogue)), dtype='float32')
            dialogue_drop_mask = numpy.ones((len(dialogue)), dtype='float32')

            # Add example to large numpy arrays...
            all_dialogues[0:len(dialogue), example_idx] = dialogue[:,0]
            all_dialogues_reversed[0:len(dialogue), example_idx] = dialogue_reversed[:,0]
            all_dialogues_mask[0:len(dialogue), example_idx] = dialogue_mask
            all_dialogues_drop_mask[0:len(dialogue), example_idx] = dialogue_drop_mask

            all_dialogues_len[example_idx] = len(dialogue)


    sorted_dialogue_indices = numpy.argsort(all_dialogues_len)
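    # Sorting by length groups dialogues of similar length into the same batch, minimizing padding.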

    all_dialogues_cost = numpy.ones((examples_to_process), dtype='float32')

    batch_count = int(numpy.ceil(float(examples_to_process) / float(state['bs'])))

    print('Computing costs on batches...')
    for batch_idx in reversed(range(batch_count)):
        if batch_idx % 10 == 0:
            print '     processing batch idx: ' + str(batch_idx) + ' / ' + str(batch_count)

        example_index_start = batch_idx * state['bs']
        example_index_end = min((batch_idx+1) * state['bs'], examples_to_process)
        example_indices = sorted_dialogue_indices[example_index_start:example_index_end]

        current_batch_size = len(example_indices)

        max_batch_sequence_length = all_dialogues_len[example_indices[-1]]
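        # example_indices is sorted by length, so the last entry gives this batch's maximum sequence length.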

        # Initialize batch with zeros
        batch_dialogues = numpy.zeros((max_batch_sequence_length, state['bs']), dtype='int32')
        batch_dialogues_reversed = numpy.zeros((max_batch_sequence_length, state['bs']), dtype='int32')
        batch_dialogues_mask = numpy.zeros((max_batch_sequence_length, state['bs']), dtype='float32')
        batch_dialogues_weight = numpy.ones((max_batch_sequence_length, state['bs']), dtype='float32')
        batch_dialogues_reset_mask = numpy.zeros((state['bs']), dtype='float32')
        batch_dialogues_drop_mask = numpy.ones((max_batch_sequence_length, state['bs']), dtype='float32')

        # Fill in batch with values
        batch_dialogues[:,0:current_batch_size]  = all_dialogues[0:max_batch_sequence_length, example_indices]
        batch_dialogues_reversed[:,0:current_batch_size] = all_dialogues_reversed[0:max_batch_sequence_length, example_indices]
        batch_dialogues_mask[:,0:current_batch_size] = all_dialogues_mask[0:max_batch_sequence_length, example_indices]
        batch_dialogues_drop_mask[:,0:current_batch_size] = all_dialogues_drop_mask[0:max_batch_sequence_length, example_indices]

        if batch_idx < 10:
            print '###'
            print '###'
            print '###'

            print 'DEBUG THIS:'
            print 'batch_dialogues', batch_dialogues[:, 0]
            print 'batch_dialogues_reversed',batch_dialogues_reversed[:, 0]
            print 'max_batch_sequence_length', max_batch_sequence_length
            print 'batch_dialogues_reset_mask', batch_dialogues_reset_mask


        batch_dialogues_ran_gaussian_vectors = numpy.zeros((max_batch_sequence_length, state['bs'], model.latent_gaussian_per_utterance_dim), dtype='float32')
        batch_dialogues_ran_uniform_vectors = numpy.zeros((max_batch_sequence_length, state['bs'], model.latent_piecewise_per_utterance_dim), dtype='float32')

        for idx in range(state['bs']):
            eos_sym_list = numpy.where(batch_dialogues[:, idx] == state['eos_sym'])[0]
            if len(eos_sym_list) > 1:
                second_last_eos_sym = eos_sym_list[-2]
            else:
                print 'WARNING: batch_idx ' + str(batch_idx) + ' example index ' + str(idx) + ' does not have at least two EOS tokens!'
                print '     batch was: ', batch_dialogues[:, idx]

                if len(eos_sym_list) > 0:
                    second_last_eos_sym = eos_sym_list[-1]
                else:
                    second_last_eos_sym = 0

            batch_dialogues_ran_gaussian_vectors[second_last_eos_sym:, idx, :] = model.rng.normal(loc=0, scale=1, size=model.latent_gaussian_per_utterance_dim)
            batch_dialogues_ran_uniform_vectors[second_last_eos_sym:, idx, :] = model.rng.uniform(low=0.0, high=1.0, size=model.latent_piecewise_per_utterance_dim)

            batch_dialogues_mask[0:second_last_eos_sym+1, idx] = 0.0
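            # Only the final utterance (the candidate response) is scored: latent samples are drawn from the
            # second-last EOS onward, and the cost mask is zeroed over the preceding context.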




        if batch_idx < 10:
            print 'batch_dialogues_mask', batch_dialogues_mask[:, 0]
            print 'batch_dialogues_ran_gaussian_vectors', batch_dialogues_ran_gaussian_vectors[:, 0, 0]
            print 'batch_dialogues_ran_gaussian_vectors [-1]', batch_dialogues_ran_gaussian_vectors[:, 0, -1]
            print 'batch_dialogues_ran_uniform_vectors', batch_dialogues_ran_uniform_vectors[:, 0, 0]
            print 'batch_dialogues_ran_uniform_vectors [-1]', batch_dialogues_ran_uniform_vectors[:, 0, -1]


        # Carry out mean-field inference:
        if args.mf_inference_steps > 0:
            mf_reset_batch()
            model.build_mf_reset_function()
            for mf_step in range(args.mf_inference_steps):


                training_cost, kl_divergence_cost_acc, kl_divergences_between_piecewise_prior_and_posterior, kl_divergences_between_gaussian_prior_and_posterior = mf_update_batch(batch_dialogues, batch_dialogues_reversed, max_batch_sequence_length, batch_dialogues_mask, batch_dialogues_reset_mask, batch_dialogues_ran_gaussian_vectors, batch_dialogues_ran_uniform_vectors, batch_dialogues_drop_mask)
                if batch_idx % 10 == 0:
                    print  '  mf_step', mf_step
                    print '     training_cost', training_cost
                    print '     kl_divergence_cost_acc', kl_divergence_cost_acc
                    print '     kl_divergences_between_gaussian_prior_and_posterior',  numpy.sum(kl_divergences_between_gaussian_prior_and_posterior)
                    print '     kl_divergences_between_piecewise_prior_and_posterior', numpy.sum(kl_divergences_between_piecewise_prior_and_posterior)



        # Evaluate the batch and get the per-token costs
        _, c_list, _ = eval_batch(batch_dialogues, batch_dialogues_reversed, max_batch_sequence_length, batch_dialogues_mask, batch_dialogues_reset_mask, batch_dialogues_ran_gaussian_vectors, batch_dialogues_ran_uniform_vectors, batch_dialogues_drop_mask)
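        # c_list comes back flattened over (time, batch): reshape to (batch, time), sum over time and
        # divide by the number of unmasked tokens to get a per-word cost for each candidate response.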
        c_list = c_list.reshape((state['bs'], max_batch_sequence_length-1), order='F')
        c_list = numpy.sum(c_list, axis=1) / numpy.sum(batch_dialogues_mask, axis=0)
        all_dialogues_cost[example_indices] = c_list[0:current_batch_size]

        #print 'hd_input [0]', hd_input[:, 0, 0]
        #print 'hd_input [-]', hd_input[:, 0, -1]


    print('Ranking responses...')
    example_idx = -1
    ranked_all_responses_string = ''
    ranked_all_responses_costs_string = ''
    for context_idx, context in enumerate(contexts):
        if context_idx % 100 == 0:
            print '     processing context idx: ' + str(context_idx) + ' / ' + str(len(contexts))

        potential_responses = potential_responses_set[context_idx].strip().split('\t')
        potential_example_response_indices = numpy.zeros((len(potential_responses)), dtype='int32')

        for potential_response_idx, potential_response in enumerate(potential_responses):
            example_idx += 1
            potential_example_response_indices[potential_response_idx] = example_idx

        ranked_potential_example_response_indices = numpy.argsort(all_dialogues_cost[potential_example_response_indices])
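        # argsort over the costs orders the candidates from most probable (lowest cost) to least probable.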

        ranked_responses_costs = all_dialogues_cost[potential_example_response_indices]

        ranked_responses_string = ''
        ranked_responses_costs_string = ''
        for idx in ranked_potential_example_response_indices:
            ranked_responses_string += potential_responses[idx] + '\t'
            ranked_responses_costs_string += str(ranked_responses_costs[idx]) + '\t'

        ranked_all_responses_string += ranked_responses_string[0:len(ranked_responses_string)-1] + '\n'
        ranked_all_responses_costs_string += ranked_responses_costs_string[0:len(ranked_responses_costs_string)-1] + '\n'

    print('Finished all computations!')
    print('Saving to file...')

    # Write to output file
    output_handle = open(args.output + '.txt', "w")
    output_handle.write(ranked_all_responses_string)
    output_handle.close()

    output_handle = open(args.output + '_Costs.txt' , "w")
    output_handle.write(ranked_all_responses_costs_string)
    output_handle.close()

    print('Saving to file finished.')
    print('All done!')
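A minimal follow-up sketch for the ranked output written above: it assumes the common response-ranking convention that the first tab-separated candidate on each line of the original candidates file is the ground-truth response. The function name, file arguments, and that convention are assumptions for illustration, not part of the script.

def recall_at_1(ranked_path, candidates_path):
    # Compare the top-ranked response against the assumed ground truth
    # (the first tab-separated candidate in the original candidates file).
    ranked_lines = open(ranked_path, 'r').readlines()
    gold_lines = open(candidates_path, 'r').readlines()
    assert len(ranked_lines) == len(gold_lines)
    hits = 0
    for ranked_line, gold_line in zip(ranked_lines, gold_lines):
        top_response = ranked_line.strip().split('\t')[0]
        gold_response = gold_line.strip().split('\t')[0]
        if top_response == gold_response:
            hits += 1
    return float(hits) / max(len(ranked_lines), 1)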
예제 #27
0
def main(model_prefix, dialogue_file, use_second_last_state):
    state = prototype_state()

    state_path = model_prefix + "_state.pkl"
    model_path = model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    state['bs'] = 10

    model = DialogEncoderDecoder(state)

    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")

    contexts = [[]]
    lines = open(dialogue_file, "r").readlines()
    if len(lines):
        contexts = [x.strip() for x in lines]

    model_compute_encoding = model.build_encoder_function()
    dialogue_encodings = []

    # Start loop
    joined_contexts = []
    batch_index = 0
    batch_total = int(math.ceil(float(len(contexts)) / float(model.bs)))
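    # Contexts are encoded in batches of model.bs; any leftover smaller than a full batch is encoded after the loop.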
    for context_id, context_sentences in enumerate(contexts):
        # Convert contexts into list of ids
        joined_context = []

        if len(context_sentences) == 0:
            joined_context = [model.eos_sym]
        else:
            joined_context = model.words_to_indices(context_sentences.split())

            if joined_context[0] != model.eos_sym:
                joined_context = [model.eos_sym] + joined_context

            if joined_context[-1] != model.eos_sym:
                joined_context += [model.eos_sym]

        joined_contexts.append(joined_context)

        if len(joined_contexts) == model.bs:
            batch_index = batch_index + 1
            logger.debug("[COMPUTE] - Got batch %d / %d" %
                         (batch_index, batch_total))
            encs = compute_encodings(joined_contexts, model,
                                     model_compute_encoding,
                                     use_second_last_state)
            for i in range(len(encs)):
                dialogue_encodings.append(encs[i])

            joined_contexts = []

    if len(joined_contexts) > 0:
        logger.debug("[COMPUTE] - Got batch %d / %d" %
                     (batch_total, batch_total))
        encs = compute_encodings(joined_contexts, model,
                                 model_compute_encoding, use_second_last_state)
        for i in range(len(encs)):
            dialogue_encodings.append(encs[i])

    return dialogue_encodings
예제 #28
0
def main():
    args = parse_args()

    # Load state file
    state = prototype_state()
    state_path = args.model_prefix + "_state.pkl"
    with open(state_path) as src:
        state.update(cPickle.load(src))

    # Load dictionary

    # Load dictionaries to convert str to idx and vice-versa
    raw_dict = cPickle.load(open(state['dictionary'], 'r'))

    str_to_idx = dict([(tok, tok_id) for tok, tok_id, _, _ in raw_dict])
    idx_to_str = dict([(tok_id, tok) for tok, tok_id, freq, _ in raw_dict])


    assert len(args.test_file) > 3
    test_contexts = ''
    test_responses = ''
    utterances_to_predict = args.utterances_to_predict
    assert args.utterances_to_predict > 0

    # Is it a pickle file? Then process using model dictionaries..
    if args.test_file.endswith('.pkl'):
        test_dialogues = cPickle.load(open(args.test_file, 'r'))
        for test_dialogueid,test_dialogue in enumerate(test_dialogues):
            if test_dialogueid % 100 == 0:
                print 'test_dialogue', test_dialogueid

            utterances = []
            current_utterance = []
            for word in test_dialogue:
                current_utterance += [word]
                if word == state['eos_sym']:
                    utterances += [current_utterance]
                    current_utterance = []

            if args.leave_out_short_dialogues:
                if len(utterances) <= utterances_to_predict+1:
                    continue

            context_utterances = []
            prediction_utterances = []
            for utteranceid, utterance in enumerate(utterances):
                if utteranceid >= len(utterances) - utterances_to_predict:
                    prediction_utterances += utterance
                else:
                    context_utterances += utterance

            if args.max_words_in_context > 0:
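                # Drop tokens from the front of the context until it fits the word limit.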
                while len(context_utterances) > args.max_words_in_context:
                    del context_utterances[0]


            test_contexts += indices_to_words(idx_to_str, context_utterances) + '\n'
            test_responses += indices_to_words(idx_to_str, prediction_utterances) + '\n'

    else: # Assume it's a text file

        test_dialogues = [[]]
        lines = open(args.test_file, "r").readlines()
        if len(lines):
            test_dialogues = [x.strip() for x in lines]

        for test_dialogueid,test_dialogue in enumerate(test_dialogues):
            if test_dialogueid % 100 == 0:
                print 'test_dialogue', test_dialogueid

            utterances = []
            current_utterance = []
            for word in test_dialogue.split():
                current_utterance += [word]
                if word == state['end_sym_utterance']:
                    utterances += [current_utterance]
                    current_utterance = []

            if args.leave_out_short_dialogues:
                if len(utterances) <= utterances_to_predict+1:
                    continue

            context_utterances = []
            prediction_utterances = []
            for utteranceid, utterance in enumerate(utterances):
                if utteranceid >= len(utterances) - utterances_to_predict:
                    prediction_utterances += utterance
                else:
                    context_utterances += utterance

            if args.max_words_in_context > 0:
                while len(context_utterances) > args.max_words_in_context:
                    del context_utterances[0]


            test_contexts += ' '.join(context_utterances) + '\n'
            test_responses += ' '.join(prediction_utterances) + '\n'


    print('Writing to files...')
    f = open('test_contexts.txt','w')
    f.write(test_contexts)
    f.close()

    f = open('test_responses.txt','w')
    f.write(test_responses)
    f.close()

    print('All done!')
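The script above calls an indices_to_words helper that is not defined in this snippet. A plausible minimal version, assuming idx_to_str maps token ids back to token strings as built above, could look like the following sketch (not the original helper):

def indices_to_words(idx_to_str, token_ids):
    # Map each token id back to its surface form (skipping ids missing from
    # the dictionary) and join the tokens with single spaces.
    return ' '.join(idx_to_str[token_id] for token_id in token_ids if token_id in idx_to_str)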
예제 #29
0
def main():
    args = parse_args()

    state = prototype_state()
    with open(args.state) as src:
        state.update(cPickle.load(src))
    state.update(eval("dict({})".format(args.changes)))

    state['sort_k_batches'] = 1
    state['shuffle'] = False
    state['use_infinite_loop'] = False
    state['force_enc_repr_cpu'] = False

    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    rng = numpy.random.RandomState(state['seed'])
    enc_dec = RNNEncoderDecoder(state, rng, skip_init=True, compute_alignment=True)
    enc_dec.build()
    lm_model = enc_dec.create_lm_model()
    lm_model.load(args.model_path)

    indx_word_src = cPickle.load(open(state['word_indx'],'rb'))
    indx_word_trgt = cPickle.load(open(state['word_indx_trgt'], 'rb'))

    if args.mode == "batch":
        data_given = args.src or args.trg
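        # Inputs count as plain text unless both the source and target files are .h5 datasets.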
        txt = data_given and not (args.src.endswith(".h5") and args.trg.endswith(".h5"))
        if data_given and not txt:
            state['source'] = [args.src]
            state['target'] = [args.trg]
        if not data_given and not txt:
            logger.info("Using the training data")
        if txt:
            data_iter = BatchBiTxtIterator(state,
                    args.src, indx_word_src, args.trg, indx_word_trgt,
                    state['bs'], raise_unk=not args.allow_unk)
            data_iter.start()
        else:
            data_iter = get_batch_iterator(state)
            data_iter.start(0)

        score_file = open(args.scores, "w") if args.scores else sys.stdout

        scorer = enc_dec.create_scorer(batch=True)

        count = 0
        n_samples = 0
        logger.info('Scoring phrases')
        for i, batch in enumerate(data_iter):
            if batch is None:
                continue
            if args.n_batches >= 0 and i == args.n_batches:
                break

            if args.y_noise:
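                # Replace each target token, with probability y_noise, by a random word id in [0, 100).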
                y = batch['y']
                random_words = numpy.random.randint(0, 100, y.shape).astype("int64")
                change_mask = numpy.random.binomial(1, args.y_noise, y.shape).astype("int64")
                y = change_mask * random_words + (1 - change_mask) * y
                batch['y'] = y

            st = time.time()
            [scores] = scorer(batch['x'], batch['y'],
                    batch['x_mask'], batch['y_mask'])
            if args.print_probs:
                scores = numpy.exp(scores)
            up_time = time.time() - st
            for s in scores:
                print >>score_file, "{:.5e}".format(float(s))

            n_samples += batch['x'].shape[1]
            count += 1

            if count % 100 == 0:
                score_file.flush()
                logger.debug("Scores flushed")
            logger.debug("{} batches, {} samples, {} per sample; example scores: {}".format(
                count, n_samples, up_time/scores.shape[0], scores[:5]))

        logger.info("Done")
        score_file.flush()
    elif args.mode == "interact":
        scorer = enc_dec.create_scorer()
        while True:
            try:
                compute_probs = enc_dec.create_probs_computer()
                src_line = raw_input('Source sequence: ')
                trgt_line = raw_input('Target sequence: ')
                src_seq = parse_input(state, indx_word_src, src_line, raise_unk=not args.allow_unk, 
                                      unk_sym=state['unk_sym_source'], null_sym=state['null_sym_source'])
                trgt_seq = parse_input(state, indx_word_trgt, trgt_line, raise_unk=not args.allow_unk,
                                       unk_sym=state['unk_sym_target'], null_sym=state['null_sym_target'])
                print "Binarized source: ", src_seq
                print "Binarized target: ", trgt_seq
                probs = compute_probs(src_seq, trgt_seq)
                print "Probs: {}, cost: {}".format(probs, -numpy.sum(numpy.log(probs)))
            except Exception:
                traceback.print_exc()
    elif args.mode == "txt":
        assert args.src and args.trg
        scorer = enc_dec.create_scorer()
        src_file = open(args.src, "r")
        trg_file = open(args.trg, "r")
        compute_probs = enc_dec.create_probs_computer(return_alignment=True)
        try:
            numpy.set_printoptions(precision=3, linewidth=150, suppress=True)
            i = 0
            while True:
                src_line = next(src_file).strip()
                trgt_line = next(trg_file).strip()
                src_seq, src_words = parse_input(state,
                        indx_word_src, src_line, raise_unk=not args.allow_unk,
                        unk_sym=state['unk_sym_source'], null_sym=state['null_sym_source'])
                trgt_seq, trgt_words = parse_input(state,
                        indx_word_trgt, trgt_line, raise_unk=not args.allow_unk,
                        unk_sym=state['unk_sym_target'], null_sym=state['null_sym_target'])
                probs, alignment = compute_probs(src_seq, trgt_seq)
                if args.verbose:
                    print "Probs: ", probs.flatten()
                    if alignment.ndim == 3:
                        print "Alignment:".ljust(20), src_line, "<eos>"
                        # Use a separate index so the outer line counter 'i' is not overwritten
                        for word_pos, word in enumerate(trgt_words):
                            print "{}{}".format(word.ljust(20), alignment[word_pos, :, 0])
                        print "Generated by:"
                        for word_pos, word in enumerate(trgt_words):
                            j = numpy.argmax(alignment[word_pos, :, 0])
                            print "{} <--- {}".format(word,
                                    src_words[j] if j < len(src_words) else "<eos>")
                i += 1
                if i % 100 == 0:
                    sys.stdout.flush()
                    logger.debug(i)
                print -numpy.sum(numpy.log(probs))
        except StopIteration:
            pass
    else:
        raise Exception("Unknown mode {}".format(args.mode))
예제 #30
0
def main():
    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    model = DialogEncoderDecoder(state)

    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")

    contexts = [[]]
    lines = open(args.dialogues, "r").readlines()
    if len(lines):
        contexts = [x.strip().split('\t') for x in lines]
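        # Each line is one dialogue context; utterances within a line are separated by tabs.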

    model_compute_encoding = model.build_encoder_function()
    dialogue_encodings = []

    # Start loop
    joined_contexts = []
    batch_index = 0
    batch_total = int(math.ceil(float(len(contexts)) / float(model.bs)))
    for context_id, context_sentences in enumerate(contexts):

        # Convert contexts into a list of ids
        joined_context = []

        if len(context_sentences) == 0:
            joined_context = [model.eos_sym]
        else:
            for sentence in context_sentences:
                sentence_ids = model.words_to_indices(sentence.split())
                # Add sos and eos tokens
                joined_context += [model.sos_sym
                                   ] + sentence_ids + [model.eos_sym]

        # HACK
        for i in range(0, 50):
            joined_context += [model.sos_sym] + [0] + [model.eos_sym]

        joined_contexts.append(joined_context)

        if len(joined_contexts) == model.bs:
            batch_index = batch_index + 1
            logger.debug("[COMPUTE] - Got batch %d / %d" %
                         (batch_index, batch_total))
            encs = compute_encodings(joined_contexts, model,
                                     model_compute_encoding,
                                     args.use_second_last_state)
            for i in range(len(encs)):
                dialogue_encodings.append(encs[i])

            joined_contexts = []

    if len(joined_contexts) > 0:
        logger.debug("[COMPUTE] - Got batch %d / %d" %
                     (batch_total, batch_total))
        encs = compute_encodings(joined_contexts, model,
                                 model_compute_encoding,
                                 args.use_second_last_state)
        for i in range(len(encs)):
            dialogue_encodings.append(encs[i])

    # Save encodings to disc
    cPickle.dump(dialogue_encodings, open(args.output + '.pkl', 'w'))
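A small usage sketch for the encodings pickle written above: load it back and compare two dialogue encodings with cosine similarity. The helper names and numpy usage here are illustrative, not part of the original script; the path is whatever was passed as the output prefix plus '.pkl'.

import cPickle
import numpy

def load_dialogue_encodings(path):
    # Load the list of per-dialogue encoding vectors dumped by the script above.
    return cPickle.load(open(path, 'r'))

def cosine_similarity(enc_a, enc_b):
    # Cosine similarity between two encoding vectors (0.0 if either is a zero vector).
    a = numpy.asarray(enc_a).flatten()
    b = numpy.asarray(enc_b).flatten()
    denominator = numpy.linalg.norm(a) * numpy.linalg.norm(b)
    if denominator == 0.0:
        return 0.0
    return float(numpy.dot(a, b) / denominator)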
예제 #31
0
def main():
    ####yawa add
    raw_dict = cPickle.load(open('./Data/Dataset.dict.pkl', 'r'))
    str_to_idx = dict([(tok, tok_id) for tok, tok_id, _, _ in raw_dict])
    idx_to_str = dict([(tok_id, tok) for tok, tok_id, _, _ in raw_dict])
    #########

    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    model = DialogEncoderDecoder(state)

    sampler = search.RandomSampler(model)
    if args.beam_search:
        sampler = search.BeamSampler(model)

    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")

    contexts = [[]]
    lines = open(args.context, "r").readlines()
    if len(lines):
        contexts = [x.strip() for x in lines]
    #contexts = cPickle.load(open('./Data/Test.dialogues.pkl', 'r'))
    print('Sampling started...')
    context_samples, context_costs, att_weights, att_context = sampler.sample(
        contexts,
        n_samples=args.n_samples,
        n_turns=args.n_turns,
        ignore_unk=args.ignore_unk,
        verbose=args.verbose)
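    # context_samples holds the sampled responses per context, context_costs their costs, and
    # att_weights/att_context the attention weights and attended contexts recorded while sampling.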
    print('Sampling finished.')
    print('Saving to file...')

    # Write to output file
    output_handle = open(args.output, "w")
    for context_sample in context_samples:
        print >> output_handle, '\t'.join(context_sample)
    outline = ''
    #for att_weight in att_weights:
    #for att_in in att_weight:
    #print >> output_handle, str(att_in)
    print "number of weights:" + str(len(att_weights))
    #for i in range(len(att_weights)):
    #outline = att_weights[0]
    cPickle.dump(att_weights,
                 open('Data/beam_search_2000_2_weight.pkl', 'wb'),
                 protocol=cPickle.HIGHEST_PROTOCOL)
    cPickle.dump(att_context,
                 open('Data/beam_search_2000_2_context.pkl', 'wb'),
                 protocol=cPickle.HIGHEST_PROTOCOL)
    #for i in range(len(att_context)):
    #print att_context[i]
    #print numpy.array(att_weights[0])
    #print type(att_weights[0])
    #aa = numpy.array(att_weights[0])
    #size  = aa.shape[1]
    #bb = aa.reshape(5,5,size/5)
    #print bb.shape

    output_handle.close()
    print('Saving to file finished.')
    print('All done!')