class InnerRecurrent(BaseRecurrent, Initializable):
    """A GRU transition whose step input is the sum of two forked streams.

    Both an "inner" per-step input and an "outer" context are pushed
    through `Fork` bricks onto the GRU's input sequences (gates included,
    mask excluded) and added element-wise before each transition step.
    """

    def __init__(self, inner_input_dim, outer_input_dim, inner_dim, **kwargs):
        self.inner_gru = GatedRecurrent(dim=inner_dim, name='inner_gru')
        # Fork onto every GRU input sequence except the mask.
        self.inner_input_fork = Fork(
            output_names=[name for name in self.inner_gru.apply.sequences
                          if 'mask' not in name],
            input_dim=inner_input_dim, name='inner_input_fork')
        self.outer_input_fork = Fork(
            output_names=[name for name in self.inner_gru.apply.sequences
                          if 'mask' not in name],
            input_dim=outer_input_dim, name='inner_outer_fork')
        super(InnerRecurrent, self).__init__(**kwargs)
        self.children = [
            self.inner_gru, self.inner_input_fork, self.outer_input_fork]

    def _push_allocation_config(self):
        # Each fork's outputs must match the dims the GRU expects for them.
        self.inner_input_fork.output_dims = self.inner_gru.get_dims(
            self.inner_input_fork.output_names)
        self.outer_input_fork.output_dims = self.inner_gru.get_dims(
            self.outer_input_fork.output_names)

    @recurrent(sequences=['inner_inputs'], states=['states'],
               contexts=['outer_inputs'], outputs=['states'])
    def apply(self, inner_inputs, states, outer_inputs):
        """Perform one GRU step on the summed forked inputs.

        `outer_inputs` is a context: the same outer projection is added
        at every time step of the scan.
        """
        forked_inputs = self.inner_input_fork.apply(inner_inputs, as_dict=True)
        forked_states = self.outer_input_fork.apply(outer_inputs, as_dict=True)
        # Element-wise sum of the two projections, per GRU input name.
        gru_inputs = {key: forked_inputs[key] + forked_states[key]
                      for key in forked_inputs.keys()}
        new_states = self.inner_gru.apply(
            iterate=False,
            **dict_union(gru_inputs, {'states': states}))
        return new_states  # mean according to the time axis

    def get_dim(self, name):
        if name == 'states':
            return self.inner_gru.get_dim(name)
        else:
            # BUG FIX: original code did `return AttributeError`, handing the
            # exception *class* back to the caller instead of raising it.
            raise AttributeError(name)
class GatedRecurrentWithContext(Initializable):
    """Wrap a ``GatedRecurrent`` so context vectors are mixed into its
    inputs and both gate inputs at every step.

    All other attribute access is delegated to the wrapped brick.
    """

    def __init__(self, *args, **kwargs):
        self.gated_recurrent = GatedRecurrent(*args, **kwargs)
        self.children = [self.gated_recurrent]

    @application(states=['states'], outputs=['states'],
                 contexts=['readout_context', 'transition_context',
                           'update_context', 'reset_context'])
    def apply(self, transition_context, update_context, reset_context,
              *args, **kwargs):
        # Fold each context into its corresponding input stream.
        kwargs['inputs'] += transition_context
        kwargs['update_inputs'] += update_context
        kwargs['reset_inputs'] += reset_context
        # readout_context was only added for the Readout brick, discard it
        kwargs.pop('readout_context')
        return self.gated_recurrent.apply(*args, **kwargs)

    def get_dim(self, name):
        # Every context shares the recurrent state dimension.
        if name in ('readout_context', 'transition_context',
                    'update_context', 'reset_context'):
            return self.dim
        return self.gated_recurrent.get_dim(name)

    def __getattr__(self, name):
        # Guard: avoid infinite recursion while the wrapped brick is
        # not yet set during __init__.
        if name == 'gated_recurrent':
            raise AttributeError
        return getattr(self.gated_recurrent, name)

    @apply.property('sequences')
    def apply_inputs(self):
        # Sequences depend on which gates the wrapped GRU actually uses.
        names = ['mask', 'inputs']
        if self.use_update_gate:
            names.append('update_inputs')
        if self.use_reset_gate:
            names.append('reset_inputs')
        return names
class Encoder(Initializable):
    """Encoder of RNNsearch model."""

    def __init__(self, blockid, vocab_size, embedding_dim, state_dim,
                 **kwargs):
        super(Encoder, self).__init__(**kwargs)
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.state_dim = state_dim
        self.blockid = blockid

        # Child bricks; names are suffixed with blockid so several
        # encoder blocks can coexist with distinct parameter names.
        self.lookup = LookupTable(name='embeddings' + '_' + self.blockid)
        self.gru = GatedRecurrent(activation=Tanh(), dim=state_dim,
                                  name="GatedRNN" + self.blockid)
        self.fwd_fork = Fork(
            [name for name in self.gru.apply.sequences if name != 'mask'],
            prototype=Linear(), name='fwd_fork' + '_' + self.blockid)
        self.children = [self.lookup, self.gru, self.fwd_fork]

    def _push_allocation_config(self):
        self.lookup.length = self.vocab_size
        self.lookup.dim = self.embedding_dim
        # Fork embeddings onto each GRU input sequence with matching dims.
        self.fwd_fork.input_dim = self.embedding_dim
        self.fwd_fork.output_dims = [self.gru.get_dim(name)
                                     for name in self.fwd_fork.output_names]

    @application(inputs=['source_sentence', 'source_sentence_mask'],
                 outputs=['representation'])
    def apply(self, source_sentence, source_sentence_mask):
        # Transpose so time is the leading dimension, as the RNN expects.
        source_sentence = source_sentence.T
        source_sentence_mask = source_sentence_mask.T

        embeddings = self.lookup.apply(source_sentence)
        gru_kwargs = merge(self.fwd_fork.apply(embeddings, as_dict=True),
                           {'mask': source_sentence_mask})
        representation = self.gru.apply(**gru_kwargs)
        return representation
class GatedRecurrentWithContext(Initializable):
    """A ``GatedRecurrent`` whose inputs and gate inputs receive additive
    context terms on every application; delegates everything else to the
    wrapped brick."""

    #: Names treated as contexts by :meth:`apply` and :meth:`get_dim`.
    _CONTEXT_NAMES = ('readout_context', 'transition_context',
                      'update_context', 'reset_context')

    def __init__(self, *args, **kwargs):
        self.gated_recurrent = GatedRecurrent(*args, **kwargs)
        self.children = [self.gated_recurrent]

    @application(states=['states'], outputs=['states'],
                 contexts=['readout_context', 'transition_context',
                           'update_context', 'reset_context'])
    def apply(self, transition_context, update_context, reset_context,
              *args, **kwargs):
        # Add contexts onto the forked input streams before stepping.
        kwargs['inputs'] += transition_context
        kwargs['update_inputs'] += update_context
        kwargs['reset_inputs'] += reset_context
        # readout_context was only added for the Readout brick, discard it
        kwargs.pop('readout_context')
        return self.gated_recurrent.apply(*args, **kwargs)

    def get_dim(self, name):
        if name in self._CONTEXT_NAMES:
            return self.dim
        return self.gated_recurrent.get_dim(name)

    def __getattr__(self, name):
        # Break the recursion that would otherwise occur if
        # ``gated_recurrent`` itself is missing (e.g. during __init__).
        if name == 'gated_recurrent':
            raise AttributeError
        return getattr(self.gated_recurrent, name)

    @apply.property('sequences')
    def apply_inputs(self):
        # Advertise only the sequences the wrapped GRU is configured for.
        sequences = ['mask', 'inputs']
        if self.use_update_gate:
            sequences.append('update_inputs')
        if self.use_reset_gate:
            sequences.append('reset_inputs')
        return sequences