def _adapt(self, samples, extra_feed_dict=None):
    """
    Performs MAML inner step for each task and stores the updated parameters in the policy

    Args:
        samples (list) : list of dicts of samples (each is a dict) split by meta task
        extra_feed_dict (dict) : additional placeholder -> value pairs merged into the feed dict
    """
    if extra_feed_dict is None:
        extra_feed_dict = {}
    assert len(samples) == self.meta_batch_size
    assert all(sample_dict.keys() for sample_dict in samples)  # every task's sample dict must be non-empty
    sess = tf.get_default_session()

    # prepare feed dict
    input_dict = self._extract_input_dict(samples, self._optimization_keys, prefix='adapt')
    input_ph_dict = self.adapt_input_ph_dict

    feed_dict_inputs = utils.create_feed_dict(placeholder_dict=input_ph_dict, value_dict=input_dict)
    feed_dict_params = self.policy.policies_params_feed_dict
    feed_dict = {**feed_dict_inputs, **feed_dict_params, **extra_feed_dict}

    if len(extra_feed_dict) == 0:
        param_phs = self.adapted_policies_params
    else:
        param_phs = self.mod_adapted_policies_params

    # compute the post-update / adapted policy parameters
    adapted_policies_params_vals = sess.run(param_phs, feed_dict=feed_dict)

    # store the new parameter values in the policy
    self.policy.update_task_parameters(adapted_policies_params_vals)
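# Illustrative sketch (assumption, not part of the original class): how _adapt is
# typically driven from the outer meta-training loop. `sampler`, `sample_processor`,
# and `algo` are hypothetical collaborators standing in for the surrounding training code.
def _example_inner_adaptation(sampler, sample_processor, algo):
    paths = sampler.obtain_samples(log=False)            # pre-update rollouts, split by meta task
    samples = sample_processor.process_samples(paths)    # -> list of sample dicts, one per task
    algo._adapt(samples)                                 # inner step: writes adapted params into the policy
    return sampler.obtain_samples(log=False)             # post-update rollouts with the adapted policies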
def compute_gradients(self, all_samples_data, log=True):
    """
    Computes the gradient values of the meta objective for the provided samples

    Args:
        all_samples_data (list) : samples used for the meta-optimization, split by meta task
        log (bool) : whether to log the computation

    Returns:
        (list) : evaluated gradient values
    """
    meta_op_input_dict = self._extract_input_dict_meta_op(all_samples_data, self._optimization_keys)
    feed_dict = utils.create_feed_dict(placeholder_dict=self.meta_op_phs_dict, value_dict=meta_op_input_dict)

    if log:
        logger.log("compute gradients")

    gradients_values = tf.get_default_session().run(self.gradients, feed_dict=feed_dict)
    return gradients_values
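# Illustrative sketch (assumption, not part of the original class): one way the raw
# values returned by compute_gradients can be summarized for diagnostics. Assumes
# self.gradients is a flat list of gradient tensors, so the returned values are a
# list of numpy arrays; `algo` is a hypothetical instance of this optimizer.
import numpy as np

def _example_log_gradient_norm(algo, all_samples_data):
    gradients_values = algo.compute_gradients(all_samples_data, log=False)
    flat_grad = np.concatenate([np.asarray(g).ravel() for g in gradients_values])
    logger.log("meta-gradient norm: %f" % np.linalg.norm(flat_grad))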
def _ng_adapt(self, samples):
    """
    Performs a natural gradient MAML inner step for each task and stores the updated parameters in the policy

    Args:
        samples (list) : list of dicts of samples (each is a dict) split by meta task
    """
    assert len(samples) == self.meta_batch_size
    assert all(sample_dict.keys() for sample_dict in samples)  # every task's sample dict must be non-empty
    sess = tf.get_default_session()

    # prepare feed dict
    input_dict = self._extract_input_dict(samples, self._optimization_keys, prefix='adapt')
    input_ph_dict = self.adapt_input_ph_dict

    feed_dict_inputs = utils.create_feed_dict(placeholder_dict=input_ph_dict, value_dict=input_dict)
    feed_dict_params = self.policy.policies_params_feed_dict
    feed_dict = {**feed_dict_inputs, **feed_dict_params}  # merge the two feed dicts

    # evaluate the inner-loop step sizes once (assuming self.step_sizes holds tf variables keyed by param name)
    step_sizes = sess.run(self.step_sizes)

    # compute the post-update / adapted policy parameters, one natural gradient step per task
    adapted_policies_params_vals = []
    for i in range(self.meta_batch_size):
        # Hessian-vector product callable for this task's surrogate objective
        Hx = self.hvps[i].build_eval(input_dict)
        # Hx = self._hvp_approach.build_eval(input_dict)  # todo: decide between per-task hvps and the shared hvp approach

        # per-task gradient (assuming ng_gradients is a list parallel to hvps)
        gradient = sess.run(self.ng_gradients[i], feed_dict=feed_dict)
        flatten_grad = _flatten_params(gradient)

        # solve H x = g with conjugate gradient to obtain the natural gradient descent direction
        descent_direction = conjugate_gradients(Hx, flatten_grad, cg_iters=10)
        descent_direction = _unflatten_params(descent_direction, gradient)

        params = self.policy.get_param_values()  # todo match or not
        update_param_keys = list(params.keys())

        # one adapted parameter dict per task, mirroring the format update_task_parameters receives in _adapt
        adapted_policy_params = {key: params[key] - step_sizes[key] * descent_direction[key]
                                 for key in update_param_keys}
        adapted_policies_params_vals.append(adapted_policy_params)

    # store the new parameter values in the policy
    self.policy.update_task_parameters(adapted_policies_params_vals)
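# Illustrative sketch (assumption): minimal numpy versions of the helpers _ng_adapt
# relies on. The real conjugate_gradients, _flatten_params, and _unflatten_params
# live elsewhere in the codebase; the sketches below only spell out the contract
# assumed above: Hx is a callable v -> H @ v, and parameters/gradients are ordered
# dicts of arrays keyed by parameter name.
import numpy as np
from collections import OrderedDict

def _flatten_params_sketch(param_dict):
    # concatenate all parameter arrays into one flat vector
    return np.concatenate([np.asarray(v).ravel() for v in param_dict.values()])

def _unflatten_params_sketch(flat_vector, reference_dict):
    # split the flat vector back into arrays shaped like the reference dict's values
    unflattened, idx = OrderedDict(), 0
    for key, ref in reference_dict.items():
        ref = np.asarray(ref)
        unflattened[key] = flat_vector[idx:idx + ref.size].reshape(ref.shape)
        idx += ref.size
    return unflattened

def conjugate_gradients_sketch(Hx, b, cg_iters=10, residual_tol=1e-10):
    # solve H x = b without forming H explicitly, using only Hessian-vector products
    x = np.zeros_like(b)
    r = b.copy()           # residual b - H x (x starts at zero)
    p = b.copy()           # search direction
    rdotr = r.dot(r)
    for _ in range(cg_iters):
        Hp = Hx(p)
        alpha = rdotr / (p.dot(Hp) + 1e-8)
        x += alpha * p
        r -= alpha * Hp
        new_rdotr = r.dot(r)
        if new_rdotr < residual_tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x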
def create_feed_dict(self, input_val_dict):
    return utils.create_feed_dict(placeholder_dict=self._input_ph_dict, value_dict=input_val_dict)
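# Illustrative sketch (assumption): the contract expected of utils.create_feed_dict
# as it is used throughout this file -- pair each placeholder with the value that
# shares its key. The real implementation lives in the utils module.
def _create_feed_dict_sketch(placeholder_dict, value_dict):
    # every placeholder key needs a matching entry in the value dict
    assert set(placeholder_dict.keys()) <= set(value_dict.keys())
    return {placeholder: value_dict[key] for key, placeholder in placeholder_dict.items()}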