Example #1
File: in_place.py Project: NagisaZj/oyster
 def obtain_samples(self,
                    deterministic=False,
                    max_samples=np.inf,
                    max_trajs=np.inf,
                    accum_context=True,
                    resample=1):
     """
     Obtains samples in the environment until either we reach either max_samples transitions or
     num_traj trajectories.
     The resample argument specifies how often (in trajectories) the agent will resample it's context.
     """
     assert max_samples < np.inf or max_trajs < np.inf, "either max_samples or max_trajs must be finite"
     policy = MakeDeterministic(
         self.policy) if deterministic else self.policy
     paths = []
     n_steps_total = 0
     n_trajs = 0
     while n_steps_total < max_samples and n_trajs < max_trajs:
         path = rollout(self.env,
                        policy,
                        max_path_length=self.max_path_length,
                        accum_context=accum_context)
         # save the latent context that generated this trajectory
         path['context'] = policy.z.detach().cpu().numpy()
         paths.append(path)
         n_steps_total += len(path['observations'])
         n_trajs += 1
         # don't we also want the option to resample z every transition?
         if n_trajs % resample == 0:
             policy.sample_z()
     return paths, n_steps_total
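
The method above implements a budgeted sampling loop: roll out full trajectories until either a transition budget (max_samples) or a trajectory budget (max_trajs) is exhausted, and redraw the latent context z every resample trajectories. Below is a minimal, self-contained sketch of just that loop pattern, with a stub in place of the project's rollout helper (all names in the sketch are illustrative, not from the source):

import numpy as np

def stub_rollout(max_path_length):
    # Stand-in for rollout(): a path dict with one row per transition.
    length = np.random.randint(1, max_path_length + 1)
    return {'observations': np.zeros((length, 1))}

def obtain_samples_sketch(max_samples=np.inf, max_trajs=np.inf,
                          max_path_length=5, resample=1):
    assert max_samples < np.inf or max_trajs < np.inf
    paths, n_steps_total, n_trajs = [], 0, 0
    while n_steps_total < max_samples and n_trajs < max_trajs:
        path = stub_rollout(max_path_length)
        paths.append(path)
        n_steps_total += len(path['observations'])
        n_trajs += 1
        if n_trajs % resample == 0:
            pass  # the real sampler redraws the context here via policy.sample_z()
    return paths, n_steps_total

paths, n_steps = obtain_samples_sketch(max_samples=20, resample=2)
print(len(paths), n_steps)

Note that the budgets are only checked between trajectories, so the loop can overshoot max_samples by up to max_path_length - 1 transitions; max_samples is effectively a soft cap.
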
Example #2
    def obtain_samples(self,
                       deterministic=False,
                       max_samples=np.inf,
                       max_trajs=np.inf,
                       accum_context=True,
                       resample=1,
                       testing=False):
        """
        Obtains samples until we reach either max_samples transitions or
        max_trajs trajectories, optionally training the dynamics model
        first. Unless testing is True, trajectories are rolled out in the
        learned model (self.model) rather than the real environment.
        The resample argument specifies how often (in trajectories) the
        agent resamples its latent context.
        """
        assert max_samples < np.inf or max_trajs < np.inf, "either max_samples or max_trajs must be finite"
        policy = MakeDeterministic(
            self.policy) if deterministic else self.policy
        paths = []
        n_steps_total = 0
        n_trajs = 0

        if self.itr <= self.num_train_itr:
            if self.tandem_train:
                self._train(policy, accum_context)
                self.itr += 1
            else:
                for _ in range(self.num_train_itr):
                    self._train(policy, accum_context)
                    self.itr += 1

        while n_steps_total < max_samples and n_trajs < max_trajs:
            if testing:
                path = rollout(self.env,
                               policy,
                               max_path_length=self.max_path_length,
                               accum_context=accum_context)
            else:
                path = rollout(self.model,
                               policy,
                               max_path_length=self.max_path_length,
                               accum_context=accum_context)

            # save the latent context that generated this trajectory
            path['context'] = policy.z.detach().cpu().numpy()
            paths.append(path)
            n_steps_total += len(path['observations'])
            n_trajs += 1
            # don't we also want the option to resample z every transition?
            if n_trajs % resample == 0:
                policy.sample_z()

        return paths, n_steps_total
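
The structural difference from Example #1 is the training schedule and the rollout target: with tandem_train the sampler interleaves one model-training iteration per obtain_samples call, otherwise it runs the whole num_train_itr budget up front, and outside of testing it rolls out in the learned model rather than the real environment. A runnable sketch of just the scheduling branch (class and method names are illustrative, not the project's API):

class TrainScheduleSketch:
    def __init__(self, num_train_itr, tandem_train):
        self.num_train_itr = num_train_itr
        self.tandem_train = tandem_train
        self.itr = 0

    def _train(self):
        # Stand-in for one model-training iteration.
        print('training iteration', self.itr)

    def maybe_train(self):
        # Mirrors the branch at the top of obtain_samples in Example #2.
        if self.itr <= self.num_train_itr:
            if self.tandem_train:
                self._train()           # one iteration per sampling call
                self.itr += 1
            else:
                for _ in range(self.num_train_itr):
                    self._train()       # full budget on the first call
                    self.itr += 1

TrainScheduleSketch(num_train_itr=3, tandem_train=False).maybe_train()

One quirk worth noting: because the guard uses itr <= num_train_itr, the non-tandem branch passes the guard again on the next call (itr equals num_train_itr after the first call) and reruns the whole budget; if a single pre-training pass is intended, a strict < would be the safer guard.
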