def _adapt(self, samples):
    """
    Performs the MAML inner step for each task and stores the updated parameters in the policy

    Args:
        samples (list): list of dicts of samples (each is a dict) split by meta task
    """
    assert len(samples) == self.meta_batch_size
    assert all(sample_dict.keys() for sample_dict in samples)  # each task's sample dict must be non-empty
    sess = tf.get_default_session()

    # prepare feed dict
    input_dict = self._extract_input_dict(samples, self._optimization_keys, prefix='adapt')
    input_ph_dict = self.adapt_input_ph_dict

    feed_dict_inputs = utils.create_feed_dict(placeholder_dict=input_ph_dict, value_dict=input_dict)
    feed_dict_params = self.policy.policies_params_feed_dict
    feed_dict = {**feed_dict_inputs, **feed_dict_params}  # merge the two feed dicts

    # compute the post-update / adapted policy parameters
    adapted_policies_params_vals = sess.run(self.adapted_policies_params, feed_dict=feed_dict)

    # store the new parameter values in the policy
    self.policy.update_task_parameters(adapted_policies_params_vals)
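Conceptually, the adapted_policies_params op evaluates one inner gradient step per task. A minimal NumPy sketch of that step, with a hypothetical step size inner_lr and toy gradients (none of these names belong to the class above):

import numpy as np

def inner_adapt(theta, grad, inner_lr=0.1):
    """One MAML-style inner step: theta' = theta - inner_lr * grad."""
    return {k: theta[k] - inner_lr * grad[k] for k in theta}

# toy example: per-task gradients yield per-task adapted parameters
theta = {'w': np.ones(3), 'b': np.zeros(3)}
task_grads = [{'w': np.full(3, 0.5), 'b': np.full(3, -0.2)},
              {'w': np.full(3, -1.0), 'b': np.full(3, 0.3)}]
adapted_params = [inner_adapt(theta, g) for g in task_grads]  # one parameter dict per meta task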
def optimize_policy(self, buffer, timestep, grad_steps, log=True):
    sess = tf.get_default_session()

    for i in range(grad_steps):
        # sample a random minibatch from the replay buffer and run one gradient step
        feed_dict = create_feed_dict(placeholder_dict=self.op_phs_dict,
                                     value_dict=buffer.random_batch(self.sampler_batch_size))
        sess.run(self.training_ops, feed_dict)

        if log:
            diagnostics = sess.run(self.diagnostics_ops, feed_dict)
            for k, v in diagnostics.items():
                logger.logkv(k, v)

        # periodically synchronize the target network with the current one
        if timestep % self.target_update_interval == 0:
            self._update_target()
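_update_target is not shown here; a common choice is a Polyak (soft) update of the target network toward the online one. A minimal NumPy sketch under that assumption, with a hypothetical smoothing coefficient tau:

import numpy as np

def polyak_update(target_params, source_params, tau=0.005):
    """Soft target update: target <- (1 - tau) * target + tau * source."""
    return {k: (1.0 - tau) * target_params[k] + tau * source_params[k] for k in target_params}

target = {'w': np.zeros(4)}
source = {'w': np.ones(4)}
target = polyak_update(target, source)  # moves slightly toward the online parameters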
def optimize_policy(self, samples_data, log=True):
    sess = tf.get_default_session()
    input_dict = self._extract_input_dict(samples_data, self._optimization_keys, prefix='train')

    # append the new samples to the dataset, keeping at most buffer_size entries per key (FIFO)
    if self._dataset is None:
        self._dataset = input_dict
    else:
        for k, v in input_dict.items():
            n_new_samples = len(v)
            n_max = self.buffer_size - n_new_samples
            self._dataset[k] = np.concatenate([self._dataset[k][-n_max:], v], axis=0)

    num_elements = len(list(self._dataset.values())[0])
    policy_params_dict = self.policy.policy_params_feed_dict

    for _ in range(self.num_grad_steps):
        # sample a random minibatch of indices and run one gradient step
        idxs = np.random.randint(0, num_elements, size=self.batch_size)
        batch_dict = self._get_indices_from_dict(self._dataset, idxs)
        feed_dict = create_feed_dict(placeholder_dict=self.op_phs_dict, value_dict=batch_dict)
        feed_dict.update(policy_params_dict)
        _ = sess.run(self._train_op, feed_dict=feed_dict)
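The dataset update above keeps roughly the last buffer_size samples per key; note that when n_new_samples >= buffer_size the slice [-n_max:] with a non-positive n_max retains old entries, so the buffer can temporarily exceed its nominal size. A small self-contained NumPy sketch of the intended FIFO logic (fifo_append and the toy batches are illustrative only):

import numpy as np

def fifo_append(dataset, new_samples, buffer_size):
    """Keep at most buffer_size entries per key, dropping the oldest first."""
    if dataset is None:
        return {k: v[-buffer_size:] for k, v in new_samples.items()}
    out = {}
    for k, v in new_samples.items():
        n_keep = max(buffer_size - len(v), 0)  # old entries to retain
        old_tail = dataset[k][-n_keep:] if n_keep > 0 else dataset[k][:0]
        out[k] = np.concatenate([old_tail, v], axis=0)[-buffer_size:]
    return out

dataset = None
for _ in range(3):
    batch = {'obs': np.random.randn(40, 2), 'act': np.random.randn(40, 1)}
    dataset = fifo_append(dataset, batch, buffer_size=100)
assert len(dataset['obs']) <= 100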
def create_feed_dict(self, input_val_dict):
    return utils.create_feed_dict(placeholder_dict=self._input_ph_dict, value_dict=input_val_dict)
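utils.create_feed_dict itself is not shown; a minimal sketch of such a helper, assuming flat dictionaries that share the same keys (make_feed_dict and the stand-in values below are hypothetical):

def make_feed_dict(placeholder_dict, value_dict):
    """Pair each placeholder with the value stored under the same key."""
    return {ph: value_dict[key] for key, ph in placeholder_dict.items()}

# usage with plain stand-ins for tf placeholders
placeholders = {'obs': 'obs_ph', 'act': 'act_ph'}
values = {'obs': [[0.1, 0.2]], 'act': [[1.0]]}
feed = make_feed_dict(placeholders, values)  # {'obs_ph': [[0.1, 0.2]], 'act_ph': [[1.0]]}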