Code Example #1
File: test_per.py Project: Akhilez/reward_lab
    def test_adding(self):
        buffer = PrioritizedReplay(5, 2)
        buffer.add(0.1, (1, 1))
        buffer.add_batch([0.2, 0.3], ([6, 7], [3, 4]))

        # expectation = [(0.1, (1, 1)), (0.2, (6, 3)), (0.3, (7, 4))]

        self.assertIn((0.1, 1, (1, 1)), buffer.memory)
        self.assertIn((0.2, 2, (6, 3)), buffer.memory)
        self.assertIn((0.3, 3, (7, 4)), buffer.memory)

        buffer.add_batch([0.4, 0.5, 0.6], ([8, 9, 10], [5, 6, 7]))

        self.assertNotIn((0.1, 1, (1, 1)), buffer.memory)
        self.assertEqual(len(buffer.memory), 5)

        self.assertIn((0.2, 2, (6, 3)), buffer.memory)
        buffer.add(0.7, (11, 8))
        self.assertNotIn((0.2, 2, (6, 3)), buffer.memory)
Code Example #2
File: test_per.py Project: Akhilez/reward_lab
    def test_get_batch(self):
        buffer = PrioritizedReplay(3, 2)
        buffer.add(0.1, (1, 1))
        buffer.add(0.2, (2, 2))
        buffer.add(0.3, (3, 3))

        batch = buffer.get_batch()  # [(0.1, (1, 1)), (0.2, (2, 2))]

        self.assertEqual(len(batch), 2)
        self.assertEqual(len(buffer.memory), 1)

        self.assertNotIn(buffer.memory[0], batch)

        # Dynamically reduce batch size
        batch = buffer.get_batch()
        self.assertEqual(len(batch), 1)
        self.assertEqual(len(buffer.memory), 0)

        # What happens if we get a batch from an empty buffer?
        batch = buffer.get_batch()
        self.assertEqual(len(batch), 0)
Code Example #3
File: test_per.py Project: Akhilez/reward_lab
    def test_duplicate_loss_key(self):
        # Adding entries with duplicate loss values must not raise.
        buffer = PrioritizedReplay(3, 2)
        buffer.add(0.1, (1, 1))
        buffer.add(0.1, (2, 2))
        buffer.add(0.3, (3, 3))
        buffer.add(0.1, (3, 3))
Code Example #4
File: test_per.py Project: Akhilez/reward_lab
    def test_add_duplicate_losses(self):
        buffer = PrioritizedReplay(20, 1)
        dummy = np.ones(10)
        buffer.add_batch(np.zeros(10), (dummy, dummy, dummy, dummy))

        self.assertEqual(len(buffer.memory), 10)
Code Example #5
File: test_per.py Project: Akhilez/reward_lab
    def test_add_batch_more_than_limit(self):
        buffer = PrioritizedReplay(2, 1)
        buffer.add_batch([0.1, 0.2, 0.3], ([5, 6, 7], [2, 3, 4]))

        self.assertNotIn((0.1, 1, (5, 2)), buffer.memory)
        self.assertEqual(len(buffer.memory), 2)
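
The tests above pin down the PrioritizedReplay surface they rely on: memory holds (loss, counter, data) tuples, the lowest-loss entry is evicted once the buffer is full, equal losses are tie-broken by an insertion counter, and get_batch removes up to batch_size entries, shrinking as the buffer empties. The following is only a minimal sketch consistent with that tested behaviour, not the project's implementation; the real class also accepts delete_freq, delete_percentage, and transform arguments (see train_dqn_double below), which are omitted here.

import heapq
from itertools import count


class PrioritizedReplay:
    # Minimal sketch: a min-heap keyed on loss, matching only the behaviour tested above.
    def __init__(self, buffer_size, batch_size):
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.memory = []          # heap of (loss, counter, data) tuples
        self._counter = count(1)  # unique tiebreaker so equal losses never compare data

    def add(self, loss, data):
        heapq.heappush(self.memory, (loss, next(self._counter), data))
        if len(self.memory) > self.buffer_size:
            heapq.heappop(self.memory)  # evict the lowest-loss entry

    def add_batch(self, losses, data_columns):
        # data_columns are parallel sequences; sample i is built from the i-th element of each.
        for loss, data in zip(losses, zip(*data_columns)):
            self.add(loss, data)

    def get_batch(self):
        # Pops up to batch_size entries; the project's class may instead sample high-loss entries.
        n = min(self.batch_size, len(self.memory))
        return [heapq.heappop(self.memory) for _ in range(n)]
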
Code Example #6
def train_dqn_double(
    env_class: Type[EnvWrapper],
    model: nn.Module,
    config: DictConfig,
    project_name=None,
    run_name=None,
):
    env = BatchEnvWrapper(env_class, config.batch_size)
    env.reset()
    optim = torch.optim.Adam(model.parameters(), lr=config.lr)
    epsilon_scheduler = decay_functions[config.epsilon_decay_function]

    target_model = deepcopy(model)
    target_model.load_state_dict(model.state_dict())
    target_model.eval()

    wandb.init(
        name=f"{run_name}_{str(datetime.now().timestamp())[5:10]}",
        project=project_name or "testing_dqn",
        config=dict(config),
        save_code=True,
        group=None,
        tags=None,  # List of string tags
        notes=None,  # longer description of run
        dir=BASE_DIR,
    )
    wandb.watch(model)
    replay = PrioritizedReplay(
        buffer_size=config.replay_size,
        batch_size=config.replay_batch,
        delete_freq=config.delete_freq,
        delete_percentage=config.delete_percentage,
        transform=state_action_reward_state_2_transform,
    )
    env_recorder = EnvRecorder(config.env_record_freq, config.env_record_duration)
    sample_actions = EpsilonRandomActionSampler()

    cumulative_reward = 0
    cumulative_done = 0

    # ======= Start training ==========

    # We need _some_ initial replay buffer to start with.
    store_initial_replay(env, replay)

    for step in range(config.steps):
        log = DictConfig({"step": step})

        (
            states_replay,
            actions_replay,
            rewards_replay,
            states2_replay,
        ) = replay.get_batch()
        states = _combine(env.get_state_batch(), states_replay)

        q_pred = model(states)

        epsilon_exploration = epsilon_scheduler(config, log)
        actions_live = sample_actions(
            valid_actions=env.get_legal_actions(),
            q_values=q_pred[: config.batch_size],
            epsilon=epsilon_exploration,
        )

        # ============ Observe the reward and predict the value of the next state ==============

        states2, actions, rewards, dones_live = step_with_replay(
            env, actions_live, actions_replay, states2_replay, rewards_replay
        )

        with torch.no_grad():
            q_next_target = target_model(states2)
            model.eval()
            q_next_primary = model(states2)
            model.train()

        # Double-DQN Bellman target: the primary network picks the next action,
        # the target network evaluates it.
        state2_primary_actions = torch.argmax(q_next_primary, dim=1)
        state2_value = q_next_target[range(len(q_next_target)), state2_primary_actions]
        value = rewards + config.gamma_discount * state2_value

        q_select_actions = q_pred[range(len(q_pred)), actions]

        # =========== LEARN ===============

        loss = F.mse_loss(q_select_actions, value, reduction="none")

        # Per-sample losses become the priorities of the stored transitions
        replay.add_batch(loss, (states, actions, rewards, states2))
        loss = torch.mean(loss)

        optim.zero_grad()
        loss.backward()
        optim.step()

        # Copy parameters to the target network every so often
        if step % config.target_model_sync_freq == 0:
            target_model.load_state_dict(model.state_dict())

        # ============ Logging =============

        log.loss = loss.item()

        max_reward = torch.amax(rewards, 0).item()
        min_reward = torch.amin(rewards, 0).item()
        mean_reward = torch.mean(rewards, 0).item()
        log.max_reward = max_reward
        log.min_reward = min_reward
        log.mean_reward = mean_reward

        cumulative_done += dones_live.sum()  # number of dones
        log.cumulative_done = int(cumulative_done)

        cumulative_reward += mean_reward
        log.cumulative_reward = cumulative_reward

        log.epsilon_exploration = epsilon_exploration

        env_recorder.record(step, env.envs, wandb)

        wandb.log(log)
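
For orientation, a hypothetical invocation is sketched below. The config keys are the ones read inside train_dqn_double, but every value, the environment wrapper, and the model are placeholders rather than anything taken from the project.

from omegaconf import DictConfig

config = DictConfig(
    {
        # Keys read by train_dqn_double; all values are illustrative placeholders.
        "batch_size": 32,
        "lr": 1e-3,
        "epsilon_decay_function": "exponential",  # assumed to be a key of decay_functions
        "steps": 10_000,
        "gamma_discount": 0.99,
        "target_model_sync_freq": 500,
        "replay_size": 10_000,
        "replay_batch": 32,
        "delete_freq": 1_000,
        "delete_percentage": 0.2,
        "env_record_freq": 1_000,
        "env_record_duration": 100,
    }
)

train_dqn_double(
    env_class=MyEnvWrapper,  # hypothetical EnvWrapper subclass
    model=my_q_network,      # hypothetical nn.Module mapping states to Q-values
    config=config,
    project_name="testing_dqn",
    run_name="double_dqn_example",
)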