Example #1
def main():
    args = parse_args()

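    # args.proto names a prototype configuration function (e.g. prototype_state);
    # it is looked up with eval() and called to build the base state dict.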
    state = eval(args.proto)()
    if args.state:
        if args.state.endswith(".py"):
            state.update(eval(open(args.state).read()))
        else:
            with open(args.state) as src:
                state.update(cPickle.load(src))
    for change in args.changes:
        state.update(eval("dict({})".format(change)))

    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
    logger.debug("State:\n{}".format(pprint.pformat(state)))

    rng = numpy.random.RandomState(state['seed'])
    enc_dec = RNNEncoderDecoder(state, rng, args.skip_init)
    enc_dec.build()
    lm_model = enc_dec.create_lm_model()

    logger.debug("Load data")
    train_data = get_batch_iterator(state)
    logger.debug("Compile trainer")
    algo = eval(state['algo'])(lm_model, state, train_data)
    logger.debug("Run training")
    main = MainLoop(train_data, None, None, lm_model, algo, state, None,
            reset=state['reset'],
            hooks=[RandomSamplePrinter(state, lm_model, train_data)]
                if state['hookFreq'] >= 0
                else None)
    if state['reload']:
        main.load()
    if state['loopIters'] > 0:
        main.main()
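The configuration in Example #1 is assembled in layers: a prototype function named by args.proto, an optional state file (a Python literal or a pickle), and finally per-run "key=value" overrides that are parsed with eval("dict({})".format(change)). A minimal sketch of that override step, with purely illustrative values (the keys bs, seed and level do appear in the state; the numbers here are made up):

# Illustrative sketch of the command-line override mechanism used above.
# Each change string is a comma-separated list of keyword arguments;
# wrapping it in dict(...) lets eval() parse the values as Python literals.
state = {'bs': 64, 'seed': 1234, 'level': 'DEBUG'}
for change in ["bs=80,seed=4321", "level='INFO'"]:
    state.update(eval("dict({})".format(change)))
print(state)  # -> {'bs': 80, 'seed': 4321, 'level': 'INFO'}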
Example #2
def main():
    args = parse_args()

    state = prototype_state()
    with open(args.state) as src:
        state.update(cPickle.load(src))
    state.update(eval("dict({})".format(args.changes)))

    state['sort_k_batches'] = 1
    state['shuffle'] = False
    state['use_infinite_loop'] = False
    state['force_enc_repr_cpu'] = False

    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    rng = numpy.random.RandomState(state['seed'])
    enc_dec = RNNEncoderDecoder(state, rng, skip_init=True, compute_alignment=True)
    enc_dec.build()
    lm_model = enc_dec.create_lm_model()
    lm_model.load(args.model_path)

    indx_word_src = cPickle.load(open(state['word_indx'],'rb'))
    indx_word_trgt = cPickle.load(open(state['word_indx_trgt'], 'rb'))

    if args.mode == "batch":
        data_given = args.src or args.trg
        txt = data_given and not (args.src.endswith(".h5") and args.trg.endswith(".h5"))
        if data_given and not txt:
            state['source'] = [args.src]
            state['target'] = [args.trg]
        if not data_given and not txt:
            logger.info("Using the training data")
        if txt:
            data_iter = BatchBiTxtIterator(state,
                    args.src, indx_word_src, args.trg, indx_word_trgt,
                    state['bs'], raise_unk=not args.allow_unk)
            data_iter.start()
        else:
            data_iter = get_batch_iterator(state)
            data_iter.start(0)

        score_file = open(args.scores, "w") if args.scores else sys.stdout

        scorer = enc_dec.create_scorer(batch=True)

        count = 0
        n_samples = 0
        logger.info('Scoring phrases')
        for i, batch in enumerate(data_iter):
            if batch is None:
                continue
            if args.n_batches >= 0 and i == args.n_batches:
                break

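            # Optionally corrupt the target side: with probability args.y_noise each
            # target token is replaced by a random word id drawn from [0, 100).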
            if args.y_noise:
                y = batch['y']
                random_words = numpy.random.randint(0, 100, y.shape).astype("int64")
                change_mask = numpy.random.binomial(1, args.y_noise, y.shape).astype("int64")
                y = change_mask * random_words + (1 - change_mask) * y
                batch['y'] = y

            st = time.time()
            [scores] = scorer(batch['x'], batch['y'],
                    batch['x_mask'], batch['y_mask'])
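            # Exponentiate to report probabilities instead of log-domain scores.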
            if args.print_probs:
                scores = numpy.exp(scores)
            up_time = time.time() - st
            for s in scores:
                print >>score_file, "{:.5e}".format(float(s))

            n_samples += batch['x'].shape[1]
            count += 1

            if count % 100 == 0:
                score_file.flush()
                logger.debug("Scores flushed")
            logger.debug("{} batches, {} samples, {} per sample; example scores: {}".format(
                count, n_samples, up_time/scores.shape[0], scores[:5]))

        logger.info("Done")
        score_file.flush()
    elif args.mode == "interact":
        scorer = enc_dec.create_scorer()
        while True:
            try:
                compute_probs = enc_dec.create_probs_computer()
                src_line = raw_input('Source sequence: ')
                trgt_line = raw_input('Target sequence: ')
                src_seq = parse_input(state, indx_word_src, src_line, raise_unk=not args.allow_unk, 
                                      unk_sym=state['unk_sym_source'], null_sym=state['null_sym_source'])
                trgt_seq = parse_input(state, indx_word_trgt, trgt_line, raise_unk=not args.allow_unk,
                                       unk_sym=state['unk_sym_target'], null_sym=state['null_sym_target'])
                print "Binarized source: ", src_seq
                print "Binarized target: ", trgt_seq
                probs = compute_probs(src_seq, trgt_seq)
                print "Probs: {}, cost: {}".format(probs, -numpy.sum(numpy.log(probs)))
            except Exception:
                traceback.print_exc()
    elif args.mode == "txt":
        assert args.src and args.trg
        scorer = enc_dec.create_scorer()
        src_file = open(args.src, "r")
        trg_file = open(args.trg, "r")
        compute_probs = enc_dec.create_probs_computer(return_alignment=True)
        try:
            numpy.set_printoptions(precision=3, linewidth=150, suppress=True)
            i = 0
            while True:
                src_line = next(src_file).strip()
                trgt_line = next(trg_file).strip()
                src_seq, src_words = parse_input(state,
                        indx_word_src, src_line, raise_unk=not args.allow_unk,
                        unk_sym=state['unk_sym_source'], null_sym=state['null_sym_source'])
                trgt_seq, trgt_words = parse_input(state,
                        indx_word_trgt, trgt_line, raise_unk=not args.allow_unk,
                        unk_sym=state['unk_sym_target'], null_sym=state['null_sym_target'])
                probs, alignment = compute_probs(src_seq, trgt_seq)
                if args.verbose:
                    print "Probs: ", probs.flatten()
                    if alignment.ndim == 3:
                        print "Alignment:".ljust(20), src_line, "<eos>"
                        # use a separate index here; i counts sentence pairs below
                        for k, word in enumerate(trgt_words):
                            print "{}{}".format(word.ljust(20), alignment[k, :, 0])
                        print "Generated by:"
                        for k, word in enumerate(trgt_words):
                            j = numpy.argmax(alignment[k, :, 0])
                            print "{} <--- {}".format(word,
                                    src_words[j] if j < len(src_words) else "<eos>")
                i += 1
                if i % 100 == 0:
                    sys.stdout.flush()
                    logger.debug(i)
                print -numpy.sum(numpy.log(probs))
        except StopIteration:
            pass
    else:
        raise Exception("Unknown mode {}".format(args.mode))
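Both the "interact" and "txt" modes in Example #2 turn the per-word probabilities returned by compute_probs into a sentence cost with -numpy.sum(numpy.log(probs)), i.e. the negative log-likelihood of the target sequence. A self-contained sketch with made-up probabilities:

import numpy

# Hypothetical per-word probabilities for a four-token target sequence.
probs = numpy.array([0.42, 0.10, 0.73, 0.05])

# Sentence cost as printed by the scripts above: negative log-likelihood (in nats).
cost = -numpy.sum(numpy.log(probs))
print("cost = {:.2f}".format(cost))  # approximately 6.48; lower is better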