Example #1
0
def main():
    """Generate beam-search suggestions for every context file in ``args.ext_file``.

    The model state is read from ``<model_prefix>_state.pkl`` and the
    parameters from ``<model_prefix>_model.npz``. A ``BeamSampler`` is
    compiled once. Then, for each context file, the lines are converted to
    index sequences with ``context_to_indices`` and passed to ``sample``. The
    result is written to ``<ctx_file>_<model.state['model_id']>`` through
    ``print_output_suggestions``.

    Raises:
        Exception: if the model parameter file does not exist.
    """
    args = parse_args()
    state = prototype_state()
    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    # Values from the pickled state override the prototype_state() entries.
    # NOTE(review): unpickling can execute arbitrary code -- only load trusted
    # state files.
    with open(state_path) as src:
        state.update(cPickle.load(src))
    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
    model = SessionEncoderDecoder(state)
    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")
    beam_search = BeamSampler(model)
    beam_search.compile()
    for ctx_file in args.ext_file:
        # Close each context file right away instead of leaking one handle
        # per iteration.
        with open(ctx_file, "r") as ctx_handle:
            lines = ctx_handle.readlines()
        seqs = context_to_indices(lines, model)
        sugg_text, sugg_ranks, sugg_costs = \
            sample(model, seqs=seqs, ignore_unk=args.ignore_unk,
                   beam_search=beam_search, n_samples=args.n_samples,
                   session=args.session)
        output_path = ctx_file + "_" + model.state['model_id']
        print_output_suggestions(output_path, lines, sugg_text, sugg_ranks,
                                 sugg_costs)
Example #2
0
def main():
    """Sample continuations for every line of ``args.context`` and save them.

    The model state is read from ``<model_prefix>_state.pkl`` and the
    parameters from ``<model_prefix>_model.npz``. Each input line is split on
    tabs into one context, and the contexts are sampled with a
    ``search.BeamSampler``. One tab-joined line per entry of the returned
    samples is written to ``<args.context>_HED_<model.run_id>.gen``.

    Raises:
        Exception: if the model parameter file does not exist.
    """
    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    # Values from the pickled state override the prototype_state() entries.
    # NOTE(review): unpickling can execute arbitrary code -- only load trusted
    # state files.
    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    model = SessionEncoderDecoder(state)
    sampler = search.BeamSampler(model)

    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")

    # One context per input line: the list of its tab-separated fields.
    # `with` closes the input file instead of leaking the handle.
    with open(args.context, "r") as context_handle:
        contexts = [x.strip().split('\t') for x in context_handle]

    # The second return value (costs) is not used by this script.
    context_samples, _ = sampler.sample(
        contexts, n_samples=args.n_samples, ignore_unk=args.ignore_unk,
        verbose=args.verbose)

    # Write to output file; `with` closes it even if a write fails.
    output_path = args.context + "_HED_" + model.run_id + ".gen"
    with open(output_path, "w") as output_handle:
        for context_sample in context_samples:
            print >> output_handle, '\t'.join(context_sample)
Example #3
0
def main():
    """Beam-sample continuations for the contexts in ``args.context``.

    The model is restored from ``<model_prefix>_state.pkl`` (state) and
    ``<model_prefix>_model.npz`` (parameters). Every line of the context file
    is split on tabs into one context. The samples returned by
    ``search.BeamSampler.sample`` are written one tab-joined line each to
    ``<args.context>_HED_<model.run_id>.gen``.

    Raises:
        Exception: if the model parameter file does not exist.
    """
    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    # Values from the pickled state override the prototype_state() entries.
    # NOTE(review): unpickling can execute arbitrary code -- only load trusted
    # state files.
    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    model = SessionEncoderDecoder(state)
    sampler = search.BeamSampler(model)

    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")

    # One context per input line: the list of its tab-separated fields.
    # `with` closes the input file instead of leaking the handle.
    with open(args.context, "r") as context_handle:
        contexts = [x.strip().split('\t') for x in context_handle]

    # The second return value (costs) is not used by this script.
    context_samples, _ = sampler.sample(contexts,
                                        n_samples=args.n_samples,
                                        ignore_unk=args.ignore_unk,
                                        verbose=args.verbose)

    # Write to output file; `with` closes it even if a write fails.
    output_path = args.context + "_HED_" + model.run_id + ".gen"
    with open(output_path, "w") as output_handle:
        for context_sample in context_samples:
            print >> output_handle, '\t'.join(context_sample)
Example #4
0
def main():
    args = parse_args()
    state = prototype_state()
    seqs = [[]]
     
    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    with open(state_path) as src:
        state.update(cPickle.load(src)) 
    
    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
     
    model = SessionEncoderDecoder(state)
    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")
     
    beam_search = BeamSearch(model)
    beam_search.compile()
    
    sugg_text, sugg_ranks, sugg_costs = \
        sample(model, seqs=[[]], ignore_unk=args.ignore_unk, 
                beam_search=beam_search, n_samples=args.n_samples)
         
    print sugg_text
Example #5
0
def main():
    """Write beam-search suggestions for each context file in ``args.ext_file``.

    State (``<model_prefix>_state.pkl``) and parameters
    (``<model_prefix>_model.npz``) are restored and a ``BeamSampler`` is
    compiled once. Each context file is converted with
    ``context_to_indices``, sampled, and written to
    ``<ctx_file>_<model.state['model_id']>`` by ``print_output_suggestions``.

    Raises:
        Exception: if the model parameter file does not exist.
    """
    args = parse_args()
    state = prototype_state()
    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    # Values from the pickled state override the prototype_state() entries.
    # NOTE(review): unpickling can execute arbitrary code -- only load trusted
    # state files.
    with open(state_path) as src:
        state.update(cPickle.load(src))
    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
    model = SessionEncoderDecoder(state)
    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")
    beam_search = BeamSampler(model)
    beam_search.compile()
    for ctx_file in args.ext_file:
        # Close each context file right away instead of leaking one handle
        # per iteration.
        with open(ctx_file, "r") as ctx_handle:
            lines = ctx_handle.readlines()
        seqs = context_to_indices(lines, model)
        sugg_text, sugg_ranks, sugg_costs = \
            sample(model, seqs=seqs, ignore_unk=args.ignore_unk,
                   beam_search=beam_search, n_samples=args.n_samples,
                   session=args.session)
        output_path = ctx_file + "_" + model.state['model_id']
        print_output_suggestions(output_path, lines, sugg_text, sugg_ranks,
                                 sugg_costs)
Example #6
0
def main():
    """Score the candidate targets with the model and write costs or a reranking.

    ``args.context`` and ``args.targets`` are read line by line, and each line
    is split on tabs. ``Scorer.score`` returns ``costs``, where
    ``costs[i]`` belongs to target line ``i``.

    The output file is ``<args.targets>_HED_[nn_]<model.run_id>`` plus an
    extension (``nn_`` is added when length normalization is off):

    * ``.f`` (``args.feature_gen``): a header of ``args.multi_feature``
      feature names, then the space-joined rows of ``costs[i].T`` for each
      target line.
    * ``.gen`` otherwise: the target's fields reordered by
      ``numpy.argsort(costs[i])``, tab-joined.

    Raises:
        Exception: if the model parameter file does not exist.
    """
    args = parse_args()
    state = prototype_state()

    state_path = args.model_prefix + "_state.pkl"
    model_path = args.model_prefix + "_model.npz"

    # Values from the pickled state override the prototype_state() entries.
    # NOTE(review): unpickling can execute arbitrary code -- only load trusted
    # state files.
    with open(state_path) as src:
        state.update(cPickle.load(src))

    logging.basicConfig(
        level=getattr(logging, state['level']),
        format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    model = SessionEncoderDecoder(state)
    scorer = Scorer(model)

    if os.path.isfile(model_path):
        logger.debug("Loading previous model")
        model.load(model_path)
    else:
        raise Exception("Must specify a valid model path")

    # Each line becomes the list of its tab-separated fields. `with` closes
    # the input files instead of leaking their handles.
    with open(args.context, "r") as context_handle:
        contexts = [x.strip().split('\t') for x in context_handle]
    with open(args.targets, "r") as targets_handle:
        targets = [x.strip().split('\t') for x in targets_handle]

    logging.info('Normalizing by length = {}'.format(args.normalize_by_length))
    logging.info('Multi feature = {}'.format(args.multi_feature))

    costs = scorer.score(contexts,
                         targets,
                         verbose=args.verbose,
                         normalize_by_length=args.normalize_by_length,
                         N=args.multi_feature)

    output_path = (args.targets + "_HED_" +
                   ("nn_" if not args.normalize_by_length else "") +
                   model.run_id + (".f" if args.feature_gen else ".gen"))

    # `with` closes the output file even if a write fails part-way.
    with open(output_path, "w") as output_handle:
        if args.feature_gen:
            # Header row: one column name per feature.
            print >> output_handle, ' '.join(
                ["%d_HED_" % i + model.run_id
                 for i in range(args.multi_feature)])

        for num_target, target in enumerate(targets):
            if args.feature_gen:
                for cost in numpy.array(costs[num_target]).T:
                    print >> output_handle, ' '.join(map(str, cost))
            else:
                # The reranking is only written in .gen mode, so it is only
                # computed here.
                reranked = numpy.array(target)[numpy.argsort(costs[num_target])]
                print >> output_handle, '\t'.join(reranked)

            # Flush after every target line so partial results reach disk.
            output_handle.flush()