class InnerRecurrent(BaseRecurrent, Initializable):
    """A GRU transition driven by the sum of two forked input streams.

    Each step combines a per-step ("inner") input sequence with a fixed
    ("outer") context: both are forked into the GRU's non-mask input
    sequences, summed element-wise, and fed to a single
    :class:`GatedRecurrent` step.

    Parameters
    ----------
    inner_input_dim : int
        Dimensionality of the per-step inner inputs.
    outer_input_dim : int
        Dimensionality of the outer (context) inputs.
    inner_dim : int
        Dimensionality of the GRU state.
    """
    def __init__(self, inner_input_dim, outer_input_dim, inner_dim, **kwargs):
        self.inner_gru = GatedRecurrent(dim=inner_dim, name='inner_gru')
        # Fork each source into every GRU input sequence except the mask,
        # so the two contributions can be summed name-by-name in apply().
        self.inner_input_fork = Fork(
            output_names=[name for name in self.inner_gru.apply.sequences
                          if 'mask' not in name],
            input_dim=inner_input_dim, name='inner_input_fork')
        self.outer_input_fork = Fork(
            output_names=[name for name in self.inner_gru.apply.sequences
                          if 'mask' not in name],
            input_dim=outer_input_dim, name='inner_outer_fork')

        super(InnerRecurrent, self).__init__(**kwargs)

        self.children = [
            self.inner_gru, self.inner_input_fork, self.outer_input_fork]

    def _push_allocation_config(self):
        # Each fork's output dims must match what the GRU expects for the
        # corresponding input sequences.
        self.inner_input_fork.output_dims = self.inner_gru.get_dims(
            self.inner_input_fork.output_names)
        self.outer_input_fork.output_dims = self.inner_gru.get_dims(
            self.outer_input_fork.output_names)

    @recurrent(sequences=['inner_inputs'], states=['states'],
               contexts=['outer_inputs'], outputs=['states'])
    def apply(self, inner_inputs, states, outer_inputs):
        """Perform one GRU step on the summed forked inputs.

        Parameters
        ----------
        inner_inputs
            Per-step inner input (sequence dimension handled by
            the :func:`recurrent` wrapper).
        states
            Previous GRU state.
        outer_inputs
            Outer context, constant across steps.

        Returns
        -------
        The new GRU state.
        """
        forked_inputs = self.inner_input_fork.apply(inner_inputs,
                                                    as_dict=True)
        forked_states = self.outer_input_fork.apply(outer_inputs,
                                                    as_dict=True)

        # Sum the two forked contributions for each GRU input sequence.
        # Both forks were built from the same output_names, so the key
        # sets are identical.
        gru_inputs = {key: forked_inputs[key] + forked_states[key]
                      for key in forked_inputs}

        new_states = self.inner_gru.apply(
            iterate=False,
            **dict_union(gru_inputs, {'states': states}))
        return new_states

    def get_dim(self, name):
        """Return the dimensionality of `name`.

        Raises
        ------
        AttributeError
            If `name` is not a known role of this brick.
        """
        if name == 'states':
            return self.inner_gru.get_dim(name)
        # Bug fix: the original did `return AttributeError`, silently
        # handing callers the exception *class* instead of signalling
        # the unknown name.
        raise AttributeError(name)