示例#1
0
    def test_save(self):
        """Check save() reads 'step' from metrics and saves the model there."""
        metrics = Metrics('test')
        controller = BaseController(metrics, 110, 120, 130, 140)

        current_step = np.random.randint(100)
        # Replace both metrics hooks with mocks so the calls can be inspected.
        get_mock = MagicMock(return_value=current_step)
        save_mock = MagicMock()
        metrics.get, metrics.save_model = get_mock, save_mock

        controller.save()

        get_mock.assert_called_once_with('step')
        save_mock.assert_called_once_with(current_step)
示例#2
0
    def test_should_eval(self):
        """Check should_eval() is False off, and True on, multiples of 140.

        The controller is built with 140 as its last interval argument; the
        assertions expect an evaluation exactly at positive multiples of 140.
        Each call must read the current step from metrics exactly once.
        """
        metrics = Metrics('test')
        controller = BaseController(metrics, 110, 120, 130, 140)

        # Steps in [1, 99] are never multiples of 140. Step 0 is excluded on
        # purpose: it is a multiple of every interval, so whether it triggers
        # an evaluation depends on how BaseController treats step 0, and
        # np.random.randint(100) would draw it ~1% of the time (flaky test).
        metrics.get = MagicMock(return_value=np.random.randint(1, 100))
        assert not controller.should_eval()
        metrics.get.assert_called_once_with('step')

        # Strictly positive multiples of 140 (140 .. 13860); 0 is excluded
        # for the same reason as above.
        metrics.get = MagicMock(return_value=np.random.randint(1, 100) * 140)
        assert controller.should_eval()
        metrics.get.assert_called_once_with('step')
示例#3
0
def main(args):
    """Build a TD3 agent for a MuJoCo task and run training/evaluation.

    Creates seeded training and evaluation environments, the TD3 network,
    a replay buffer, metrics with a checkpoint saver, exploration noise and
    the training/evaluation controllers. It then runs ``interact`` inside a
    TensorFlow session, optionally restoring a checkpoint first.

    Args:
        args: options object (presumably an argparse Namespace) whose
            attributes supply every hyperparameter; ``vars(args)`` is logged
            through ``Metrics.log_parameters``.
    """
    # The training env gets reward scaling and rendering; the evaluation env
    # is built without them. Both are seeded with the same seed.
    train_env = MuJoCoWrapper(gym.make(args.env), args.reward_scale,
                              args.render)
    train_env.seed(args.seed)
    test_env = MuJoCoWrapper(gym.make(args.env))
    test_env.seed(args.seed)
    action_dim = train_env.action_space.shape[0]

    network_kwargs = dict(
        fcs=args.layers,
        concat_index=args.concat_index,
        state_shape=train_env.observation_space.shape,
        num_actions=action_dim,
        gamma=args.gamma,
        tau=args.tau,
        actor_lr=args.actor_lr,
        critic_lr=args.critic_lr,
        target_noise_sigma=args.target_noise_sigma,
        target_noise_clip=args.target_noise_clip,
    )
    net = TD3Network(TD3NetworkParams(**network_kwargs))

    replay = Buffer(args.buffer_size)

    # tf.train.Saver() without a var_list captures the variables that exist
    # when it is constructed, so it has to be created after the network.
    checkpoint_saver = tf.train.Saver()
    metrics = Metrics(args.name, args.log_adapter, checkpoint_saver)

    # Per-action vectors of zeros and 0.1 (presumably mean and std).
    exploration = NormalActionNoise(np.zeros(action_dim),
                                    0.1 * np.ones(action_dim))

    trainer = TD3Controller(net, replay, metrics, exploration, action_dim,
                            args.batch_size, args.final_steps,
                            args.log_interval, args.save_interval,
                            args.eval_interval)
    train_view = View(trainer)

    evaluator = EvalController(net, metrics, args.eval_episode)
    test_view = View(evaluator)

    metrics.log_parameters(vars(args))

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        # Record the graph for debugging.
        metrics.set_model_graph(session.graph)

        checkpoint = args.load
        if checkpoint is not None:
            # Restoring after initialization overwrites the fresh values.
            checkpoint_saver.restore(session, checkpoint)

        interact(train_env, train_view, test_env, test_view)
示例#4
0
def main(args):
    """Build a PPO agent over batched environments and run training/eval.

    Creates seeded batched training and evaluation environments, the PPO
    network, a rollout buffer, metrics with a checkpoint saver and the
    training/evaluation controllers. It then runs ``interact`` in batch mode
    inside a TensorFlow session, optionally restoring a checkpoint first.

    Args:
        args: options object (presumably an argparse Namespace) whose
            attributes supply every hyperparameter; ``vars(args)`` is logged
            through ``Metrics.log_parameters``.
    """
    # Both batches use the same env id, count and reward scale; only the
    # training batch receives the render flag. Both share one seed.
    train_env = BatchEnvWrapper(
        make_envs(args.env, args.num_envs, args.reward_scale), args.render)
    train_env.seed(args.seed)
    test_env = BatchEnvWrapper(
        make_envs(args.env, args.num_envs, args.reward_scale))
    test_env.seed(args.seed)
    action_dim = train_env.action_space.shape[0]

    network_kwargs = dict(
        fcs=args.layers,
        num_actions=action_dim,
        state_shape=train_env.observation_space.shape,
        num_envs=args.num_envs,
        batch_size=args.batch_size,
        epsilon=args.epsilon,
        learning_rate=args.lr,
        grad_clip=args.grad_clip,
        value_factor=args.value_factor,
        entropy_factor=args.entropy_factor,
    )
    net = PPONetwork(PPONetworkParams(**network_kwargs))

    trajectory = Rollout()

    # tf.train.Saver() without a var_list captures the variables that exist
    # when it is constructed, so it has to be created after the network.
    checkpoint_saver = tf.train.Saver()
    metrics = Metrics(args.name, args.log_adapter, checkpoint_saver)

    trainer = PPOController(net, trajectory, metrics, args.num_envs,
                            args.time_horizon, args.epoch, args.batch_size,
                            args.gamma, args.lam, args.final_steps,
                            args.log_interval, args.save_interval,
                            args.eval_interval)
    train_view = View(trainer)

    evaluator = EvalController(net, metrics, args.eval_episodes)
    test_view = View(evaluator)

    metrics.log_parameters(vars(args))

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        # Record the graph for debugging.
        metrics.set_model_graph(session.graph)

        checkpoint = args.load
        if checkpoint is not None:
            # Restoring after initialization overwrites the fresh values.
            checkpoint_saver.restore(session, checkpoint)

        interact(train_env, train_view, test_env, test_view, batch=True)
示例#5
0
文件: sac.py 项目: MegaYEye/mvc-drl
def main(args):
    """Build a SAC agent for a MuJoCo task and run training/evaluation.

    Creates training and evaluation environments (not seeded here), the SAC
    network, a replay buffer, metrics with a checkpoint saver, a no-op noise
    object and the training/evaluation controllers. It then runs
    ``interact`` inside a TensorFlow session, optionally restoring a
    checkpoint first.

    Args:
        args: options object (presumably an argparse Namespace) whose
            attributes supply every hyperparameter; ``vars(args)`` is logged
            through ``Metrics.log_parameters``.
    """
    # The training env gets reward scaling and rendering; the evaluation env
    # is built without them.
    train_env = MuJoCoWrapper(gym.make(args.env), args.reward_scale,
                              args.render)
    test_env = MuJoCoWrapper(gym.make(args.env))
    action_dim = train_env.action_space.shape[0]

    # SACNetwork takes its configuration positionally, in this exact order.
    network_args = (args.layers, args.concat_index,
                    train_env.observation_space.shape, action_dim, args.gamma,
                    args.tau, args.pi_lr, args.q_lr, args.v_lr, args.reg)
    net = SACNetwork(*network_args)

    replay = Buffer(args.buffer_size)

    # tf.train.Saver() without a var_list captures the variables that exist
    # when it is constructed, so it has to be created after the network.
    checkpoint_saver = tf.train.Saver()
    metrics = Metrics(args.name, args.log_adapter, checkpoint_saver)

    # SAC gets an EmptyNoise object rather than added exploration noise.
    exploration = EmptyNoise()

    trainer = SACController(net, replay, metrics, exploration, action_dim,
                            args.batch_size, args.final_steps,
                            args.log_interval, args.save_interval,
                            args.eval_interval)
    train_view = View(trainer)

    evaluator = EvalController(net, metrics, args.eval_episode)
    test_view = View(evaluator)

    metrics.log_parameters(vars(args))

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        # Record the graph for debugging.
        metrics.set_model_graph(session.graph)

        checkpoint = args.load
        if checkpoint is not None:
            # Restoring after initialization overwrites the fresh values.
            checkpoint_saver.restore(session, checkpoint)

        interact(train_env, train_view, test_env, test_view)
示例#6
0
def main(args):
    """Build a PPO agent over batched environments and run batch_interact.

    Creates batched training and evaluation environments (not seeded here),
    the PPO network, a rollout buffer, metrics with a checkpoint saver and
    the training/evaluation controllers. It then runs ``batch_interact``
    inside a TensorFlow session, optionally restoring a checkpoint first.

    Args:
        args: options object (presumably an argparse Namespace) whose
            attributes supply every hyperparameter; ``vars(args)`` is logged
            through ``Metrics.log_parameters``.
    """
    # Both batches use the same env id, count and reward scale; only the
    # training batch receives the render flag.
    train_env = BatchEnvWrapper(
        make_envs(args.env, args.num_envs, args.reward_scale), args.render)
    test_env = BatchEnvWrapper(
        make_envs(args.env, args.num_envs, args.reward_scale))

    action_dim = train_env.action_space.shape[0]

    # PPONetwork takes its configuration positionally, in this exact order.
    network_args = (args.layers, train_env.observation_space.shape,
                    args.num_envs, action_dim, args.batch_size,
                    args.epsilon, args.lr, args.grad_clip,
                    args.value_factor, args.entropy_factor)
    net = PPONetwork(*network_args)

    trajectory = Rollout()

    # tf.train.Saver() without a var_list captures the variables that exist
    # when it is constructed, so it has to be created after the network.
    checkpoint_saver = tf.train.Saver()
    metrics = Metrics(args.name, args.log_adapter, checkpoint_saver)

    trainer = PPOController(net, trajectory, metrics, args.num_envs,
                            args.time_horizon, args.epoch, args.batch_size,
                            args.gamma, args.lam, args.final_steps,
                            args.log_interval, args.save_interval,
                            args.eval_interval)
    train_view = View(trainer)

    evaluator = EvalController(net, metrics, args.eval_episodes)
    test_view = View(evaluator)

    metrics.log_parameters(vars(args))

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        # Record the graph for debugging.
        metrics.set_model_graph(session.graph)

        checkpoint = args.load
        if checkpoint is not None:
            # Restoring after initialization overwrites the fresh values.
            checkpoint_saver.restore(session, checkpoint)

        batch_interact(train_env, train_view, test_env, test_view)
示例#7
0
    def test_log(self):
        """Check that BaseController.log() raises NotImplementedError."""
        base = BaseController(Metrics('test'), 110, 120, 130, 140)
        # Callable form of pytest.raises: invokes base.log() and expects
        # the exception.
        pytest.raises(NotImplementedError, base.log)
示例#8
0
    def test_stop_episode(self):
        """Check BaseController.stop_episode() raises NotImplementedError."""
        base = BaseController(Metrics('test'), 110, 120, 130, 140)
        episode_args = ('obs', 'reward', 'info')
        # Callable form of pytest.raises: forwards the positional arguments.
        pytest.raises(NotImplementedError, base.stop_episode, *episode_args)