Code Example #1
File: test_apex.py  Project: chentianba/coflowgym
    def test_run_continuous(self):
        from tf2rl.algos.ddpg import DDPG
        parser = DDPG.get_argument(self.parser)
        parser.set_defaults(n_warmup=1)
        args, _ = parser.parse_known_args()

        def env_fn():
            return gym.make('Pendulum-v0')

        def policy_fn(env, name, memory_capacity=int(1e6), gpu=-1, *args, **kwargs):
            return DDPG(
                state_shape=env.observation_space.shape,
                action_dim=env.action_space.high.size,
                n_warmup=500,
                gpu=-1)

        def get_weights_fn(policy):
            # Weights to broadcast from the learner to the explorer policies.
            return [policy.actor.weights,
                    policy.critic.weights,
                    policy.critic_target.weights]

        def set_weights_fn(policy, weights):
            # tau=1.0 makes update_target_variables a hard copy of the received weights.
            actor_weights, critic_weights, critic_target_weights = weights
            update_target_variables(
                policy.actor.weights, actor_weights, tau=1.)
            update_target_variables(
                policy.critic.weights, critic_weights, tau=1.)
            update_target_variables(
                policy.critic_target.weights, critic_target_weights, tau=1.)

        run(args, env_fn, policy_fn, get_weights_fn, set_weights_fn)
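The weight-sync helpers in these examples rely on tf2rl's update_target_variables, which (as in standard soft-target updates) applies the Polyak rule target = (1 - tau) * target + tau * source; with tau=1.0 that is simply a hard copy of the learner's weights into the explorer's networks. Below is a minimal sketch of the same rule written directly against TensorFlow variables; the helper name hard_copy_weights is illustrative only, not part of tf2rl:

import tensorflow as tf

def hard_copy_weights(target_vars, source_vars, tau=1.0):
    # Polyak averaging: target <- (1 - tau) * target + tau * source.
    # With tau=1.0 this copies the source weights verbatim, which is what
    # the set_weights_fn helpers in these examples depend on.
    for target, source in zip(target_vars, source_vars):
        target.assign((1.0 - tau) * target + tau * source)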
Code Example #2
File: test_apex.py  Project: chentianba/coflowgym
    def test_run_discrete(self):
        from tf2rl.algos.dqn import DQN
        parser = DQN.get_argument(self.parser)
        parser.set_defaults(n_warmup=1)
        args, _ = parser.parse_known_args()

        def env_fn():
            return gym.make("CartPole-v0")

        def policy_fn(env, name, memory_capacity=int(1e6), gpu=-1, *args, **kwargs):
            return DQN(
                name=name,
                state_shape=env.observation_space.shape,
                action_dim=env.action_space.n,
                n_warmup=500,
                target_replace_interval=300,
                batch_size=32,
                memory_capacity=memory_capacity,
                discount=0.99,
                gpu=-1)

        def get_weights_fn(policy):
            return [policy.q_func.weights,
                    policy.q_func_target.weights]

        def set_weights_fn(policy, weights):
            q_func_weights, q_func_target_weights = weights
            update_target_variables(
                policy.q_func.weights, q_func_weights, tau=1.)
            update_target_variables(
                policy.q_func_target.weights, q_func_target_weights, tau=1.)

        run(args, env_fn, policy_fn, get_weights_fn, set_weights_fn)
Code Example #3
File: test_apex.py  Project: zhb0318/tf2rl
def _test_run_continuous(parser):
    from tf2rl.algos.ddpg import DDPG
    parser = DDPG.get_argument(parser)
    args = parser.parse_args()

    def env_fn():
        return gym.make('Pendulum-v0')

    sample_env = env_fn()

    def policy_fn(env, name, memory_capacity=int(1e6), gpu=-1):
        return DDPG(state_shape=env.observation_space.shape,
                    action_dim=env.action_space.high.size,
                    gpu=-1)

    def get_weights_fn(policy):
        return [
            policy.actor.weights, policy.critic.weights,
            policy.critic_target.weights
        ]

    def set_weights_fn(policy, weights):
        actor_weights, critic_weights, critic_target_weights = weights
        update_target_variables(policy.actor.weights, actor_weights, tau=1.)
        update_target_variables(policy.critic.weights, critic_weights, tau=1.)
        update_target_variables(policy.critic_target.weights,
                                critic_target_weights,
                                tau=1.)

    run(args, env_fn, policy_fn, get_weights_fn, set_weights_fn)
Code Example #4
    def test_run_continuous(self):
        from tf2rl.algos.ddpg import DDPG
        parser = DDPG.get_argument(self.parser)
        parser.set_defaults(n_warmup=1)
        args, _ = parser.parse_known_args()

        run(args, env_fn_continuous, policy_fn_continuous,
            get_weights_fn_continuous, set_weights_fn_continuous)
Code Example #5
    def test_run_discrete(self):
        from tf2rl.algos.dqn import DQN
        parser = DQN.get_argument(self.parser)
        parser.set_defaults(n_warmup=1)
        args, _ = parser.parse_known_args()

        run(args, env_fn_discrete, policy_fn_discrete, get_weights_fn_discrete,
            set_weights_fn_discrete)
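Code Examples #4 and #5 pass module-level helpers (env_fn_discrete, policy_fn_discrete, get_weights_fn_discrete, set_weights_fn_discrete, and their continuous counterparts) that are defined elsewhere in the test module and not shown here. Judging from Code Example #2, the discrete helpers presumably look roughly like the sketch below; the actual definitions in the project may differ, and the import path for update_target_variables is assumed:

import gym
from tf2rl.algos.dqn import DQN
from tf2rl.misc.target_update_ops import update_target_variables


def env_fn_discrete():
    return gym.make("CartPole-v0")


def policy_fn_discrete(env, name, memory_capacity=int(1e6), gpu=-1, *args, **kwargs):
    return DQN(
        name=name,
        state_shape=env.observation_space.shape,
        action_dim=env.action_space.n,
        n_warmup=500,
        target_replace_interval=300,
        batch_size=32,
        memory_capacity=memory_capacity,
        gpu=gpu)


def get_weights_fn_discrete(policy):
    return [policy.q_func.weights, policy.q_func_target.weights]


def set_weights_fn_discrete(policy, weights):
    q_func_weights, q_func_target_weights = weights
    update_target_variables(policy.q_func.weights, q_func_weights, tau=1.)
    update_target_variables(policy.q_func_target.weights, q_func_target_weights, tau=1.)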
Code Example #6
                critic_units=[400, 300],
                memory_capacity=memory_capacity)


def get_weights_fn(policy):
    # TODO: Check if following needed
    import tensorflow as tf
    with tf.device(policy.device):
        return [
            policy.actor.weights, policy.critic.weights,
            policy.critic_target.weights
        ]


def set_weights_fn(policy, weights):
    actor_weights, critic_weights, critic_target_weights = weights
    update_target_variables(policy.actor.weights, actor_weights, tau=1.)
    update_target_variables(policy.critic.weights, critic_weights, tau=1.)
    update_target_variables(policy.critic_target.weights,
                            critic_target_weights,
                            tau=1.)


if __name__ == '__main__':
    parser = apex_argument()
    parser.add_argument('--env-name', type=str, default="Pendulum-v0")
    parser = DDPG.get_argument(parser)
    args = parser.parse_args()

    run(args, env_fn(args.env_name), policy_fn, get_weights_fn, set_weights_fn)
Code Example #7
File: run_apex_ddpg.py  Project: zhb0318/tf2rl
                    action_dim=env.action_space.high.size,
                    max_action=env.action_space.high[0],
                    gpu=gpu,
                    name=name,
                    sigma=noise_level,
                    batch_size=100,
                    lr_actor=0.001,
                    lr_critic=0.001,
                    actor_units=[400, 300],
                    critic_units=[400, 300],
                    memory_capacity=memory_capacity)

    def get_weights_fn(policy):
        # TODO: Check if following needed
        import tensorflow as tf
        with tf.device(policy.device):
            return [
                policy.actor.weights, policy.critic.weights,
                policy.critic_target.weights
            ]

    def set_weights_fn(policy, weights):
        actor_weights, critic_weights, critic_target_weights = weights
        update_target_variables(policy.actor.weights, actor_weights, tau=1.)
        update_target_variables(policy.critic.weights, critic_weights, tau=1.)
        update_target_variables(policy.critic_target.weights,
                                critic_target_weights,
                                tau=1.)

    run(args, env_fn, policy_fn, get_weights_fn, set_weights_fn)
Code Example #8
    parser = DQN.get_argument(parser)
    parser.add_argument('--atari', action='store_true')
    parser.add_argument('--env-name',
                        type=str,
                        default="SpaceInvadersNoFrameskip-v4")
    args = parser.parse_args()

    if args.atari:
        env_name = args.env_name
        n_warmup = 50000
        target_replace_interval = 10000
        batch_size = 32
        optimizer = tf.keras.optimizers.Adam(learning_rate=0.0000625,
                                             epsilon=1.5e-4)
        epsilon_decay_rate = int(1e6)
        QFunc = AtariQFunc
    else:
        env_name = "CartPole-v0"
        n_warmup = 500
        target_replace_interval = 300
        batch_size = 32
        optimizer = None
        epsilon_decay_rate = int(1e3)
        QFunc = None

    run(
        args, env_fn(env_name),
        policy_fn(args, n_warmup, target_replace_interval, batch_size,
                  optimizer, epsilon_decay_rate, QFunc), get_weights_fn,
        set_weights_fn)
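Unlike the earlier tests, Code Example #6 calls env_fn(args.env_name) before handing the result to run, and Code Example #8 additionally builds its policy constructor with policy_fn(args, n_warmup, ...). In those scripts the top-level env_fn and policy_fn are therefore factories that return the zero-argument environment constructor and the per-worker policy constructor that run expects. A rough sketch of that closure pattern is below; the real definitions in run_apex_dqn.py may differ in detail, and only keyword arguments already shown in these examples are forwarded:

import gym
from tf2rl.algos.dqn import DQN


def env_fn(env_name):
    # Returns the zero-argument environment constructor expected by the apex workers.
    def _env_fn():
        return gym.make(env_name)
    return _env_fn


def policy_fn(args, n_warmup, target_replace_interval, batch_size,
              optimizer, epsilon_decay_rate, QFunc):
    # Returns the per-worker policy constructor, closing over the chosen settings.
    # optimizer, epsilon_decay_rate, and QFunc would be forwarded through whatever
    # keyword names DQN actually exposes; they are omitted from this sketch.
    def _policy_fn(env, name, memory_capacity=int(1e6), gpu=-1, *extra_args, **kwargs):
        return DQN(
            name=name,
            state_shape=env.observation_space.shape,
            action_dim=env.action_space.n,
            n_warmup=n_warmup,
            target_replace_interval=target_replace_interval,
            batch_size=batch_size,
            memory_capacity=memory_capacity,
            gpu=gpu)
    return _policy_fn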