def apply_gradients(
    self,
    inner_objective_feed_dicts=None,
    outer_objective_feed_dicts=None,
    initializer_feed_dict=None,
    param_dict=OrderedDict(),
    global_step=None,
    session=None,
):
    """Run a single inner step and the two initializer ops built on it.

    Sequence of session calls, in order:

    1. ``self.initialization`` (result stored via ``self._save_history``),
    2. ``self.iteration`` once (result stored as well),
    3. ``self._diff_initializer`` fed with the outer and inner feeds,
    4. ``self._darts_initializer`` fed with the state rebuilt from the
       history entry that precedes the final one, plus inner/outer feeds.

    Args:
        inner_objective_feed_dicts: feed dict for the inner objective. It is
            no longer modified in place; in "Aggr" mode a private copy
            receives the outer feed and the step indicator.
        outer_objective_feed_dicts: feed dict for the outer objective; it is
            also passed through ``utils.maybe_call`` with the evaluated
            ``global_step``.
        initializer_feed_dict: passed through ``utils.maybe_call`` and fed to
            ``self.initialization``.
        param_dict: only read here, so the shared default instance is never
            modified. In "Aggr" mode it must hold ``"alpha"`` and
            ``"t_tensor"``.
        global_step: forwarded to ``utils.maybe_eval``.
        session: session to run in; defaults to ``tf.get_default_session()``.
    """
    aggregate = self._inner_method == "Aggr"
    if aggregate:
        # Read up front so a missing key fails before any state is touched.
        alpha = param_dict["alpha"]
        t_tensor = param_dict["t_tensor"]
    ss = session or tf.get_default_session()

    self._history.clear()

    init_fd = utils.maybe_call(
        initializer_feed_dict, utils.maybe_eval(global_step, ss)
    )
    self._save_history(ss.run(self.initialization, feed_dict=init_fd))

    # Perform a one-step update of the task parameters and store the weights
    # along the optimization path.
    inner_fd = inner_objective_feed_dicts
    if aggregate:
        # Work on a copy: the caller's dict must not silently gain the outer
        # feed and the step indicator as a side effect of this call.
        if inner_fd is not None:
            inner_fd = inner_fd.copy()
        inner_fd.update(outer_objective_feed_dicts)
        if not alpha.get_shape().as_list():
            # `alpha` has an empty shape list: feed the step as a plain float.
            step_value = float(1.0)
        else:
            # Otherwise feed a column vector with a 1 in the first row; its
            # length is the second dimension of `alpha`.
            step_value = np.zeros((alpha.get_shape().as_list()[1], 1))
            step_value[0][0] = 1.0
        inner_fd[t_tensor] = step_value
    self._save_history(ss.run(self.iteration, feed_dict=inner_fd))

    # Compute the differentiation part, multiplied by Epsilon with a one-step
    # forward pass.
    outer_fd = utils.maybe_call(
        outer_objective_feed_dicts, utils.maybe_eval(global_step, ss)
    )
    # `inner_fd` stands in for the dict the original code had mutated in
    # place, so the feed handed to the session is unchanged.
    darts_init_fd = utils.merge_dicts(outer_fd, inner_fd)
    ss.run(self._diff_initializer, feed_dict=darts_init_fd)

    del self._history[-1]  # do not consider the final task parameters

    # Compute the second-order part and add it to the first-order item.
    state_feed_dict = utils.merge_dicts(
        *[
            od.state_feed_dict(h)
            for od, h in zip(sorted(self._optimizer_dicts), self._history[-1])
        ]
    )
    new_fd = utils.merge_dicts(state_feed_dict, inner_fd)
    if aggregate:
        new_fd = utils.merge_dicts(new_fd, outer_objective_feed_dicts)
        # Same step indicator as in the forward step above.
        new_fd[t_tensor] = step_value
    new_fd = utils.merge_dicts(new_fd, outer_objective_feed_dicts)
    ss.run(self._darts_initializer, new_fd)
def _state_feed_dict_generator(self, history, T_or_generator):
    """Yield ``(t, feed_dict)`` pairs for the entries of ``history``.

    Steps come from ``utils.solve_int_or_generator(T_or_generator)`` and are
    paired with ``history`` entries via ``zip``, so iteration stops as soon
    as either side is exhausted. Each feed dict merges the
    ``state_feed_dict`` of every optimizer dict (taken in sorted order) with
    the matching element of the history entry.
    """
    steps = utils.solve_int_or_generator(T_or_generator)
    for step, states in zip(steps, history):
        per_optimizer = []
        # Sorted on every step, exactly as each state is consumed.
        for opt_dict, state in zip(sorted(self._optimizer_dicts), states):
            per_optimizer.append(opt_dict.state_feed_dict(state))
        yield step, utils.merge_dicts(*per_optimizer)
def _opt_fd():
    """Return the inner and outer objective feeds merged into one dict.

    A falsy feed (``None`` or empty) contributes an empty dict; any other
    feed goes through ``utils.maybe_call`` together with the evaluated
    ``self._global_step``.
    """
    if inner_objective_feed_dicts:
        inner_feed = utils.maybe_call(
            inner_objective_feed_dicts, utils.maybe_eval(self._global_step)
        )
    else:
        inner_feed = {}

    if outer_objective_feed_dicts:
        outer_feed = utils.maybe_call(
            outer_objective_feed_dicts, utils.maybe_eval(self._global_step)
        )
    else:
        outer_feed = {}

    return utils.merge_dicts(inner_feed, outer_feed)
def _opt_fd():
    """Build the merged feed dict from the inner and outer objective feeds.

    Needed e.g. when the hyper-learning rate is a placeholder.
    """

    def _resolve(feed):
        # A missing or empty feed contributes nothing to the merge.
        if not feed:
            return {}
        return utils.maybe_call(feed, utils.maybe_eval(self._global_step))

    inner_feed = _resolve(inner_objective_feed_dicts)
    outer_feed = _resolve(outer_objective_feed_dicts)
    return utils.merge_dicts(inner_feed, outer_feed)
def apply_gradients(
    self,
    inner_objective_feed_dicts=None,
    outer_objective_feed_dicts=None,
    initializer_feed_dict=None,
    param_dict=OrderedDict(),
    global_step=None,
    session=None,
):
    """Run the forward steps, then minimize every linear system.

    Args:
        inner_objective_feed_dicts: normalised with
            ``utils.as_tuple_or_list``; the first element feeds the forward
            steps and the last one feeds the linear-system solve.
        outer_objective_feed_dicts: passed through ``utils.maybe_call`` and
            merged into the feed used for the solve.
        initializer_feed_dict: passed through ``utils.maybe_call`` and handed
            to ``self._run_batch_initialization``.
        param_dict: must contain ``"T"``, resolved with
            ``utils.solve_int_or_generator``.
        global_step: forwarded to ``utils.maybe_eval``.
        session: session to run in; defaults to ``tf.get_default_session()``.
    """
    sess = session or tf.get_default_session()
    inner_feeds = utils.as_tuple_or_list(inner_objective_feed_dicts)

    init_feed = utils.maybe_call(
        initializer_feed_dict, utils.maybe_eval(global_step, sess)
    )
    self._run_batch_initialization(sess, init_feed)

    for step in utils.solve_int_or_generator(param_dict["T"]):
        step_feed = utils.maybe_call(inner_feeds[0], step)
        self._forward_step(sess, step_feed)

    # End of optimization: solve the linear systems.
    # The tolerance is resolved against the global step (decreasing
    # tolerance sequence).
    tolerance = utils.maybe_call(
        self.tolerance, utils.maybe_eval(global_step, sess)
    )

    # Feed dictionaries for the solve (in theory a stochastic solution of
    # this linear system could be implemented here).
    last_inner_feed = utils.maybe_call(inner_feeds[-1], -1)
    outer_feed = utils.maybe_call(
        outer_objective_feed_dicts, utils.maybe_eval(global_step, sess)
    )
    solver_feed = utils.merge_dicts(last_inner_feed, outer_feed)

    for build_solver in self._lin_sys:
        # Implicitly warm restarts with the previously found q.
        solver = build_solver(tolerance)
        solver.minimize(sess, solver_feed)
def apply_gradients(
    self,
    inner_objective_feed_dicts=None,
    outer_objective_feed_dicts=None,
    initializer_feed_dict=None,
    param_dict=OrderedDict(),
    train_batches=None,
    experiments=[],
    global_step=None,
    session=None,
):
    """Run the forward steps, then sweep the stored history in reverse.

    Sequence of session calls, in order:

    1. ``self.initialization`` (result stored via ``self._save_history``),
    2. ``self.iteration`` once per forward step (each result stored),
    3. ``self._reverse_initializer`` fed with the outer and initializer feeds,
    4. ``self._alpha_iter`` once per entry produced by
       ``self._state_feed_dict_generator`` over the reversed history (the
       final history entry is dropped first).

    Args:
        inner_objective_feed_dicts: feed dict for the inner objective. It is
            no longer modified in place; in "Aggr" mode a private copy
            receives the outer feed and the per-step indicator.
        outer_objective_feed_dicts: feed dict for the outer objective; it is
            also passed through ``utils.maybe_call``.
        initializer_feed_dict: passed through ``utils.maybe_call``.
        param_dict: only read here. Must contain ``"T"``; in "Aggr" mode also
            ``"alpha"`` and ``"t_tensor"``.
        train_batches: not used by this method.
        experiments: not used by this method (never modified, so the shared
            default list is harmless).
        global_step: forwarded to ``utils.maybe_eval``.
        session: session to run in; defaults to ``tf.get_default_session()``.
    """
    aggregate = self._inner_method == "Aggr"
    if aggregate:
        alpha = param_dict["alpha"]
        t_tensor = param_dict["t_tensor"]

    # Same thing for T: the first element drives the forward pass, the last
    # one the reverse pass.
    T_or_generator = utils.as_tuple_or_list(param_dict["T"])

    ss = session or tf.get_default_session()

    self._history.clear()

    init_fd = utils.maybe_call(
        initializer_feed_dict, utils.maybe_eval(global_step, ss)
    )
    self._save_history(ss.run(self.initialization, feed_dict=init_fd))

    # Work on a copy in "Aggr" mode: the caller's dict must not silently gain
    # the outer feed and the step indicator as a side effect of this call.
    inner_fd = inner_objective_feed_dicts
    if aggregate and inner_fd is not None:
        inner_fd = inner_fd.copy()

    T = 0  # remembers the last forward step if T_or_generator is a generator
    for t in utils.solve_int_or_generator(T_or_generator[0]):
        # (`nonlocal t` would make `T` unnecessary, but it is not compatible
        # with Python 2.7.)
        if aggregate:
            inner_fd.update(outer_objective_feed_dicts)
            if not alpha.get_shape().as_list():
                # `alpha` has an empty shape list: feed the 1-based step.
                inner_fd[t_tensor] = float(t + 1.0)
            else:
                # Otherwise feed a column vector with a 1 in row `t`; its
                # length is the second dimension of `alpha`.
                indicator = np.zeros((alpha.get_shape().as_list()[1], 1))
                indicator[t][0] = 1.0
                inner_fd[t_tensor] = indicator
        self._save_history(ss.run(self.iteration, feed_dict=inner_fd))
        T = t

    # Initialization of support variables (supports stochastic evaluation of
    # the outer objective via global_step -> variable).
    reverse_init_fd = utils.maybe_call(
        outer_objective_feed_dicts, utils.maybe_eval(global_step, ss)
    )
    # Also add the initializer_feed_dict because of a tf quirk.
    maybe_init_fd = utils.maybe_call(
        initializer_feed_dict, utils.maybe_eval(global_step, ss)
    )
    reverse_init_fd = utils.merge_dicts(reverse_init_fd, maybe_init_fd)
    ss.run(self._reverse_initializer, feed_dict=reverse_init_fd)

    del self._history[-1]  # do not consider last point

    for pt, state_feed_dict in self._state_feed_dict_generator(
        reversed(self._history), T_or_generator[-1]
    ):
        # This should be fine also for truncated reverse, but the index `t`
        # deserves another check: if T is an int then len(self._history) is
        # T + 1.
        t = T - pt - 1
        # `inner_fd` stands in for the dict the original code had mutated in
        # place, so the feed handed to the session is unchanged.
        new_fd = utils.merge_dicts(state_feed_dict, inner_fd)
        if aggregate:
            new_fd = utils.merge_dicts(new_fd, outer_objective_feed_dicts)
            if not alpha.get_shape().as_list():
                new_fd[t_tensor] = float(t + 2.0)
            else:
                # NOTE(review): the scalar branch feeds `t + 2`, i.e. forward
                # step `t + 1`, while this branch indexes row `t` (which is -1
                # on the final sweep step and wraps to the last row). Looks
                # like a possible off-by-one; left unchanged pending
                # confirmation.
                indicator = np.zeros((alpha.get_shape().as_list()[1], 1))
                indicator[t][0] = 1
                new_fd[t_tensor] = indicator
        ss.run(self._alpha_iter, new_fd)
# Define the lower-level (inner) problem on the experiment's model variables.
inner_grad = boml_ho.ll_problem(
    inner_objective=loss_inner,
    learning_rate=args.lr,
    T=args.T,
    experiment=ex,
    var_list=ex.model.var_list,
)

# Define the UL objective and the UL calculation process.
outer_prediction = ex.model.re_forward(ex.x_).out
loss_outer = utils.cross_entropy(pred=outer_prediction, label=ex.y_)
meta_parameters = tf.get_collection(boml.extension.GraphKeys.METAPARAMETERS)
boml_ho.ul_problem(
    outer_objective=loss_outer,
    meta_learning_rate=args.meta_lr,
    inner_grad=inner_grad,
    meta_param=meta_parameters,
)

# Aggregate all the defined operations.
boml_ho.aggregate_all()

# Meta training iterations.
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    init_op.run(session=sess)

    for itr in range(args.meta_train_iterations):
        # Generate the feed dicts for this call to run(); a new batch queue
        # is built on every iteration.
        train_batch = BatchQueueMock(
            dataset.train, 1, args.meta_batch_size, utils.get_rand_state(1)
        )
        tr_fd, v_fd = utils.feed_dict(train_batch.get_single_batch(), ex)

        # Meta training step.
        boml_ho.run(tr_fd, v_fd)

        # Only report on every 100th iteration.
        if itr % 100:
            continue
        print(sess.run(loss_inner, utils.merge_dicts(tr_fd, v_fd)))
# NOTE(review): this fragment begins in the middle of a call whose head is not
# visible here; the three keyword arguments below close that call.
    T=args.T,
    experiment=ex,
    var_list=ex.model.var_list,
)
# Define the UL objective and the UL calculation process: the outer loss is
# the cross entropy between the `re_forward` output on `ex.x_` and `ex.y_`.
loss_outer = utils.cross_entropy(pred=ex.model.re_forward(ex.x_).out, label=ex.y_)
boml_ho.ul_problem(
    outer_objective=loss_outer,
    meta_learning_rate=args.meta_lr,
    inner_grad=inner_grad,
    meta_param=tf.get_collection(boml.extension.GraphKeys.METAPARAMETERS),
)
# Aggregate all the defined operations.
boml_ho.aggregate_all()
# Meta training iterations.
with tf.Session() as sess:
    tf.global_variables_initializer().run(session=sess)
    for itr in range(args.meta_train_iterations):
        # Generate the feed dicts for this call to run(); a new batch queue
        # is built on every iteration.
        train_batch = BatchQueueMock(dataset.train, 1, args.meta_batch_size, utils.get_rand_state(1))
        tr_fd, v_fd = utils.feed_dict(train_batch.get_single_batch(), ex)
        # Meta training step.
        boml_ho.run(tr_fd, v_fd)
        if itr % 100 == 0:
            # Every 100th iteration: evaluate both losses on the merged
            # feed dicts and print them with the iteration index.
            loss_list = sess.run([loss_inner, loss_outer], utils.merge_dicts(tr_fd, v_fd))
            print('Iteration {}: Inner_loss {} , Outer_loss {}'.format(
                itr, loss_list[0], loss_list[1]))