Пример #1
0
    def _adapt(self, samples):
        """
        Performs the MAML inner (adaptation) step for each task and stores the
        updated parameters in the policy.

        Args:
            samples (list) : list of sample dicts, one per meta task; its length
                must equal ``self.meta_batch_size``

        Raises:
            AssertionError: if the number of task sample dicts differs from
                ``self.meta_batch_size`` or any task's sample dict is empty
                (only checked when assertions are enabled, i.e. not under -O)
        """
        assert len(samples) == self.meta_batch_size, (
            "expected %d task sample dicts, got %d" % (self.meta_batch_size, len(samples)))
        # Check every task individually: asserting on a list comprehension would
        # pass for any non-empty `samples`, whatever the dicts contain.
        assert all(sample_dict.keys() for sample_dict in samples), \
            "every task's sample dict must be non-empty"
        sess = tf.get_default_session()

        # Build the feed dict: the adaptation inputs mapped onto the 'adapt'
        # placeholders, plus the policy's current parameter values (presumably
        # the pre-update ones -- confirm against the policy implementation).
        input_dict = self._extract_input_dict(samples, self._optimization_keys, prefix='adapt')
        feed_dict_inputs = utils.create_feed_dict(placeholder_dict=self.adapt_input_ph_dict,
                                                  value_dict=input_dict)
        feed_dict_params = self.policy.policies_params_feed_dict
        # On a key collision the parameter feed wins (right-hand side of the merge).
        feed_dict = {**feed_dict_inputs, **feed_dict_params}

        # Compute the post-update / adapted policy parameters.
        adapted_policies_params_vals = sess.run(self.adapted_policies_params, feed_dict=feed_dict)

        # Store the new parameter values in the policy.
        self.policy.update_task_parameters(adapted_policies_params_vals)
Пример #2
0
 def optimize_policy(self, buffer, timestep, grad_steps, log=True):
     """
     Run ``grad_steps`` training updates on minibatches sampled from ``buffer``.

     Args:
         buffer: object providing ``random_batch(batch_size)``; each batch is
             mapped onto ``self.op_phs_dict`` to build the feed dict
         timestep: current step counter; ``self._update_target()`` is called
             when it is a multiple of ``self.target_update_interval``
         grad_steps (int): number of update iterations
         log (bool): when True, evaluate ``self.diagnostics_ops`` after each
             update and pass every entry to ``logger.logkv``
     """
     session = tf.get_default_session()
     for _ in range(grad_steps):
         batch = buffer.random_batch(self.sampler_batch_size)
         feed_dict = create_feed_dict(placeholder_dict=self.op_phs_dict, value_dict=batch)
         session.run(self.training_ops, feed_dict=feed_dict)

         if log:
             diagnostics = session.run(dict(self.diagnostics_ops), feed_dict=feed_dict)
             for key, value in diagnostics.items():
                 logger.logkv(key, value)

         # NOTE(review): `timestep` does not change inside this loop, so this
         # condition is either true on every iteration (target updated
         # `grad_steps` times) or on none -- confirm that is intended.
         if timestep % self.target_update_interval == 0:
             self._update_target()
Пример #3
0
    def optimize_policy(self, samples_data, log=True):
        """
        Adds the new samples to a bounded FIFO training buffer and takes
        ``self.num_grad_steps`` gradient steps on random minibatches from it.

        Args:
            samples_data : samples to add; converted with
                ``self._extract_input_dict(..., prefix='train')``
            log (bool) : currently unused; kept for interface compatibility

        Raises:
            ValueError: if gradient steps are requested but the buffer is empty
        """
        sess = tf.get_default_session()
        input_dict = self._extract_input_dict(samples_data, self._optimization_keys, prefix='train')
        if self._dataset is None:
            # First call: the initial batch is stored as-is (not truncated).
            self._dataset = input_dict
        else:
            for k, v in input_dict.items():
                # Keep only the newest `buffer_size` samples. Start indices are
                # computed explicitly (never negative): a slice like old[-n:]
                # returns the WHOLE array when n == 0, which would let the buffer
                # grow without bound whenever a batch has >= buffer_size samples.
                old = self._dataset[k]
                n_keep_old = max(self.buffer_size - len(v), 0)
                old_tail = old[max(len(old) - n_keep_old, 0):]
                new_tail = v[max(len(v) - self.buffer_size, 0):]
                self._dataset[k] = np.concatenate([old_tail, new_tail], axis=0)

        # Entries are assumed to share one length; the first one is used.
        num_elements = len(list(self._dataset.values())[0])
        policy_params_dict = self.policy.policy_params_feed_dict
        if self.num_grad_steps > 0 and num_elements == 0:
            raise ValueError("cannot sample minibatches: the training buffer is empty")
        for _ in range(self.num_grad_steps):
            # Minibatch indices drawn uniformly with replacement.
            idxs = np.random.randint(0, num_elements, size=self.batch_size)
            batch_dict = self._get_indices_from_dict(self._dataset, idxs)

            feed_dict = create_feed_dict(placeholder_dict=self.op_phs_dict, value_dict=batch_dict)
            # The policy parameter feeds are added alongside the minibatch feeds.
            feed_dict.update(policy_params_dict)

            sess.run(self._train_op, feed_dict=feed_dict)
Пример #4
0
 def create_feed_dict(self, input_val_dict):
     """Map ``input_val_dict`` onto this object's input placeholders.

     Delegates to ``utils.create_feed_dict`` with ``self._input_ph_dict`` as the
     placeholder mapping.

     Args:
         input_val_dict (dict): values to feed, passed through as ``value_dict``

     Returns:
         Whatever ``utils.create_feed_dict`` returns for these arguments
         (presumably a TensorFlow feed dict -- confirm against utils).
     """
     placeholders = self._input_ph_dict
     return utils.create_feed_dict(value_dict=input_val_dict,
                                   placeholder_dict=placeholders)