Code example #1
File: train.py Project: tangyaohua/ProperNouns
def main():
    args = parse_args()

    state = eval(args.proto)()
    if args.state:
        if args.state.endswith(".py"):
            state.update(eval(open(args.state).read()))
        else:
            with open(args.state) as src:
                state.update(cPickle.load(src))
    for change in args.changes:
        state.update(eval("dict({})".format(change)))

    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
    logger.debug("State:\n{}".format(pprint.pformat(state)))

    rng = numpy.random.RandomState(state['seed'])
    enc_dec = RNNEncoderDecoder(state, rng, args.skip_init)
    enc_dec.build()
    lm_model = enc_dec.create_lm_model()

    logger.debug("Load data")
    train_data = get_batch_iterator(state)
    logger.debug("Compile trainer")
    algo = eval(state['algo'])(lm_model, state, train_data)
    logger.debug("Run training")
    main = MainLoop(train_data, None, None, lm_model, algo, state, None,
            reset=state['reset'],
            hooks=[RandomSamplePrinter(state, lm_model, train_data)]
                if state['hookFreq'] >= 0
                else None)
    if state['reload']:
        main.load()
    if state['loopIters'] > 0:
        main.main()
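
The state dictionary above is built in layers: a prototype returned by the function named in args.proto, then an optional state file (pickled or .py), then each entry of args.changes evaluated into a dict of overrides. A minimal sketch of that chain, with illustrative keys rather than the project's real prototype:

# Minimal sketch of the layered state overrides used in main() above; the keys
# and values are illustrative only.
state = {'seed': 1234, 'bs': 80, 'level': 'DEBUG'}          # prototype defaults
state.update({'bs': 64})                                    # e.g. a pickled state file
for change in ["seed=4321", "level='INFO'"]:                # e.g. args.changes
    state.update(eval("dict({})".format(change)))
# state is now {'seed': 4321, 'bs': 64, 'level': 'INFO'}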
Code example #2
File: segment.py Project: tangyaohua/ProperNouns
def get_models():
    args = parse_args()

    state_en2fr = prototype_state()
    if hasattr(args, 'state_en2fr'):
        with open(args.state_en2fr) as src:
            state_en2fr.update(cPickle.load(src))
    state_en2fr.update(eval("dict({})".format(args.changes)))

    state_fr2en = prototype_state()
    if hasattr(args, 'state_fr2en') and args.state_fr2en is not None:
        with open(args.state_fr2en) as src:
            state_fr2en.update(cPickle.load(src))
    state_fr2en.update(eval("dict({})".format(args.changes)))

    rng = numpy.random.RandomState(state_en2fr['seed'])
    enc_dec_en_2_fr = RNNEncoderDecoder(state_en2fr, rng, skip_init=True)
    enc_dec_en_2_fr.build()
    lm_model_en_2_fr = enc_dec_en_2_fr.create_lm_model()
    lm_model_en_2_fr.load(args.model_path_en2fr)
    indx_word_src = cPickle.load(open(state_en2fr['word_indx'],'rb'))
    indx_word_trgt = cPickle.load(open(state_en2fr['word_indx_trgt'], 'rb'))

    if hasattr(args, 'state_fr2en') and args.state_fr2en is not None:
        rng = numpy.random.RandomState(state_fr2en['seed'])
        enc_dec_fr_2_en = RNNEncoderDecoder(state_fr2en, rng, skip_init=True)
        enc_dec_fr_2_en.build()
        lm_model_fr_2_en = enc_dec_fr_2_en.create_lm_model()
        lm_model_fr_2_en.load(args.model_path_fr2en)

        return [lm_model_en_2_fr, enc_dec_en_2_fr, indx_word_src, indx_word_trgt, state_en2fr, \
            lm_model_fr_2_en, enc_dec_fr_2_en, state_fr2en]
    else:
        return [lm_model_en_2_fr, enc_dec_en_2_fr, indx_word_src, indx_word_trgt, state_en2fr,\
                None, None, None]
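
get_models() always returns eight items in a fixed order, with the last three set to None when no fr2en state is supplied. A plausible way for a caller to unpack the result (the variable names simply mirror the return statements above):

# Unpacking the result of get_models(); the order follows the return statements above.
(lm_model_en_2_fr, enc_dec_en_2_fr, indx_word_src, indx_word_trgt,
 state_en2fr, lm_model_fr_2_en, enc_dec_fr_2_en, state_fr2en) = get_models()
if lm_model_fr_2_en is None:
    pass  # only the en->fr direction was loaded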
Code example #3
def main():
    args = parse_args()

    state = getattr(experiments.nmt, args.proto)()
    if args.state:
        if args.state.endswith(".py"):
            state.update(eval(open(args.state).read()))
        else:
            with open(args.state) as src:
                state.update(cPickle.load(src))
    for change in args.changes:
        state.update(eval("dict({})".format(change)))

    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")
    logger.debug("State:\n{}".format(pprint.pformat(state)))

    rng = numpy.random.RandomState(state['seed'])
    enc_dec = RNNEncoderDecoder(state, rng, args.skip_init)
    enc_dec.build()
    lm_model = enc_dec.create_lm_model()

    logger.debug("Load data")
    train_data = get_batch_iterator(state)
    logger.debug("Compile trainer")
    algo = eval(state['algo'])(lm_model, state, train_data)
    logger.debug("Run training")
    main = MainLoop(train_data, None, None, lm_model, algo, state, None,
            reset=state['reset'],
            hooks=[RandomSamplePrinter(state, lm_model, train_data)]
                if state['hookFreq'] >= 0
                else None,
            valid=validate_translation)
    if state['reload']:
        main.load()
    if state['loopIters'] > 0:
        main.main()
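
This variant differs from code example #1 mainly in how the prototype is resolved (getattr on the experiments.nmt package instead of eval) and in the extra valid=validate_translation hook passed to MainLoop. A hedged sketch of the lookup pattern, assuming experiments.nmt exposes prototype functions such as prototype_state:

# Sketch of name-based prototype lookup; assumes experiments.nmt provides a
# function with the requested name (e.g. prototype_state).
import experiments.nmt
proto_name = "prototype_state"                 # would come from args.proto
state = getattr(experiments.nmt, proto_name)()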
Code example #4
File: score.py Project: tangyaohua/ProperNouns
def main():
    args = parse_args()

    state = prototype_state()
    with open(args.state) as src:
        state.update(cPickle.load(src))
    state.update(eval("dict({})".format(args.changes)))

    state['sort_k_batches'] = 1
    state['shuffle'] = False
    state['use_infinite_loop'] = False
    state['force_enc_repr_cpu'] = False

    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    rng = numpy.random.RandomState(state['seed'])
    enc_dec = RNNEncoderDecoder(state, rng, skip_init=True, compute_alignment=True)
    enc_dec.build()
    lm_model = enc_dec.create_lm_model()
    lm_model.load(args.model_path)

    indx_word_src = cPickle.load(open(state['word_indx'],'rb'))
    indx_word_trgt = cPickle.load(open(state['word_indx_trgt'], 'rb'))

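    # Three modes follow: "batch" scores whole corpora (text or .h5) with the batch
    # scorer, "interact" reads source/target pairs from stdin, and "txt" scores
    # parallel text files and can print soft alignments.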
    if args.mode == "batch":
        data_given = args.src or args.trg
        txt = data_given and not (args.src.endswith(".h5") and args.trg.endswith(".h5"))
        if data_given and not txt:
            state['source'] = [args.src]
            state['target'] = [args.trg]
        if not data_given and not txt:
            logger.info("Using the training data")
        if txt:
            data_iter = BatchBiTxtIterator(state,
                    args.src, indx_word_src, args.trg, indx_word_trgt,
                    state['bs'], raise_unk=not args.allow_unk)
            data_iter.start()
        else:
            data_iter = get_batch_iterator(state)
            data_iter.start(0)

        score_file = open(args.scores, "w") if args.scores else sys.stdout

        scorer = enc_dec.create_scorer(batch=True)

        count = 0
        n_samples = 0
        logger.info('Scoring phrases')
        for i, batch in enumerate(data_iter):
            if batch is None:
                continue
            if args.n_batches >= 0 and i == args.n_batches:
                break

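            # Optional target-side noise: with probability args.y_noise each target
            # token is replaced by a random word id drawn from [0, 100).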
            if args.y_noise:
                y = batch['y']
                random_words = numpy.random.randint(0, 100, y.shape).astype("int64")
                change_mask = numpy.random.binomial(1, args.y_noise, y.shape).astype("int64")
                y = change_mask * random_words + (1 - change_mask) * y
                batch['y'] = y

            st = time.time()
            [scores] = scorer(batch['x'], batch['y'],
                    batch['x_mask'], batch['y_mask'])
            if args.print_probs:
                scores = numpy.exp(scores)
            up_time = time.time() - st
            for s in scores:
                print >>score_file, "{:.5e}".format(float(s))

            n_samples += batch['x'].shape[1]
            count += 1

            if count % 100 == 0:
                score_file.flush()
                logger.debug("Scores flushed")
            logger.debug("{} batches, {} samples, {} per sample; example scores: {}".format(
                count, n_samples, up_time/scores.shape[0], scores[:5]))

        logger.info("Done")
        score_file.flush()
    elif args.mode == "interact":
        scorer = enc_dec.create_scorer()
        while True:
            try:
                compute_probs = enc_dec.create_probs_computer()
                src_line = raw_input('Source sequence: ')
                trgt_line = raw_input('Target sequence: ')
                src_seq = parse_input(state, indx_word_src, src_line, raise_unk=not args.allow_unk, 
                                      unk_sym=state['unk_sym_source'], null_sym=state['null_sym_source'])
                trgt_seq = parse_input(state, indx_word_trgt, trgt_line, raise_unk=not args.allow_unk,
                                       unk_sym=state['unk_sym_target'], null_sym=state['null_sym_target'])
                print "Binarized source: ", src_seq
                print "Binarized target: ", trgt_seq
                probs = compute_probs(src_seq, trgt_seq)
                print "Probs: {}, cost: {}".format(probs, -numpy.sum(numpy.log(probs)))
            except Exception:
                traceback.print_exc()
    elif args.mode == "txt":
        assert args.src and args.trg
        scorer = enc_dec.create_scorer()
        src_file = open(args.src, "r")
        trg_file = open(args.trg, "r")
        compute_probs = enc_dec.create_probs_computer(return_alignment=True)
        try:
            numpy.set_printoptions(precision=3, linewidth=150, suppress=True)
            i = 0
            while True:
                src_line = next(src_file).strip()
                trgt_line = next(trg_file).strip()
                src_seq, src_words = parse_input(state,
                        indx_word_src, src_line, raise_unk=not args.allow_unk,
                        unk_sym=state['unk_sym_source'], null_sym=state['null_sym_source'])
                trgt_seq, trgt_words = parse_input(state,
                        indx_word_trgt, trgt_line, raise_unk=not args.allow_unk,
                        unk_sym=state['unk_sym_target'], null_sym=state['null_sym_target'])
                probs, alignment = compute_probs(src_seq, trgt_seq)
                if args.verbose:
                    print "Probs: ", probs.flatten()
                    if alignment.ndim == 3:
                        print "Alignment:".ljust(20), src_line, "<eos>"
                        # Use a separate index here so the sentence counter i
                        # (incremented below) is not clobbered by these loops.
                        for k, word in enumerate(trgt_words):
                            print "{}{}".format(word.ljust(20), alignment[k, :, 0])
                        print "Generated by:"
                        for k, word in enumerate(trgt_words):
                            j = numpy.argmax(alignment[k, :, 0])
                            print "{} <--- {}".format(word,
                                    src_words[j] if j < len(src_words) else "<eos>")
                i += 1
                if i % 100 == 0:
                    sys.stdout.flush()
                    logger.debug(i)
                print -numpy.sum(numpy.log(probs))
        except StopIteration:
            pass
    else:
        raise Exception("Unknown mode {}".format(args.mode))
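
For reference, the cost printed by the "interact" and "txt" modes above is the negative log-likelihood of the target words; a tiny self-contained NumPy check of that reduction:

# Tiny check of the cost reduction used above; the probabilities are illustrative.
import numpy
probs = numpy.array([0.25, 0.5, 0.125])    # per-word probabilities for one sentence
cost = -numpy.sum(numpy.log(probs))        # same formula as the printed "cost"
# cost is approximately 4.159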
Code example #5
File: sample.py Project: tangyaohua/ProperNouns
def main():
    args = parse_args()

    state = prototype_state()
    with open(args.state) as src:
        state.update(cPickle.load(src))
    state.update(eval("dict({})".format(args.changes)))

    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    rng = numpy.random.RandomState(state['seed'])
    enc_dec = RNNEncoderDecoder(state, rng, skip_init=True)
    enc_dec.build()
    lm_model = enc_dec.create_lm_model()
    lm_model.load(args.model_path)
    indx_word = cPickle.load(open(state['word_indx'],'rb'))

    sampler = None
    beam_search = None
    if args.beam_search:
        beam_search = BeamSearch(enc_dec)
        beam_search.compile()
    else:
        sampler = enc_dec.create_sampler(many_samples=True)

    idict_src = cPickle.load(open(state['indx_word'],'r'))

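    # Two modes: translate args.source into args.trans line by line (beam search
    # only), or fall through to the interactive sampling loop below.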
    if args.source and args.trans:
        # Actually only beam search is currently supported here
        assert beam_search
        assert args.beam_size

        fsrc = open(args.source, 'r')
        ftrans = open(args.trans, 'w')

        start_time = time.time()

        n_samples = args.beam_size
        total_cost = 0.0
        logging.debug("Beam size: {}".format(n_samples))
        for i, line in enumerate(fsrc):
            seqin = line.strip()
            seq, parsed_in = parse_input(state, indx_word, seqin, idx2word=idict_src)
            if args.verbose:
                print "Parsed Input:", parsed_in
            trans, costs, _ = sample(lm_model, seq, n_samples, sampler=sampler,
                    beam_search=beam_search, ignore_unk=args.ignore_unk, normalize=args.normalize)
            best = numpy.argmin(costs)
            print >>ftrans, trans[best]
            if args.verbose:
                print "Translation:", trans[best]
            total_cost += costs[best]
            if (i + 1)  % 100 == 0:
                ftrans.flush()
                logger.debug("Current speed is {} per sentence".
                        format((time.time() - start_time) / (i + 1)))
        print "Total cost of the translations: {}".format(total_cost)

        fsrc.close()
        ftrans.close()
    else:
        while True:
            try:
                seqin = raw_input('Input Sequence: ')
                n_samples = int(raw_input('How many samples? '))
                alpha = None
                if not args.beam_search:
                    alpha = float(raw_input('Inverse Temperature? '))
                seq,parsed_in = parse_input(state, indx_word, seqin, idx2word=idict_src)
                print "Parsed Input:", parsed_in
            except Exception:
                print "Exception while parsing your input:"
                traceback.print_exc()
                continue

            sample(lm_model, seq, n_samples, sampler=sampler,
                    beam_search=beam_search,
                    ignore_unk=args.ignore_unk, normalize=args.normalize,
                    alpha=alpha, verbose=True)
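
A plausible command line for this sampler; the flag spellings are assumed from the attribute names used above (args.state, args.beam_search, args.beam_size, args.source, args.trans, args.model_path) and are not taken from this page:

# Hypothetical invocation; adjust flag names to whatever parse_args actually defines.
#   python sample.py --state search_state.pkl --beam-search --beam-size 12 \
#       --source newstest.en --trans newstest.fr.txt search_model.npz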
Code example #6
File: tree.py Project: tangyaohua/ProperNouns
def main():
    args = parse_args()

    state = getattr(experiments.nmt, args.state_fn)()
    if hasattr(args, 'state') and args.state:
        with open(args.state) as src:
            state.update(cPickle.load(src))
    state.update(eval("dict({})".format(args.changes)))

    assert state['enc_rec_layer'] == "RecursiveConvolutionalLayer", "Only works with gated recursive convolutional encoder"

    logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s")

    rng = numpy.random.RandomState(state['seed'])
    enc_dec = RNNEncoderDecoder(state, rng, skip_init=True)
    enc_dec.build()
    lm_model = enc_dec.create_lm_model()
    lm_model.load(args.model_path)

    indx_word = cPickle.load(open(state['word_indx'],'rb'))
    idict_src = cPickle.load(open(state['indx_word'],'r'))

    x = TT.lvector()
    h = TT.tensor3()

    proj_x = theano.function([x], enc_dec.encoder.input_embedders[0](
        enc_dec.encoder.approx_embedder(x)).out, name='proj_x')
    new_h, gater = enc_dec.encoder.transitions[0].step_fprop(
        None, h, return_gates = True)
    step_up = theano.function([h], [new_h, gater], name='gater_step')

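    # Interactive loop: embed the input sentence, apply the recursive convolutional
    # step once per level (each level shortens the sequence by one position), record
    # the gate activations as edge weights, and draw the resulting merge tree with
    # networkx/matplotlib.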
    while True:
        try:
            seqin = raw_input('Input Sequence: ')
            seq,parsed_in = parse_input(state, indx_word, seqin, idx2word=idict_src)
            print "Parsed Input:", parsed_in
        except Exception:
            print "Exception while parsing your input:"
            traceback.print_exc()
            continue

        # get the initial embedding
        new_h = proj_x(seq)
        new_h = new_h.reshape(new_h.shape[0], 1, new_h.shape[1])

        nodes = numpy.arange(len(seq)).tolist()
        node_idx = len(seq)-1
        rules = []
        nodes_level = copy.deepcopy(nodes)

        G = nx.DiGraph()

        input_nodes = []
        merge_nodes = []
        aggregate_nodes = []

        nidx = 0 
        vpos = 0
        nodes_pos = {}
        nodes_labels = {}
        # input nodes
        for nn in nodes[:-1]:
            nidx += 1
            G.add_node(nn, pos=(nidx, 0), ndcolor="blue", label="%d"%nn)
            nodes_pos[nn] = (nidx, vpos)
            nodes_labels[nn] = idict_src[seq[nidx-1]]
            input_nodes.append(nn)
        node_idx = len(seq) - 1

        vpos += 6
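        # One graph level per step: at depth dd the layer proposes len(seq)-(dd+1)
        # merges; gater[..., 0] weights the merged candidate while gater[..., 1] and
        # gater[..., 2] weight the two children (see the edges added below).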
        for dd in xrange(len(seq)-1):
            new_h, gater = step_up(new_h)
            decisions = numpy.argmax(gater, -1)
            new_nodes_level = numpy.zeros(len(seq) - (dd+1))
            hpos = float(len(seq)+1) - 0.5 * (dd+1)
            last_node = True
            for nn in xrange(len(seq)-(dd+1)):
                hpos -= 1
                if not last_node:
                    # merge nodes
                    node_idx += 1
                    G.add_node(node_idx, ndcolor="red", label="m")
                    nodes_labels[node_idx] = ""
                    nodes_pos[node_idx] = (hpos, vpos)
                    G.add_edge(nodes_level[-(nn+1)], node_idx, weight=gater[-(nn+1),0,0])
                    G.add_edge(nodes_level[-(nn+2)], node_idx, weight=gater[-(nn+1),0,0])
                    merge_nodes.append(node_idx)

                    merge_node = node_idx
                    # linear aggregation nodes
                    node_idx += 1
                    G.add_node(node_idx, ndcolor="red", label="")
                    nodes_labels[node_idx] = "$+$"
                    nodes_pos[node_idx] = (hpos, vpos+6)
                    G.add_edge(merge_node, node_idx, weight=gater[-(nn+1),0,0])
                    G.add_edge(nodes_level[-(nn+2)], node_idx, weight=gater[-(nn+1),0,1])
                    G.add_edge(nodes_level[-(nn+1)], node_idx, weight=gater[-(nn+1),0,2])
                    aggregate_nodes.append(node_idx)

                    new_nodes_level[-(nn+1)] = node_idx
                last_node = False
            nodes_level = copy.deepcopy(new_nodes_level)
            vpos += 12

        # TODO: Show only strong edges.
        threshold = float(raw_input('Threshold: '))
        edges = [(u,v,d) for (u,v,d) in G.edges(data=True) if d['weight'] > threshold]
        #edges = G.edges(data=True)

        use_weighting = raw_input('Color according to weight [Y/N]: ')
        if use_weighting == 'Y':
            cm = plt.get_cmap('binary') 
            cNorm  = colors.Normalize(vmin=0., vmax=1.)
            scalarMap = cmx.ScalarMappable(norm=cNorm, cmap=cm)
            colorList = [scalarMap.to_rgba(d['weight']) for (u,v,d) in edges]
        else:
            colorList = 'k'

        nx.draw_networkx_nodes(G, pos=nodes_pos, nodelist=input_nodes, node_color='white', alpha=1., edge_color='white')
        nx.draw_networkx_nodes(G, pos=nodes_pos, nodelist=merge_nodes, node_color='blue', alpha=0.8, node_size=20)
        nx.draw_networkx_nodes(G, pos=nodes_pos, nodelist=aggregate_nodes, node_color='red', alpha=0.8, node_size=80)
        nx.draw_networkx_edges(G, pos=nodes_pos, edge_color=colorList, edgelist=edges)
        nx.draw_networkx_labels(G,pos=nodes_pos,labels=nodes_labels,font_family='sans-serif')
        plt.axis('off')
        figname = raw_input('Save to: ')
        if figname[-3:] == "pdf":
            plt.savefig(figname, format='pdf')
        else:
            plt.savefig(figname)
        plt.close()
        G.clear()