def hgrads_hvars(self, hyper_list=None, aggregation_fn=None, process_fn=None):
    """
    Method for getting hypergradients and hyperparameters in the format required by the
    apply_gradients method of tensorflow optimizers.

    :param hyper_list: Optional list of hyperparameters to consider. If not provided will get all
                        variables in the hyperparameter collection in the current scope.
    :param aggregation_fn: Optional operation to aggregate multiple hypergradients (for the same
                            hyperparameter), by default reduce_mean
    :param process_fn: Optional operation, such as gradient clipping, applied to each
                        aggregated hypergradient.

    :return: list of (hypergradient, hyperparameter) pairs
    """
    if hyper_list is None:
        hyper_list = utils.hyperparameters(tf.get_variable_scope().name)

    assert all([h in self._hypergrad_dictionary for h in hyper_list]), \
        'Some hyperparameters in hyper_list have no computed hypergradient'

    if aggregation_fn is None:
        aggregation_fn = lambda hgrad_list: tf.reduce_mean(hgrad_list, axis=0)

    def _aggregate_process_manage_collection(_hg_lst):
        if len(_hg_lst) == 1:  # avoid useless aggregation operations
            aggr = _hg_lst[0]
        else:
            with tf.name_scope(_hg_lst[0].op.name):
                aggr = aggregation_fn(_hg_lst)
        if process_fn is not None:
            with tf.name_scope('process_gradients'):
                aggr = process_fn(aggr)
        tf.add_to_collection(utils.GraphKeys.HYPERGRADIENTS, aggr)
        return aggr

    return [(_aggregate_process_manage_collection(self._hypergrad_dictionary[h]), h)
            for h in hyper_list]
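# Usage sketch for hgrads_hvars (illustrative only, not part of the library):
# `hg` stands for an instance of this class on which compute_gradients has already
# been called, and the clipping threshold is arbitrary. The returned
# (hypergradient, hyperparameter) pairs plug directly into apply_gradients:
#
#     hgs_hvs = hg.hgrads_hvars(process_fn=lambda g: tf.clip_by_norm(g, 1.0))
#     outer_opt = tf.train.AdamOptimizer(1e-3)
#     hyper_step = outer_opt.apply_gradients(hgs_hvs)  # one step on the hyperparameters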
def compute_gradients(self, outer_objective, optimizer_dict, hyper_list=None):
    # Doesn't do anything useful here. To be overridden.
    """
    Function overridden by specific methods.

    :param optimizer_dict: OptimizerDict object resulting from the inner objective optimization.
    :param outer_objective: A loss function for the hyperparameters (scalar tensor)
    :param hyper_list: Optional list of hyperparameters to consider. If not provided will get all
                        variables in the hyperparameter collection in the current scope.

    :return: list of hyperparameters involved in the computation
    """
    assert isinstance(optimizer_dict, OptimizerDict), \
        HyperGradient._ERROR_NOT_OPTIMIZER_DICT.format(optimizer_dict)
    self._optimizer_dicts.add(optimizer_dict)

    if hyper_list is None:  # get default hyperparameters
        hyper_list = utils.hyperparameters(tf.get_variable_scope().name)
    return hyper_list
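# Override sketch for compute_gradients. `ExampleHG` is a hypothetical subclass name;
# concrete implementations derive hypergradients through the inner optimization
# dynamics, whereas here, purely for illustration, only the direct partial derivative
# of the outer objective is stored. This assumes _hypergrad_dictionary maps each
# hyperparameter to a list of hypergradient tensors, as hgrads_hvars above expects:
#
#     class ExampleHG(HyperGradient):
#         def compute_gradients(self, outer_objective, optimizer_dict, hyper_list=None):
#             hyper_list = super(ExampleHG, self).compute_gradients(
#                 outer_objective, optimizer_dict, hyper_list)
#             for h in hyper_list:
#                 # tf.gradients returns [None] if outer_objective does not depend on h
#                 self._hypergrad_dictionary[h].append(
#                     tf.gradients(outer_objective, h)[0])
#             return hyper_list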