def test_adding(self):
    buffer = PrioritizedReplay(5, 2)
    buffer.add(0.1, (1, 1))
    buffer.add_batch([0.2, 0.3], ([6, 7], [3, 4]))
    # expectation = [(0.1, (1, 1)), (0.2, (6, 3)), (0.3, (7, 4))]
    self.assertIn((0.1, 1, (1, 1)), buffer.memory)
    self.assertIn((0.2, 2, (6, 3)), buffer.memory)
    self.assertIn((0.3, 3, (7, 4)), buffer.memory)

    # Once the buffer exceeds its capacity of 5, the lowest-loss entry is evicted.
    buffer.add_batch([0.4, 0.5, 0.6], ([8, 9, 10], [5, 6, 7]))
    self.assertNotIn((0.1, 1, (1, 1)), buffer.memory)
    self.assertEqual(len(buffer.memory), 5)
    self.assertIn((0.2, 2, (6, 3)), buffer.memory)

    buffer.add(0.7, (11, 8))
    self.assertNotIn((0.2, 2, (6, 3)), buffer.memory)
def test_get_batch(self):
    buffer = PrioritizedReplay(3, 2)
    buffer.add(0.1, (1, 1))
    buffer.add(0.2, (2, 2))
    buffer.add(0.3, (3, 3))

    batch = buffer.get_batch()  # [(0.1, (1, 1)), (0.2, (2, 2))]
    self.assertEqual(len(batch), 2)
    self.assertEqual(len(buffer.memory), 1)
    self.assertNotIn(buffer.memory[0], batch)

    # The batch size shrinks dynamically when fewer entries remain.
    batch = buffer.get_batch()
    self.assertEqual(len(batch), 1)
    self.assertEqual(len(buffer.memory), 0)

    # Getting a batch from an empty buffer returns an empty batch.
    batch = buffer.get_batch()
    self.assertEqual(len(batch), 0)
def test_duplicate_loss_key(self):
    # Duplicate loss values must be accepted without raising, even for the same item.
    buffer = PrioritizedReplay(3, 2)
    buffer.add(0.1, (1, 1))
    buffer.add(0.1, (2, 2))
    buffer.add(0.3, (3, 3))
    buffer.add(0.1, (3, 3))
    # The buffer still never grows past its capacity of 3.
    self.assertEqual(len(buffer.memory), 3)
def test_add_duplicate_losses(self):
    buffer = PrioritizedReplay(20, 1)
    dummy = np.ones(10)
    buffer.add_batch(np.zeros(10), (dummy, dummy, dummy, dummy))
    self.assertEqual(len(buffer.memory), 10)
def test_add_batch_more_than_limit(self):
    buffer = PrioritizedReplay(2, 1)
    # A batch larger than the capacity of 2 drops the lowest-loss entry.
    buffer.add_batch([0.1, 0.2, 0.3], ([5, 6, 7], [2, 3, 4]))
    self.assertNotIn((0.1, 1, (5, 2)), buffer.memory)
    self.assertEqual(len(buffer.memory), 2)
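# Illustrative sketch, not part of the test suite: the tuples asserted above have the
# shape (loss, insertion_index, transition). Assuming the buffer is heap-backed (which
# these tests do not confirm), a monotonically increasing counter as the second field
# breaks ties between equal losses, so heap comparisons never fall through to the
# transition payload (which may be a numpy array and raise on comparison).
def _sketch_counter_tiebreak():
    import heapq

    heap, counter = [], 0
    for loss, item in [(0.1, (1, 1)), (0.1, (2, 2))]:  # duplicate loss keys
        heapq.heappush(heap, (loss, counter, item))  # the counter decides the tie
        counter += 1
    return heapq.heappop(heap)  # returns (0.1, 0, (1, 1)); no error despite equal losses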
def train_dqn_double(
    env_class: Type[EnvWrapper],
    model: nn.Module,
    config: DictConfig,
    project_name=None,
    run_name=None,
):
    env = BatchEnvWrapper(env_class, config.batch_size)
    env.reset()
    optim = torch.optim.Adam(model.parameters(), lr=config.lr)
    epsilon_scheduler = decay_functions[config.epsilon_decay_function]

    # The target network starts as a frozen copy of the online model.
    target_model = deepcopy(model)
    target_model.load_state_dict(model.state_dict())
    target_model.eval()

    wandb.init(
        name=f"{run_name}_{str(datetime.now().timestamp())[5:10]}",
        project=project_name or "testing_dqn",
        config=dict(config),
        save_code=True,
        group=None,
        tags=None,  # List of string tags
        notes=None,  # Longer description of the run
        dir=BASE_DIR,
    )
    wandb.watch(model)

    replay = PrioritizedReplay(
        buffer_size=config.replay_size,
        batch_size=config.replay_batch,
        delete_freq=config.delete_freq,
        delete_percentage=config.delete_percentage,
        transform=state_action_reward_state_2_transform,
    )
    env_recorder = EnvRecorder(config.env_record_freq, config.env_record_duration)
    sample_actions = EpsilonRandomActionSampler()

    cumulative_reward = 0
    cumulative_done = 0

    # ======= Start training ==========

    # We need _some_ initial replay buffer to start with.
    store_initial_replay(env, replay)

    for step in range(config.steps):
        log = DictConfig({"step": step})

        (
            states_replay,
            actions_replay,
            rewards_replay,
            states2_replay,
        ) = replay.get_batch()
        states = _combine(env.get_state_batch(), states_replay)
        q_pred = model(states)

        epsilon_exploration = epsilon_scheduler(config, log)
        actions_live = sample_actions(
            valid_actions=env.get_legal_actions(),
            q_values=q_pred[: config.batch_size],
            epsilon=epsilon_exploration,
        )

        # ============ Observe the reward && predict value of next state ==============
        states2, actions, rewards, dones_live = step_with_replay(
            env, actions_live, actions_replay, states2_replay, rewards_replay
        )
        with torch.no_grad():
            q_next_target = target_model(states2)
            model.eval()
            q_next_primary = model(states2)
            model.train()

        # Bellman equation (Double DQN): the online model selects the next action,
        # the target model evaluates it.
        state2_primary_actions = torch.argmax(q_next_primary, dim=1)
        state2_value = q_next_target[range(len(q_next_target)), state2_primary_actions]
        value = rewards + config.gamma_discount * state2_value
        q_select_actions = q_pred[range(len(q_pred)), actions]

        # =========== LEARN ===============
        loss = F.mse_loss(q_select_actions, value, reduction="none")
        # Store per-sample losses as priorities; detach so the buffer does not keep
        # the autograd graph alive.
        replay.add_batch(loss.detach(), (states, actions, rewards, states2))

        loss = torch.mean(loss)
        optim.zero_grad()
        loss.backward()
        optim.step()

        # Copy parameters every so often
        if step % config.target_model_sync_freq == 0:
            target_model.load_state_dict(model.state_dict())

        # ============ Logging =============
        log.loss = loss.item()

        max_reward = torch.amax(rewards, 0).item()
        min_reward = torch.amin(rewards, 0).item()
        mean_reward = torch.mean(rewards, 0).item()
        log.max_reward = max_reward
        log.min_reward = min_reward
        log.mean_reward = mean_reward

        cumulative_done += dones_live.sum()  # number of finished episodes this step
        log.cumulative_done = int(cumulative_done)

        cumulative_reward += mean_reward
        log.cumulative_reward = cumulative_reward

        log.epsilon_exploration = epsilon_exploration

        env_recorder.record(step, env.envs, wandb)
        wandb.log(log)
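# Illustrative sketch, separate from the training loop above: the Double DQN target in
# isolation. The online model selects the greedy next action; the target model evaluates
# it. Argument names and shapes here are hypothetical.
def _sketch_double_dqn_target(model, target_model, states2, rewards, gamma):
    with torch.no_grad():
        next_actions = torch.argmax(model(states2), dim=1)  # action selection: online net
        next_values = target_model(states2)[range(len(states2)), next_actions]  # evaluation: target net
    return rewards + gamma * next_values  # Bellman target used in the MSE loss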