class BeamSearchEvaluator(object):
    def __init__(self, eol_symbol, beam_size, x, x_mask, samples,
                 phoneme_dict=None, black_list=None):
        if black_list is None:
            self.black_list = []
        else:
            self.black_list = black_list
        self.x = x
        self.x_mask = x_mask
        self.eol_symbol = eol_symbol
        self.beam_size = beam_size
        self.beam_search = BeamSearch(beam_size, samples)
        self.beam_search.compile()
        self.phoneme_dict = phoneme_dict

    def evaluate(self, data_stream, train=False, file_pred=None,
                 file_targets=None):
        loss = 0.
        num_examples = 0
        iterator = data_stream.get_epoch_iterator()
        if train:
            print 'Train evaluation started'
        i = 0
        for inputs in iterator:
            inputs = dict(zip(data_stream.sources, inputs))
            x_mask_val = inputs['features_mask']
            x_val = inputs['features']
            y_val = inputs['phonemes']
            y_mask_val = inputs['phonemes_mask']
            for batch_ind in xrange(inputs['features'].shape[1]):
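                # Decode one utterance at a time: the example and its mask are
                # tiled beam_size times along the batch axis so that the beam
                # search can expand that many hypotheses in parallel.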
                if x_val.ndim == 2:
                    input_beam = numpy.tile(x_val[:, batch_ind][:, None],
                        (1, self.beam_size))
                else:
                    input_beam = numpy.tile(x_val[:, batch_ind, :][:, None, :],
                                            (1, self.beam_size, 1))
                input_mask_beam = numpy.tile(x_mask_val[:, batch_ind][:, None],
                                             (1, self.beam_size))
                predictions, _ = self.beam_search.search(
                    {self.x: input_beam,
                     self.x_mask: input_mask_beam},
                    self.eol_symbol, 100)
                predictions = [self.phoneme_dict[phone_ind] for phone_ind
                             in predictions[0]
                             if self.phoneme_dict[phone_ind] not in
                             self.black_list][1:-1]

                targets = y_val[:sum(y_mask_val[:, batch_ind]), batch_ind]
                targets = [self.phoneme_dict[phone_ind] for phone_ind
                             in targets
                             if self.phoneme_dict[phone_ind] not in
                             self.black_list][1:-1]
                predictions = [x[0] for x in groupby(predictions)]
                targets = [x[0] for x in groupby(targets)]
                i += 1
                if file_pred:
                    file_pred.write(' '.join(predictions) + '(%d)\n' % i)
                if file_targets:
                    file_targets.write(' '.join(targets) + '(%d)\n' % i)

                loss += Evaluation.wer([predictions], [targets])
                num_examples += 1

            print '.. found sequence example:', ' '.join(predictions)
            print '.. real output was:       ', ' '.join(targets)
            if train:
                break
        if train:
            print 'Train evaluation finished'
        per = loss.sum() / num_examples
        return {'per': per}
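# Side note (not part of the original example): Evaluation.wer above is
# assumed to return the per-utterance error rate, i.e. the edit distance
# between hypothesis and reference divided by the reference length.  A
# self-contained sketch of that computation on a toy phoneme pair (the real
# implementation lives elsewhere in the project):
def toy_edit_distance(hyp, ref):
    # Classic dynamic-programming Levenshtein distance.
    dist = [[0] * (len(ref) + 1) for _ in range(len(hyp) + 1)]
    for i in range(len(hyp) + 1):
        dist[i][0] = i
    for j in range(len(ref) + 1):
        dist[0][j] = j
    for i in range(1, len(hyp) + 1):
        for j in range(1, len(ref) + 1):
            cost = 0 if hyp[i - 1] == ref[j - 1] else 1
            dist[i][j] = min(dist[i - 1][j] + 1,         # deletion
                             dist[i][j - 1] + 1,         # insertion
                             dist[i - 1][j - 1] + cost)  # substitution
    return dist[len(hyp)][len(ref)]

hyp = ['hh', 'ax', 'l', 'ow']
ref = ['hh', 'eh', 'l', 'ow', 'sil']
print float(toy_edit_distance(hyp, ref)) / len(ref)  # 0.4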
Example #2
class BeamSearchEvaluator(object):
    def __init__(self,
                 eol_symbol,
                 beam_size,
                 x,
                 x_mask,
                 samples,
                 phoneme_dict=None,
                 black_list=None,
                 language_model=False):
        if black_list is None:
            self.black_list = []
        else:
            self.black_list = black_list
        self.x = x
        self.x_mask = x_mask
        self.eol_symbol = eol_symbol
        self.beam_size = beam_size
        if language_model:
            lm = TrigramLanguageModel()
            ind_to_word = dict(enumerate(lm.unigrams))
            self.beam_search = BeamSearchLM(lm, 1., ind_to_word, beam_size,
                                            samples)
        else:
            self.beam_search = BeamSearch(beam_size, samples)
        self.beam_search.compile()
        self.phoneme_dict = phoneme_dict

    def evaluate(self,
                 data_stream,
                 train=False,
                 file_pred=None,
                 file_targets=None):
        loss = 0.
        num_examples = 0
        iterator = data_stream.get_epoch_iterator()
        if train:
            print 'Train evaluation started'
        i = 0
        for inputs in iterator:
            inputs = dict(zip(data_stream.sources, inputs))
            x_mask_val = inputs['features_mask']
            x_val = inputs['features']
            y_val = inputs['phonemes']
            y_mask_val = inputs['phonemes_mask']
            for batch_ind in xrange(inputs['features'].shape[1]):
                if x_val.ndim == 2:
                    input_beam = numpy.tile(x_val[:, batch_ind][:, None],
                                            (1, self.beam_size))
                else:
                    input_beam = numpy.tile(x_val[:, batch_ind, :][:, None, :],
                                            (1, self.beam_size, 1))
                input_mask_beam = numpy.tile(x_mask_val[:, batch_ind][:, None],
                                             (1, self.beam_size))
                predictions, _ = self.beam_search.search(
                    {
                        self.x: input_beam,
                        self.x_mask: input_mask_beam
                    }, self.eol_symbol, 100)
                predictions = [
                    self.phoneme_dict[phone_ind]
                    for phone_ind in predictions[0]
                    if self.phoneme_dict[phone_ind] not in self.black_list
                ][1:-1]

                targets = y_val[:sum(y_mask_val[:, batch_ind]), batch_ind]
                targets = [
                    self.phoneme_dict[phone_ind] for phone_ind in targets
                    if self.phoneme_dict[phone_ind] not in self.black_list
                ][1:-1]
                predictions = [x[0] for x in groupby(predictions)]
                targets = [x[0] for x in groupby(targets)]
                i += 1
                if file_pred:
                    file_pred.write(' '.join(predictions) + '(%d)\n' % i)
                if file_targets:
                    file_targets.write(' '.join(targets) + '(%d)\n' % i)

                loss += Evaluation.wer([predictions], [targets])
                num_examples += 1

            print '.. found sequence example:', ' '.join(predictions)
            print '.. real output was:       ', ' '.join(targets)
            if train:
                break
        if train:
            print 'Train evaluation finished'
        per = loss.sum() / num_examples
        return {'per': per}
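# Side note (not from the original code): before scoring, both evaluators
# collapse runs of identical phonemes with itertools.groupby, as in
# `[x[0] for x in groupby(predictions)]`.  A runnable illustration:
from itertools import groupby

decoded = ['sil', 'hh', 'hh', 'ax', 'l', 'l', 'l', 'ow', 'sil']
collapsed = [phone for phone, _ in groupby(decoded)]
print collapsed  # ['sil', 'hh', 'ax', 'l', 'ow', 'sil']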
Example #3
class SpeechRecognizer(Initializable):
    """Encapsulate all reusable logic.

    This class plays a few roles: (a) it's a top brick that knows
    how to combine the bottom, bidirectional and recognizer networks,
    (b) it has the input variables and can build whole computation
    graphs starting with them, and (c) it hides compilation of Theano
    functions and initialization of beam search. I find it simpler to
    have it all in one place for research code.

    Parameters
    ----------
    All parameters define the structure and the dimensions of the model.
    Typically everything is received from the "net" section of the config.

    """
    def __init__(
        self,
        recordings_source,
        labels_source,
        eos_label,
        num_features,
        num_phonemes,
        dim_dec,
        dims_bidir,
        dims_bottom,
        enc_transition,
        dec_transition,
        use_states_for_readout,
        attention_type,
        lm=None,
        character_map=None,
        subsample=None,
        dims_top=None,
        prior=None,
        conv_n=None,
        bottom_activation=None,
        post_merge_activation=None,
        post_merge_dims=None,
        dim_matcher=None,
        embed_outputs=True,
        dec_stack=1,
        conv_num_filters=1,
        data_prepend_eos=True,
        energy_normalizer=None,  # softmax is the default set in SequenceContentAndConvAttention
        **kwargs):
        if bottom_activation is None:
            bottom_activation = Tanh()
        if post_merge_activation is None:
            post_merge_activation = Tanh()
        super(SpeechRecognizer, self).__init__(**kwargs)
        self.recordings_source = recordings_source
        self.labels_source = labels_source
        self.eos_label = eos_label
        self.data_prepend_eos = data_prepend_eos

        self.rec_weights_init = None
        self.initial_states_init = None

        self.enc_transition = enc_transition
        self.dec_transition = dec_transition
        self.dec_stack = dec_stack

        bottom_activation = bottom_activation
        post_merge_activation = post_merge_activation

        if dim_matcher is None:
            dim_matcher = dim_dec

        # The bottom part, before BiRNN
        if dims_bottom:
            bottom = MLP([bottom_activation] * len(dims_bottom),
                         [num_features] + dims_bottom,
                         name="bottom")
        else:
            bottom = Identity(name='bottom')

        # BiRNN
        if not subsample:
            subsample = [1] * len(dims_bidir)
        encoder = Encoder(
            self.enc_transition, dims_bidir,
            dims_bottom[-1] if len(dims_bottom) else num_features, subsample)

        # The top part, on top of BiRNN but before the attention
        if dims_top:
            top = MLP([Tanh()],
                      [2 * dims_bidir[-1]] + dims_top + [2 * dims_bidir[-1]],
                      name="top")
        else:
            top = Identity(name='top')

        if dec_stack == 1:
            transition = self.dec_transition(dim=dim_dec,
                                             activation=Tanh(),
                                             name="transition")
        else:
            transitions = [
                self.dec_transition(dim=dim_dec,
                                    activation=Tanh(),
                                    name="transition_{}".format(trans_level))
                for trans_level in xrange(dec_stack)
            ]
            transition = RecurrentStack(transitions=transitions,
                                        skip_connections=True)
        # Choose attention mechanism according to the configuration
        if attention_type == "content":
            attention = SequenceContentAttention(
                state_names=transition.apply.states,
                attended_dim=2 * dims_bidir[-1],
                match_dim=dim_matcher,
                name="cont_att")
        elif attention_type == "content_and_conv":
            attention = SequenceContentAndConvAttention(
                state_names=transition.apply.states,
                conv_n=conv_n,
                conv_num_filters=conv_num_filters,
                attended_dim=2 * dims_bidir[-1],
                match_dim=dim_matcher,
                prior=prior,
                energy_normalizer=energy_normalizer,
                name="conv_att")
        else:
            raise ValueError(
                "Unknown attention type {}".format(attention_type))
        if embed_outputs:
            feedback = LookupFeedback(num_phonemes + 1, dim_dec)
        else:
            feedback = OneOfNFeedback(num_phonemes + 1)
        if lm:
            # In case we use LM it is Readout that is responsible
            # for normalization.
            emitter = LMEmitter()
        else:
            emitter = SoftmaxEmitter(initial_output=num_phonemes,
                                     name="emitter")
        readout_config = dict(readout_dim=num_phonemes,
                              source_names=(transition.apply.states if
                                            use_states_for_readout else []) +
                              [attention.take_glimpses.outputs[0]],
                              emitter=emitter,
                              feedback_brick=feedback,
                              name="readout")
        if post_merge_dims:
            readout_config['merged_dim'] = post_merge_dims[0]
            readout_config['post_merge'] = InitializableSequence(
                [
                    Bias(post_merge_dims[0]).apply,
                    post_merge_activation.apply,
                    MLP(
                        [post_merge_activation] *
                        (len(post_merge_dims) - 1) + [Identity()],
                        # MLP was not designed to support Maxout as the
                        # activation (because Maxout in a way is not one).
                        # However, a single-layer Maxout network works with
                        # the trick below. For a deeper Maxout network one
                        # has to use the Sequence brick.
                        [
                            d //
                            getattr(post_merge_activation, 'num_pieces', 1)
                            for d in post_merge_dims
                        ] + [num_phonemes]).apply,
                ],
                name='post_merge')
        readout = Readout(**readout_config)

        language_model = None
        if lm:
            lm_weight = lm.pop('weight', 0.0)
            normalize_am_weights = lm.pop('normalize_am_weights', True)
            normalize_lm_weights = lm.pop('normalize_lm_weights', False)
            normalize_tot_weights = lm.pop('normalize_tot_weights', False)
            am_beta = lm.pop('am_beta', 1.0)
            if normalize_am_weights + normalize_lm_weights + normalize_tot_weights < 1:
                logger.warn(
                    "Beam search is prone to fail with no log-prob normalization"
                )
            language_model = LanguageModel(nn_char_map=character_map, **lm)
            readout = ShallowFusionReadout(
                lm_costs_name='lm_add',
                lm_weight=lm_weight,
                normalize_am_weights=normalize_am_weights,
                normalize_lm_weights=normalize_lm_weights,
                normalize_tot_weights=normalize_tot_weights,
                am_beta=am_beta,
                **readout_config)

        generator = SequenceGenerator(readout=readout,
                                      transition=transition,
                                      attention=attention,
                                      language_model=language_model,
                                      name="generator")

        # Remember child bricks
        self.encoder = encoder
        self.bottom = bottom
        self.top = top
        self.generator = generator
        self.children = [encoder, top, bottom, generator]

        # Create input variables
        self.recordings = tensor.tensor3(self.recordings_source)
        self.recordings_mask = tensor.matrix(self.recordings_source + "_mask")
        self.labels = tensor.lmatrix(self.labels_source)
        self.labels_mask = tensor.matrix(self.labels_source + "_mask")
        self.batch_inputs = [
            self.recordings, self.recordings_mask, self.labels,
            self.labels_mask
        ]
        self.single_recording = tensor.matrix(self.recordings_source)
        self.single_transcription = tensor.lvector(self.labels_source)

    def push_initialization_config(self):
        super(SpeechRecognizer, self).push_initialization_config()
        if self.rec_weights_init:
            rec_weights_config = {
                'weights_init': self.rec_weights_init,
                'recurrent_weights_init': self.rec_weights_init
            }
            global_push_initialization_config(self, rec_weights_config,
                                              BaseRecurrent)
        if self.initial_states_init:
            global_push_initialization_config(
                self, {'initial_states_init': self.initial_states_init})

    @application
    def cost(self, recordings, recordings_mask, labels, labels_mask):
        bottom_processed = self.bottom.apply(recordings)
        encoded, encoded_mask = self.encoder.apply(input_=bottom_processed,
                                                   mask=recordings_mask)
        encoded = self.top.apply(encoded)
        return self.generator.cost_matrix(labels,
                                          labels_mask,
                                          attended=encoded,
                                          attended_mask=encoded_mask)

    @application
    def generate(self, recordings):
        encoded, encoded_mask = self.encoder.apply(
            input_=self.bottom.apply(recordings))
        encoded = self.top.apply(encoded)
        return self.generator.generate(n_steps=recordings.shape[0],
                                       batch_size=recordings.shape[1],
                                       attended=encoded,
                                       attended_mask=encoded_mask,
                                       as_dict=True)

    def load_params(self, path):
        generated = self.get_generate_graph()
        param_values = load_parameter_values(path)
        SpeechModel(generated['outputs']).set_parameter_values(param_values)

    def get_generate_graph(self):
        result = self.generate(self.recordings)
        return result

    def get_cost_graph(self, batch=True):
        if batch:
            return self.cost(self.recordings, self.recordings_mask,
                             self.labels, self.labels_mask)
        recordings = self.single_recording[:, None, :]
        labels = self.single_transcription[:, None]
        return self.cost(recordings, tensor.ones_like(recordings[:, :, 0]),
                         labels, None)

    def analyze(self, recording, transcription):
        """Compute cost and aligment for a recording/transcription pair."""
        if not hasattr(self, "_analyze"):
            cost = self.get_cost_graph(batch=False)
            cg = ComputationGraph(cost)
            energies = VariableFilter(bricks=[self.generator],
                                      name="energies")(cg)
            energies_output = [
                energies[0][:, 0, :] if energies else tensor.zeros(
                    (self.single_transcription.shape[0],
                     self.single_recording.shape[0]))
            ]
            states, = VariableFilter(applications=[self.encoder.apply],
                                     roles=[OUTPUT],
                                     name="encoded")(cg)
            ctc_matrix_output = []
            # Temporarily disabled for compatibility with LM code
            # if len(self.generator.readout.source_names) == 1:
            #    ctc_matrix_output = [
            #        self.generator.readout.readout(weighted_averages=states)[:, 0, :]]
            weights, = VariableFilter(bricks=[self.generator],
                                      name="weights")(cg)
            self._analyze = theano.function(
                [self.single_recording, self.single_transcription],
                [cost[:, 0], weights[:, 0, :]] + energies_output +
                ctc_matrix_output)
        return self._analyze(recording, transcription)

    def init_beam_search(self, beam_size):
        """Compile beam search and set the beam size.

        See Blocks issue #500.

        """
        self.beam_size = beam_size
        generated = self.get_generate_graph()
        samples, = VariableFilter(applications=[self.generator.generate],
                                  name="outputs")(ComputationGraph(
                                      generated['outputs']))
        self._beam_search = BeamSearch(beam_size, samples)
        self._beam_search.compile()

    def beam_search(self, recording, char_discount=0.0):
        if not hasattr(self, '_beam_search'):
            self.init_beam_search(self.beam_size)
        input_ = recording[:, numpy.newaxis, :]
        outputs, search_costs = self._beam_search.search(
            {self.recordings: input_},
            self.eos_label,
            input_.shape[0] / 3,
            ignore_first_eol=self.data_prepend_eos,
            char_discount=char_discount)
        return outputs, search_costs

    def __getstate__(self):
        state = dict(self.__dict__)
        for attr in ['_analyze', '_beam_search']:
            state.pop(attr, None)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Allow bricks that were first used on a GPU to be used on a CPU later
        try:
            emitter = self.generator.readout.emitter
            del emitter._theano_rng
        except AttributeError:
            pass
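# Side note (not from the original project): __getstate__/__setstate__ above
# drop the compiled Theano functions ('_analyze', '_beam_search') before
# pickling; they are recompiled lazily after unpickling.  A toy illustration
# of the same pattern with a hypothetical LazyRecognizer class:
import pickle

class LazyRecognizer(object):
    def recognize(self, x):
        if not hasattr(self, '_compiled'):
            # Stands in for an expensive theano.function compilation.
            self._compiled = lambda v: v * 2
        return self._compiled(x)

    def __getstate__(self):
        state = dict(self.__dict__)
        state.pop('_compiled', None)  # never pickle the compiled function
        return state

rec = LazyRecognizer()
rec.recognize(3)
restored = pickle.loads(pickle.dumps(rec))
print restored.recognize(3)  # 6, recompiled on first use after unpickling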
Example #4
class SpeechRecognizer(Initializable):
    """Encapsulate all reusable logic.

    This class plays a few roles: (a) it's a top brick that knows
    how to combine the bottom, bidirectional and recognizer networks,
    (b) it has the input variables and can build whole computation
    graphs starting with them, and (c) it hides compilation of Theano
    functions and initialization of beam search. I find it simpler to
    have it all in one place for research code.

    Parameters
    ----------
    All parameters define the structure and the dimensions of the model.
    Typically everything is received from the "net" section of the config.

    """
    def __init__(self, recordings_source, labels_source, eos_label,
                 num_features, num_phonemes,
                 dim_dec, dims_bidir, dims_bottom,
                 enc_transition, dec_transition,
                 use_states_for_readout,
                 attention_type,
                 lm=None, character_map=None,
                 subsample=None,
                 dims_top=None,
                 prior=None, conv_n=None,
                 bottom_activation=None,
                 post_merge_activation=None,
                 post_merge_dims=None,
                 dim_matcher=None,
                 embed_outputs=True,
                 dec_stack=1,
                 conv_num_filters=1,
                 data_prepend_eos=True,
                 energy_normalizer=None,  # softmax is the default set in SequenceContentAndConvAttention
                 **kwargs):
        if bottom_activation is None:
            bottom_activation = Tanh()
        if post_merge_activation is None:
            post_merge_activation = Tanh()
        super(SpeechRecognizer, self).__init__(**kwargs)
        self.recordings_source = recordings_source
        self.labels_source = labels_source
        self.eos_label = eos_label
        self.data_prepend_eos = data_prepend_eos

        self.rec_weights_init = None
        self.initial_states_init = None

        self.enc_transition = enc_transition
        self.dec_transition = dec_transition
        self.dec_stack = dec_stack

        bottom_activation = bottom_activation
        post_merge_activation = post_merge_activation

        if dim_matcher is None:
            dim_matcher = dim_dec

        # The bottom part, before BiRNN
        if dims_bottom:
            bottom = MLP([bottom_activation] * len(dims_bottom),
                         [num_features] + dims_bottom,
                         name="bottom")
        else:
            bottom = Identity(name='bottom')

        # BiRNN
        if not subsample:
            subsample = [1] * len(dims_bidir)
        encoder = Encoder(self.enc_transition, dims_bidir,
                          dims_bottom[-1] if len(dims_bottom) else num_features,
                          subsample)

        # The top part, on top of BiRNN but before the attention
        if dims_top:
            top = MLP([Tanh()],
                      [2 * dims_bidir[-1]] + dims_top + [2 * dims_bidir[-1]], name="top")
        else:
            top = Identity(name='top')

        if dec_stack == 1:
            transition = self.dec_transition(
                dim=dim_dec, activation=Tanh(), name="transition")
        else:
            transitions = [self.dec_transition(dim=dim_dec,
                                               activation=Tanh(),
                                               name="transition_{}".format(trans_level))
                           for trans_level in xrange(dec_stack)]
            transition = RecurrentStack(transitions=transitions,
                                        skip_connections=True)
        # Choose attention mechanism according to the configuration
        if attention_type == "content":
            attention = SequenceContentAttention(
                state_names=transition.apply.states,
                attended_dim=2 * dims_bidir[-1], match_dim=dim_matcher,
                name="cont_att")
        elif attention_type == "content_and_conv":
            attention = SequenceContentAndConvAttention(
                state_names=transition.apply.states,
                conv_n=conv_n,
                conv_num_filters=conv_num_filters,
                attended_dim=2 * dims_bidir[-1], match_dim=dim_matcher,
                prior=prior,
                energy_normalizer=energy_normalizer,
                name="conv_att")
        else:
            raise ValueError("Unknown attention type {}"
                             .format(attention_type))
        if embed_outputs:
            feedback = LookupFeedback(num_phonemes + 1, dim_dec)
        else:
            feedback = OneOfNFeedback(num_phonemes + 1)
        if lm:
            # In case we use LM it is Readout that is responsible
            # for normalization.
            emitter = LMEmitter()
        else:
            emitter = SoftmaxEmitter(initial_output=num_phonemes, name="emitter")
        readout_config = dict(
            readout_dim=num_phonemes,
            source_names=(transition.apply.states if use_states_for_readout else [])
                         + [attention.take_glimpses.outputs[0]],
            emitter=emitter,
            feedback_brick=feedback,
            name="readout")
        if post_merge_dims:
            readout_config['merged_dim'] = post_merge_dims[0]
            readout_config['post_merge'] = InitializableSequence([
                Bias(post_merge_dims[0]).apply,
                post_merge_activation.apply,
                MLP([post_merge_activation] * (len(post_merge_dims) - 1) + [Identity()],
                    # MLP was not designed to support Maxout as the
                    # activation (because Maxout in a way is not one).
                    # However, a single-layer Maxout network works with
                    # the trick below. For a deeper Maxout network one
                    # has to use the Sequence brick.
                    [d // getattr(post_merge_activation, 'num_pieces', 1)
                     for d in post_merge_dims] + [num_phonemes]).apply,
            ],
                name='post_merge')
        readout = Readout(**readout_config)

        language_model = None
        if lm:
            lm_weight = lm.pop('weight', 0.0)
            normalize_am_weights = lm.pop('normalize_am_weights', True)
            normalize_lm_weights = lm.pop('normalize_lm_weights', False)
            normalize_tot_weights = lm.pop('normalize_tot_weights', False)
            am_beta = lm.pop('am_beta', 1.0)
            if normalize_am_weights + normalize_lm_weights + normalize_tot_weights < 1:
                logger.warn("Beam search is prone to fail with no log-prob normalization")
            language_model = LanguageModel(nn_char_map=character_map, **lm)
            readout = ShallowFusionReadout(lm_costs_name='lm_add',
                                           lm_weight=lm_weight,
                                           normalize_am_weights=normalize_am_weights,
                                           normalize_lm_weights=normalize_lm_weights,
                                           normalize_tot_weights=normalize_tot_weights,
                                           am_beta=am_beta,
                                           **readout_config)

        generator = SequenceGenerator(
            readout=readout, transition=transition, attention=attention,
            language_model=language_model,
            name="generator")

        # Remember child bricks
        self.encoder = encoder
        self.bottom = bottom
        self.top = top
        self.generator = generator
        self.children = [encoder, top, bottom, generator]

        # Create input variables
        self.recordings = tensor.tensor3(self.recordings_source)
        self.recordings_mask = tensor.matrix(self.recordings_source + "_mask")
        self.labels = tensor.lmatrix(self.labels_source)
        self.labels_mask = tensor.matrix(self.labels_source + "_mask")
        self.batch_inputs = [self.recordings, self.recordings_mask,
                             self.labels, self.labels_mask]
        self.single_recording = tensor.matrix(self.recordings_source)
        self.single_transcription = tensor.lvector(self.labels_source)

    def push_initialization_config(self):
        super(SpeechRecognizer, self).push_initialization_config()
        if self.rec_weights_init:
            rec_weights_config = {'weights_init': self.rec_weights_init,
                                  'recurrent_weights_init': self.rec_weights_init}
            global_push_initialization_config(self,
                                              rec_weights_config,
                                              BaseRecurrent)
        if self.initial_states_init:
            global_push_initialization_config(self,
                                              {'initial_states_init': self.initial_states_init})

    @application
    def cost(self, recordings, recordings_mask, labels, labels_mask):
        bottom_processed = self.bottom.apply(recordings)
        encoded, encoded_mask = self.encoder.apply(
            input_=bottom_processed,
            mask=recordings_mask)
        encoded = self.top.apply(encoded)
        return self.generator.cost_matrix(
            labels, labels_mask,
            attended=encoded, attended_mask=encoded_mask)

    @application
    def generate(self, recordings):
        encoded, encoded_mask = self.encoder.apply(
            input_=self.bottom.apply(recordings))
        encoded = self.top.apply(encoded)
        return self.generator.generate(
            n_steps=recordings.shape[0], batch_size=recordings.shape[1],
            attended=encoded,
            attended_mask=encoded_mask,
            as_dict=True)

    def load_params(self, path):
        generated = self.get_generate_graph()
        param_values = load_parameter_values(path)
        SpeechModel(generated['outputs']).set_parameter_values(param_values)

    def get_generate_graph(self):
        result = self.generate(self.recordings)
        return result

    def get_cost_graph(self, batch=True):
        if batch:
            return self.cost(
                self.recordings, self.recordings_mask,
                self.labels, self.labels_mask)
        recordings = self.single_recording[:, None, :]
        labels = self.single_transcription[:, None]
        return self.cost(
            recordings, tensor.ones_like(recordings[:, :, 0]),
            labels, None)

    def analyze(self, recording, transcription):
        """Compute cost and aligment for a recording/transcription pair."""
        if not hasattr(self, "_analyze"):
            cost = self.get_cost_graph(batch=False)
            cg = ComputationGraph(cost)
            energies = VariableFilter(
                bricks=[self.generator], name="energies")(cg)
            energies_output = [energies[0][:, 0, :] if energies
                               else tensor.zeros((self.single_transcription.shape[0],
                                                  self.single_recording.shape[0]))]
            states, = VariableFilter(
                applications=[self.encoder.apply], roles=[OUTPUT],
                name="encoded")(cg)
            ctc_matrix_output = []
            # Temporarily disabled for compatibility with LM code
            # if len(self.generator.readout.source_names) == 1:
            #    ctc_matrix_output = [
            #        self.generator.readout.readout(weighted_averages=states)[:, 0, :]]
            weights, = VariableFilter(
                bricks=[self.generator], name="weights")(cg)
            self._analyze = theano.function(
                [self.single_recording, self.single_transcription],
                [cost[:, 0], weights[:, 0, :]] + energies_output + ctc_matrix_output)
        return self._analyze(recording, transcription)

    def init_beam_search(self, beam_size):
        """Compile beam search and set the beam size.

        See Blocks issue #500.

        """
        self.beam_size = beam_size
        generated = self.get_generate_graph()
        samples, = VariableFilter(
            applications=[self.generator.generate], name="outputs")(
            ComputationGraph(generated['outputs']))
        self._beam_search = BeamSearch(beam_size, samples)
        self._beam_search.compile()

    def beam_search(self, recording, char_discount=0.0):
        if not hasattr(self, '_beam_search'):
            self.init_beam_search(self.beam_size)
        input_ = recording[:, numpy.newaxis, :]
        outputs, search_costs = self._beam_search.search(
            {self.recordings: input_}, self.eos_label, input_.shape[0] / 3,
            ignore_first_eol=self.data_prepend_eos,
            char_discount=char_discount)
        return outputs, search_costs

    def __getstate__(self):
        state = dict(self.__dict__)
        for attr in ['_analyze', '_beam_search']:
            state.pop(attr, None)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Allow bricks that were first used on a GPU to be used on a CPU later
        try:
            emitter = self.generator.readout.emitter
            del emitter._theano_rng
        except AttributeError:
            pass
Example #5
class SpeechRecognizer(Initializable):
    """Encapsulate all reusable logic.

    This class plays a few roles: (a) it's a top brick that knows
    how to combine the bottom, bidirectional and recognizer networks,
    (b) it has the input variables and can build whole computation
    graphs starting with them, and (c) it hides compilation of Theano
    functions and initialization of beam search. I find it simpler to
    have it all in one place for research code.

    Parameters
    ----------
    All parameters define the structure and the dimensions of the model.
    Typically everything is received from the "net" section of the config.

    """
    def __init__(
            self,
            input_dims,
            input_num_chars,
            eos_label,
            num_phonemes,
            dim_dec,
            dims_bidir,
            enc_transition,
            dec_transition,
            use_states_for_readout,
            attention_type,
            criterion,
            bottom,
            lm=None,
            character_map=None,
            bidir=True,
            subsample=None,
            dims_top=None,
            prior=None,
            conv_n=None,
            post_merge_activation=None,
            post_merge_dims=None,
            dim_matcher=None,
            embed_outputs=True,
            dim_output_embedding=None,
            dec_stack=1,
            conv_num_filters=1,
            data_prepend_eos=True,
            # softmax is the default set in SequenceContentAndConvAttention
            energy_normalizer=None,
            # for speech this is the approximate phoneme duration in frames
            max_decoded_length_scale=1,
            **kwargs):

        if post_merge_activation is None:
            post_merge_activation = Tanh()
        super(SpeechRecognizer, self).__init__(**kwargs)
        self.eos_label = eos_label
        self.data_prepend_eos = data_prepend_eos

        self.rec_weights_init = None
        self.initial_states_init = None

        self.enc_transition = enc_transition
        self.dec_transition = dec_transition
        self.dec_stack = dec_stack

        self.criterion = criterion

        self.max_decoded_length_scale = max_decoded_length_scale

        post_merge_activation = post_merge_activation

        if dim_matcher is None:
            dim_matcher = dim_dec

        # The bottom part, before BiRNN
        bottom_class = bottom.pop('bottom_class')
        bottom = bottom_class(input_dims=input_dims,
                              input_num_chars=input_num_chars,
                              name='bottom',
                              **bottom)

        # BiRNN
        if not subsample:
            subsample = [1] * len(dims_bidir)
        encoder = Encoder(self.enc_transition,
                          dims_bidir,
                          bottom.get_dim(bottom.apply.outputs[0]),
                          subsample,
                          bidir=bidir)
        dim_encoded = encoder.get_dim(encoder.apply.outputs[0])

        generators = [None, None]
        for i in range(2):
            # The top part, on top of BiRNN but before the attention
            if dims_top:
                top = MLP([Tanh()], [dim_encoded] + dims_top + [dim_encoded],
                          name="top{}".format(i))
            else:
                top = Identity(name='top{}'.format(i))

            if dec_stack == 1:
                transition = self.dec_transition(dim=dim_dec,
                                                 activation=Tanh(),
                                                 name="transition{}".format(i))
            else:
                transitions = [
                    self.dec_transition(dim=dim_dec,
                                        activation=Tanh(),
                                        name="transition_{}_{}".format(
                                            i, trans_level))
                    for trans_level in xrange(dec_stack)
                ]
                transition = RecurrentStack(transitions=transitions,
                                            skip_connections=True)
            # Choose attention mechanism according to the configuration
            if attention_type == "content":
                attention = SequenceContentAttention(
                    state_names=transition.apply.states,
                    attended_dim=dim_encoded,
                    match_dim=dim_matcher,
                    name="cont_att" + i)
            elif attention_type == "content_and_conv":
                attention = SequenceContentAndConvAttention(
                    state_names=transition.apply.states,
                    conv_n=conv_n,
                    conv_num_filters=conv_num_filters,
                    attended_dim=dim_encoded,
                    match_dim=dim_matcher,
                    prior=prior,
                    energy_normalizer=energy_normalizer,
                    name="conv_att{}".format(i))
            else:
                raise ValueError(
                    "Unknown attention type {}".format(attention_type))
            if embed_outputs:
                feedback = LookupFeedback(
                    num_phonemes + 1, dim_dec
                    if dim_output_embedding is None else dim_output_embedding)
            else:
                feedback = OneOfNFeedback(num_phonemes + 1)
            if criterion['name'] == 'log_likelihood':
                emitter = SoftmaxEmitter(initial_output=num_phonemes,
                                         name="emitter{}".format(i))
                if lm:
                    # In case we use LM it is Readout that is responsible
                    # for normalization.
                    emitter = LMEmitter()
            elif criterion['name'].startswith('mse'):
                emitter = RewardRegressionEmitter(criterion['name'],
                                                  eos_label,
                                                  num_phonemes,
                                                  criterion.get(
                                                      'min_reward', -1.0),
                                                  name="emitter")
            else:
                raise ValueError("Unknown criterion {}".format(
                    criterion['name']))
            readout_config = dict(
                readout_dim=num_phonemes,
                source_names=(transition.apply.states if use_states_for_readout
                              else []) + [attention.take_glimpses.outputs[0]],
                emitter=emitter,
                feedback_brick=feedback,
                name="readout{}".format(i))
            if post_merge_dims:
                readout_config['merged_dim'] = post_merge_dims[0]
                readout_config['post_merge'] = InitializableSequence(
                    [
                        Bias(post_merge_dims[0]).apply,
                        post_merge_activation.apply,
                        MLP(
                            [post_merge_activation] *
                            (len(post_merge_dims) - 1) + [Identity()],
                            # MLP was not designed to support Maxout as the
                            # activation (because Maxout in a way is not one).
                            # However, a single-layer Maxout network works with
                            # the trick below. For a deeper Maxout network one
                            # has to use the Sequence brick.
                            [
                                d //
                                getattr(post_merge_activation, 'num_pieces', 1)
                                for d in post_merge_dims
                            ] + [num_phonemes]).apply,
                    ],
                    name='post_merge{}'.format(i))
            readout = Readout(**readout_config)

            language_model = None
            if lm and lm.get('path'):
                lm_weight = lm.pop('weight', 0.0)
                normalize_am_weights = lm.pop('normalize_am_weights', True)
                normalize_lm_weights = lm.pop('normalize_lm_weights', False)
                normalize_tot_weights = lm.pop('normalize_tot_weights', False)
                am_beta = lm.pop('am_beta', 1.0)
                if normalize_am_weights + normalize_lm_weights + normalize_tot_weights < 1:
                    logger.warn(
                        "Beam search is prone to fail with no log-prob normalization"
                    )
                language_model = LanguageModel(nn_char_map=character_map, **lm)
                readout = ShallowFusionReadout(
                    lm_costs_name='lm_add',
                    lm_weight=lm_weight,
                    normalize_am_weights=normalize_am_weights,
                    normalize_lm_weights=normalize_lm_weights,
                    normalize_tot_weights=normalize_tot_weights,
                    am_beta=am_beta,
                    **readout_config)

            generators[i] = SequenceGenerator(readout=readout,
                                              transition=transition,
                                              attention=attention,
                                              language_model=language_model,
                                              name="generator{}".format(i))

        self.generator = generators[0]

        self.forward_to_backward = Linear(dim_dec, dim_dec)

        # Remember child bricks
        self.encoder = encoder
        self.bottom = bottom
        self.top = top
        self.generators = generators
        self.children = [self.forward_to_backward, encoder, top, bottom
                         ] + generators

        # Create input variables
        self.inputs = self.bottom.batch_inputs
        self.inputs_mask = self.bottom.mask

        self.labels = tensor.lmatrix('labels')
        self.labels_mask = tensor.matrix("labels_mask")

        self.single_inputs = self.bottom.single_inputs
        self.single_labels = tensor.lvector('labels')
        self.n_steps = tensor.lscalar('n_steps')

    def push_initialization_config(self):
        super(SpeechRecognizer, self).push_initialization_config()
        if self.rec_weights_init:
            rec_weights_config = {
                'weights_init': self.rec_weights_init,
                'recurrent_weights_init': self.rec_weights_init
            }
            global_push_initialization_config(self, rec_weights_config,
                                              BaseRecurrent)
        if self.initial_states_init:
            global_push_initialization_config(
                self, {'initial_states_init': self.initial_states_init})

    @application
    def cost(self, application_call, **kwargs):
        # pop inputs we know about
        inputs_mask = kwargs.pop('inputs_mask')
        labels = kwargs.pop('labels')
        labels_mask = kwargs.pop('labels_mask')

        # the rest is for bottom
        bottom_processed = self.bottom.apply(**kwargs)
        encoded, encoded_mask = self.encoder.apply(input_=bottom_processed,
                                                   mask=inputs_mask)
        encoded = self.top.apply(encoded)
        outs_forward = self.generators[0].evaluate(labels,
                                                   labels_mask,
                                                   attended=encoded,
                                                   attended_mask=encoded_mask)
        costs_forward, states_forward, _, _, _, _ = outs_forward
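        # The second generator decodes right-to-left: labels, masks and the
        # encoded sequence are all reversed in time before it is evaluated.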
        outs_backward = self.generators[1].evaluate(
            labels[::-1],
            labels_mask[::-1] if labels_mask is not None else None,
            attended=encoded[::-1],
            attended_mask=encoded_mask[::-1])
        costs_backward, states_backward, _, _, _, _ = outs_backward
        costs_backward = costs_backward[::-1]
        states_backward = states_backward[::-1]

        states_shape = states_forward.shape
        backward_predicted = self.forward_to_backward.apply(
            states_forward.reshape((states_shape[0] * states_shape[1], -1)))
        backward_predicted = backward_predicted.reshape(states_shape)
        backward_predicted = backward_predicted * labels_mask[:, :, None]

        states_backward = gradient.disconnected_grad(states_backward)
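        # disconnected_grad above turns the backward states into fixed
        # targets: no gradient flows into the backward decoder through the
        # agreement term below.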
        states_backward = states_backward * labels_mask[:, :, None]
        l2_cost = ((backward_predicted - states_backward)**2).mean(axis=2)
        l2_cost.name = 'l2_cost_aux'
        application_call.add_auxiliary_variable(
            l2_cost.sum(axis=0).mean().copy(name='l2_cost_aux'))
        costs_forward_aux = (costs_forward.sum(axis=0).mean()).copy(
            name='costs_forward_aux')
        application_call.add_auxiliary_variable(costs_forward_aux)
        return costs_forward + costs_backward + 1.5 * l2_cost

    @application
    def generate(self, **kwargs):
        inputs_mask = kwargs.pop('inputs_mask')
        n_steps = kwargs.pop('n_steps')

        encoded, encoded_mask = self.encoder.apply(
            input_=self.bottom.apply(**kwargs), mask=inputs_mask)
        encoded = self.top.apply(encoded)
        return self.generator.generate(
            n_steps=n_steps if n_steps is not None else self.n_steps,
            batch_size=encoded.shape[1],
            attended=encoded,
            attended_mask=encoded_mask,
            as_dict=True)

    def load_params(self, path):
        generated = self.get_generate_graph()
        with open(path, 'r') as src:
            param_values = load_parameters(src)
        Model(generated['outputs']).set_parameter_values(param_values)

    def get_generate_graph(self, use_mask=True, n_steps=None):
        inputs_mask = None
        if use_mask:
            inputs_mask = self.inputs_mask
        bottom_inputs = self.inputs
        return self.generate(n_steps=n_steps,
                             inputs_mask=inputs_mask,
                             **bottom_inputs)

    def get_cost_graph(self,
                       batch=True,
                       prediction=None,
                       prediction_mask=None):

        if batch:
            inputs = self.inputs
            inputs_mask = self.inputs_mask
            groundtruth = self.labels
            groundtruth_mask = self.labels_mask
        else:
            inputs, inputs_mask = self.bottom.single_to_batch_inputs(
                self.single_inputs)
            groundtruth = self.single_labels[:, None]
            groundtruth_mask = None

        if prediction is None:
            prediction = groundtruth
        if prediction_mask is None:
            prediction_mask = groundtruth_mask

        cost = self.cost(inputs_mask=inputs_mask,
                         labels=prediction,
                         labels_mask=prediction_mask,
                         **inputs)
        cost_cg = ComputationGraph(cost)
        if self.criterion['name'].startswith("mse"):
            placeholder, = VariableFilter(theano_name='groundtruth')(cost_cg)
            cost_cg = cost_cg.replace({placeholder: groundtruth})
        return cost_cg

    def analyze(self, inputs, groundtruth, prediction=None):
        """Compute cost and aligment."""

        input_values_dict = dict(inputs)
        input_values_dict['groundtruth'] = groundtruth
        if prediction is not None:
            input_values_dict['prediction'] = prediction
        if not hasattr(self, "_analyze"):
            input_variables = list(self.single_inputs.values())
            input_variables.append(self.single_labels.copy(name='groundtruth'))

            prediction_variable = tensor.lvector('prediction')
            if prediction is not None:
                input_variables.append(prediction_variable)
                cg = self.get_cost_graph(batch=False,
                                         prediction=prediction_variable[:,
                                                                        None])
            else:
                cg = self.get_cost_graph(batch=False)
            cost = cg.outputs[0]

            weights, = VariableFilter(bricks=[self.generator],
                                      name="weights")(cg)

            energies = VariableFilter(bricks=[self.generator],
                                      name="energies")(cg)
            energies_output = [
                energies[0][:,
                            0, :] if energies else tensor.zeros_like(weights)
            ]

            states, = VariableFilter(applications=[self.encoder.apply],
                                     roles=[OUTPUT],
                                     name="encoded")(cg)

            ctc_matrix_output = []
            # Temporarily disabled for compatibility with LM code
            # if len(self.generator.readout.source_names) == 1:
            #    ctc_matrix_output = [
            #        self.generator.readout.readout(weighted_averages=states)[:, 0, :]]

            self._analyze = theano.function(
                input_variables, [cost[:, 0], weights[:, 0, :]] +
                energies_output + ctc_matrix_output,
                on_unused_input='warn')
        return self._analyze(**input_values_dict)

    def init_beam_search(self, beam_size):
        """Compile beam search and set the beam size.

        See Blocks issue #500.

        """
        if hasattr(self, '_beam_search') and self.beam_size == beam_size:
            # Only recompile if the user wants a different beam size
            return
        self.beam_size = beam_size
        generated = self.get_generate_graph(use_mask=False, n_steps=3)
        cg = ComputationGraph(generated.values())
        samples, = VariableFilter(applications=[self.generator.generate],
                                  name="outputs")(cg)
        self._beam_search = BeamSearch(beam_size, samples)
        self._beam_search.compile()

    def beam_search(self, inputs, **kwargs):
        # When a recognizer is unpickled, self.beam_size is available
        # but beam search has to be recompiled.

        self.init_beam_search(self.beam_size)
        inputs = dict(inputs)
        max_length = int(
            self.bottom.num_time_steps(**inputs) /
            self.max_decoded_length_scale)
        search_inputs = {}
        for var in self.inputs.values():
            search_inputs[var] = inputs.pop(var.name)[:, numpy.newaxis, ...]
        if inputs:
            raise Exception('Unknown inputs passed to beam search: {}'.format(
                inputs.keys()))
        outputs, search_costs = self._beam_search.search(
            search_inputs,
            self.eos_label,
            max_length,
            ignore_first_eol=self.data_prepend_eos,
            **kwargs)
        return outputs, search_costs

    def init_generate(self):
        generated = self.get_generate_graph(use_mask=False)
        cg = ComputationGraph(generated['outputs'])
        self._do_generate = cg.get_theano_function()

    def sample(self, inputs, n_steps=None):
        if not hasattr(self, '_do_generate'):
            self.init_generate()
        batch, unused_mask = self.bottom.single_to_batch_inputs(inputs)
        batch['n_steps'] = n_steps if n_steps is not None \
            else int(self.bottom.num_time_steps(**batch) /
                     self.max_decoded_length_scale)
        return self._do_generate(**batch)[0]

    def __getstate__(self):
        state = dict(self.__dict__)
        for attr in ['_analyze', '_beam_search']:
            state.pop(attr, None)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Allow bricks that were first used on a GPU to be used on a CPU later
        try:
            emitter = self.generator.readout.emitter
            del emitter._theano_rng
        except AttributeError:
            pass
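# Side note (not from the original code): the cost method above adds an L2
# "agreement" term that pushes a linear map of the forward decoder states
# towards the time-reversed backward decoder states.  A numpy illustration
# with made-up shapes (time x batch x dim); W stands in for the
# forward_to_backward Linear brick:
import numpy

T, B, D = 5, 2, 4
states_forward = numpy.random.randn(T, B, D)
states_backward = numpy.random.randn(T, B, D)[::-1]  # reversed in time
W = numpy.random.randn(D, D)
backward_predicted = states_forward.reshape(T * B, D).dot(W).reshape(T, B, D)
l2_cost = ((backward_predicted - states_backward) ** 2).mean(axis=2)
print l2_cost.sum(axis=0).mean()  # the auxiliary value logged as l2_cost_aux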
Example #6
loader = LoadNMTUtils(get_nmt_model_path_best_bleu(config),
                      config['saveto'],
                      nmt_model.search_model)
loader.load_weights()

src_sentences = load_sentences(args.src_test,
                               args.range,
                               config['src_vocab_size'])
n_sentences = len(src_sentences)

logging.info("%d source sentences loaded. Initializing decoding..."
             % n_sentences)

beam_search = BeamSearch(samples=nmt_model.samples)
beam_search.compile()

logging.info("Sort sentences, longest sentence first...")
src_sentences.sort(key=lambda x: len(x[1]), reverse=True)

# Bucketing
cur_bucket = Bucket(0)
buckets = [cur_bucket]
for sen_id, sen in src_sentences:
    sen_len = len(sen)
    if not cur_bucket.can_add(sen_len):
        cur_bucket.compile()
        cur_bucket = Bucket(len(buckets))
        buckets.append(cur_bucket)
    cur_bucket.add_task(DecodingTask(sen_id, sen))
cur_bucket.compile()
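# Side note: Bucket and DecodingTask are defined elsewhere in this script.
# Purely as an illustration (not the project's actual implementation), a
# hypothetical ToyBucket sketches the idea behind the loop above: sentences
# arrive sorted longest first, and a bucket accepts new ones until the length
# gap to its longest member grows too large, which keeps padding per batch
# small when each bucket is later decoded together.
class ToyBucket(object):
    def __init__(self, bucket_id, max_spread=3):
        self.bucket_id = bucket_id
        self.max_spread = max_spread
        self.lengths = []

    def can_add(self, sen_len):
        # Lengths arrive in decreasing order, so it suffices to compare
        # against the first (longest) entry in the bucket.
        return not self.lengths or self.lengths[0] - sen_len <= self.max_spread

    def add_task(self, sen_len):
        self.lengths.append(sen_len)

    def compile(self):
        print 'bucket %d: lengths %s' % (self.bucket_id, self.lengths)

toy_buckets = [ToyBucket(0)]
for length in [12, 11, 10, 6, 5]:
    if not toy_buckets[-1].can_add(length):
        toy_buckets[-1].compile()
        toy_buckets.append(ToyBucket(len(toy_buckets)))
    toy_buckets[-1].add_task(length)
toy_buckets[-1].compile()  # two buckets: [12, 11, 10] and [6, 5]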
Example #7
class SpeechRecognizer(Initializable):
    """Encapsulate all reusable logic.

    This class plays a few roles: (a) it's a top brick that knows
    how to combine the bottom, bidirectional and recognizer networks,
    (b) it has the input variables and can build whole computation
    graphs starting with them, and (c) it hides compilation of Theano
    functions and initialization of beam search. I find it simpler to
    have it all in one place for research code.

    Parameters
    ----------
    All parameters define the structure and the dimensions of the model.
    Typically everything is received from the "net" section of the config.

    """

    def __init__(self,
                 input_dims,
                 input_num_chars,
                 eos_label,
                 num_phonemes,
                 dim_dec, dims_bidir,
                 enc_transition, dec_transition,
                 use_states_for_readout,
                 attention_type,
                 criterion,
                 bottom,
                 lm=None, character_map=None,
                 bidir=True,
                 subsample=None,
                 dims_top=None,
                 prior=None, conv_n=None,
                 post_merge_activation=None,
                 post_merge_dims=None,
                 dim_matcher=None,
                 embed_outputs=True,
                 dim_output_embedding=None,
                 dec_stack=1,
                 conv_num_filters=1,
                 data_prepend_eos=True,
                 # softmax is the default set in SequenceContentAndConvAttention
                 energy_normalizer=None,
                 # for speech this is the approximate phoneme duration in frames
                 max_decoded_length_scale=1,
                 **kwargs):

        if post_merge_activation is None:
            post_merge_activation = Tanh()
        super(SpeechRecognizer, self).__init__(**kwargs)
        self.eos_label = eos_label
        self.data_prepend_eos = data_prepend_eos

        self.rec_weights_init = None
        self.initial_states_init = None

        self.enc_transition = enc_transition
        self.dec_transition = dec_transition
        self.dec_stack = dec_stack

        self.criterion = criterion

        self.max_decoded_length_scale = max_decoded_length_scale

        self.post_merge_activation = post_merge_activation

        if dim_matcher is None:
            dim_matcher = dim_dec

        # The bottom part, before BiRNN
        bottom_class = bottom.pop('bottom_class')
        bottom = bottom_class(
            input_dims=input_dims, input_num_chars=input_num_chars,
            name='bottom',
            **bottom)

        # BiRNN
        if not subsample:
            subsample = [1] * len(dims_bidir)
        encoder = Encoder(self.enc_transition, dims_bidir,
                          bottom.get_dim(bottom.apply.outputs[0]),
                          subsample, bidir=bidir)
        dim_encoded = encoder.get_dim(encoder.apply.outputs[0])

        # The top part, on top of BiRNN but before the attention
        if dims_top:
            top = MLP([Tanh()],
                      [dim_encoded] + dims_top + [dim_encoded], name="top")
        else:
            top = Identity(name='top')

        if dec_stack == 1:
            transition = self.dec_transition(
                dim=dim_dec, activation=Tanh(), name="transition")
        else:
            transitions = [self.dec_transition(dim=dim_dec,
                                               activation=Tanh(),
                                               name="transition_{}".format(trans_level))
                           for trans_level in xrange(dec_stack)]
            transition = RecurrentStack(transitions=transitions,
                                        skip_connections=True)
        # Choose attention mechanism according to the configuration
        if attention_type == "content":
            attention = SequenceContentAttention(
                state_names=transition.apply.states,
                attended_dim=dim_encoded, match_dim=dim_matcher,
                name="cont_att")
        elif attention_type == "content_and_conv":
            attention = SequenceContentAndConvAttention(
                state_names=transition.apply.states,
                conv_n=conv_n,
                conv_num_filters=conv_num_filters,
                attended_dim=dim_encoded, match_dim=dim_matcher,
                prior=prior,
                energy_normalizer=energy_normalizer,
                name="conv_att")
        else:
            raise ValueError("Unknown attention type {}"
                             .format(attention_type))
        if embed_outputs:
            feedback = LookupFeedback(num_phonemes + 1,
                                      dim_dec if
                                      dim_output_embedding is None
                                      else dim_output_embedding)
        else:
            feedback = OneOfNFeedback(num_phonemes + 1)
        if criterion['name'] == 'log_likelihood':
            emitter = SoftmaxEmitter(initial_output=num_phonemes, name="emitter")
            if lm:
                # When an LM is used, it is the Readout that is
                # responsible for normalization.
                emitter = LMEmitter()
        elif criterion['name'].startswith('mse'):
            emitter = RewardRegressionEmitter(
                criterion['name'], eos_label, num_phonemes,
                criterion.get('min_reward', -1.0),
                name="emitter")
        else:
            raise ValueError("Unknown criterion {}".format(criterion['name']))
        readout_config = dict(
            readout_dim=num_phonemes,
            source_names=(transition.apply.states if use_states_for_readout else [])
                         + [attention.take_glimpses.outputs[0]],
            emitter=emitter,
            feedback_brick=feedback,
            name="readout")
        if post_merge_dims:
            readout_config['merged_dim'] = post_merge_dims[0]
            readout_config['post_merge'] = InitializableSequence([
                Bias(post_merge_dims[0]).apply,
                post_merge_activation.apply,
                MLP([post_merge_activation] * (len(post_merge_dims) - 1) + [Identity()],
                    # MLP was not designed to support Maxout as an activation
                    # (because Maxout in a way is not one). However,
                    # a single-layer Maxout network works with the trick below.
                    # For a deeper Maxout network one has to use the
                    # Sequence brick.
                    [d//getattr(post_merge_activation, 'num_pieces', 1)
                     for d in post_merge_dims] + [num_phonemes]).apply,
            ],
                name='post_merge')
        readout = Readout(**readout_config)

        language_model = None
        if lm and lm.get('path'):
            lm_weight = lm.pop('weight', 0.0)
            normalize_am_weights = lm.pop('normalize_am_weights', True)
            normalize_lm_weights = lm.pop('normalize_lm_weights', False)
            normalize_tot_weights = lm.pop('normalize_tot_weights', False)
            am_beta = lm.pop('am_beta', 1.0)
            if normalize_am_weights + normalize_lm_weights + normalize_tot_weights < 1:
                logger.warn("Beam search is prone to fail with no log-prob normalization")
            language_model = LanguageModel(nn_char_map=character_map, **lm)
            readout = ShallowFusionReadout(lm_costs_name='lm_add',
                                           lm_weight=lm_weight,
                                           normalize_am_weights=normalize_am_weights,
                                           normalize_lm_weights=normalize_lm_weights,
                                           normalize_tot_weights=normalize_tot_weights,
                                           am_beta=am_beta,
                                           **readout_config)

        generator = SequenceGenerator(
            readout=readout, transition=transition, attention=attention,
            language_model=language_model,
            name="generator")

        # Remember child bricks
        self.encoder = encoder
        self.bottom = bottom
        self.top = top
        self.generator = generator
        self.children = [encoder, top, bottom, generator]

        # Create input variables
        self.inputs = self.bottom.batch_inputs
        self.inputs_mask = self.bottom.mask

        self.labels = tensor.lmatrix('labels')
        self.labels_mask = tensor.matrix("labels_mask")

        self.single_inputs = self.bottom.single_inputs
        self.single_labels = tensor.lvector('labels')
        self.n_steps = tensor.lscalar('n_steps')

    def push_initialization_config(self):
        super(SpeechRecognizer, self).push_initialization_config()
        if self.rec_weights_init:
            rec_weights_config = {'weights_init': self.rec_weights_init,
                                  'recurrent_weights_init': self.rec_weights_init}
            global_push_initialization_config(self,
                                              rec_weights_config,
                                              BaseRecurrent)
        if self.initial_states_init:
            global_push_initialization_config(self,
                                              {'initial_states_init': self.initial_states_init})

    @application
    def cost(self, **kwargs):
        # pop inputs we know about
        inputs_mask = kwargs.pop('inputs_mask')
        labels = kwargs.pop('labels')
        labels_mask = kwargs.pop('labels_mask')

        # the rest is for bottom
        bottom_processed = self.bottom.apply(**kwargs)
        encoded, encoded_mask = self.encoder.apply(
            input_=bottom_processed,
            mask=inputs_mask)
        encoded = self.top.apply(encoded)
        return self.generator.cost_matrix(
            labels, labels_mask,
            attended=encoded, attended_mask=encoded_mask)

    @application
    def generate(self, **kwargs):
        inputs_mask = kwargs.pop('inputs_mask')
        n_steps = kwargs.pop('n_steps')

        encoded, encoded_mask = self.encoder.apply(
            input_=self.bottom.apply(**kwargs),
            mask=inputs_mask)
        encoded = self.top.apply(encoded)
        return self.generator.generate(
            n_steps=n_steps if n_steps is not None else self.n_steps,
            batch_size=encoded.shape[1],
            attended=encoded,
            attended_mask=encoded_mask,
            as_dict=True)

    def load_params(self, path):
        generated = self.get_generate_graph()
        with open(path, 'r') as src:
            param_values = load_parameters(src)
        Model(generated['outputs']).set_parameter_values(param_values)

    def get_generate_graph(self, use_mask=True, n_steps=None):
        inputs_mask = None
        if use_mask:
            inputs_mask = self.inputs_mask
        bottom_inputs = self.inputs
        return self.generate(n_steps=n_steps,
                             inputs_mask=inputs_mask,
                             **bottom_inputs)

    def get_cost_graph(self, batch=True,
                       prediction=None, prediction_mask=None):

        if batch:
            inputs = self.inputs
            inputs_mask = self.inputs_mask
            groundtruth = self.labels
            groundtruth_mask = self.labels_mask
        else:
            inputs, inputs_mask = self.bottom.single_to_batch_inputs(
                self.single_inputs)
            groundtruth = self.single_labels[:, None]
            groundtruth_mask = None

        if prediction is None:
            prediction = groundtruth
        if prediction_mask is None:
            prediction_mask = groundtruth_mask

        cost = self.cost(inputs_mask=inputs_mask,
                         labels=prediction,
                         labels_mask=prediction_mask,
                         **inputs)
        cost_cg = ComputationGraph(cost)
        if self.criterion['name'].startswith("mse"):
            placeholder, = VariableFilter(theano_name='groundtruth')(cost_cg)
            cost_cg = cost_cg.replace({placeholder: groundtruth})
        return cost_cg

    def analyze(self, inputs, groundtruth, prediction=None):
        """Compute cost and aligment."""

        input_values_dict = dict(inputs)
        input_values_dict['groundtruth'] = groundtruth
        if prediction is not None:
            input_values_dict['prediction'] = prediction
        if not hasattr(self, "_analyze"):
            input_variables = list(self.single_inputs.values())
            input_variables.append(self.single_labels.copy(name='groundtruth'))

            prediction_variable = tensor.lvector('prediction')
            if prediction is not None:
                input_variables.append(prediction_variable)
                cg = self.get_cost_graph(
                    batch=False, prediction=prediction_variable[:, None])
            else:
                cg = self.get_cost_graph(batch=False)
            cost = cg.outputs[0]

            weights, = VariableFilter(
                bricks=[self.generator], name="weights")(cg)

            energies = VariableFilter(
                bricks=[self.generator], name="energies")(cg)
            energies_output = [energies[0][:, 0, :] if energies
                               else tensor.zeros_like(weights)]

            states, = VariableFilter(
                applications=[self.encoder.apply], roles=[OUTPUT],
                name="encoded")(cg)

            ctc_matrix_output = []
            # Temporarily disabled for compatibility with LM code
            # if len(self.generator.readout.source_names) == 1:
            #    ctc_matrix_output = [
            #        self.generator.readout.readout(weighted_averages=states)[:, 0, :]]

            self._analyze = theano.function(
                input_variables,
                [cost[:, 0], weights[:, 0, :]] + energies_output + ctc_matrix_output,
                on_unused_input='warn')
        return self._analyze(**input_values_dict)

    def init_beam_search(self, beam_size):
        """Compile beam search and set the beam size.

        See Blocks issue #500.

        """
        if hasattr(self, '_beam_search') and self.beam_size == beam_size:
            # Only recompile if the user wants a different beam size
            return
        self.beam_size = beam_size
        generated = self.get_generate_graph(use_mask=False, n_steps=3)
        cg = ComputationGraph(generated.values())
        samples, = VariableFilter(
            applications=[self.generator.generate], name="outputs")(cg)
        self._beam_search = BeamSearch(beam_size, samples)
        self._beam_search.compile()

    def beam_search(self, inputs, **kwargs):
        # When a recognizer is unpickled, self.beam_size is available
        # but beam search has to be recompiled.

        self.init_beam_search(self.beam_size)
        inputs = dict(inputs)
        max_length = int(self.bottom.num_time_steps(**inputs) /
                         self.max_decoded_length_scale)
        search_inputs = {}
        for var in self.inputs.values():
            search_inputs[var] = inputs.pop(var.name)[:, numpy.newaxis, ...]
        if inputs:
            raise Exception(
                'Unknown inputs passed to beam search: {}'.format(
                    inputs.keys()))
        outputs, search_costs = self._beam_search.search(
            search_inputs, self.eos_label,
            max_length,
            ignore_first_eol=self.data_prepend_eos,
            **kwargs)
        return outputs, search_costs

    def init_generate(self):
        generated = self.get_generate_graph(use_mask=False)
        cg = ComputationGraph(generated['outputs'])
        self._do_generate = cg.get_theano_function()

    def sample(self, inputs, n_steps=None):
        if not hasattr(self, '_do_generate'):
            self.init_generate()
        batch, unused_mask = self.bottom.single_to_batch_inputs(inputs)
        batch['n_steps'] = n_steps if n_steps is not None \
            else int(self.bottom.num_time_steps(**batch) /
                     self.max_decoded_length_scale)
        return self._do_generate(**batch)[0]

    def __getstate__(self):
        state = dict(self.__dict__)
        for attr in ['_analyze', '_beam_search']:
            state.pop(attr, None)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Drop the cached Theano RNG so that bricks first used on a GPU
        # can later be reused on a CPU.
        try:
            emitter = self.generator.readout.emitter
            del emitter._theano_rng
        except AttributeError:
            pass
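
# A hedged usage sketch, not part of the original class: the constructor
# arguments, the parameter path, the input-variable name 'recordings' and
# the `features` array are all assumptions; only load_params,
# init_beam_search and beam_search are methods defined above.
import numpy

recognizer = SpeechRecognizer(name='recognizer', **config['net'])  # hypothetical config
recognizer.load_params('recognizer_params.zip')                    # hypothetical path
recognizer.init_beam_search(beam_size=10)

# `features` would be a (num_time_steps, feature_dim) array for a single
# utterance; beam_search() adds the beam axis itself and returns all beam
# hypotheses together with their costs.
outputs, costs = recognizer.beam_search({'recordings': features})
best_hypothesis = outputs[numpy.argmin(costs)]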