コード例 #1
0
 def obtain_samples(self,
                    dyn_model=None,
                    itr=None,
                    policy=None,
                    rau=None,
                    delta=0,
                    constraint_fn=None,
                    constraint_cost_fn=None,
                    HCMPC_Activation=False,
                    Constrained=False):
     """Sample paths with the algorithm's current policy parameters.

     Every keyword argument except ``itr`` is forwarded unchanged to
     ``parallel_sampler.sample_paths``, together with ``self.algo``'s
     ``batch_size`` (as ``max_samples``), ``max_path_length`` and ``scope``.
     ``itr`` is accepted but not used.

     Returns:
         The sampled paths as returned by the sampler when
         ``self.algo.whole_paths`` is truthy; otherwise the result of
         ``parallel_sampler.truncate_paths(paths, self.algo.batch_size)``.
     """
     algo = self.algo
     sampler_kwargs = dict(
         policy_params=algo.policy.get_param_values(),
         max_samples=algo.batch_size,
         dyn_model=dyn_model,
         max_path_length=algo.max_path_length,
         scope=algo.scope,
         policy=policy,
         rau=rau,
         delta=delta,
         constraint_fn=constraint_fn,
         constraint_cost_fn=constraint_cost_fn,
         HCMPC_Activation=HCMPC_Activation,
         Constrained=Constrained,
     )
     sampled = parallel_sampler.sample_paths(**sampler_kwargs)
     if not algo.whole_paths:
         # Cut the batch down to the requested sample budget.
         sampled = parallel_sampler.truncate_paths(sampled, algo.batch_size)
     return sampled
コード例 #2
0
def test_truncate_paths():
    """Check that ``truncate_paths`` cuts a batch down to a sample budget.

    Two paths of length 100 and 50 are truncated to 130 samples. The first
    path must be kept whole and the second cut to 30 steps. The input list
    and its second path must keep their original lengths.
    """
    from rllab.sampler.parallel_sampler import truncate_paths

    def make_path(length):
        # Minimal path dict; agent_infos holds one nested array so that
        # nested-dict truncation is exercised too.
        return dict(
            observations=np.zeros((length, 1)),
            actions=np.zeros((length, 1)),
            rewards=np.zeros(length),
            env_infos=dict(),
            agent_infos=dict(lala=np.zeros(length)),
        )

    paths = [make_path(length) for length in (100, 50)]

    truncated = truncate_paths(paths, 130)
    assert len(truncated) == 2
    first, last = truncated[0], truncated[-1]
    assert len(last["observations"]) == 30
    assert len(first["observations"]) == 100

    # Truncation must not modify the caller's input.
    assert len(paths) == 2
    assert len(paths[-1]["observations"]) == 50
コード例 #3
0
 def obtain_samples(self, itr):
     """Sample paths, also shipping low-level policy params when available.

     If ``self.algo`` has a ``low_policy`` (and its env has
     ``time_steps_agg``), the low-level policy's parameters are sent to the
     sampler as ``low_policy_params``. ``[self.algo.env.time_steps_agg,
     self.algo]`` is sent as ``env_params``, the parameters the workers use
     to recover their environment. Otherwise only the policy parameters are
     sent. ``itr`` is accepted but not used.

     Returns:
         The sampled paths as-is when ``self.algo.whole_paths`` is truthy,
         else ``parallel_sampler.truncate_paths(paths, self.algo.batch_size)``.
     """
     cur_params = self.algo.policy.get_param_values()  # a list of numbers
     # Only the attribute lookups are guarded. Previously the sampling call
     # itself sat inside the ``try``. An AttributeError raised anywhere during
     # sampling was then silently swallowed, and the whole batch was
     # re-sampled without the low-level policy.
     try:
         cur_low_params = self.algo.low_policy.get_param_values()
         env_params = [self.algo.env.time_steps_agg, self.algo]
     except AttributeError:
         # No hierarchical low-level policy configured: plain sampling.
         paths = parallel_sampler.sample_paths(
             policy_params=cur_params,
             max_samples=self.algo.batch_size,
             max_path_length=self.algo.max_path_length,
             scope=self.algo.scope,
         )
     else:
         paths = parallel_sampler.sample_paths(
             policy_params=cur_params,
             low_policy_params=cur_low_params,  # low policy params as env params
             env_params=env_params,  # the parameters to recover for env
             max_samples=self.algo.batch_size,
             max_path_length=self.algo.max_path_length,
             scope=self.algo.scope,
         )
     if self.algo.whole_paths:
         return paths
     return parallel_sampler.truncate_paths(paths, self.algo.batch_size)
コード例 #4
0
 def obtain_samples(self, itr):
     """Sample multi-agent paths through ``ma_sampler``.

     In ``'concurrent'`` mode, the parameters of every policy in
     ``self.algo.policies`` are collected into a list. In any other mode,
     only ``self.algo.policy``'s parameters are used. Environment parameters
     are sent when the env provides ``get_param_values``, else ``None``.
     ``itr`` is accepted but not used.

     Returns:
         The sampled paths as-is when ``self.algo.whole_paths`` is truthy,
         else ``parallel_sampler.truncate_paths(paths, self.algo.batch_size)``.
     """
     algo = self.algo
     cur_policy_params = (
         [agent_policy.get_param_values() for agent_policy in algo.policies]
         if algo.ma_mode == 'concurrent'
         else algo.policy.get_param_values()
     )
     env = algo.env
     cur_env_params = (
         env.get_param_values() if hasattr(env, "get_param_values") else None
     )
     paths = ma_sampler.sample_paths(
         policy_params=cur_policy_params,
         env_params=cur_env_params,
         max_samples=algo.batch_size,
         max_path_length=algo.max_path_length,
         ma_mode=algo.ma_mode,
         scope=algo.scope,
     )
     if algo.whole_paths:
         return paths
     return parallel_sampler.truncate_paths(paths, algo.batch_size)
コード例 #5
0
 def obtain_samples(self, itr):
     """Sample paths using policy and batch settings stored on ``self``.

     ``itr`` is accepted but not used.

     Returns:
         The sampled paths as-is when ``self.whole_paths`` is truthy, else
         ``parallel_sampler.truncate_paths(paths, self.batch_size)``.
     """
     paths = parallel_sampler.sample_paths(
         policy_params=self.policy.get_param_values(),
         max_samples=self.batch_size,
         max_path_length=self.max_path_length,
     )
     return (
         paths
         if self.whole_paths
         else parallel_sampler.truncate_paths(paths, self.batch_size)
     )
コード例 #6
0
 def obtain_samples(self, itr):
     """Sample paths via this object's own ``sample_paths`` method.

     ``self.sample_paths`` is presumably a sampling routine defined on the
     same class (not visible here; confirm). It receives the current
     policy parameters plus ``self.algo``'s batch size, path length and
     scope. ``itr`` is accepted but not used.

     Returns:
         The sampled paths as-is when ``self.algo.whole_paths`` is truthy,
         else ``parallel_sampler.truncate_paths(paths, self.algo.batch_size)``.
     """
     algo = self.algo
     paths = self.sample_paths(
         policy_params=algo.policy.get_param_values(),
         max_samples=algo.batch_size,
         max_path_length=algo.max_path_length,
         scope=algo.scope,
     )
     if not algo.whole_paths:
         paths = parallel_sampler.truncate_paths(paths, algo.batch_size)
     return paths
コード例 #7
0
 def obtain_samples(self, itr):
     """Sample paths with both protagonist and adversary policy parameters.

     The parameters of ``self.algo.pro_policy`` and ``self.algo.adv_policy``
     are forwarded as ``pro_policy_params`` and ``adv_policy_params``.
     ``itr`` is accepted but not used.

     Returns:
         The sampled paths as-is when ``self.algo.whole_paths`` is truthy,
         else ``parallel_sampler.truncate_paths(paths, self.algo.batch_size)``.
     """
     algo = self.algo
     protagonist_params = algo.pro_policy.get_param_values()
     adversary_params = algo.adv_policy.get_param_values()
     paths = parallel_sampler.sample_paths(
         pro_policy_params=protagonist_params,
         max_samples=algo.batch_size,
         max_path_length=algo.max_path_length,
         scope=algo.scope,
         adv_policy_params=adversary_params,
     )
     if algo.whole_paths:
         return paths
     return parallel_sampler.truncate_paths(paths, algo.batch_size)
コード例 #8
0
 def obtain_samples(self, itr):
     """Sample paths with the algorithm's current policy parameters.

     ``itr`` is accepted but not used.

     Returns:
         The sampled paths as-is when ``self.algo.whole_paths`` is truthy,
         else ``parallel_sampler.truncate_paths(paths, self.algo.batch_size)``.
     """
     algo = self.algo
     # TODO: could updated (post-gradient-step) parameters be passed here
     # instead of the current ones?
     current = algo.policy.get_param_values()
     paths = parallel_sampler.sample_paths(
         policy_params=current,
         max_samples=algo.batch_size,
         max_path_length=algo.max_path_length,
         scope=algo.scope,
     )
     # TODO: does the optimizer assume these paths were generated with
     # ``current``, or would ``current - alpha * grads`` also be valid?
     if not algo.whole_paths:
         paths = parallel_sampler.truncate_paths(paths, algo.batch_size)
     return paths
コード例 #9
0
ファイル: batch_polopt.py プロジェクト: zizai/EMI
 def obtain_samples(self, itr):
     """Sample paths and ask the sampler to keep the original frames.

     Forwards ``include_original_frames=True`` to
     ``parallel_sampler.sample_paths`` along with the current policy
     parameters and ``self.algo``'s batch settings. ``itr`` is accepted but
     not used.

     Returns:
         The sampled paths as-is when ``self.algo.whole_paths`` is truthy,
         else ``parallel_sampler.truncate_paths(paths, self.algo.batch_size)``.
     """
     algo = self.algo
     paths = parallel_sampler.sample_paths(
         policy_params=algo.policy.get_param_values(),
         max_samples=algo.batch_size,
         include_original_frames=True,
         max_path_length=algo.max_path_length,
         scope=algo.scope,
     )
     return (
         paths
         if algo.whole_paths
         else parallel_sampler.truncate_paths(paths, algo.batch_size)
     )
コード例 #10
0
ファイル: BP.py プロジェクト: hl00/maml_rl
 def obtain_samples(self, itr):
     """Sample paths with the algorithm's current policy parameters.

     ``itr`` is accepted but not used.

     Returns:
         The sampled paths as-is when ``self.algo.whole_paths`` is truthy,
         else ``parallel_sampler.truncate_paths(paths, self.algo.batch_size)``.
     """
     # Open questions kept from the original author:
     #  - could updated (post-gradient-step) parameters be sampled with
     #    instead of the current ones?
     #  - does the optimizer assume the paths came from the current params,
     #    or would ``current - alpha * grads`` also be valid?
     sampling_params = self.algo.policy.get_param_values()
     collected = parallel_sampler.sample_paths(
         policy_params=sampling_params,
         max_samples=self.algo.batch_size,
         max_path_length=self.algo.max_path_length,
         scope=self.algo.scope,
     )
     if self.algo.whole_paths:
         return collected
     return parallel_sampler.truncate_paths(collected, self.algo.batch_size)
コード例 #11
0
 def obtain_samples(self, itr, determ=False):
     """Sample paths with the current policy and environment parameters.

     Environment parameters are taken from ``self.algo.env.get_param_values()``
     when the environment provides that method. Otherwise ``env_params=None``
     is sent instead of failing with AttributeError. ``itr`` and ``determ``
     are accepted but not used here.

     Returns:
         The sampled paths as-is when ``self.algo.whole_paths`` is truthy,
         else ``parallel_sampler.truncate_paths(paths, self.algo.batch_size)``.
     """
     cur_policy_params = self.algo.policy.get_param_values()
     env = self.algo.env
     # Not every environment exposes parameters; fall back to None rather
     # than crashing on a missing method.
     if hasattr(env, "get_param_values"):
         cur_env_params = env.get_param_values()
     else:
         cur_env_params = None
     paths = parallel_sampler.sample_paths(
         policy_params=cur_policy_params,
         env_params=cur_env_params,
         max_samples=self.algo.batch_size,
         max_path_length=self.algo.max_path_length,
         scope=self.algo.scope,
     )
     if self.algo.whole_paths:
         return paths
     return parallel_sampler.truncate_paths(paths, self.algo.batch_size)
コード例 #12
0
 def obtain_samples(self, itr):
     """Sample paths with the current policy parameters and no env params.

     When ``config.TF_NN_SETTRACE`` is truthy, execution first drops into
     the ``ipdb`` debugger. ``env_params`` is always sent as ``None``.
     ``itr`` is accepted but not used.

     Returns:
         The sampled paths as-is when ``self.algo.whole_paths`` is truthy,
         else ``parallel_sampler.truncate_paths(paths, self.algo.batch_size)``.
     """
     if config.TF_NN_SETTRACE:
         ipdb.set_trace()  # opt-in debugging hook
     algo = self.algo
     policy_params = algo.policy.get_param_values()
     paths = parallel_sampler.sample_paths(
         policy_params=policy_params,
         env_params=None,
         max_samples=algo.batch_size,
         max_path_length=algo.max_path_length,
         scope=algo.scope,
     )
     if not algo.whole_paths:
         paths = parallel_sampler.truncate_paths(paths, algo.batch_size)
     return paths
コード例 #13
0
 def obtain_samples(self, itr):
     """Sample paths, including baseline parameters when the policy has them.

     If the policy defines ``get_param_values_with_baseline``, that method
     supplies the parameters sent to the sampler. Otherwise the plain
     ``get_param_values`` is used. ``itr`` is accepted but not used.

     Returns:
         The sampled paths as-is when ``self.algo.whole_paths`` is truthy,
         else ``parallel_sampler.truncate_paths(paths, self.algo.batch_size)``.
     """
     algo = self.algo
     policy = algo.policy
     if hasattr(policy, 'get_param_values_with_baseline'):
         fetch_params = policy.get_param_values_with_baseline
     else:
         fetch_params = policy.get_param_values
     paths = parallel_sampler.sample_paths(
         policy_params=fetch_params(),
         max_samples=algo.batch_size,
         max_path_length=algo.max_path_length,
         scope=algo.scope,
     )
     if algo.whole_paths:
         return paths
     return parallel_sampler.truncate_paths(paths, algo.batch_size)
コード例 #14
0
 def obtain_samples(self, itr, include_joint_coords=False):
     """Sample paths with the current policy and environment parameters.

     Args:
         itr: accepted but not used.
         include_joint_coords: not supported by this batch sampler yet.
             Passing a truthy value emits a ``UserWarning`` instead of being
             silently ignored, and sampling proceeds without joint coords.

     Environment parameters come from ``self.algo.env.get_param_values()``
     when the environment provides that method, otherwise ``None``.

     Returns:
         The sampled paths as-is when ``self.algo.whole_paths`` is truthy,
         else ``parallel_sampler.truncate_paths(paths, self.algo.batch_size)``.
     """
     if include_joint_coords:
         import warnings

         # TODO: include_joint_coords not supported for BatchSampler yet.
         warnings.warn(
             "include_joint_coords is not supported by BatchSampler yet; "
             "ignoring it.",
             stacklevel=2,
         )
     cur_policy_params = self.algo.policy.get_param_values()
     env = self.algo.env
     # Not every environment exposes parameters; fall back to None rather
     # than crashing on a missing method.
     if hasattr(env, "get_param_values"):
         cur_env_params = env.get_param_values()
     else:
         cur_env_params = None
     paths = parallel_sampler.sample_paths(
         policy_params=cur_policy_params,
         env_params=cur_env_params,
         max_samples=self.algo.batch_size,
         max_path_length=self.algo.max_path_length,
         scope=self.algo.scope,
     )
     if self.algo.whole_paths:
         return paths
     return parallel_sampler.truncate_paths(paths, self.algo.batch_size)
コード例 #15
0
 def obtain_samples(self, itr):
     """Sample paths with the current policy parameters only.

     Environment parameters are deliberately not forwarded: ``env_params`` is
     always ``None``, even if the env has ``get_param_values``. ``itr`` is
     accepted but not used.

     Returns:
         The sampled paths as-is when ``self.algo.whole_paths`` is truthy,
         else ``parallel_sampler.truncate_paths(paths, self.algo.batch_size)``.
     """
     algo = self.algo
     paths = parallel_sampler.sample_paths(
         policy_params=algo.policy.get_param_values(),
         env_params=None,  # env param syncing intentionally disabled
         max_samples=algo.batch_size,
         max_path_length=algo.max_path_length,
         scope=algo.scope,
     )
     return (
         paths
         if algo.whole_paths
         else parallel_sampler.truncate_paths(paths, algo.batch_size)
     )
コード例 #16
0
ファイル: batch_sampler.py プロジェクト: paulhendricks/rllab
 def obtain_samples(self, itr):
     """Sample paths with the current policy and (optional) env parameters.

     Environment parameters are sent when ``self.algo.env`` provides
     ``get_param_values``; otherwise ``env_params`` is ``None``. ``itr`` is
     accepted but not used.

     Returns:
         The sampled paths as-is when ``self.algo.whole_paths`` is truthy,
         else ``parallel_sampler.truncate_paths(paths, self.algo.batch_size)``.
     """
     algo = self.algo
     policy_params = algo.policy.get_param_values()
     env = algo.env
     env_params = (
         env.get_param_values() if hasattr(env, "get_param_values") else None
     )
     paths = parallel_sampler.sample_paths(
         policy_params=policy_params,
         env_params=env_params,
         max_samples=algo.batch_size,
         max_path_length=algo.max_path_length,
         scope=algo.scope,
     )
     if not algo.whole_paths:
         paths = parallel_sampler.truncate_paths(paths, algo.batch_size)
     return paths
コード例 #17
0
    def obtain_samples(self, itr, target_task=None):
        """Sample paths, forwarding iteration and task context to the sampler.

        Besides the current policy parameters and ``self.algo``'s batch
        settings, the iteration number (as ``iter``), the policy, env and
        baseline objects, and ``target_task`` are passed to
        ``parallel_sampler.sample_paths``.

        Returns:
            The sampled paths as-is when ``self.algo.whole_paths`` is truthy,
            else ``parallel_sampler.truncate_paths(paths, self.algo.batch_size)``.
        """
        algo = self.algo
        paths = parallel_sampler.sample_paths(
            policy_params=algo.policy.get_param_values(),
            max_samples=algo.batch_size,
            max_path_length=algo.max_path_length,
            scope=algo.scope,
            iter=itr,
            policy=algo.policy,
            env=algo.env,
            baseline=algo.baseline,
            target_task=target_task,
        )

        if algo.whole_paths:
            return paths
        return parallel_sampler.truncate_paths(paths, algo.batch_size)
コード例 #18
0
    def obtain_samples(self, itr):
        """Sample paths and optionally trim each to a multiple of ``self.period``.

        When ``self.period`` is ``None``, the sampled paths are used
        unchanged. Otherwise each path is truncated in place to the largest
        multiple of ``self.period`` that fits its length. Every top-level
        value is sliced, and every value of a nested dict one level deep is
        sliced too. Paths left with no rewards are dropped. ``itr`` is
        accepted but not used.

        Returns:
            The (possibly trimmed) paths as-is when ``self.algo.whole_paths``
            is truthy, else
            ``parallel_sampler.truncate_paths(paths, self.algo.batch_size)``.
        """
        algo = self.algo
        raw_paths = parallel_sampler.sample_paths(
            policy_params=algo.policy.get_param_values(),
            max_samples=algo.batch_size,
            max_path_length=algo.max_path_length,
            scope=algo.scope,
        )

        # NOTE(review): the original tags the None case "hippo random p";
        # presumably a random-period setting. Confirm.
        if self.period is None:
            paths = raw_paths
        else:
            # TODO: this breaks for environments whose rollouts terminate
            # once the goal is reached (the trailing partial period is lost).
            period = self.period
            paths = []
            for path in raw_paths:
                keep = (len(path['rewards']) // period) * period
                # Only values of existing keys are replaced, so iterating
                # items() while assigning is safe.
                for key, value in path.items():
                    if isinstance(value, dict):
                        for sub_key, sub_value in value.items():
                            value[sub_key] = sub_value[:keep]
                    else:
                        path[key] = value[:keep]
                if len(path['rewards']) > 0:
                    paths.append(path)

        if algo.whole_paths:
            return paths
        return parallel_sampler.truncate_paths(paths, algo.batch_size)