Ejemplo n.º 1
0
 def _state_feed_dict_generator(self, history, T_or_generator):
     """Yield ``(t, feed_dict)`` pairs built from recorded optimizer states.

     Step indices come from ``utils.solve_int_or_generator(T_or_generator)``
     and are paired one-to-one with the entries of ``history``; iteration
     stops as soon as either of the two is exhausted (``zip`` semantics).

     :param history: iterable of recorded entries. Each entry is iterated in
         lockstep with ``sorted(self._optimizer_dicts)``, so it is expected
         to hold one item per optimizer dict in that sorted order.
     :param T_or_generator: forwarded unchanged to
         ``utils.solve_int_or_generator``.
     :return: generator of ``(t, merged_feed_dict)`` tuples, where the feed
         dict merges ``state_feed_dict(...)`` of every optimizer dict.
     """
     step_indices = utils.solve_int_or_generator(T_or_generator)
     for step, recorded_entry in zip(step_indices, history):
         per_optimizer_feeds = []
         # Sorting is repeated on every step, exactly like the one-expression
         # form, so changes to the optimizer-dict collection are picked up.
         for opt_dict, saved_values in zip(sorted(self._optimizer_dicts),
                                           recorded_entry):
             per_optimizer_feeds.append(opt_dict.state_feed_dict(saved_values))
         yield step, utils.merge_dicts(*per_optimizer_feeds)
Ejemplo n.º 2
0
    def run(self,
            T_or_generator,
            inner_objective_feed_dicts=None,
            outer_objective_feed_dicts=None,
            initializer_feed_dict=None,
            global_step=None,
            session=None,
            online=False,
            callback=None):
        """Optionally initialize, then execute the forward steps.

        :param T_or_generator: passed to ``utils.solve_int_or_generator`` to
            obtain the step indices ``t`` of the loop.
        :param inner_objective_feed_dicts: resolved with
            ``utils.maybe_call(..., t)`` at every step and handed to
            ``self._forward_step``.
        :param outer_objective_feed_dicts: accepted for signature
            compatibility; not used by this method.
        :param initializer_feed_dict: resolved with ``utils.maybe_call`` on
            the evaluated ``global_step`` and passed to the initialization.
        :param global_step: evaluated through ``utils.maybe_eval`` in the
            session when an initialization is performed.
        :param session: session to use; falls back to
            ``tf.get_default_session()`` when falsy.
        :param online: ``False`` triggers the batch initialization, ``'wr'``
            triggers the "wr" initialization, any other value skips
            initialization altogether.
        :param callback: invoked through ``utils.maybe_call`` with
            ``(t, feed_dict, session)`` after every forward step.
        """
        ss = session or tf.get_default_session()

        # Equality (not truthiness) is kept here: `online` may also be the
        # string 'wr', and values such as True must skip initialization.
        if online == False:  # noqa: E712
            initialize = self._run_batch_initialization
        elif online == 'wr':
            initialize = self._run_wr_initialization
        else:
            initialize = None

        if initialize is not None:
            current_step = utils.maybe_eval(global_step, ss)
            initialize(ss, utils.maybe_call(initializer_feed_dict,
                                            current_step))

        for step in utils.solve_int_or_generator(T_or_generator):
            step_feed = utils.maybe_call(inner_objective_feed_dicts, step)
            self._forward_step(ss, step_feed)
            utils.maybe_call(callback, step, step_feed, ss)
Ejemplo n.º 3
0
    def run(self,
            T_or_generator,
            inner_objective_feed_dicts=None,
            outer_objective_feed_dicts=None,
            initializer_feed_dict=None,
            global_step=None,
            session=None,
            online=False,
            callback=None):
        """Run the forward pass while recording states, then the reverse pass.

        The forward pass runs ``self.iteration`` once per step and stores
        every result with ``self._save_history``. The reverse pass first runs
        ``self._reverse_initializer`` and then ``self._alpha_iter`` once per
        recorded state, walking the history backwards (the last recorded
        state is skipped).

        :param T_or_generator: a single value or a pair; the first element
            drives the forward loop, the last one the reverse loop (both via
            ``utils.solve_int_or_generator``).
        :param inner_objective_feed_dicts: resolved with
            ``utils.maybe_call(..., t)`` for forward step ``t`` and again for
            the matching reverse step.
        :param outer_objective_feed_dicts: resolved with ``utils.maybe_call``
            on the evaluated ``global_step``; fed to the reverse initializer.
        :param initializer_feed_dict: resolved with ``utils.maybe_call`` on
            the evaluated ``global_step``; fed to ``self.initialization``
            (when not ``online``) and also merged into the reverse
            initializer feed dict.
        :param global_step: evaluated through ``utils.maybe_eval``.
        :param session: session to use; falls back to
            ``tf.get_default_session()`` when falsy.
        :param online: when falsy, ``self.initialization`` is run and its
            result becomes the first history entry.
        :param callback: a single callback or a pair; the first is called
            after each forward step, the second (if present) after each
            reverse step, both with ``(t, feed_dict, session)``.
        """
        # callback may be a pair: first for the forward pass, second for the
        # reverse pass
        callback = utils.as_tuple_or_list(callback)
        # same thing for T
        T_or_generator = utils.as_tuple_or_list(T_or_generator)

        ss = session or tf.get_default_session()

        del self._history[:]
        if not online:
            _fd = utils.maybe_call(initializer_feed_dict,
                                   utils.maybe_eval(global_step, ss))
            self._save_history(ss.run(self.initialization, feed_dict=_fd))
        # NOTE: in online mode the current state is not recorded before the
        # forward pass (left as an open question by the original author).

        # Index of the last forward step that actually ran. Tracked explicitly
        # because T_or_generator may be a generator, whose length is unknown
        # (and `nonlocal` is not available on Python 2.7).
        last_t = 0
        for t in utils.solve_int_or_generator(T_or_generator[0]):
            _fd = utils.maybe_call(inner_objective_feed_dicts, t)
            self._save_history(ss.run(self.iteration, feed_dict=_fd))
            utils.maybe_call(callback[0], t, _fd, ss)
            last_t = t

        # Initialization of support variables (supports stochastic evaluation
        # of the outer objective via global_step -> variable).
        # TODO (maybe a tf bug or oddity): if some variable's initializer
        # depends on a placeholder, the initializer of alpha SEEMS to depend
        # on that placeholder too, as if the primary variable had to be
        # reinitialized as well. It is NOT actually reinitialized (checked by
        # the original author), so the placeholder is simply fed again below.
        reverse_init_fd = utils.maybe_call(outer_objective_feed_dicts,
                                           utils.maybe_eval(global_step, ss))
        # now adding also the initializer_feed_dict because of the tf quirk
        maybe_init_fd = utils.maybe_call(initializer_feed_dict,
                                         utils.maybe_eval(global_step, ss))
        reverse_init_fd = utils.merge_dicts(reverse_init_fd, maybe_init_fd)
        ss.run(self._reverse_initializer, feed_dict=reverse_init_fd)

        for pt, state_feed_dict in self._state_feed_dict_generator(
                reversed(self._history[:-1]), T_or_generator[-1]):
            # Every forward step appends exactly one history entry, so after
            # dropping the final entry the pt-th state visited backwards is
            # the one forward step `last_t - pt` started from (assuming
            # `self.iteration` returns the state *after* the update -- TODO
            # confirm). That step's inner feed dict must be paired with it.
            # The former `T - pt - 1` was off by one: T held the last step
            # index (N - 1 for an integer N), so the index began at N - 2
            # instead of the intended N - 1 and ended at -1.
            t = last_t - pt
            _fd = utils.merge_dicts(
                state_feed_dict,
                utils.maybe_call(inner_objective_feed_dicts, t))
            ss.run(self._alpha_iter, _fd)
            if len(callback) == 2:
                utils.maybe_call(callback[1], t, _fd, ss)
Ejemplo n.º 4
0
    def run(self,
            T_or_generator,
            inner_objective_feed_dicts=None,
            outer_objective_feed_dicts=None,
            initializer_feed_dict=None,
            global_step=None,
            session=None,
            online=False,
            callback=None):
        """Execute the forward steps, then solve the linear systems.

        :param T_or_generator: passed to ``utils.solve_int_or_generator`` to
            obtain the step indices of the forward loop.
        :param inner_objective_feed_dicts: a single value or a pair; the
            first element is resolved with ``utils.maybe_call(..., t)`` at
            every forward step, the last one with step ``-1`` for the linear
            system solution.
        :param outer_objective_feed_dicts: resolved with ``utils.maybe_call``
            on the evaluated ``global_step`` and merged into the feed dict
            used for the linear systems.
        :param initializer_feed_dict: resolved with ``utils.maybe_call`` on
            the evaluated ``global_step``; used only when not ``online``.
        :param global_step: evaluated through ``utils.maybe_eval``.
        :param session: session to use; falls back to
            ``tf.get_default_session()`` when falsy.
        :param online: when falsy, the batch initialization is run first.
        :param callback: invoked through ``utils.maybe_call`` with
            ``(t, feed_dict, session)`` after every forward step.
        """
        ss = session or tf.get_default_session()

        inner_fd_sources = utils.as_tuple_or_list(inner_objective_feed_dicts)

        if not online:
            initialize = self._run_batch_initialization
            init_step = utils.maybe_eval(global_step, ss)
            initialize(ss, utils.maybe_call(initializer_feed_dict, init_step))

        for step in utils.solve_int_or_generator(T_or_generator):
            step_feed = utils.maybe_call(inner_fd_sources[0], step)
            self._forward_step(ss, step_feed)
            utils.maybe_call(callback, step, step_feed, ss)

        # Optimization finished: solve the linear systems.
        # The tolerance may depend on the global step (decreasing sequence).
        tolerance = utils.maybe_call(self.tolerance,
                                     utils.maybe_eval(global_step, ss))

        # Feed dictionaries for the solve (in theory a stochastic solution of
        # the linear system could be implemented here). The inner one is
        # requested with the conventional step index -1.
        final_inner_fd = utils.maybe_call(inner_fd_sources[-1], -1)
        outer_fd = utils.maybe_call(outer_objective_feed_dicts,
                                    utils.maybe_eval(global_step, ss))
        solve_fd = utils.merge_dicts(final_inner_fd, outer_fd)

        for make_solver in self._lin_sys:
            # Per the original author: implicitly warm-restarts from the
            # previously found q.
            make_solver(tolerance).minimize(ss, solve_fd)