Code Example #1
File: run_alg.py  Project: ASzot/rl-toolkit
    def get_policy(self):
        alg = self.base_args.alg
        if alg in DIST_ACTOR_CRITIC:
            return DistActorCritic()
        elif alg in REG_ACTOR_CRITIC:
            # DDPG is hard to train, so tweak the default actor/critic widths
            # for specific environments.
            if alg == 'ddpg' and self.base_args.env_name == 'Hopper-v3':
                return RegActorCritic(
                    get_actor_head_fn=lambda _, i_shape: MLPBase(
                        i_shape[0], False, (128, 128)),
                    get_critic_head_fn=lambda _, i_shape, a_space: TwoLayerMlpWithAction(
                        i_shape[0], (128, 128), a_space.shape[0])
                )
            elif alg == 'ddpg' and self.base_args.env_name == 'MountainCarContinuous-v0':
                return RegActorCritic(
                    get_actor_head_fn=lambda _, i_shape: MLPBase(
                        i_shape[0], False, (400, 300)),
                    get_critic_head_fn=lambda _, i_shape, a_space: TwoLayerMlpWithAction(
                        i_shape[0], (400, 300), a_space.shape[0])
                )
            else:
                return RegActorCritic()
        elif alg in NO_CRITIC:
            return DQN()
        elif alg in BASIC_POLICY:
            return BasicPolicy()
        elif alg in RND_POLICY:
            return RandomPolicy()
        else:
            raise ValueError(f'Unrecognized alg {alg!r} for policy architecture')
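
The constants DIST_ACTOR_CRITIC, REG_ACTOR_CRITIC, NO_CRITIC, BASIC_POLICY, and RND_POLICY are not shown here; they are presumably containers of algorithm-name strings defined elsewhere in rl-toolkit. A minimal sketch of what such a mapping could look like (the contents below are hypothetical, not the project's actual definitions):

    # Hypothetical contents -- the real definitions live elsewhere in rl-toolkit.
    DIST_ACTOR_CRITIC = {'ppo', 'a2c'}   # stochastic actor-critic methods
    REG_ACTOR_CRITIC = {'ddpg', 'sac'}   # off-policy actor-critic methods
    NO_CRITIC = {'dqn'}                  # value-based, no separate critic
    BASIC_POLICY = {'bc'}                # plain supervised policies
    RND_POLICY = {'rnd'}                 # random-action baseline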
Code Example #2
File: utils.py  Project: ASzot/rl-toolkit
def get_mlp_net_var_out_fn(hidden_sizes):
    """
    Same as `get_mlp_net_fn`, but the output size can be specified later,
    after the function has been returned.
    """
    return lambda i_shp, n_out: MLPBase(i_shp[0],
                                        False, [*hidden_sizes, n_out],
                                        weight_init=reg_mlp_weight_init,
                                        no_last_act=True)
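
A hedged usage sketch of the factory above (the Gym-style obs_space and action_space objects are assumed for illustration and are not part of the snippet):

    # Build the factory once; supply the output size at call time, which is
    # useful when the number of outputs differs per environment.
    make_net = get_mlp_net_var_out_fn((256, 256))
    q_net = make_net(obs_space.shape, action_space.n)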
Code Example #3
File: utils.py  Project: ASzot/rl-toolkit
def get_mlp_net_out_fn(hidden_sizes):
    """
    Gives a function that takes an input shape and returns an MLP base with
    the specified architecture. This is an easy way to create default
    NN-creation functions that can later be overridden.
    Returns: (i_shape -> MLPBase)
    """
    return lambda i_shape: MLPBase(i_shape[0],
                                   False,
                                   hidden_sizes,
                                   weight_init=reg_mlp_weight_init,
                                   no_last_act=True)
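
And the corresponding usage for the fixed-architecture variant (again, obs_space is an assumed Gym-style space):

    # The returned function only needs the input shape, so it can be passed
    # anywhere a default net-creation function is expected and overridden later.
    make_encoder = get_mlp_net_out_fn((256, 256))
    encoder = make_encoder(obs_space.shape)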
Code Example #4
    def get_policy(self):
        hidden_size = 256
        if 'BitFlip' in self.base_args.env_name:
            # Goal-conditioned DQN for the bit-flipping environment.
            return DQN(
                get_base_net_fn=lambda i_shape, recurrent: MLPBase(
                    i_shape[0],
                    False, (hidden_size, ),
                    weight_init=reg_init,
                    get_activation=lambda: nn.ReLU()),
                use_goal=True)
        else:
            # Goal-conditioned actor-critic for continuous-control tasks.
            return RegActorCritic(
                get_actor_fn=lambda _, i_shape: MLPBase(
                    i_shape[0],
                    False, (hidden_size, hidden_size),
                    weight_init=reg_init,
                    get_activation=lambda: nn.ReLU()),
                get_actor_head_fn=get_actor_head,
                get_critic_fn=lambda _, i_shape, a_space: TwoLayerMlpWithAction(
                    i_shape[0], (hidden_size, hidden_size),
                    a_space.shape[0],
                    weight_init=reg_init,
                    get_activation=lambda: nn.ReLU()),
                get_critic_head_fn=lambda hidden_dim: nn.Linear(hidden_dim, 1),
                use_goal=True)
Code Example #5
File: distributions.py  Project: ASzot/rl-toolkit
    def __init__(self, num_inputs, num_outputs, hidden_dim, hidden_depth,
                 log_std_bounds):
        super().__init__()

        # The trunk outputs both a mean and a log std for each output dim.
        dims = [hidden_dim] * hidden_depth
        dims.append(2 * num_outputs)

        self.log_std_bounds = log_std_bounds
        self.trunk = MLPBase(num_inputs,
                             False,
                             dims,
                             weight_init=reg_mlp_weight_init,
                             get_activation=lambda: nn.ReLU(inplace=True),
                             no_last_act=True)

        self.apply(no_bias_weight_init)
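
The forward pass is not part of this snippet. A typical implementation of this bounded-log-std Gaussian head (a sketch of the common pattern, not necessarily the project's exact code) splits the trunk output into a mean and a log std and rescales the log std into log_std_bounds:

    def forward(self, obs):
        # Assumes the trunk returns a single tensor of size 2 * num_outputs.
        mu, log_std = self.trunk(obs).chunk(2, dim=-1)
        # Squash with tanh, then rescale into [log_std_min, log_std_max].
        log_std_min, log_std_max = self.log_std_bounds
        log_std = torch.tanh(log_std)
        log_std = log_std_min + 0.5 * (log_std_max - log_std_min) * (log_std + 1)
        return torch.distributions.Normal(mu, log_std.exp())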
Code Example #6
File: def.py  Project: ASzot/rl-toolkit
    def get_policy(self):
        if 'Pendulum' in self.base_args.env_name:
            hidden_size = 128
        else:
            hidden_size = 256
        return RegActorCritic(
            get_actor_fn=lambda _, i_shape: MLPBase(
                i_shape[0], False, (hidden_size, hidden_size),
                weight_init=reg_init,
                get_activation=lambda: nn.ReLU()),
            get_actor_head_fn=get_actor_head,
            get_critic_fn=lambda _, i_shape, a_space: TwoLayerMlpWithAction(
                i_shape[0], (hidden_size, hidden_size), a_space.shape[0],
                weight_init=reg_init,
                get_activation=lambda: nn.ReLU()),
            get_critic_head_fn=lambda hidden_dim: nn.Linear(hidden_dim, 1))
Code Example #7
File: def.py  Project: ASzot/rl-toolkit
    def get_policy(self):
        return DistActorCritic(
            get_actor_fn=lambda _, i_shape: MLPBase(i_shape[0], False,
                                                    (400, 300)),
            get_critic_fn=lambda _, i_shape, a_shape: MLPBase(
                i_shape[0], False, (400, 300)))
Code Example #8
    def get_policy(self):
        return BasicPolicy(is_stoch=self.base_args.stoch_policy,
                           get_base_net_fn=lambda i_shape, recurrent: MLPBase(
                               i_shape[0], False, (400, 300)))
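
A pattern shared by all of these examples is that the policy classes take get_*_fn constructor hooks, so per-environment architectures can be injected without subclassing. A stripped-down illustration of that design (the class below is hypothetical, not part of rl-toolkit):

    class InjectablePolicy:
        """Hypothetical policy accepting a network-constructor hook."""

        def __init__(self, get_base_net_fn=None):
            if get_base_net_fn is None:
                # Default architecture, used unless the caller overrides it.
                get_base_net_fn = lambda i_shape, recurrent: MLPBase(
                    i_shape[0], False, (64, 64))
            self.get_base_net_fn = get_base_net_fn

        def init_net(self, obs_shape):
            # The hook receives the observation shape and a recurrence flag.
            self.base_net = self.get_base_net_fn(obs_shape, False)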