Example #1
def get_config_multiWay():

    cgs = ['fi_en', 'de_en', 'en_de']
    enc_ids, dec_ids = get_enc_dec_ids(cgs)

    # Model related
    config = prototype_config_multiCG_08(cgs)
    config['saveto'] = 'multiWay'

    # Vocabulary/dataset related
    basedir = ''
    config['src_vocabs'] = get_paths(enc_ids, src_vocabs, basedir)
    config['trg_vocabs'] = get_paths(dec_ids, trg_vocabs, basedir)
    config['src_datas'] = get_paths(cgs, src_datas, basedir)
    config['trg_datas'] = get_paths(cgs, trg_datas, basedir)

    # Early stopping based on bleu related
    config['save_freq'] = 5000
    config['bleu_script'] = basedir + '/multi-bleu.perl'
    config['val_sets'] = get_paths(cgs, val_sets_src, basedir)
    config['val_set_grndtruths'] = get_paths(cgs, val_sets_ref, basedir)
    config['val_set_outs'] = get_val_set_outs(config['cgs'], config['saveto'])
    config['val_burn_in'] = 1

    # Validation set for log probs related
    config['log_prob_sets'] = get_paths(cgs, log_prob_sets, basedir)

    return ReadOnlyDict(config)
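
The helpers used above (get_enc_dec_ids, get_paths, prototype_config_multiCG_08, ReadOnlyDict, and the module-level path dictionaries such as src_vocabs) are imported from elsewhere in the project. As a rough, non-authoritative sketch of the convention the examples appear to rely on, get_enc_dec_ids is assumed to split each computational-graph name such as 'de_en' into an encoder id and a decoder id, keeping every id once in order of first appearance:

from collections import OrderedDict

def _split_cg_sketch(cg_name):
    # Assumed naming convention: '<enc>_<dec>', e.g. 'de_en' -> ('de', 'en').
    enc_id, dec_id = cg_name.split('_', 1)
    return enc_id, dec_id

def _get_enc_dec_ids_sketch(cgs):
    # Unique encoder/decoder ids in order of first appearance.
    enc_ids = list(OrderedDict((_split_cg_sketch(cg)[0], None) for cg in cgs))
    dec_ids = list(OrderedDict((_split_cg_sketch(cg)[1], None) for cg in cgs))
    return enc_ids, dec_ids

# For cgs = ['fi_en', 'de_en', 'en_de'] this sketch yields
# enc_ids = ['fi', 'de', 'en'] and dec_ids = ['en', 'de'].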
Example #2
def get_config_single():

    cgs = ['de_en']
    config = prototype_config_multiCG_08(cgs)
    enc_ids, dec_ids = get_enc_dec_ids(cgs)
    config['saveto'] = 'single'

    basedir = ''
    config['batch_sizes'] = OrderedDict([('de_en', 80)])
    config['schedule'] = OrderedDict([('de_en', 12)])
    config['src_vocabs'] = get_paths(enc_ids, src_vocabs, basedir)
    config['trg_vocabs'] = get_paths(dec_ids, trg_vocabs, basedir)
    config['src_datas'] = get_paths(cgs, src_datas, basedir)
    config['trg_datas'] = get_paths(cgs, trg_datas, basedir)
    config['save_freq'] = 5000
    config['val_burn_in'] = 60000
    config['bleu_script'] = basedir + '/multi-bleu.perl'
    config['val_sets'] = get_paths(cgs, val_sets_src, basedir)
    config['val_set_grndtruths'] = get_paths(cgs, val_sets_ref, basedir)
    config['val_set_outs'] = get_val_set_outs(config['cgs'], config['saveto'])
    config['log_prob_sets'] = get_paths(cgs, log_prob_sets, basedir)

    return ReadOnlyDict(config)
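
get_paths and get_val_set_outs are likewise project helpers that are not shown here. Judging from how they are called, get_paths is assumed to map each id (or cg name) to a file path rooted at basedir, and get_val_set_outs to produce one output filename per computational graph under the model directory. A minimal sketch under those assumptions, with a made-up filename pattern purely for illustration:

import os
from collections import OrderedDict

def _get_paths_sketch(ids, name_by_id, basedir):
    # name_by_id plays the role of the module-level dicts such as src_vocabs.
    return OrderedDict((i, os.path.join(basedir, name_by_id[i])) for i in ids)

def _get_val_set_outs_sketch(cgs, saveto):
    # Hypothetical filename pattern; only the id -> path shape matters here.
    return OrderedDict(
        (cg, os.path.join(saveto, 'validation_out_' + cg + '.txt'))
        for cg in cgs)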
Example #3
def main(config, model, normalize=False, n_process=5, chr_level=False,
         cgs_to_translate=None, n_best=1, zero_shot=False, test=False):

    trng = RandomStreams(config['seed'] if 'seed' in config else 1234)
    enc_ids, dec_ids = get_enc_dec_ids(config['cgs'])
    iternum = re.search('(?<=iter)[0-9]+', model)

    # Translate only the chosen cgs if they are valid
    if cgs_to_translate is None:
        cgs_to_translate = config['cgs']

    # Determine the version, this is for backward compatibility
    version = get_version(config)

    # Check if computational graphs are valid
    if not set(config['cgs']) >= set(cgs_to_translate) and not zero_shot:
        raise ValueError('{} not a subset of {}!'.format(
            cgs_to_translate, config['cgs']))

    # Check if zero shot computational graph is valid
    if zero_shot:
        if len(cgs_to_translate) > 1:
            raise ValueError('Only one cg can be translated for zero shot')
        if p_(cgs_to_translate[0])[0] not in enc_ids or \
                p_(cgs_to_translate[0])[1] not in dec_ids:
            raise ValueError('Zero shot is not valid for {}'
                             .format(cgs_to_translate[0]))
        config['cgs'] += cgs_to_translate

    # Create Theano variables
    floatX = theano.config.floatX
    src_sel = tensor.matrix('src_selector', dtype=floatX)
    trg_sel = tensor.matrix('trg_selector', dtype=floatX)
    x_sampling = tensor.matrix('source', dtype='int64')
    y_sampling = tensor.vector('target', dtype='int64')
    prev_state = tensor.matrix('prev_state', dtype=floatX)

    # Create encoder-decoder architecture
    logger.info('Creating encoder-decoder')
    enc_dec = EncoderDecoder(
        encoder=MultiEncoder(enc_ids=enc_ids, **config),
        decoder=MultiDecoder(**config))

    # Allocate parameters
    enc_dec.init_params()

    # Build sampling models
    logger.info('Building sampling models')
    f_inits, f_nexts, f_next_states = enc_dec.build_sampling_models(
        x_sampling, y_sampling, src_sel, trg_sel, prev_state, trng=trng)

    # Load parameters
    logger.info('Loading parameters')
    enc_dec.load_params(model)

    # Output translation file names to be returned
    translations = {}

    # Iterate over computational graphs
    for cg_name in f_inits.keys():

        enc_name = p_(cg_name)[0]
        dec_name = p_(cg_name)[1]
        enc_idx = enc_ids.index(enc_name)
        dec_idx = dec_ids.index(dec_name)
        f_init = f_inits[cg_name]
        f_next = f_nexts[cg_name]
        f_next_state = f_next_states.get(cg_name, None)
        aux_lm = enc_dec.decoder.decoders[dec_name].aux_lm

        # For monolingual paths do not perform any translations
        if enc_name == dec_name or cg_name not in cgs_to_translate:
            logger.info('Skipping validation of computational graph [{}]'
                        .format(cg_name))
            continue

        logger.info('Validating computational graph [{}]'.format(cg_name))

        # Change output filename
        if zero_shot:
            config['val_set_outs'][cg_name] += '_zeroShot'

        # Get input and output file names
        source_file = config['val_sets'][cg_name]
        saveto = config['val_set_outs'][cg_name]
        saveto = saveto + '{}_{}'.format(
            '' if iternum is None else '_iter' + iternum.group(),
            'nbest' if n_best > 1 else 'BLEU')

        # Skip if the output already exists
        if len([ff for ff in os.listdir(config['saveto'])
                if ff.startswith(os.path.basename(saveto))]):
            logger.info('Output file {}* exists, skipping'.format(saveto))
            continue

        # Prepare source vocabs and files, make sure special tokens are there
        src_vocab = cPickle.load(open(config['src_vocabs'][enc_name]))
        src_vocab['<S>'] = 0
        src_vocab['</S>'] = config['src_eos_idxs'][enc_name]
        src_vocab['<UNK>'] = config['unk_id']

        # Invert dictionary
        src_ivocab = dict()
        for kk, vv in src_vocab.iteritems():
            src_ivocab[vv] = kk

        # Prepare target vocabs and files, make sure special tokens are there
        trg_vocab = cPickle.load(open(config['trg_vocabs'][dec_name]))
        trg_vocab['<S>'] = 0
        trg_vocab['</S>'] = config['trg_eos_idxs'][dec_name]
        trg_vocab['<UNK>'] = config['unk_id']

        # Invert dictionary
        trg_ivocab = dict()
        for kk, vv in trg_vocab.iteritems():
            trg_ivocab[vv] = kk

        def _send_jobs(fname):
            with open(fname, 'r') as f:
                for idx, line in enumerate(f):
                    x = words2seqs(
                        line, src_vocab,
                        vocab_size=config['src_vocab_sizes'][enc_name],
                        chr_level=chr_level)
                    queue.put((idx, x))
            return idx+1

        def _finish_processes():
            for midx in xrange(n_process):
                queue.put(None)

        def _retrieve_jobs(n_samples):
            trans = [None] * n_samples
            scores = [None] * n_samples
            for idx in xrange(n_samples):
                resp = rqueue.get()
                trans[resp[0]] = resp[1]
                scores[resp[0]] = resp[2]
                if numpy.mod(idx, 10) == 0:
                    print 'Sample ', (idx+1), '/', n_samples, ' Done'
            return trans, scores

        # Create source and target selector vectors
        src_selector_input = numpy.zeros(
            (1, enc_dec.num_encs)).astype(theano.config.floatX)
        src_selector_input[0, enc_idx] = 1.
        trg_selector_input = numpy.zeros(
            (1, enc_dec.num_decs)).astype(theano.config.floatX)
        trg_selector_input[0, dec_idx] = 1.

        # Actual translation here
        logger.info('Translating ' + source_file + '...')
        val_start_time = time.time()
        if n_process == 1:
            trans = []
            scores = []
            with open(source_file, 'r') as f:
                for idx, line in enumerate(f):
                    if idx % 100 == 0 and idx != 0:
                        logger.info('...translated [{}] lines'.format(idx))
                    seq = words2seqs(
                        line, src_vocab,
                        vocab_size=config['src_vocab_sizes'][enc_name],
                        chr_level=chr_level)
                    _t, _s = _translate(
                        seq, f_init, f_next, trg_vocab['</S>'],
                        src_selector_input, trg_selector_input,
                        config['beam_size'],
                        config.get('cond_init_trg', False),
                        normalize, n_best, f_next_state=f_next_state,
                        version=version, aux_lm=aux_lm)
                    trans.append(_t)
                    scores.append(_s)

        else:
            # Create queues
            queue = Queue()
            rqueue = Queue()
            processes = [None] * n_process
            for midx in xrange(n_process):
                processes[midx] = Process(
                    target=translate_model,
                    args=(queue, rqueue, midx, f_init, f_next,
                          src_selector_input, trg_selector_input,
                          trg_vocab['</S>'], config['beam_size'], normalize,
                          config.get('cond_init_trg', False), n_best),
                    kwargs={'f_next_state': f_next_state,
                            'version': version,
                            'aux_lm': aux_lm})
                processes[midx].start()

            n_samples = _send_jobs(source_file)
            trans, scores = _retrieve_jobs(n_samples)
            _finish_processes()

        logger.info("Validation Took: {} minutes".format(
            float(time.time() - val_start_time) / 60.))

        # Prepare translation outputs and calculate BLEU if necessary
        # Note that translations are post-processed for BPE here
        if n_best == 1:
            trans = seqs2words(trans, trg_vocab, trg_ivocab)
            trans = [tt.replace('@@ ', '') for tt in trans]
            bleu_score = calculate_bleu(
                bleu_script=config['bleu_script'], trans=trans,
                gold=config['val_set_grndtruths'][cg_name])
            saveto += '{}'.format(bleu_score)
        else:
            n_best_trans = []
            for idx, (n_best_tr, score_) in enumerate(zip(trans, scores)):
                sentences = seqs2words(n_best_tr, trg_vocab, trg_ivocab)
                sentences = [tt.replace('@@ ', '') for tt in sentences]
                for ids, trans_ in enumerate(sentences):
                    n_best_trans.append(
                        '|||'.join(
                            ['{}'.format(idx), trans_,
                             '{}'.format(score_[ids])]))
            trans = n_best_trans

        # Write to file
        with open(saveto, 'w') as f:
            print >>f, '\n'.join(trans)
        translations[cg_name] = saveto
    return translations, saveto
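
words2seqs and seqs2words are imported helpers as well. The loop above seems to assume that words2seqs tokenises a line (optionally at character level), maps tokens to ids with an <UNK> fallback for unknown or out-of-range ids, and terminates the sequence with the </S> index, while seqs2words inverts the mapping. A hedged sketch of that assumed behaviour for a single sequence:

def _words2seq_sketch(line, vocab, vocab_size, chr_level=False,
                      unk_id=1, eos_idx=0):
    # Character- or word-level tokenisation, then id lookup with <UNK> fallback.
    tokens = list(line.strip()) if chr_level else line.strip().split()
    seq = [vocab.get(tok, unk_id) for tok in tokens]
    # Ids beyond the configured vocabulary size also collapse to <UNK>.
    seq = [idx if idx < vocab_size else unk_id for idx in seq]
    return seq + [eos_idx]

def _seq2words_sketch(seq, ivocab, eos_idx=0):
    # Inverse direction: stop at </S> and join the surface tokens.
    words = []
    for idx in seq:
        if idx == eos_idx:
            break
        words.append(ivocab.get(idx, '<UNK>'))
    return ' '.join(words)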
Example #4
def train(config, tr_stream, dev_stream, logprob_stream):

    trng = RandomStreams(config['seed'] if 'seed' in config else 1234)
    enc_ids, dec_ids = get_enc_dec_ids(config['cgs'])

    # Create Theano variables
    floatX = theano.config.floatX
    src_sel = tensor.matrix('src_selector', dtype=floatX)
    trg_sel = tensor.matrix('trg_selector', dtype=floatX)
    x = tensor.lmatrix('source')
    y = tensor.lmatrix('target')
    x_mask = tensor.matrix('source_mask')
    y_mask = tensor.matrix('target_mask')
    x_sampling = tensor.matrix('source', dtype='int64')
    y_sampling = tensor.vector('target', dtype='int64')
    prev_state = tensor.matrix('prev_state', dtype=floatX)
    src_sel_sampling = tensor.matrix('src_selector', dtype=floatX)
    trg_sel_sampling = tensor.matrix('trg_selector', dtype=floatX)

    # Create encoder-decoder architecture
    enc_dec = EncoderDecoder(encoder=MultiEncoder(enc_ids=enc_ids, **config),
                             decoder=MultiDecoder(**config))

    # Build training computational graphs
    probs, opt_rets = enc_dec.build_models(x, x_mask, y, y_mask, src_sel,
                                           trg_sel)

    # Get costs
    costs = enc_dec.get_costs(probs,
                              y,
                              y_mask,
                              decay_cs=config.get('decay_c', None),
                              opt_rets=opt_rets)

    # Computation graphs
    cgs = enc_dec.get_computational_graphs(costs)

    # Build sampling models
    f_inits, f_nexts, f_next_states = enc_dec.build_sampling_models(
        x_sampling,
        y_sampling,
        src_sel_sampling,
        trg_sel_sampling,
        prev_state,
        trng=trng)

    # Some printing
    enc_dec.print_params(cgs)

    # Get training parameters with optional excludes
    training_params, excluded_params = enc_dec.get_training_params(
        cgs,
        exclude_encs=config['exclude_encs'],
        additional_excludes=config['additional_excludes'],
        readout_only=config.get('readout_only', None),
        train_shared=config.get('train_shared', None))

    # Some more printing
    enc_dec.print_training_params(cgs, training_params)

    # Set up training algorithm
    algorithm = SGDMultiCG(costs=costs,
                           tparams=training_params,
                           drop_input=config['drop_input'],
                           step_rule=config['step_rule'],
                           learning_rate=config['learning_rate'],
                           clip_c=config['step_clipping'],
                           step_rule_kwargs=config.get('step_rule_kwargs', {}))

    # Set up training model
    training_models = OrderedDict()
    for k, v in costs.iteritems():
        training_models[k] = Model(costs[k])

    # Set extensions
    extensions = [
        Timing(after_batch=True),
        FinishAfter(after_n_batches=config['finish_after']),
        CostMonitoringWithMultiCG(after_batch=True),
        Printing(after_batch=True),
        PrintMultiStream(after_batch=True),
        DumpWithMultiCG(saveto=config['saveto'],
                        save_accumulators=config['save_accumulators'],
                        every_n_batches=config['save_freq'],
                        no_blocks=True)
    ]

    # Reload model if necessary
    if config['reload'] and os.path.exists(config['saveto']):
        extensions.append(
            LoadFromDumpMultiCG(saveto=config['saveto'],
                                load_accumulators=config['load_accumulators'],
                                no_blocks=True))

    # Add sampling to computational graphs
    for i, (cg_name, cg) in enumerate(cgs.iteritems()):
        eid, did = p_(cg_name)
        if config['hook_samples'] > 0:
            extensions.append(
                Sampler(f_init=f_inits[cg_name],
                        f_next=f_nexts[cg_name],
                        data_stream=tr_stream,
                        num_samples=config['hook_samples'],
                        src_eos_idx=config['src_eos_idxs'][eid],
                        trg_eos_idx=config['trg_eos_idxs'][did],
                        enc_id=eid,
                        dec_id=did,
                        every_n_batches=config['sampling_freq'],
                        cond_init_trg=config.get('cond_init_trg', False),
                        f_next_state=f_next_states.get(cg_name, None)))

    # Save parameters incrementally without overwriting
    if config.get('incremental_dump', False):
        extensions.append(
            IncrementalDump(saveto=config['saveto'],
                            burnin=config['val_burn_in'],
                            every_n_batches=config['save_freq']))

    # Compute log probability on dev set
    if 'log_prob_freq' in config:
        extensions.append(
            LogProbComputer(cgs=config['cgs'],
                            f_log_probs=enc_dec.build_f_log_probs(
                                probs, x, x_mask, y, y_mask, src_sel, trg_sel),
                            streams=logprob_stream,
                            every_n_batches=config['log_prob_freq']))

    # Initialize main loop
    main_loop = MainLoopWithMultiCGnoBlocks(models=training_models,
                                            algorithm=algorithm,
                                            data_stream=tr_stream,
                                            extensions=extensions,
                                            num_encs=config['num_encs'],
                                            num_decs=config['num_decs'])

    # Train!
    main_loop.run()

    # Be patient, after a month :-)
    print 'done'
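
config['schedule'] (an OrderedDict from cg name to an integer, set to 12 for the single de_en run and 1 by default) is consumed by the project's stream and main-loop code, which is not shown. One plausible reading, sketched below purely as an illustration and not taken from the codebase, is a round-robin over the computational graphs in which each graph is served its scheduled number of consecutive batches before moving on:

from collections import OrderedDict
from itertools import cycle, islice

def _schedule_order_sketch(schedule, n_steps):
    # Yield cg names round-robin, repeating each one schedule[cg] times per pass.
    def passes():
        for cg, reps in cycle(schedule.items()):
            for _ in range(reps):
                yield cg
    return list(islice(passes(), n_steps))

# Example: de_en served twice for every en_de batch.
# _schedule_order_sketch(OrderedDict([('de_en', 2), ('en_de', 1)]), 6)
# -> ['de_en', 'de_en', 'en_de', 'de_en', 'de_en', 'en_de']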
Example #5
def main(config, ref_encs=None, ref_decs=None, ref_att=None,
         ref_enc_embs=None, ref_dec_embs=None):

    # Create Theano variables
    floatX = theano.config.floatX
    src_sel = tensor.matrix('src_selector', dtype=floatX)
    trg_sel = tensor.matrix('trg_selector', dtype=floatX)
    x = tensor.lmatrix('source')
    y = tensor.lmatrix('target')
    x_mask = tensor.matrix('source_mask')
    y_mask = tensor.matrix('target_mask')

    # for multi source - maximum is 5 for now
    xs = [tensor.lmatrix('source%d' % i) for i in range(5)]
    x_masks = [tensor.matrix('source%d_mask' % i) for i in range(5)]

    # Create encoder-decoder architecture, and initialize
    logger.info('Creating encoder-decoder')
    enc_ids, dec_ids = get_enc_dec_ids(config['cgs'])
    enc_dec = EncoderDecoder(
        encoder=MultiEncoder(enc_ids=enc_ids, **config),
        decoder=MultiDecoder(**config))
    enc_dec.build_models(x, x_mask, y, y_mask, src_sel, trg_sel,
                         xs=xs, x_masks=x_masks)

    # load reference encoder models
    r_encs = {}
    if ref_encs is not None:
        for eid, path in ref_encs.items():
            logger.info('... ref-enc[{}] loading [{}]'.format(eid, path))
            r_encs[eid] = dict(numpy.load(path))

    # load reference decoder models
    r_decs = {}
    if ref_decs is not None:
        for did, path in ref_decs.items():
            logger.info('... ref-dec[{}] loading [{}]'.format(did, path))
            r_decs[did] = dict(numpy.load(path))

    # load reference model for the shared components
    if ref_att is not None:
        logger.info('... ref-shared loading [{}]'.format(ref_att))
        r_att = dict(numpy.load(ref_att))

    num_params_set = 0
    params_set = {k: 0 for k in enc_dec.get_params().keys()}

    # set encoder parameters of target model
    for eid, rparams in r_encs.items():
        logger.info(' Setting encoder [{}] parameters ...'.format(eid))
        tparams = enc_dec.encoder.encoders[eid].tparams
        for pname, pval in tparams.items():
            set_tparam(tparams[pname], rparams[pname])
            params_set[pname] += 1
            num_params_set += 1
        set_tparam(enc_dec.encoder.tparams['ctx_embedder_%s_W' % eid],
                   rparams['ctx_embedder_%s_W' % eid])
        set_tparam(enc_dec.encoder.tparams['ctx_embedder_%s_b' % eid],
                   rparams['ctx_embedder_%s_b' % eid])
        params_set['ctx_embedder_%s_W' % eid] += 1
        params_set['ctx_embedder_%s_b' % eid] += 1
        num_params_set += 2

    # set decoder parameters of target model
    for did, rparams in r_decs.items():
        logger.info(' Setting decoder [{}] parameters ...'.format(did))
        tparams = enc_dec.decoder.decoders[did].tparams
        for pname, pval in tparams.items():
            set_tparam(tparams[pname], rparams[pname])
            params_set[pname] += 1
            num_params_set += 1

    # set shared component parameters of target model
    if ref_att is not None:
        logger.info(' Setting shared parameters ...')
        shared_enc, shared_params = enc_dec.decoder._get_shared_params()
        for pname in shared_params.keys():
            set_tparam(enc_dec.decoder.tparams[pname], r_att[pname])
            params_set[pname] += 1
            num_params_set += 1

    # set encoder embeddings
    if ref_enc_embs is not None:
        logger.info(' Setting encoder embeddings ...')
        for eid, path in ref_enc_embs.items():
            pname = 'Wemb_%s' % eid
            logger.info(' ... [{}]-[{}]'.format(eid, pname))
            emb = numpy.load(path)[pname]
            set_tparam(enc_dec.encoder.tparams[pname], emb)
            params_set[pname] += 1
            num_params_set += 1

    # set decoder embeddings
    if ref_dec_embs is not None:
        logger.info(' Setting decoder embeddings ...')
        for did, path in ref_dec_embs.items():
            pname = 'Wemb_dec_%s' % did
            logger.info(' ... [{}]-[{}]'.format(did, pname))
            emb = numpy.load(path)[pname]
            set_tparam(enc_dec.decoder.tparams[pname], emb)
            params_set[pname] += 1
            num_params_set += 1

    logger.info(' Saving initialized params to [{}/params.npz]'
                .format(config['saveto']))
    if not os.path.exists(config['saveto']):
        os.makedirs(config['saveto'])

    numpy.savez('{}/params.npz'.format(config['saveto']),
                **tparams_asdict(enc_dec.get_params()))
    logger.info(' Total number of params    : [{}]'
                .format(len(enc_dec.get_params())))
    logger.info(' Total number of params set: [{}]'.format(num_params_set))
    logger.info(' Duplicates [{}]'.format(
        [k for k, v in params_set.items() if v > 1]))
    logger.info(' Unset (random) [{}]'.format(
        [k for k, v in params_set.items() if v == 0]))
    logger.info(' Set {}'.format(
        [k for k, v in params_set.items() if v > 0]))
Example #6
def train(config, tr_stream, dev_stream, logprob_stream):

    trng = RandomStreams(config['seed'] if 'seed' in config else 1234)
    enc_ids, dec_ids = get_enc_dec_ids(config['cgs'])

    # Create Theano variables
    floatX = theano.config.floatX
    src_sel = tensor.matrix('src_selector', dtype=floatX)
    trg_sel = tensor.matrix('trg_selector', dtype=floatX)
    x = tensor.lmatrix('source')
    y = tensor.lmatrix('target')
    x_mask = tensor.matrix('source_mask')
    y_mask = tensor.matrix('target_mask')
    x_sampling = tensor.matrix('source', dtype='int64')
    y_sampling = tensor.vector('target', dtype='int64')
    prev_state = tensor.matrix('prev_state', dtype=floatX)
    src_sel_sampling = tensor.matrix('src_selector', dtype=floatX)
    trg_sel_sampling = tensor.matrix('trg_selector', dtype=floatX)

    # Create encoder-decoder architecture
    enc_dec = EncoderDecoder(
        encoder=MultiEncoder(enc_ids=enc_ids, **config),
        decoder=MultiDecoder(**config))

    # Build training computational graphs
    probs, opt_rets = enc_dec.build_models(
        x, x_mask, y, y_mask, src_sel, trg_sel)

    # Get costs
    costs = enc_dec.get_costs(probs, y, y_mask,
                              decay_cs=config.get('decay_c', None),
                              opt_rets=opt_rets)

    # Computation graphs
    cgs = enc_dec.get_computational_graphs(costs)

    # Build sampling models
    f_inits, f_nexts, f_next_states = enc_dec.build_sampling_models(
        x_sampling, y_sampling, src_sel_sampling, trg_sel_sampling, prev_state,
        trng=trng)

    # Some printing
    enc_dec.print_params(cgs)

    # Get training parameters with optional excludes
    training_params, excluded_params = enc_dec.get_training_params(
        cgs, exclude_encs=config['exclude_encs'],
        additional_excludes=config['additional_excludes'],
        readout_only=config.get('readout_only', None),
        train_shared=config.get('train_shared', None))

    # Some more printing
    enc_dec.print_training_params(cgs, training_params)

    # Set up training algorithm
    algorithm = SGDMultiCG(
        costs=costs, tparams=training_params, drop_input=config['drop_input'],
        step_rule=config['step_rule'], learning_rate=config['learning_rate'],
        clip_c=config['step_clipping'],
        step_rule_kwargs=config.get('step_rule_kwargs', {}))

    # Set up training model
    training_models = OrderedDict()
    for k, v in costs.iteritems():
        training_models[k] = Model(costs[k])

    # Set extensions
    extensions = [
        Timing(after_batch=True),
        FinishAfter(after_n_batches=config['finish_after']),
        CostMonitoringWithMultiCG(after_batch=True),
        Printing(after_batch=True),
        PrintMultiStream(after_batch=True),
        DumpWithMultiCG(saveto=config['saveto'],
                        save_accumulators=config['save_accumulators'],
                        every_n_batches=config['save_freq'],
                        no_blocks=True)]

    # Reload model if necessary
    if config['reload'] and os.path.exists(config['saveto']):
        extensions.append(
            LoadFromDumpMultiCG(saveto=config['saveto'],
                                load_accumulators=config['load_accumulators'],
                                no_blocks=True))

    # Add sampling to computational graphs
    for i, (cg_name, cg) in enumerate(cgs.iteritems()):
        eid, did = p_(cg_name)
        if config['hook_samples'] > 0:
            extensions.append(Sampler(
                f_init=f_inits[cg_name], f_next=f_nexts[cg_name],
                data_stream=tr_stream, num_samples=config['hook_samples'],
                src_eos_idx=config['src_eos_idxs'][eid],
                trg_eos_idx=config['trg_eos_idxs'][did],
                enc_id=eid, dec_id=did,
                every_n_batches=config['sampling_freq'],
                cond_init_trg=config.get('cond_init_trg', False),
                f_next_state=f_next_states.get(cg_name, None)))

    # Save parameters incrementally without overwriting
    if config.get('incremental_dump', False):
        extensions.append(
            IncrementalDump(saveto=config['saveto'],
                            burnin=config['val_burn_in'],
                            every_n_batches=config['save_freq']))

    # Compute log probability on dev set
    if 'log_prob_freq' in config:
        extensions.append(
            LogProbComputer(
                cgs=config['cgs'],
                f_log_probs=enc_dec.build_f_log_probs(
                    probs, x, x_mask, y, y_mask, src_sel, trg_sel),
                streams=logprob_stream,
                every_n_batches=config['log_prob_freq']))

    # Initialize main loop
    main_loop = MainLoopWithMultiCGnoBlocks(
        models=training_models,
        algorithm=algorithm,
        data_stream=tr_stream,
        extensions=extensions,
        num_encs=config['num_encs'],
        num_decs=config['num_decs'])

    # Train!
    main_loop.run()

    # Be patient, after a month :-)
    print 'done'
Example #7
def prototype_config_multiCG_08(cgs):

    enc_ids, dec_ids = get_enc_dec_ids(cgs)

    # Model related
    config = {}
    config['cgs'] = cgs
    config['num_encs'] = len(enc_ids)
    config['num_decs'] = len(dec_ids)
    config['seq_len'] = 50
    config['representation_dim'] = 1200  # joint annotation dimension
    config['enc_nhids'] = get_odict(enc_ids, 1000)
    config['dec_nhids'] = get_odict(dec_ids, 1000)
    config['enc_embed_sizes'] = get_odict(enc_ids, 620)
    config['dec_embed_sizes'] = get_odict(dec_ids, 620)

    # Additional options for the model
    config['take_last'] = True
    config['multi_latent'] = True
    config['readout_dim'] = 1000
    config['representation_act'] = 'linear'  # encoder representation act
    config['lencoder_act'] = 'tanh'  # att-encoder latent space act
    config['ldecoder_act'] = 'tanh'  # att-decoder latent space act
    config['dec_rnn_type'] = 'gru_cond_multiEnc_v08'
    config['finit_mid_dim'] = 600
    config['finit_code_dim'] = 500
    config['finit_act'] = 'tanh'
    config['att_dim'] = 1200

    # Optimization related
    config['batch_sizes'] = get_odict(cgs, 60)
    config['sort_k_batches'] = 12
    config['step_rule'] = 'uAdam'
    config['learning_rate'] = 2e-4
    config['step_clipping'] = 1
    config['weight_scale'] = 0.01
    config['schedule'] = get_odict(cgs, 1)
    config['save_accumulators'] = True  # algorithms' update step variables
    config['load_accumulators'] = True  # be careful with this
    config['exclude_encs'] = get_odict(enc_ids, False)
    config['min_seq_lens'] = get_odict(cgs, 0)
    config['additional_excludes'] = get_odict(cgs, [])

    # Regularization related
    config['drop_input'] = get_odict(cgs, 0.)
    config['decay_c'] = get_odict(cgs, 0.)
    config['alpha_c'] = get_odict(cgs, 0.)
    config['weight_noise_ff'] = False
    config['weight_noise_rec'] = False
    config['dropout'] = 1.0

    # Vocabulary related
    config['src_vocab_sizes'] = get_odict(enc_ids, 30000)
    config['trg_vocab_sizes'] = get_odict(dec_ids, 30000)
    config['src_eos_idxs'] = get_odict(enc_ids, 0)
    config['trg_eos_idxs'] = get_odict(dec_ids, 0)
    config['stream'] = 'multiCG_stream'
    config['unk_id'] = 1

    # Early stopping based on bleu related
    config['normalized_bleu'] = True
    config['track_n_models'] = 3
    config['output_val_set'] = True
    config['beam_size'] = 12

    # Validation set for log probs related
    config['log_prob_freq'] = 2000
    config['log_prob_bs'] = 10

    # Timing related
    config['reload'] = True
    config['save_freq'] = 10000
    config['sampling_freq'] = 17
    config['bleu_val_freq'] = 10000000
    config['val_burn_in'] = 1
    config['finish_after'] = 2000000
    config['incremental_dump'] = True

    # Monitoring related
    config['hook_samples'] = 2
    config['plot'] = False
    config['bokeh_port'] = 3333

    return config
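
get_odict is another unshown helper; its calls above suggest it simply builds an OrderedDict assigning the same default to every id, which the per-run configs then override selectively (as the single-pair config does for batch_sizes and schedule). A minimal sketch under that assumption:

from collections import OrderedDict

def _get_odict_sketch(ids, default):
    # One entry per id, all initialised to the same default value.
    return OrderedDict((i, default) for i in ids)

# e.g. _get_odict_sketch(['fi', 'de', 'en'], 1000)
# -> OrderedDict([('fi', 1000), ('de', 1000), ('en', 1000)])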
Example #8
def main(config,
         model,
         normalize=False,
         n_process=5,
         chr_level=False,
         cgs_to_translate=None,
         n_best=1,
         zero_shot=False,
         test=False):

    trng = RandomStreams(config['seed'] if 'seed' in config else 1234)
    enc_ids, dec_ids = get_enc_dec_ids(config['cgs'])
    iternum = re.search('(?<=iter)[0-9]+', model)

    # Translate only the chosen cgs if they are valid
    if cgs_to_translate is None:
        cgs_to_translate = config['cgs']

    # Determine the version, this is for backward compatibility
    version = get_version(config)

    # Check if computational graphs are valid
    if not set(config['cgs']) >= set(cgs_to_translate) and not zero_shot:
        raise ValueError('{} not a subset of {}!'.format(
            cgs_to_translate, config['cgs']))

    # Check if zero shot computational graph is valid
    if zero_shot:
        if len(cgs_to_translate) > 1:
            raise ValueError('Only one cg can be translated for zero shot')
        if p_(cgs_to_translate[0])[0] not in enc_ids or \
                p_(cgs_to_translate[0])[1] not in dec_ids:
            raise ValueError('Zero shot is not valid for {}'.format(
                cgs_to_translate[0]))
        config['cgs'] += cgs_to_translate

    # Create Theano variables
    floatX = theano.config.floatX
    src_sel = tensor.matrix('src_selector', dtype=floatX)
    trg_sel = tensor.matrix('trg_selector', dtype=floatX)
    x_sampling = tensor.matrix('source', dtype='int64')
    y_sampling = tensor.vector('target', dtype='int64')
    prev_state = tensor.matrix('prev_state', dtype=floatX)

    # Create encoder-decoder architecture
    logger.info('Creating encoder-decoder')
    enc_dec = EncoderDecoder(encoder=MultiEncoder(enc_ids=enc_ids, **config),
                             decoder=MultiDecoder(**config))

    # Allocate parameters
    enc_dec.init_params()

    # Build sampling models
    logger.info('Building sampling models')
    f_inits, f_nexts, f_next_states = enc_dec.build_sampling_models(x_sampling,
                                                                    y_sampling,
                                                                    src_sel,
                                                                    trg_sel,
                                                                    prev_state,
                                                                    trng=trng)

    # Load parameters
    logger.info('Loading parameters')
    enc_dec.load_params(model)

    # Output translation file names to be returned
    translations = {}

    # Iterate over computational graphs
    for cg_name in f_inits.keys():

        enc_name = p_(cg_name)[0]
        dec_name = p_(cg_name)[1]
        enc_idx = enc_ids.index(enc_name)
        dec_idx = dec_ids.index(dec_name)
        f_init = f_inits[cg_name]
        f_next = f_nexts[cg_name]
        f_next_state = f_next_states.get(cg_name, None)
        aux_lm = enc_dec.decoder.decoders[dec_name].aux_lm

        # For monolingual paths do not perform any translations
        if enc_name == dec_name or cg_name not in cgs_to_translate:
            logger.info(
                'Skipping validation of computational graph [{}]'.format(
                    cg_name))
            continue

        logger.info('Validating computational graph [{}]'.format(cg_name))

        # Change output filename
        if zero_shot:
            config['val_set_outs'][cg_name] += '_zeroShot'

        # Get input and output file names
        source_file = config['val_sets'][cg_name]
        saveto = config['val_set_outs'][cg_name]
        saveto = saveto + '{}_{}'.format(
            '' if iternum is None else '_iter' + iternum.group(),
            'nbest' if n_best > 1 else 'BLEU')

        # Skip if the output already exists
        if len([
                ff for ff in os.listdir(config['saveto'])
                if ff.startswith(os.path.basename(saveto))
        ]):
            logger.info('Output file {}* exists, skipping'.format(saveto))
            continue

        # Prepare source vocabs and files, make sure special tokens are there
        src_vocab = cPickle.load(open(config['src_vocabs'][enc_name]))
        src_vocab['<S>'] = 0
        src_vocab['</S>'] = config['src_eos_idxs'][enc_name]
        src_vocab['<UNK>'] = config['unk_id']

        # Invert dictionary
        src_ivocab = dict()
        for kk, vv in src_vocab.iteritems():
            src_ivocab[vv] = kk

        # Prepare target vocabs and files, make sure special tokens are there
        trg_vocab = cPickle.load(open(config['trg_vocabs'][dec_name]))
        trg_vocab['<S>'] = 0
        trg_vocab['</S>'] = config['trg_eos_idxs'][dec_name]
        trg_vocab['<UNK>'] = config['unk_id']

        # Invert dictionary
        trg_ivocab = dict()
        for kk, vv in trg_vocab.iteritems():
            trg_ivocab[vv] = kk

        def _send_jobs(fname):
            with open(fname, 'r') as f:
                for idx, line in enumerate(f):
                    x = words2seqs(
                        line,
                        src_vocab,
                        vocab_size=config['src_vocab_sizes'][enc_name],
                        chr_level=chr_level)
                    queue.put((idx, x))
            return idx + 1

        def _finish_processes():
            for midx in xrange(n_process):
                queue.put(None)

        def _retrieve_jobs(n_samples):
            trans = [None] * n_samples
            scores = [None] * n_samples
            for idx in xrange(n_samples):
                resp = rqueue.get()
                trans[resp[0]] = resp[1]
                scores[resp[0]] = resp[2]
                if numpy.mod(idx, 10) == 0:
                    print 'Sample ', (idx + 1), '/', n_samples, ' Done'
            return trans, scores

        # Create source and target selector vectors
        src_selector_input = numpy.zeros(
            (1, enc_dec.num_encs)).astype(theano.config.floatX)
        src_selector_input[0, enc_idx] = 1.
        trg_selector_input = numpy.zeros(
            (1, enc_dec.num_decs)).astype(theano.config.floatX)
        trg_selector_input[0, dec_idx] = 1.

        # Actual translation here
        logger.info('Translating ' + source_file + '...')
        val_start_time = time.time()
        if n_process == 1:
            trans = []
            scores = []
            with open(source_file, 'r') as f:
                for idx, line in enumerate(f):
                    if idx % 100 == 0 and idx != 0:
                        logger.info('...translated [{}] lines'.format(idx))
                    seq = words2seqs(
                        line,
                        src_vocab,
                        vocab_size=config['src_vocab_sizes'][enc_name],
                        chr_level=chr_level)
                    _t, _s = _translate(seq,
                                        f_init,
                                        f_next,
                                        trg_vocab['</S>'],
                                        src_selector_input,
                                        trg_selector_input,
                                        config['beam_size'],
                                        config.get('cond_init_trg', False),
                                        normalize,
                                        n_best,
                                        f_next_state=f_next_state,
                                        version=version,
                                        aux_lm=aux_lm)
                    trans.append(_t)
                    scores.append(_s)

        else:
            # Create queues
            queue = Queue()
            rqueue = Queue()
            processes = [None] * n_process
            for midx in xrange(n_process):
                processes[midx] = Process(
                    target=translate_model,
                    args=(queue, rqueue, midx, f_init, f_next,
                          src_selector_input, trg_selector_input,
                          trg_vocab['</S>'], config['beam_size'], normalize,
                          config.get('cond_init_trg', False), n_best),
                    kwargs={
                        'f_next_state': f_next_state,
                        'version': version,
                        'aux_lm': aux_lm
                    })
                processes[midx].start()

            n_samples = _send_jobs(source_file)
            trans, scores = _retrieve_jobs(n_samples)
            _finish_processes()

        logger.info("Validation Took: {} minutes".format(
            float(time.time() - val_start_time) / 60.))

        # Prepare translation outputs and calculate BLEU if necessary
        # Note that translations are post-processed for BPE here
        if n_best == 1:
            trans = seqs2words(trans, trg_vocab, trg_ivocab)
            trans = [tt.replace('@@ ', '') for tt in trans]
            bleu_score = calculate_bleu(
                bleu_script=config['bleu_script'],
                trans=trans,
                gold=config['val_set_grndtruths'][cg_name])
            saveto += '{}'.format(bleu_score)
        else:
            n_best_trans = []
            for idx, (n_best_tr, score_) in enumerate(zip(trans, scores)):
                sentences = seqs2words(n_best_tr, trg_vocab, trg_ivocab)
                sentences = [tt.replace('@@ ', '') for tt in sentences]
                for ids, trans_ in enumerate(sentences):
                    n_best_trans.append('|||'.join(
                        ['{}'.format(idx), trans_, '{}'.format(score_[ids])]))
            trans = n_best_trans

        # Write to file
        with open(saveto, 'w') as f:
            print >> f, '\n'.join(trans)
        translations[cg_name] = saveto
    return translations, saveto
Example #9
def main(configs,
         models,
         val_sets,
         output_file,
         n_process=5,
         chr_level=False,
         cgs_to_translate=None,
         gold_file=None,
         bleu_script=None,
         beam_size=7):

    # Translate only the chosen cgs if they are valid
    if cgs_to_translate is None:
        raise ValueError('cgs-to-translate cannot be None')

    # Check if computational graphs are valid
    for cidx, config in enumerate(configs):
        enc_ids, dec_ids = get_enc_dec_ids(config['cgs'])
        enc_ids = add_single_pairs(enc_ids)
        for cg_name in cgs_to_translate.values()[cidx]:
            if cg_name not in config['cgs']:
                eids = p_(cg_name)[0].split('.')
                dids = p_(cg_name)[1].split('.')
                if not all([eid in enc_ids for eid in eids]) or \
                        not all([did in dec_ids for did in dids]):
                    raise ValueError(
                        'Zero shot is NOT valid for [{}]'.format(cg_name))
                logger.info('Zero shot is valid for [{}]'.format(cg_name))
                config['cgs'].append(cg_name)
            else:
                logger.info('Translation is valid for [{}]'.format(cg_name))

    # Create Theano variables
    floatX = theano.config.floatX
    src_sel = tensor.matrix('src_selector', dtype=floatX)
    trg_sel = tensor.matrix('trg_selector', dtype=floatX)
    x_sampling = tensor.matrix('source', dtype='int64')
    y_sampling = tensor.vector('target', dtype='int64')
    prev_state = tensor.matrix('prev_state', dtype=floatX)

    # for multi source - maximum is 10 for now
    xs_sampling = [
        tensor.matrix('source%d' % i, dtype='int64') for i in range(10)
    ]

    # Iterate over multiple models
    enc_dec_dict = OrderedDict()
    f_init_dict = OrderedDict()
    f_next_dict = OrderedDict()
    enc_ids_dict = OrderedDict()
    dec_ids_dict = OrderedDict()
    for mi, (model_id, model_path,
             config) in enumerate(zip(cgs_to_translate.keys(), models,
                                      configs)):

        # Helper variables
        cgs = config['cgs']
        trng = RandomStreams(config['seed'] if 'seed' in config else 1234)
        enc_ids, dec_ids = get_enc_dec_ids_mSrc(cgs)
        enc_ids_dict[model_id] = enc_ids
        dec_ids_dict[model_id] = dec_ids

        # Create encoder-decoder architecture
        logger.info('Creating encoder-decoder for model [{}]'.format(mi))
        enc_dec = EncoderDecoder(encoder=MultiEncoder(
            enc_ids=get_enc_dec_ids(cgs)[0], **config),
                                 decoder=MultiDecoder(**config))

        # Allocate parameters
        enc_dec.init_params()

        # Build sampling models
        logger.info('Building sampling models for model [{}]'.format(mi))
        f_inits, f_nexts, f_next_states = enc_dec.build_sampling_models(
            x_sampling,
            y_sampling,
            src_sel,
            trg_sel,
            prev_state,
            trng=trng,
            xs=xs_sampling)

        # Load parameters
        logger.info('Loading parameters for model [{}]'.format(mi))
        enc_dec.load_params(model_path)

        enc_dec_dict[model_id] = enc_dec
        f_init_dict[model_id] = f_inits
        f_next_dict[model_id] = f_nexts

    # Output translation file names to be returned
    translations = {}

    # Fetch necessary functions and variables from all models
    f_inits_list = []
    f_nexts_list = []
    src_vocabs_list = []
    src_vocabs_sizes_list = []

    source_files = OrderedDict()

    for midx, (model_id, cg_names) in enumerate(cgs_to_translate.items()):
        for cg_name in cg_names:

            config = configs[midx]
            enc_name = p_(cg_name)[0]
            dec_name = p_(cg_name)[1]
            enc_ids = enc_name.split('.')

            f_inits_list.append(f_init_dict[model_id][cg_name])
            f_nexts_list.append(f_next_dict[model_id][cg_name])

            if is_multiSource(cg_name):
                source_files.update(val_sets[cg_name])
            else:
                source_files[enc_name] = val_sets[cg_name]

            src_vocabs = OrderedDict()
            src_vocab_sizes = OrderedDict()

            # This ordering is preserved throughout
            for eid in p_(cg_name)[0].split('.'):
                src_vocabs[eid] = load_vocab(config['src_vocabs'][eid], 0,
                                             config['src_eos_idxs'][eid],
                                             config['unk_id'])
                src_vocab_sizes[eid] = config['src_vocab_sizes'][eid]
            src_vocabs_list.append(src_vocabs)
            src_vocabs_sizes_list.append(src_vocab_sizes)

    saveto = output_file

    # Skip if outputs exist
    if os.path.exists(saveto):
        logger.info('Outputs exist:')
        logger.info(' ... {}'.format(saveto))
        logger.info(' ... skipping')
        return

    logger.info('Output file: [{}]'.format(saveto))

    # Prepare target vocabs and files, make sure special tokens are there
    trg_vocab = load_vocab(configs[0]['trg_vocabs'][dec_name], 0,
                           configs[0]['trg_eos_idxs'][dec_name],
                           configs[0]['unk_id'])

    # Invert dictionary
    trg_ivocab = invert_vocab(trg_vocab)

    # Actual translation here
    for eid, fname in source_files.items():
        logger.info('Translating from [{}]-[{}]...'.format(eid, fname))
    logger.info('Using [{}] processes...'.format(n_process))
    val_start_time = time.time()

    # helper functions for multi-process
    def _send_jobs(source_fnames, source_files, src_vocabs_list,
                   src_vocabs_sizes_list):
        for idx, rows in enumerate(izip(*source_fnames)):
            lines = OrderedDict(zip(source_files.keys(), rows))
            seqs_list = [
                words2seqs_multi_irregular(lines, src_vocabs_list[ii],
                                           src_vocabs_sizes_list[ii])
                for ii, _ in enumerate(src_vocabs_list)
            ]
            queue.put((idx, seqs_list))
        return idx + 1

    def _finish_processes():
        for midx in xrange(n_process):
            queue.put(None)

    def _retrieve_jobs(n_samples):
        trans = [None] * n_samples
        scores = [None] * n_samples
        for idx in xrange(n_samples):
            resp = rqueue.get()
            trans[resp[0]] = resp[1]
            scores[resp[0]] = resp[2]
            if numpy.mod(idx, 10) == 0:
                print 'Sample ', (idx + 1), '/', n_samples, ' Done'
        return trans, scores

    # Open files with the correct ordering
    source_fnames = [
        open(source_files[eid], "r") for eid in source_files.keys()
    ]

    if n_process == 1:

        trans = []
        scores = []

        # Process each line for each source simultaneously
        for idx, rows in enumerate(izip(*source_fnames)):
            if idx % 100 == 0 and idx != 0:
                logger.info('...translated [{}] lines'.format(idx))
            lines = OrderedDict(zip(source_files.keys(), rows))
            seqs_list = [
                words2seqs_multi_irregular(lines, src_vocabs_list[ii],
                                           src_vocabs_sizes_list[ii])
                for ii, _ in enumerate(src_vocabs_list)
            ]
            _t, _s = _translate(seqs_list, f_inits_list, f_nexts_list,
                                trg_vocab['</S>'], beam_size)
            trans.append(_t)
            scores.append(_s)

    else:

        # Create queues
        queue = Queue()
        rqueue = Queue()
        processes = [None] * n_process
        for midx in xrange(n_process):
            processes[midx] = Process(target=translate_model,
                                      args=(queue, rqueue, midx, f_inits_list,
                                            f_nexts_list, trg_vocab['</S>'],
                                            beam_size),
                                      kwargs={'f_next_state': f_next_states})
            processes[midx].start()

        n_samples = _send_jobs(source_fnames, source_files, src_vocabs_list,
                               src_vocabs_sizes_list)
        trans, scores = _retrieve_jobs(n_samples)
        _finish_processes()

    logger.info("Validation Took: {} minutes".format(
        float(time.time() - val_start_time) / 60.))

    # Prepare translation outputs and calculate BLEU if necessary
    # Note that translations are post-processed for BPE here
    trans = seqs2words(trans, trg_vocab, trg_ivocab)
    trans = [tt.replace('@@ ', '') for tt in trans]
    if gold_file is not None and os.path.exists(gold_file):
        bleu_score = calculate_bleu(bleu_script=bleu_script,
                                    trans=trans,
                                    gold=gold_file)
        saveto += '{}'.format(bleu_score)

    # Write to file
    with open(saveto, 'w') as f:
        print >> f, '\n'.join(trans)

    translations[cg_name] = saveto
    return translations, saveto
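
load_vocab and invert_vocab mirror what the single-model translation example does inline: load the pickled vocabulary, pin the special tokens to their configured indices, and invert the mapping for decoding. A sketch under that reading (the project's helpers may differ in detail):

try:
    import cPickle as pickle
except ImportError:
    import pickle

def _load_vocab_sketch(path, bos_idx, eos_idx, unk_idx):
    # The 0 passed above is taken to be the <S> index.
    with open(path, 'rb') as f:
        vocab = pickle.load(f)
    vocab['<S>'] = bos_idx
    vocab['</S>'] = eos_idx
    vocab['<UNK>'] = unk_idx
    return vocab

def _invert_vocab_sketch(vocab):
    # id -> token, used when mapping sampled ids back to surface words.
    return dict((idx, tok) for tok, idx in vocab.items())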
Example #10
def prototype_config_multiCG_08(cgs):

    enc_ids, dec_ids = get_enc_dec_ids(cgs)

    # Model related
    config = {}
    config['cgs'] = cgs
    config['num_encs'] = len(enc_ids)
    config['num_decs'] = len(dec_ids)
    config['src_seq_len'] = 50
    config['tgt_seq_len'] = 50
    config['representation_dim'] = 1200  # joint annotation dimension
    config['enc_nhids'] = get_odict(enc_ids, 1000)
    config['dec_nhids'] = get_odict(dec_ids, 1000)
    config['enc_embed_sizes'] = get_odict(enc_ids, 620)
    config['dec_embed_sizes'] = get_odict(dec_ids, 620)

    # Additional options for the model
    config['take_last'] = True
    config['multi_latent'] = True
    config['readout_dim'] = 1000
    config['representation_act'] = 'linear'  # encoder representation act
    config['lencoder_act'] = 'tanh'  # att-encoder latent space act
    config['ldecoder_act'] = 'tanh'  # att-decoder latent space act
    config['dec_rnn_type'] = 'gru_cond_multiEnc_v08'
    config['finit_mid_dim'] = 600
    config['finit_code_dim'] = 500
    config['finit_act'] = 'tanh'
    config['att_dim'] = 1200

    # Optimization related
    config['batch_sizes'] = get_odict(cgs, 60)
    config['sort_k_batches'] = 12
    config['step_rule'] = 'uAdam'
    config['learning_rate'] = 2e-4
    config['step_clipping'] = 1
    config['weight_scale'] = 0.01
    config['schedule'] = get_odict(cgs, 1)
    config['save_accumulators'] = True  # algorithms' update step variables
    config['load_accumulators'] = True  # be careful with this
    config['exclude_encs'] = get_odict(enc_ids, False)
    config['exclude_embs'] = False
    config['min_seq_lens'] = get_odict(cgs, 0)
    config['additional_excludes'] = get_odict(cgs, [])

    # Regularization related
    config['drop_input'] = get_odict(cgs, 0.)
    config['decay_c'] = get_odict(cgs, 0.)
    config['alpha_c'] = get_odict(cgs, 0.)
    config['weight_noise_ff'] = False
    config['weight_noise_rec'] = False
    config['dropout'] = 1.0

    # Vocabulary related
    config['src_vocab_sizes'] = get_odict(enc_ids, 30000)
    config['trg_vocab_sizes'] = get_odict(dec_ids, 30000)
    config['src_eos_idxs'] = get_odict(enc_ids, 0)
    config['trg_eos_idxs'] = get_odict(dec_ids, 0)
    config['stream'] = 'multiCG_stream'
    config['unk_id'] = 1

    # Early stopping based on bleu related
    config['normalized_bleu'] = True
    config['track_n_models'] = 3
    config['output_val_set'] = True
    config['beam_size'] = 12

    # Validation set for log probs related
    config['log_prob_freq'] = 2000
    config['log_prob_bs'] = 10

    # Timing related
    config['reload'] = True
    config['save_freq'] = 10000
    config['sampling_freq'] = 17
    config['bleu_val_freq'] = 10000000
    config['val_burn_in'] = 1
    config['finish_after'] = 2000000
    config['incremental_dump'] = True

    # Monitoring related
    config['hook_samples'] = 2
    config['plot'] = False
    config['bokeh_port'] = 3333

    return config
Example #11
def main(config,
         ref_encs=None,
         ref_decs=None,
         ref_att=None,
         ref_enc_embs=None,
         ref_dec_embs=None):

    # Create Theano variables
    floatX = theano.config.floatX
    src_sel = tensor.matrix('src_selector', dtype=floatX)
    trg_sel = tensor.matrix('trg_selector', dtype=floatX)
    x = tensor.lmatrix('source')
    y = tensor.lmatrix('target')
    x_mask = tensor.matrix('source_mask')
    y_mask = tensor.matrix('target_mask')

    # for multi source - maximum is 5 for now
    xs = [tensor.lmatrix('source%d' % i) for i in range(5)]
    x_masks = [tensor.matrix('source%d_mask' % i) for i in range(5)]

    # Create encoder-decoder architecture, and initialize
    logger.info('Creating encoder-decoder')
    enc_ids, dec_ids = get_enc_dec_ids(config['cgs'])
    enc_dec = EncoderDecoder(encoder=MultiEncoder(enc_ids=enc_ids, **config),
                             decoder=MultiDecoder(**config))
    enc_dec.build_models(x,
                         x_mask,
                         y,
                         y_mask,
                         src_sel,
                         trg_sel,
                         xs=xs,
                         x_masks=x_masks)

    # load reference encoder models
    r_encs = {}
    if ref_encs is not None:
        for eid, path in ref_encs.items():
            logger.info('... ref-enc[{}] loading [{}]'.format(eid, path))
            r_encs[eid] = dict(numpy.load(path))

    # load reference decoder models
    r_decs = {}
    if ref_decs is not None:
        for did, path in ref_decs.items():
            logger.info('... ref-dec[{}] loading [{}]'.format(did, path))
            r_decs[did] = dict(numpy.load(path))

    # load reference model for the shared components
    if ref_att is not None:
        logger.info('... ref-shared loading [{}]'.format(ref_att))
        r_att = dict(numpy.load(ref_att))

    num_params_set = 0
    params_set = {k: 0 for k in enc_dec.get_params().keys()}

    # set encoder parameters of target model
    for eid, rparams in r_encs.items():
        logger.info(' Setting encoder [{}] parameters ...'.format(eid))
        tparams = enc_dec.encoder.encoders[eid].tparams
        for pname, pval in tparams.items():
            set_tparam(tparams[pname], rparams[pname])
            params_set[pname] += 1
            num_params_set += 1
        set_tparam(enc_dec.encoder.tparams['ctx_embedder_%s_W' % eid],
                   rparams['ctx_embedder_%s_W' % eid])
        set_tparam(enc_dec.encoder.tparams['ctx_embedder_%s_b' % eid],
                   rparams['ctx_embedder_%s_b' % eid])
        params_set['ctx_embedder_%s_W' % eid] += 1
        params_set['ctx_embedder_%s_b' % eid] += 1
        num_params_set += 2

    # set decoder parameters of target model
    for did, rparams in r_decs.items():
        logger.info(' Setting decoder [{}] parameters ...'.format(did))
        tparams = enc_dec.decoder.decoders[did].tparams
        for pname, pval in tparams.items():
            set_tparam(tparams[pname], rparams[pname])
            params_set[pname] += 1
            num_params_set += 1

    # set shared component parameters of target model
    if ref_att is not None:
        logger.info(' Setting shared parameters ...')
        shared_enc, shared_params = enc_dec.decoder._get_shared_params()
        for pname in shared_params.keys():
            set_tparam(enc_dec.decoder.tparams[pname], r_att[pname])
            params_set[pname] += 1
            num_params_set += 1

    # set encoder embeddings
    if ref_enc_embs is not None:
        logger.info(' Setting encoder embeddings ...')
        for eid, path in ref_enc_embs.items():
            pname = 'Wemb_%s' % eid
            logger.info(' ... [{}]-[{}]'.format(eid, pname))
            emb = numpy.load(path)[pname]
            set_tparam(enc_dec.encoder.tparams[pname], emb)
            params_set[pname] += 1
            num_params_set += 1

    # set decoder embeddings
    if ref_dec_embs is not None:
        logger.info(' Setting decoder embeddings ...')
        for did, path in ref_dec_embs.items():
            pname = 'Wemb_dec_%s' % did
            logger.info(' ... [{}]-[{}]'.format(did, pname))
            emb = numpy.load(path)[pname]
            set_tparam(enc_dec.decoder.tparams[pname], emb)
            params_set[pname] += 1
            num_params_set += 1

    logger.info(' Saving initialized params to [{}/params.npz]'.format(
        config['saveto']))
    if not os.path.exists(config['saveto']):
        os.makedirs(config['saveto'])

    numpy.savez('{}/params.npz'.format(config['saveto']),
                **tparams_asdict(enc_dec.get_params()))
    logger.info(' Total number of params    : [{}]'.format(
        len(enc_dec.get_params())))
    logger.info(' Total number of params set: [{}]'.format(num_params_set))
    logger.info(' Duplicates [{}]'.format(
        [k for k, v in params_set.items() if v > 1]))
    logger.info(' Unset (random) [{}]'.format(
        [k for k, v in params_set.items() if v == 0]))
    logger.info(' Set {}'.format([k for k, v in params_set.items() if v > 0]))