def run(self, T_or_generator, inner_objective_feed_dicts=None, outer_objective_feed_dicts=None,
        initializer_feed_dict=None, global_step=None, session=None, online=False, callback=None):
    """Run the inner optimization dynamics forward, then the reverse pass.

    Forward phase: (optionally) runs ``self.initialization``, then executes
    ``self.iteration`` for each step yielded by the first element of
    ``T_or_generator``, saving the state after every step into
    ``self._history``. Reverse phase: runs ``self._reverse_initializer`` and
    then ``self._alpha_iter`` once per saved state, walking the history
    backwards (presumably accumulating the reverse-mode hypergradient —
    TODO(review): confirm against the definition of ``_alpha_iter``).

    :param T_or_generator: an int (number of inner steps) or a generator of
        step indices; may also be a pair, whose last element is forwarded to
        ``self._state_feed_dict_generator`` for the reverse pass.
    :param inner_objective_feed_dicts: feed dict (or callable of the step
        index ``t``) for the inner iteration ops.
    :param outer_objective_feed_dicts: feed dict (or callable of the evaluated
        ``global_step``) used to initialize the reverse pass.
    :param initializer_feed_dict: feed dict (or callable of the evaluated
        ``global_step``) for ``self.initialization``; also merged into the
        reverse initializer's feed dict (see tf-quirk note below).
    :param global_step: optional tensor/variable evaluated to parametrize the
        callable feed dicts (supports stochastic outer objectives).
    :param session: tf session to use; defaults to the current default session.
    :param online: if True, skips running ``self.initialization`` (state is
        carried over from a previous call).
    :param callback: callable, or pair of callables, invoked as
        ``callback(t, feed_dict, session)`` after each step.
    """
    # callback may be a pair, first for forward pass, second for reverse pass
    callback = utils.as_tuple_or_list(callback)
    # same thing for T
    T_or_generator = utils.as_tuple_or_list(T_or_generator)

    ss = session or tf.get_default_session()

    # discard state saved by any previous run; history is rebuilt below
    del self._history[:]
    if not online:
        _fd = utils.maybe_call(initializer_feed_dict, utils.maybe_eval(global_step, ss))
        self._save_history(ss.run(self.initialization, feed_dict=_fd))
    # else:  # not totally clear if i should add this
    #     self._save_history(ss.run(list(self.state)))

    # T records the index of the last executed step, needed when
    # T_or_generator is a generator and the total count is unknown upfront
    T = 0  # this is useful if T_or_generator is indeed a generator...
    for t in utils.solve_int_or_generator(T_or_generator[0]):
        # nonlocal t  # with nonlocal the variable T would not be necessary... not compatible with 2.7
        _fd = utils.maybe_call(inner_objective_feed_dicts, t)
        self._save_history(ss.run(self.iteration, feed_dict=_fd))
        utils.maybe_call(callback[0], t, _fd, ss)
        T = t

    # initialization of support variables (supports stochastic evaluation of outer objective via global_step ->
    # variable)
    # TODO (maybe tf bug or oddity) for some strange reason, if some variable's initializer depends on
    # a placeholder, then the initializer of alpha SEEMS TO DEPEND ALSO ON THAT placeholder,
    # as if the primary variable should be reinitialized as well, but, I've checked, the primary variable is NOT
    # actually reinitialized. This doesn't make sense since the primary variable is already initialized
    # and Tensorflow seems not to care... should maybe look better into this issue
    reverse_init_fd = utils.maybe_call(outer_objective_feed_dicts, utils.maybe_eval(global_step, ss))
    # now adding also the initializer_feed_dict because of tf quirk described above...
    maybe_init_fd = utils.maybe_call(initializer_feed_dict, utils.maybe_eval(global_step, ss))
    reverse_init_fd = utils.merge_dicts(reverse_init_fd, maybe_init_fd)
    ss.run(self._reverse_initializer, feed_dict=reverse_init_fd)

    # reverse pass: replay the saved states in reverse order (the last state
    # is excluded) and run one alpha iteration per state
    for pt, state_feed_dict in self._state_feed_dict_generator(
            reversed(self._history[:-1]), T_or_generator[-1]):
        # this should be fine also for truncated reverse... but check again the index t
        t = T - pt - 1  # if T is int then len(self.history) is T + 1 and this numerator
        # shall start at T-1
        _fd = utils.merge_dicts(
            state_feed_dict, utils.maybe_call(inner_objective_feed_dicts, t))
        ss.run(self._alpha_iter, _fd)
        if len(callback) == 2:
            utils.maybe_call(callback[1], t, _fd, ss)
def run(self, T_or_generator, inner_objective_feed_dicts=None, outer_objective_feed_dicts=None,
        initializer_feed_dict=None, global_step=None, session=None, online=False, callback=None):
    """Run the inner optimization forward, then solve the linear systems.

    Forward phase: (optionally) runs batch initialization, then executes one
    ``self._forward_step`` per step of ``T_or_generator``. Final phase: for
    each entry of ``self._lin_sys``, builds a solver at the current tolerance
    and calls its ``minimize`` (presumably solving the implicit/fixed-point
    linear system for the hypergradient — TODO(review): confirm against the
    definition of ``_lin_sys``).

    :param T_or_generator: an int (number of inner steps) or a generator of
        step indices.
    :param inner_objective_feed_dicts: feed dict (or callable of the step
        index), or a pair of them: the first is used during the forward steps,
        the last (called with ``-1``) when solving the linear systems.
    :param outer_objective_feed_dicts: feed dict (or callable of the evaluated
        ``global_step``) merged into the linear-system feed dict.
    :param initializer_feed_dict: feed dict (or callable of the evaluated
        ``global_step``) for the batch initialization.
    :param global_step: optional tensor/variable evaluated to parametrize the
        callable feed dicts and the tolerance schedule.
    :param session: tf session to use; defaults to the current default session.
    :param online: if True, skips the batch initialization.
    :param callback: optional callable invoked as
        ``callback(t, feed_dict, session)`` after each forward step.
    """
    ss = session or tf.get_default_session()

    # normalize to a list so [0] / [-1] below work for a single dict too
    inner_objective_feed_dicts = utils.as_tuple_or_list(
        inner_objective_feed_dicts)
    if not online:
        self._run_batch_initialization(
            ss, utils.maybe_call(initializer_feed_dict,
                                 utils.maybe_eval(global_step, ss)))

    for t in utils.solve_int_or_generator(T_or_generator):
        _fd = utils.maybe_call(inner_objective_feed_dicts[0], t)
        self._forward_step(ss, _fd)
        utils.maybe_call(callback, t, _fd, ss)

    # end of optimization. Solve linear systems.
    # tolerance may be a schedule indexed by global_step
    tol_val = utils.maybe_call(self.tolerance, utils.maybe_eval(
        global_step, ss))  # decreasing tolerance (seq.)
    # feed dictionaries (could...in theory, implement stochastic solution of this linear system...)
    _fd = utils.maybe_call(inner_objective_feed_dicts[-1], -1)
    _fd_outer = utils.maybe_call(outer_objective_feed_dicts,
                                 utils.maybe_eval(global_step, ss))
    _fd = utils.merge_dicts(_fd, _fd_outer)

    for lin_sys in self._lin_sys:
        lin_sys(tol_val).minimize(
            ss, _fd)  # implicitly warm restarts with previously found q