Пример #1
0
class InnerRecurrent(BaseRecurrent, Initializable):
    """Recurrent brick that feeds a GRU the sum of two forked signals.

    At every step the per-step ``inner_inputs`` and the ``outer_inputs``
    context are each projected by their own ``Fork`` into one tensor per
    non-mask sequence of the wrapped ``GatedRecurrent``. The two
    projections are added key by key and the sums are passed to a single
    (non-iterated) GRU step.

    Parameters
    ----------
    inner_input_dim : int
        Input dimension of the fork applied to ``inner_inputs``.
    outer_input_dim : int
        Input dimension of the fork applied to ``outer_inputs``.
    inner_dim : int
        Dimension of the wrapped GRU.
    \*\*kwargs
        Forwarded to the parent brick constructor.

    """
    def __init__(self, inner_input_dim, outer_input_dim, inner_dim, **kwargs):
        self.inner_gru = GatedRecurrent(dim=inner_dim, name='inner_gru')

        # Both forks must emit exactly the sequences the GRU consumes,
        # except for the mask, which is not a projected input.
        gru_sequences = [name for name in self.inner_gru.apply.sequences
                         if 'mask' not in name]

        # Each fork gets its own list object so that neither can be
        # affected by mutation of the other's ``output_names``.
        self.inner_input_fork = Fork(
            output_names=list(gru_sequences),
            input_dim=inner_input_dim, name='inner_input_fork')
        self.outer_input_fork = Fork(
            output_names=list(gru_sequences),
            input_dim=outer_input_dim, name='inner_outer_fork')

        super(InnerRecurrent, self).__init__(**kwargs)

        self.children = [
            self.inner_gru, self.inner_input_fork, self.outer_input_fork]

    def _push_allocation_config(self):
        """Set each fork's output dimensions from the wrapped GRU."""
        self.inner_input_fork.output_dims = self.inner_gru.get_dims(
            self.inner_input_fork.output_names)
        self.outer_input_fork.output_dims = self.inner_gru.get_dims(
            self.outer_input_fork.output_names)

    @recurrent(sequences=['inner_inputs'], states=['states'],
               contexts=['outer_inputs'], outputs=['states'])
    def apply(self, inner_inputs, states, outer_inputs):
        """Run one GRU step on the summed projections of both inputs.

        Parameters
        ----------
        inner_inputs
            The sequence element for the current step.
        states
            The GRU states from the previous step.
        outer_inputs
            Context passed unchanged to every step.

        Returns
        -------
        The new GRU states for this step.

        """
        forked_inner = self.inner_input_fork.apply(inner_inputs, as_dict=True)
        forked_outer = self.outer_input_fork.apply(outer_inputs, as_dict=True)

        # Both forks share the same output names, so the two projections
        # can be combined key by key.
        gru_inputs = {key: forked_inner[key] + forked_outer[key]
                      for key in forked_inner}

        # iterate=False: the surrounding @recurrent wrapper already does the
        # scan, so the GRU must perform only a single transition here.
        new_states = self.inner_gru.apply(
            iterate=False,
            **dict_union(gru_inputs, {'states': states}))
        return new_states

    def get_dim(self, name):
        """Return the dimension of the variable called ``name``.

        ``'states'`` is answered by the wrapped GRU. Any other name is
        delegated to the parent implementation instead of returning the
        ``AttributeError`` class as if it were a dimension (presumably
        ``Brick.get_dim`` raises ``ValueError`` for unknown names --
        confirm against the Blocks version in use).

        """
        if name == 'states':
            return self.inner_gru.get_dim(name)
        return super(InnerRecurrent, self).get_dim(name)
Пример #2
0
class GatedRecurrentWithContext(Initializable):
    """Wrapper around ``GatedRecurrent`` that adds context terms to its inputs.

    The wrapped brick is stored as ``gated_recurrent``. Attributes that are
    not found on the wrapper itself (for example ``dim``,
    ``use_update_gate`` and ``use_reset_gate``, which this class reads) are
    looked up on the wrapped brick through ``__getattr__``.

    """
    def __init__(self, *args, **kwargs):
        # Every constructor argument is forwarded to the wrapped brick.
        self.gated_recurrent = GatedRecurrent(*args, **kwargs)
        self.children = [self.gated_recurrent]

    @application(states=['states'],
                 outputs=['states'],
                 contexts=[
                     'readout_context', 'transition_context', 'update_context',
                     'reset_context'
                 ])
    def apply(self, transition_context, update_context, reset_context, *args,
              **kwargs):
        """Add the contexts to the matching inputs and apply the wrapped GRU.

        ``kwargs`` must contain ``inputs``, ``update_inputs`` and
        ``reset_inputs``; all three are indexed unconditionally. Everything
        left in ``args``/``kwargs`` is forwarded to
        ``self.gated_recurrent.apply`` and its result is returned.

        """
        kwargs['inputs'] += transition_context
        kwargs['update_inputs'] += update_context
        kwargs['reset_inputs'] += reset_context
        # readout_context was only added for the Readout brick, discard it.
        # A default is given so that a call which did not supply it does not
        # fail with KeyError on a value that is thrown away anyway.
        kwargs.pop('readout_context', None)
        return self.gated_recurrent.apply(*args, **kwargs)

    def get_dim(self, name):
        """Return ``self.dim`` for context names, else ask the wrapped brick."""
        if name in [
                'readout_context', 'transition_context', 'update_context',
                'reset_context'
        ]:
            # ``dim`` is resolved on the wrapped brick via ``__getattr__``.
            return self.dim
        return self.gated_recurrent.get_dim(name)

    def __getattr__(self, name):
        """Fall back to the wrapped brick for attributes missing here.

        Raises
        ------
        AttributeError
            If ``gated_recurrent`` itself is missing, or if the wrapped
            brick does not have the attribute either.

        """
        # __getattr__ only runs when normal lookup failed. If the wrapped
        # brick itself is what is missing, evaluating self.gated_recurrent
        # below would re-enter this method forever, so fail explicitly.
        if name == 'gated_recurrent':
            raise AttributeError(name)
        return getattr(self.gated_recurrent, name)

    @apply.property('sequences')
    def apply_inputs(self):
        """List the sequence names of ``apply`` for the enabled gates."""
        sequences = ['mask', 'inputs']
        if self.use_update_gate:
            sequences.append('update_inputs')
        if self.use_reset_gate:
            sequences.append('reset_inputs')
        return sequences
Пример #3
0
class Encoder(Initializable):
    """Encoder of RNNsearch model.

    Looks up an embedding for each source token, projects the embeddings
    with a ``Fork`` into the non-mask input sequences of a
    ``GatedRecurrent`` and returns the output of that recurrent brick.

    ``blockid`` is concatenated onto the child brick names, so it has to be
    a string.

    """

    def __init__(self, blockid, vocab_size, embedding_dim, state_dim, **kwargs):
        super(Encoder, self).__init__(**kwargs)
        self.blockid = blockid
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.state_dim = state_dim

        # Note the GRU name has no '_' separator, unlike the other children.
        suffix = '_' + self.blockid
        self.lookup = LookupTable(name='embeddings' + suffix)
        self.gru = GatedRecurrent(
            activation=Tanh(), dim=state_dim, name="GatedRNN" + self.blockid)

        # The fork produces one output per GRU input sequence; the mask is
        # supplied directly in ``apply`` rather than projected.
        projected_names = [
            name for name in self.gru.apply.sequences if name != 'mask']
        self.fwd_fork = Fork(
            projected_names, prototype=Linear(), name='fwd_fork' + suffix)

        self.children = [self.lookup, self.gru, self.fwd_fork]

    def _push_allocation_config(self):
        """Propagate the configured sizes to the child bricks."""
        lookup, fork = self.lookup, self.fwd_fork

        lookup.length = self.vocab_size
        lookup.dim = self.embedding_dim

        fork.input_dim = self.embedding_dim
        fork.output_dims = [
            self.gru.get_dim(output_name)
            for output_name in fork.output_names]

    @application(inputs=['source_sentence', 'source_sentence_mask'],
                 outputs=['representation'])
    def apply(self, source_sentence, source_sentence_mask):
        """Encode ``source_sentence`` and return the GRU output.

        Both arguments are transposed before use so that time becomes the
        first dimension.

        """
        # Time as first dimension
        tokens = source_sentence.T
        mask = source_sentence_mask.T

        embedded = self.lookup.apply(tokens)
        projected = self.fwd_fork.apply(embedded, as_dict=True)
        gru_kwargs = merge(projected, {'mask': mask})
        return self.gru.apply(**gru_kwargs)
Пример #4
0
class GatedRecurrentWithContext(Initializable):
    """``GatedRecurrent`` wrapper whose inputs are offset by context terms.

    The wrapped brick lives in ``gated_recurrent``; attribute lookups that
    fail on the wrapper (such as ``dim``, ``use_update_gate`` and
    ``use_reset_gate`` used below) are redirected to it by ``__getattr__``.

    """
    def __init__(self, *args, **kwargs):
        # Constructor arguments go straight to the wrapped brick.
        self.gated_recurrent = GatedRecurrent(*args, **kwargs)
        self.children = [self.gated_recurrent]

    @application(states=['states'], outputs=['states'],
                 contexts=['readout_context', 'transition_context',
                           'update_context', 'reset_context'])
    def apply(self, transition_context, update_context, reset_context,
              *args, **kwargs):
        """Apply the wrapped GRU after adding each context to its input.

        ``inputs``, ``update_inputs`` and ``reset_inputs`` must all be
        present in ``kwargs`` because each one is indexed unconditionally.
        Remaining positional and keyword arguments are passed through to
        ``self.gated_recurrent.apply``, whose result is returned.

        """
        kwargs['inputs'] += transition_context
        kwargs['update_inputs'] += update_context
        kwargs['reset_inputs'] += reset_context
        # readout_context was only added for the Readout brick, discard it.
        # Popping with a default avoids a KeyError when the caller never
        # passed it, since the value is unused either way.
        kwargs.pop('readout_context', None)
        return self.gated_recurrent.apply(*args, **kwargs)

    def get_dim(self, name):
        """Return ``self.dim`` for the context names, else defer to the GRU."""
        if name in ['readout_context', 'transition_context',
                    'update_context', 'reset_context']:
            # ``dim`` comes from the wrapped brick through ``__getattr__``.
            return self.dim
        return self.gated_recurrent.get_dim(name)

    def __getattr__(self, name):
        """Resolve attributes missing on the wrapper via the wrapped brick.

        Raises
        ------
        AttributeError
            When ``gated_recurrent`` is not set on the instance, or when the
            wrapped brick lacks the requested attribute as well.

        """
        # Only reached after normal lookup fails. Without this guard a
        # missing ``gated_recurrent`` would make the getattr below call
        # this method again and recurse without end.
        if name == 'gated_recurrent':
            raise AttributeError(name)
        return getattr(self.gated_recurrent, name)

    @apply.property('sequences')
    def apply_inputs(self):
        """Return the sequence names of ``apply`` given the enabled gates."""
        sequences = ['mask', 'inputs']
        if self.use_update_gate:
            sequences.append('update_inputs')
        if self.use_reset_gate:
            sequences.append('reset_inputs')
        return sequences