Exemple #1
0
 def __init__(
         self,
         *args,
         her_kwargs,
         tsac_kwargs,
         **kwargs
 ):
     """Run the HER and TwinSAC initialisers, then check the replay buffer.

     Args:
         *args: Positional arguments forwarded to ``TwinSAC.__init__``.
         her_kwargs: Mapping unpacked as keyword arguments into
             ``HER.__init__``.
         tsac_kwargs: Mapping unpacked as keyword arguments into
             ``TwinSAC.__init__`` alongside ``**kwargs``.
         **kwargs: Additional keyword arguments forwarded to
             ``TwinSAC.__init__``.

     Raises:
         AssertionError: If ``self.replay_buffer`` is neither a
             ``RelabelingReplayBuffer`` nor an ``ObsDictRelabelingBuffer``.
     """
     HER.__init__(self, **her_kwargs)
     # NOTE: a key present in both ``kwargs`` and ``tsac_kwargs`` makes this
     # call raise TypeError (duplicate keyword argument).
     TwinSAC.__init__(self, *args, **kwargs, **tsac_kwargs)
     # A single isinstance call with a tuple of classes replaces the chained
     # ``or`` checks; the message names the offending type on failure.
     assert isinstance(
         self.replay_buffer,
         (RelabelingReplayBuffer, ObsDictRelabelingBuffer),
     ), (
         "replay_buffer must be a RelabelingReplayBuffer or an "
         "ObsDictRelabelingBuffer, got {}".format(
             type(self.replay_buffer).__name__
         )
     )
Exemple #2
0
 def __init__(self,
              *args,
              observation_key='observation',
              desired_goal_key='desired_goal',
              **kwargs):
     """Run the HER and TwinSAC initialisers, then check the replay buffer.

     Args:
         *args: Positional arguments forwarded to ``TwinSAC.__init__``.
         observation_key: Passed to ``HER.__init__`` as ``observation_key``.
         desired_goal_key: Passed to ``HER.__init__`` as
             ``desired_goal_key``.
         **kwargs: Keyword arguments forwarded to ``TwinSAC.__init__``.

     Raises:
         AssertionError: If ``self.replay_buffer`` is neither a
             ``RelabelingReplayBuffer`` nor an ``ObsDictRelabelingBuffer``.
     """
     HER.__init__(
         self,
         observation_key=observation_key,
         desired_goal_key=desired_goal_key,
     )
     TwinSAC.__init__(self, *args, **kwargs)
     # A single isinstance call with a tuple of classes replaces the chained
     # ``or`` checks; the message names the offending type on failure.
     assert isinstance(
         self.replay_buffer,
         (RelabelingReplayBuffer, ObsDictRelabelingBuffer),
     ), (
         "replay_buffer must be a RelabelingReplayBuffer or an "
         "ObsDictRelabelingBuffer, got {}".format(
             type(self.replay_buffer).__name__
         )
     )
Exemple #3
0
def experiment(variant):
    """Build and train a TwinSAC agent on the ``HalfCheetah-v2`` gym task.

    Two Q-functions, a value function and a tanh-Gaussian policy are created,
    each with two hidden layers of equal width, and handed to ``TwinSAC``. The
    algorithm is then moved to ``ptu.device`` and trained.

    Args:
        variant: Mapping read for ``'net_size'`` (width of every hidden
            layer) and ``'algo_params'`` (keyword arguments unpacked into
            the ``TwinSAC`` constructor).
    """
    import gym

    env = NormalizedBoxEnv(gym.make('HalfCheetah-v2'))

    # Flattened sizes of the observation and action spaces.
    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))

    width = variant['net_size']

    def scalar_mlp(input_size):
        # Every critic shares this shape: two hidden layers, one output.
        return FlattenMlp(
            hidden_sizes=[width, width],
            input_size=input_size,
            output_size=1,
        )

    # Networks are created in the same order as before: qf1, qf2, vf, policy.
    state_action_dim = obs_dim + action_dim
    qf1 = scalar_mlp(state_action_dim)
    qf2 = scalar_mlp(state_action_dim)
    vf = scalar_mlp(obs_dim)
    policy = TanhGaussianPolicy(
        hidden_sizes=[width, width],
        obs_dim=obs_dim,
        action_dim=action_dim,
    )

    algorithm = TwinSAC(
        env=env,
        policy=policy,
        qf1=qf1,
        qf2=qf2,
        vf=vf,
        **variant['algo_params']
    )
    algorithm.to(ptu.device)
    algorithm.train()
Exemple #4
0
 def __init__(self, *args, url_kwargs, tsac_kwargs, **kwargs):
     """Run the URL and TwinSAC initialisers, then check the replay buffer.

     ``url_kwargs`` is unpacked into ``URL.__init__``; ``*args``,
     ``**kwargs`` and ``tsac_kwargs`` are all forwarded to
     ``TwinSAC.__init__``. Afterwards ``self.replay_buffer`` is asserted to
     be an ``ObsDictPathReplayBuffer``.
     """
     URL.__init__(self, **url_kwargs)
     TwinSAC.__init__(
         self,
         *args,
         **kwargs,
         **tsac_kwargs
     )
     buffer = self.replay_buffer
     assert isinstance(buffer, ObsDictPathReplayBuffer)