Example #1
File: HER.py Project: ymd-h/cpprb
    def test_PER(self):
        rew_func = lambda s, a, g: -1 * (s != g)
        batch_size = 4

        hrb = HindsightReplayBuffer(size=10,
                                    env_dict={
                                        "obs": {},
                                        "act": {},
                                        "next_obs": {}
                                    },
                                    max_episode_len=2,
                                    strategy="future",
                                    reward_func=rew_func,
                                    additional_goals=2,
                                    prioritized=True)

        hrb.add(obs=0, act=0, next_obs=1)
        hrb.add(obs=1, act=0, next_obs=2)

        hrb.on_episode_end(3)
        # 2 transitions x (1 original goal + 2 additional "future" goals) = 6
        self.assertEqual(hrb.get_stored_size(), 6)

        sample = hrb.sample(batch_size)
        hrb.update_priorities(indexes=sample["indexes"],
                              priorities=np.zeros_like(sample["indexes"],
                                                       dtype=np.float64))
Example #2
File: HER.py Project: ymd-h/cpprb
    def test_assert_PER(self):
        rew_func = lambda s, a, g: -1 * (s != g)
        hrb = HindsightReplayBuffer(size=10,
                                    env_dict={
                                        "obs": {},
                                        "act": {},
                                        "next_obs": {}
                                    },
                                    max_episode_len=2,
                                    strategy="future",
                                    reward_func=rew_func,
                                    additional_goals=2,
                                    prioritized=False)

        hrb.add(obs=0, act=0, next_obs=1)
        hrb.add(obs=1, act=0, next_obs=2)

        with self.assertRaises(ValueError):
            hrb.get_max_priority()

        with self.assertRaises(ValueError):
            hrb.update_priorities([], [])
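
The two test methods above come from a larger test file; to run them standalone they need the usual unittest scaffolding and imports. A minimal sketch, assuming HindsightReplayBuffer is exposed by the top-level cpprb package and that numpy is imported as np (the class name TestHER is made up for illustration):

import unittest

import numpy as np

from cpprb import HindsightReplayBuffer


class TestHER(unittest.TestCase):
    # Paste test_PER and test_assert_PER from the examples above here.
    pass


if __name__ == "__main__":
    unittest.main()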
Example #3
        # TD target from the target network; stop_gradient keeps it out of backprop
        target_Q = tf.stop_gradient(target_func(model,target_model,
                                                sg(sample["next_obs"],sample["goal"]),
                                                sample_rew,
                                                sample_done,
                                                discount,
                                                tf.constant(env.action_space.n)))
        # Absolute TD error; per-sample losses are scaled by `weights` before averaging
        absTD = tf.math.abs(target_Q - Q)
        loss = tf.reduce_mean(loss_func(absTD)*weights)

    grad = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grad, model.trainable_weights))
    tf.summary.scalar("Loss vs training step", data=loss, step=n_step)


    if prioritized:
        # Recompute the absolute TD error for the sampled transitions and
        # feed it back into the buffer as their new priorities
        Q = Q_func(model,
                   sg(sample["obs"], sample["goal"]),
                   tf.constant(sample["act"].ravel()),
                   tf.constant(env.action_space.n))
        absTD = tf.math.abs(target_Q - Q)
        rb.update_priorities(sample["indexes"], absTD)


    if n_step % target_update_freq == 0:
        target_model.set_weights(model.get_weights())

    if n_step % eval_freq == eval_freq-1:
        eval_rew = evaluate(model, eval_env)
        tf.summary.scalar("success rate vs training step",
                          data=eval_rew, step=n_step)
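
Example #3 relies on several helpers defined elsewhere in the script (sg, Q_func, target_func, loss_func, weights). As one illustration, sg appears to merge an observation batch with its goal batch before feeding the network. A minimal sketch of such a helper, assuming both inputs are 2-D batched tensors joined along the feature axis (the project's actual helper may differ):

import tensorflow as tf

def sg(state, goal):
    # Hypothetical helper: build the network input by concatenating
    # each observation with its goal along the last (feature) axis.
    state = tf.cast(state, tf.float32)
    goal = tf.cast(goal, tf.float32)
    return tf.concat([state, goal], axis=-1)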