def initial_state(self):
    """Initial state for a learning process.

    Rolling-feature state is sized from the values in the loss module's
    initial state: one shape per value, in the dict's iteration order.

    Returns:
      The result of ``base_arch.merged_namedtuple(LearnerState, ...)``,
      built from the parent class's initial state plus a
      ``rolling_features`` entry.
    """
    loss_values = self.loss_module.initial_state()
    # Each value is expected to expose `.shape.as_list()`
    # (presumably a TF tensor/variable -- confirm against loss_module).
    feature_shapes = [value.shape.as_list() for value in loss_values.values()]
    parent_state = super(Learner, self).initial_state()
    rolling_init = self.rolling_features.initial_state(feature_shapes)
    return base_arch.merged_namedtuple(
        LearnerState,
        parent_state,
        rolling_features=rolling_init,
    )
def current_state(self):
    """State stored on local tf.Variable for this Learner.

    Mirrors ``initial_state``, but reads ``current_state()`` from the loss
    module, the parent class and the rolling features. Rolling features
    receive one shape per value of the loss module's current state.

    Returns:
      The result of ``base_arch.merged_namedtuple(LearnerState, ...)``,
      built from the parent class's current state plus a
      ``rolling_features`` entry.
    """
    loss_vars = self.loss_module.current_state()
    # One static shape (as a Python list) per loss-module entry, in dict order.
    feature_shapes = [var.shape.as_list() for var in loss_vars.values()]
    parent_state = super(Learner, self).current_state()
    rolling_now = self.rolling_features.current_state(feature_shapes)
    return base_arch.merged_namedtuple(
        LearnerState,
        parent_state,
        rolling_features=rolling_now,
    )