Ejemplo n.º 1
0
    def test_run_continuous(self):
        """Call ``run`` with a DDPG policy on the ``Pendulum-v0`` environment.

        The arguments come from ``self.parser`` extended by
        ``DDPG.get_argument``, with the ``n_warmup`` default overridden to 1.
        Unknown command-line arguments are tolerated (``parse_known_args``).
        """
        from tf2rl.algos.ddpg import DDPG

        arg_parser = DDPG.get_argument(self.parser)
        arg_parser.set_defaults(n_warmup=1)
        parsed_args, _unknown = arg_parser.parse_known_args()

        def env_fn():
            """Create a fresh ``Pendulum-v0`` environment."""
            return gym.make('Pendulum-v0')

        def policy_fn(env, name, memory_capacity=int(1e6), gpu=-1, *args, **kwargs):
            """Build a DDPG policy sized from the spaces of ``env``.

            ``name``, ``memory_capacity``, ``gpu`` and any extra arguments are
            accepted but not used: the policy is always built with
            ``n_warmup=500`` and ``gpu=-1``.
            """
            state_shape = env.observation_space.shape
            action_dim = env.action_space.high.size
            return DDPG(
                state_shape=state_shape,
                action_dim=action_dim,
                n_warmup=500,
                gpu=-1)

        def get_weights_fn(policy):
            """Return the actor, critic and target-critic weights, in that order."""
            networks = (policy.actor, policy.critic, policy.critic_target)
            return [network.weights for network in networks]

        def set_weights_fn(policy, weights):
            """Push ``weights`` (actor, critic, target critic) into ``policy``."""
            # Unpacking requires exactly three weight collections.
            new_actor, new_critic, new_critic_target = weights
            assignments = (
                (policy.actor, new_actor),
                (policy.critic, new_critic),
                (policy.critic_target, new_critic_target),
            )
            for network, new_weights in assignments:
                # NOTE(review): tau=1. presumably makes this a full copy
                # rather than a soft update -- confirm against
                # update_target_variables.
                update_target_variables(network.weights, new_weights, tau=1.)

        run(parsed_args, env_fn, policy_fn, get_weights_fn, set_weights_fn)
Ejemplo n.º 2
0
def _test_run_continuous(parser):
    """Call ``run`` with a DDPG policy on the ``Pendulum-v0`` environment.

    Args:
        parser: argument parser that is extended via ``DDPG.get_argument``
            and then used to parse the command line.
    """
    from tf2rl.algos.ddpg import DDPG
    parser = DDPG.get_argument(parser)
    args = parser.parse_args()

    def env_fn():
        """Create a fresh ``Pendulum-v0`` environment."""
        return gym.make('Pendulum-v0')

    # No environment is created here: every consumer below receives its own
    # environment (or the ``env_fn`` factory), so an extra instance would only
    # be built and then left unused and unclosed.

    def policy_fn(env, name, memory_capacity=int(1e6), gpu=-1):
        """Build a DDPG policy sized from the spaces of ``env``.

        ``name``, ``memory_capacity`` and ``gpu`` are accepted but not used;
        the policy is always built with ``gpu=-1``.
        """
        return DDPG(state_shape=env.observation_space.shape,
                    action_dim=env.action_space.high.size,
                    gpu=-1)

    def get_weights_fn(policy):
        """Return the actor, critic and target-critic weights, in that order."""
        return [
            policy.actor.weights, policy.critic.weights,
            policy.critic_target.weights
        ]

    def set_weights_fn(policy, weights):
        """Push ``weights`` (actor, critic, target critic) into ``policy``."""
        actor_weights, critic_weights, critic_target_weights = weights
        # NOTE(review): tau=1. presumably makes each call a full copy rather
        # than a soft update -- confirm against update_target_variables.
        update_target_variables(policy.actor.weights, actor_weights, tau=1.)
        update_target_variables(policy.critic.weights, critic_weights, tau=1.)
        update_target_variables(policy.critic_target.weights,
                                critic_target_weights,
                                tau=1.)

    run(args, env_fn, policy_fn, get_weights_fn, set_weights_fn)
Ejemplo n.º 3
0
    def test_run_continuous(self):
        """Call ``run`` with DDPG arguments and the continuous-action helpers.

        The arguments come from ``self.parser`` extended by
        ``DDPG.get_argument``, with the ``n_warmup`` default overridden to 1.
        Unknown command-line arguments are tolerated (``parse_known_args``).
        """
        from tf2rl.algos.ddpg import DDPG

        ddpg_parser = DDPG.get_argument(self.parser)
        ddpg_parser.set_defaults(n_warmup=1)
        parsed_args, _unknown = ddpg_parser.parse_known_args()

        run(
            parsed_args,
            env_fn_continuous,
            policy_fn_continuous,
            get_weights_fn_continuous,
            set_weights_fn_continuous,
        )
Ejemplo n.º 4
0
    def get_argument(parser=None):
        """
        Extend the DDPG command-line arguments with the ``--eta`` option.

        Args:
            parser (argparse.ArgumentParser, optional): parser forwarded to
                ``DDPG.get_argument``; ``None`` is forwarded as-is.

        Returns:
            argparse.ArgumentParser: the parser returned by
            ``DDPG.get_argument`` with ``--eta`` (float, default 0.05) added.
        """
        extended_parser = DDPG.get_argument(parser)
        extended_parser.add_argument('--eta', default=0.05, type=float)
        return extended_parser
Ejemplo n.º 5
0
                critic_units=[400, 300],
                memory_capacity=memory_capacity)


def get_weights_fn(policy):
    """Return the actor, critic and target-critic weights of ``policy``.

    The weights are read inside a ``tf.device(policy.device)`` scope and
    returned as a three-element list in the order actor, critic, target
    critic.
    """
    # TODO: Check whether the device scope below is actually needed
    import tensorflow as tf

    with tf.device(policy.device):
        networks = (policy.actor, policy.critic, policy.critic_target)
        collected = [network.weights for network in networks]
    return collected


def set_weights_fn(policy, weights):
    """Push ``weights`` into the actor, critic and target critic of ``policy``.

    ``weights`` must unpack into exactly three items, ordered actor, critic,
    target critic. Each is applied with ``update_target_variables``.
    """
    new_actor, new_critic, new_critic_target = weights
    assignments = (
        (policy.actor, new_actor),
        (policy.critic, new_critic),
        (policy.critic_target, new_critic_target),
    )
    for network, new_weights in assignments:
        # NOTE(review): tau=1. presumably makes this a full copy rather than
        # a soft update -- confirm against update_target_variables.
        update_target_variables(network.weights, new_weights, tau=1.)


if __name__ == '__main__':
    # Build the command-line parser in three steps: start from
    # apex_argument(), add the environment-name option, then let DDPG
    # register its own options on the same parser.
    parser = apex_argument()
    parser.add_argument('--env-name', type=str, default="Pendulum-v0")
    parser = DDPG.get_argument(parser)
    args = parser.parse_args()

    # NOTE(review): env_fn is *called* here and its result is handed to run;
    # presumably env_fn(name) returns an environment factory -- confirm
    # against env_fn's definition, which is outside this view.
    run(args, env_fn(args.env_name), policy_fn, get_weights_fn, set_weights_fn)
Ejemplo n.º 6
0
 def get_argument(parser=None):
     """Extend the DDPG command-line arguments with the ``--eta`` option.

     Args:
         parser (argparse.ArgumentParser, optional): parser forwarded to
             ``DDPG.get_argument``; ``None`` is forwarded as-is.

     Returns:
         argparse.ArgumentParser: the parser returned by
         ``DDPG.get_argument`` with ``--eta`` (float, default 0.05) added.
     """
     extended_parser = DDPG.get_argument(parser)
     extended_parser.add_argument('--eta', default=0.05, type=float)
     return extended_parser