Пример #1
0
    def _load_theano(self):
        """
        Load models, set theano shared variables and build samplers.
        This entails irrevocable binding to a specific GPU.

        For every (model, option) pair in ``self._models`` / ``self._options``
        the model's parameters are loaded and a sampler is built that always
        returns alignments.

        Returns:
            tuple ``(trng, fs_init, fs_next, gen_sample)``:
                ``trng``       -- MRG random stream seeded with 1234,
                ``fs_init``    -- list of per-model ``f_init`` functions,
                ``fs_next``    -- list of per-model ``f_next`` functions
                                  (same order as ``self._models``),
                ``gen_sample`` -- the decoding function from ``nmt``.
        """

        from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
        from theano import shared

        from nmt import (build_sampler, gen_sample)
        from theano_util import (numpy_floatX, load_params, init_theano_params)

        trng = RandomStreams(1234)
        # shared switch set to 0. (presumably disables noise such as dropout
        # while sampling -- confirm against build_sampler)
        use_noise = shared(numpy_floatX(0.))

        fs_init = []
        fs_next = []

        for model, option in zip(self._models, self._options):
            # Only the array names are needed here; the context manager
            # closes the .npz archive's file handle instead of leaking it.
            # Entries prefixed 'adam_' are skipped (presumably optimizer
            # state that decoding does not need).
            with numpy.load(model) as archive:
                param_names = [key for key in archive.files
                               if not key.startswith('adam_')]
            params = load_params(model, dict.fromkeys(param_names, 0))
            tparams = init_theano_params(params)

            # always return alignment at this point
            f_init, f_next = build_sampler(
                tparams, option, use_noise, trng, return_alignment=True)

            fs_init.append(f_init)
            fs_next.append(f_next)

        return trng, fs_init, fs_next, gen_sample
Пример #2
0
    def _load_theano(self):
        """
        Load models, set theano shared variables and build samplers.
        This entails irrevocable binding to a specific GPU.

        For every (model, option) pair in ``self._models`` / ``self._options``:
          * checks that a multi-source model (``option['multisource_type']``
            not None) was given auxiliary sources, and warns if extra
            sources were given to a single-source model;
          * loads the model's parameters;
          * builds a sampler with ``build_multi_sampler`` for multi-source
            models and ``build_sampler`` otherwise, always returning
            alignments.

        Returns:
            tuple ``(trng, fs_init, fs_next, gen_sample)``:
                ``trng``       -- MRG random stream seeded with 1234,
                ``fs_init``    -- list of per-model ``f_init`` functions,
                ``fs_next``    -- list of per-model ``f_next`` functions
                                  (same order as ``self._models``),
                ``gen_sample`` -- the decoding function from ``nmt``.

        Exits the process with status 1 (after logging an error) when a
        multi-source model has no auxiliary source.
        """
        from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
        from theano import shared

        from nmt import (build_sampler, build_multi_sampler, gen_sample)
        from theano_util import (numpy_floatX, load_params, init_theano_params)

        trng = RandomStreams(1234)
        # shared switch set to 0. (presumably disables noise such as dropout
        # while sampling -- confirm against build_sampler)
        use_noise = shared(numpy_floatX(0.))

        fs_init = []
        fs_next = []

        for model, option in zip(self._models, self._options):
            is_multisource = option['multisource_type'] is not None
            has_extra_sources = len(option['extra_sources']) != 0

            # check compatibility with multisource
            if is_multisource and not has_extra_sources:
                logging.error(
                    "This model is multi-source but no auxiliary source file was provided."
                )
                sys.exit(1)
            elif not is_multisource and has_extra_sources:
                # logging.warn is a deprecated alias; use logging.warning.
                logging.warning(
                    "You provided an auxiliary input but this model is not multi-source. Ignoring extra input."
                )

            # Only the array names are needed here; the context manager
            # closes the .npz archive's file handle instead of leaking it.
            # Entries prefixed 'adam_' are skipped (presumably optimizer
            # state that decoding does not need).
            with numpy.load(model) as archive:
                param_names = [key for key in archive.files
                               if not key.startswith('adam_')]
            params = load_params(model, dict.fromkeys(param_names, 0))
            tparams = init_theano_params(params)

            # always return alignment at this point
            builder = build_multi_sampler if is_multisource else build_sampler
            f_init, f_next = builder(tparams,
                                     option,
                                     use_noise,
                                     trng,
                                     return_alignment=True)

            fs_init.append(f_init)
            fs_next.append(f_next)

        return trng, fs_init, fs_next, gen_sample
Пример #3
0
def _set_theano_device(deviceid):
    """Set ``device=<deviceid>`` in the THEANO_FLAGS environment variable.

    An existing ``device`` entry is replaced in place; otherwise one is
    appended. Other flags keep their order. An unset THEANO_FLAGS is treated
    as empty, and empty entries are dropped. This has to run before theano is
    imported, since theano reads THEANO_FLAGS when it loads its config.
    """
    import os

    flags = [flag for flag in os.environ.get('THEANO_FLAGS', '').split(',')
             if flag.strip()]
    device_flag = '%s=%s' % ('device', deviceid)
    for i, flag in enumerate(flags):
        # compare the key exactly rather than by prefix
        if flag.split('=', 1)[0].strip() == 'device':
            flags[i] = device_flag
            break
    else:
        flags.append(device_flag)
    os.environ['THEANO_FLAGS'] = ','.join(flags)


def translate_model(queue, rqueue, pid, models, options, k, normalization_alpha, verbose, nbest, return_alignment, suppress_unk, return_hyp_graph, deviceid):
    """Worker loop: load ``models``, then translate requests from ``queue``.

    Args:
        queue: source of requests ``(idx, x)``. A ``None`` request stops the
            loop. ``x`` is passed to ``numpy.array(x).T`` and reshaped to
            ``(len(x[0]), len(x), 1)`` (assumes a list of equal-length
            per-token lists -- confirm with the producer).
        rqueue: receives ``(idx, result)`` for every request.
        pid: worker identifier, only used in the verbose log line.
        models: paths of ``.npz`` parameter files, one per model.
        options: per-model option objects passed to ``build_sampler``.
        k: passed to ``gen_sample`` as ``k`` (presumably the beam size).
        normalization_alpha: if truthy, each score is divided by
            ``len(hypothesis) ** normalization_alpha``.
        verbose: if truthy, writes ``"<pid> - <idx>"`` to stderr per request.
        nbest: if truthy, ``result`` holds all hypotheses. Otherwise it holds
            only the one with the lowest (normalized) score.
        return_alignment, suppress_unk, return_hyp_graph: forwarded to
            ``gen_sample``. ``return_alignment`` also goes to
            ``build_sampler``.
        deviceid: if non-empty, set as theano's ``device`` in THEANO_FLAGS
            before theano is imported.

    ``result`` is ``(sample, score, word_probs, alignment, hyp_graph)``.
    """
    # if the --device-list argument is set; must precede the theano imports
    if deviceid != '':
        _set_theano_device(deviceid)

    from theano_util import (numpy_floatX, load_params, init_theano_params)
    from nmt import (build_sampler, gen_sample)

    from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
    from theano import shared
    trng = RandomStreams(1234)
    # shared switch set to 0. (presumably disables noise such as dropout
    # while sampling -- confirm against build_sampler)
    use_noise = shared(numpy_floatX(0.))

    fs_init = []
    fs_next = []

    for model, option in zip(models, options):
        # load model parameters and set theano shared variables; the context
        # manager closes the .npz file handle. 'adam_' entries are skipped
        # (presumably optimizer state not needed for decoding).
        with numpy.load(model) as archive:
            param_names = [key for key in archive.files
                           if not key.startswith('adam_')]
        params = load_params(model, dict.fromkeys(param_names, 0))
        tparams = init_theano_params(params)

        # word index
        f_init, f_next = build_sampler(tparams, option, use_noise, trng, return_alignment=return_alignment)

        fs_init.append(f_init)
        fs_next.append(f_next)

    def _translate(seq):
        # sample given an input sequence and obtain scores
        sample, score, word_probs, alignment, hyp_graph = gen_sample(fs_init, fs_next,
                                   numpy.array(seq).T.reshape([len(seq[0]), len(seq), 1]),
                                   trng=trng, k=k, maxlen=200,
                                   stochastic=False, argmax=False, return_alignment=return_alignment,
                                   suppress_unk=suppress_unk, return_hyp_graph=return_hyp_graph)

        # normalize scores according to sequence lengths
        if normalization_alpha:
            adjusted_lengths = numpy.array([len(s) ** normalization_alpha for s in sample])
            score = score / adjusted_lengths
        if nbest:
            return sample, score, word_probs, alignment, hyp_graph
        # lowest score wins
        sidx = numpy.argmin(score)
        return sample[sidx], score[sidx], word_probs[sidx], alignment[sidx], hyp_graph

    while True:
        req = queue.get()
        if req is None:
            break

        idx, x = req[0], req[1]
        if verbose:
            sys.stderr.write('{0} - {1}\n'.format(pid, idx))
        seq = _translate(x)

        rqueue.put((idx, seq))