def __init__(self, readout, transition, attention=None,
             fork_inputs=None, add_contexts=True, **kwargs):
    """Assemble the generator from a readout, a transition and optional attention.

    When `fork_inputs` is not given, every sequence input of the transition
    except the mask is forked.  With attention present, the transition is
    wrapped in an :class:`AttentionRecurrent` (glimpses distributed over the
    forked inputs); otherwise a :class:`FakeAttentionRecurrent` wrapper keeps
    the interface uniform.
    """
    if not fork_inputs:
        fork_inputs = [seq for seq in transition.apply.sequences
                       if seq != 'mask']
    fork = Fork(fork_inputs)
    if not attention:
        wrapped_transition = FakeAttentionRecurrent(
            transition, name="with_fake_attention")
    else:
        # The first output of take_glimpses is the glimpse that gets
        # distributed over the transition's regular inputs.
        glimpse_name = attention.take_glimpses.outputs[0]
        wrapped_transition = AttentionRecurrent(
            transition, attention, Distribute(fork_inputs, glimpse_name),
            add_contexts=add_contexts, name="att_trans")
    super(SequenceGenerator, self).__init__(
        readout, wrapped_transition, fork, **kwargs)
def __init__(self, transition, context_transition, attention, distribute=None,
             add_contexts=True, attended_name=None, attended_mask_name=None,
             **kwargs):
    """Wrap `transition` with an attention mechanism over a list of contexts.

    Parameters mirror the wrapped transition's application: its sequences,
    states and contexts are copied, and (when `add_contexts` is set) the
    attended list, its mask and a part-of-speech-tag context are appended to
    the context names.  A :class:`Distribute` brick merges the attention
    glimpse into the transition's regular sequence inputs.
    """
    super(AttentionRecurrent, self).__init__(**kwargs)
    self._sequence_names = list(transition.apply.sequences)
    self._state_names = list(transition.apply.states)
    self._context_names = list(transition.apply.contexts)
    # Bug fix: `posTag_name` was previously assigned only inside the
    # `add_contexts` branch, so `preprocessed_posTag_name` below raised
    # AttributeError whenever `add_contexts=False`.  Define it up front;
    # behavior for `add_contexts=True` is unchanged.
    self.posTag_name = 'posTag'
    if add_contexts:
        if not attended_name:
            attended_name = 'attended_list'
        if not attended_mask_name:
            attended_mask_name = 'attended_mask_list'
        self._context_names += [
            attended_name, attended_mask_name, self.posTag_name]
    else:
        # Caller-provided transition already carries the attended contexts
        # as its first two context names.
        attended_name = self._context_names[0]
        attended_mask_name = self._context_names[1]
    if not distribute:
        normal_inputs = [name for name in self._sequence_names
                         if 'mask' not in name]
        distribute = Distribute(normal_inputs,
                                attention.take_glimpses.outputs[0])
    self.transition = transition
    self.context_transition = context_transition
    self.attention = attention
    self.distribute = distribute
    self.add_contexts = add_contexts
    self.attended_name = attended_name
    self.attended_mask_name = attended_mask_name
    self.preprocessed_attended_name = "preprocessed_" + self.attended_name
    self.preprocessed_posTag_name = 'preprocessed_' + self.posTag_name
    self._glimpse_names = self.attention.take_glimpses.outputs
    # We need to determine which glimpses are fed back.
    # Currently we extract it from `take_glimpses` signature.
    self.previous_glimpses_needed = [
        name for name in self._glimpse_names
        if name in self.attention.take_glimpses.inputs]
    self.children = [self.transition, self.context_transition,
                     self.attention, self.distribute]
def __init__(self, recurrent, extra_input_name, extra_input_dim, **kwargs):
    """Wrap `recurrent`, distributing one extra input over its sequences."""
    self.recurrent = recurrent
    self.extra_input_name = extra_input_name
    self.extra_input_dim = extra_input_dim
    self._normal_inputs = [seq for seq in self.recurrent.apply.sequences
                           if seq != 'mask']
    self.distribute = Distribute(self._normal_inputs, self.extra_input_name)
    super(RecurrentWithExtraInput, self).__init__(
        children=[self.recurrent, self.distribute], **kwargs)
    # Mirror the wrapped brick's application signature, with the extra
    # input appended to the sequences.
    self.apply.sequences = (self.recurrent.apply.sequences
                            + [self.extra_input_name])
    for attribute in ('outputs', 'states', 'contexts'):
        setattr(self.apply, attribute,
                getattr(self.recurrent.apply, attribute))
    self.initial_states.outputs = self.recurrent.initial_states.outputs
class RecurrentWithExtraInput(Initializable):
    """A recurrent brick that feeds one extra input into every step.

    The extra input is merged into the wrapped brick's regular sequence
    inputs through a :class:`Distribute` child, and the wrapped brick's
    application signature (outputs, states, contexts, initial states) is
    re-exposed unchanged apart from the additional sequence.
    """

    @lazy(allocation=['extra_input_dim'])
    def __init__(self, recurrent, extra_input_name, extra_input_dim, **kwargs):
        self.recurrent = recurrent
        self.extra_input_name = extra_input_name
        self.extra_input_dim = extra_input_dim
        self._normal_inputs = [seq for seq in self.recurrent.apply.sequences
                               if seq != 'mask']
        self.distribute = Distribute(self._normal_inputs,
                                     self.extra_input_name)
        super(RecurrentWithExtraInput, self).__init__(
            children=[self.recurrent, self.distribute], **kwargs)
        # Re-expose the wrapped application's signature, with the extra
        # input appended to the sequences.
        self.apply.sequences = (self.recurrent.apply.sequences
                                + [self.extra_input_name])
        self.apply.outputs = self.recurrent.apply.outputs
        self.apply.states = self.recurrent.apply.states
        self.apply.contexts = self.recurrent.apply.contexts
        self.initial_states.outputs = self.recurrent.initial_states.outputs

    def _push_allocation_config(self):
        # Distribute projects the extra input into each target's dimension,
        # which only the wrapped brick knows.
        self.distribute.source_dim = self.extra_input_dim
        self.distribute.target_dims = self.recurrent.get_dims(
            self.distribute.target_names)

    @application
    def apply(self, **kwargs):
        # Should handle both "iterate=True" and "iterate=False"
        extra_input = kwargs.pop(self.extra_input_name)
        mask = kwargs.pop('mask', None)
        regular_inputs = dict_subset(kwargs, self._normal_inputs, pop=True)
        transformed = self.distribute.apply(
            as_dict=True,
            **dict_union(regular_inputs,
                         {self.extra_input_name: extra_input}))
        return self.recurrent.apply(mask=mask,
                                    **dict_union(transformed, kwargs))

    @application
    def initial_states(self, *args, **kwargs):
        # Delegate directly to the wrapped brick.
        return self.recurrent.initial_states(*args, **kwargs)

    def get_dim(self, name):
        # Only the extra input's dimension is known locally.
        return (self.extra_input_dim if name == self.extra_input_name
                else self.recurrent.get_dim(name))
def __init__(self, transition, attention, topical_attention, distribute=None,
             topical_distribute=None, add_contexts=True, attended_name=None,
             attended_mask_name=None, topical_name=None,
             topical_attended_name=None, topical_attended_mask_name=None,
             content_name=None, **kwargs):
    """Wrap `transition` with both a content and a topical attention.

    Two :class:`Distribute` bricks merge, respectively, the content glimpse
    and the topical weighted average into the transition's regular sequence
    inputs.  Name parameters left as ``None`` get sensible defaults.
    """
    super(AttentionRecurrent, self).__init__(**kwargs)
    self._sequence_names = list(transition.apply.sequences)
    self._state_names = list(transition.apply.states)
    self._context_names = list(transition.apply.contexts)
    self._topical_glimpse_names = ['topical_weighted_averages',
                                   'topical_weights']
    if add_contexts:
        if not attended_name:
            attended_name = 'attended'
        if not attended_mask_name:
            attended_mask_name = 'attended_mask'
        self._context_names += [attended_name, attended_mask_name]
    else:
        attended_name = self._context_names[0]
        attended_mask_name = self._context_names[1]
    # Bug fix: these names default to None and were never defaulted, so
    # `"preprocessed_" + self.topical_attended_name` below raised TypeError
    # for any caller relying on the defaults.  TODO(review): confirm the
    # default names against the dimension lookups elsewhere in the file.
    if not topical_attended_name:
        topical_attended_name = 'topical_attended'
    if not topical_attended_mask_name:
        topical_attended_mask_name = 'topical_attended_mask'
    if not distribute:
        normal_inputs = [name for name in self._sequence_names
                         if 'mask' not in name]
        distribute = Distribute(normal_inputs,
                                attention.take_glimpses.outputs[0],
                                name='distribute')
    if not topical_distribute:
        normal_inputs = [name for name in self._sequence_names
                         if 'mask' not in name]
        topical_distribute = Distribute(normal_inputs,
                                        self._topical_glimpse_names[0],
                                        name='topical_distribute')
    self.transition = transition
    self.attention = attention
    self.topical_attention = topical_attention
    self.distribute = distribute
    self.topical_distribute = topical_distribute
    self.add_contexts = add_contexts
    self.attended_name = attended_name
    self.attended_mask_name = attended_mask_name
    self.topical_attended_name = topical_attended_name
    self.topical_attended_mask_name = topical_attended_mask_name
    self.content_name = content_name
    if not topical_name:
        # NOTE(review): 'topical_embeddingq' looks like a typo for
        # 'topical_embedding', but this string may be matched elsewhere in
        # the project — kept byte-identical on purpose.
        self.topical_name = 'topical_embeddingq'
    else:
        self.topical_name = topical_name
    self.preprocessed_topical_attended_name = (
        "preprocessed_" + self.topical_attended_name)
    self.preprocessed_attended_name = "preprocessed_" + self.attended_name
    self._glimpse_names = self.attention.take_glimpses.outputs
    # We need to determine which glimpses are fed back.
    # Currently we extract it from `take_glimpses` signature.
    self.previous_glimpses_needed = [
        name for name in self._glimpse_names
        if name in self.attention.take_glimpses.inputs]
    self.children = [self.transition, self.attention, self.topical_attention,
                     self.distribute, self.topical_distribute]