def train_pg(
    env_class: Type[EnvWrapper],
    model: nn.Module,
    config: DictConfig,
    project_name=None,
    run_name=None,
):
    env = DoneIgnoreBatchedEnvWrapper(env_class, config.batch_size)
    optim = torch.optim.Adam(model.parameters(), lr=config.lr)

    wandb.init(
        name=f"{run_name}_{str(datetime.now().timestamp())[5:10]}",
        project=project_name or "testing_dqn",
        config=dict(config),
        save_code=True,
        group=None,
        tags=None,  # List of string tags
        notes=None,  # longer description of run
        dir=BASE_DIR,
    )
    wandb.watch(model)

    # TODO: Episodic env recorder?
    env_recorder = EnvRecorder(config.env_record_freq, config.env_record_duration)
    sample_actions = ProbabilityActionSampler()

    cumulative_reward = 0
    cumulative_done = 0

    # ======= Start training ==========
    for episode in range(config.episodes):
        stats = PGStats(config.batch_size)  # Stores (reward, policy prob)
        step = 0
        env.reset()

        # Monte Carlo loop
        while not env.is_done("all"):
            log = DictConfig({"step": step})

            states = env.get_state_batch()
            p_pred = model(states)
            p_pred = F.softmax(p_pred, 1)

            actions = sample_actions(
                valid_actions=env.get_legal_actions(), probs=p_pred, noise=0.1
            )
            _, rewards, done_list, _ = env.step(actions)

            stats.record(rewards, actions, p_pred, done_list)

            # ======== Step logging =========
            mean_reward = float(np.mean(rewards))
            log.mean_reward = mean_reward

            cumulative_reward += mean_reward
            log.cumulative_reward = cumulative_reward

            cumulative_done += float(np.sum(done_list))
            log.cumulative_done = cumulative_done

            # TODO: Log policy histograms
            wandb.log(log)
            step += 1

        # ======= Learn =======
        returns = stats.get_returns(config.gamma_discount_returns)
        credits = stats.get_credits(config.gamma_discount_credits)
        probs = stats.get_probs()

        loss = -1 * (probs * credits * returns)
        loss = torch.sum(loss)

        optim.zero_grad()
        loss.backward()
        optim.step()

        # ======== Episodic logging ========
        log = DictConfig({"episode": episode})
        log.episodic_reward = stats.get_mean_rewards()
        wandb.log(log)
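
# --- Illustrative sketch (not part of the training loop above) -------------------------
# PGStats.get_returns / get_credits are defined elsewhere in the repo; the helpers below
# are only a hedged sketch of the quantities the policy-gradient loss is assumed to
# combine: per-step discounted returns and a per-step credit (discount) weight. The
# names, shapes, and the gamma^t credit convention are assumptions, not the project's
# actual implementation.
def _sketch_discounted_returns(rewards, gamma):
    """G_t = r_t + gamma * r_{t+1} + gamma^2 * r_{t+2} + ... for each step t."""
    returns, running = [], 0.0
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
    return list(reversed(returns))


def _sketch_credits(num_steps, gamma):
    """One possible credit-assignment weighting: credit_t = gamma^t."""
    return [gamma ** t for t in range(num_steps)]
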
def train_dqn_double(
    env_class: Type[EnvWrapper],
    model: nn.Module,
    config: DictConfig,
    project_name=None,
    run_name=None,
):
    env = BatchEnvWrapper(env_class, config.batch_size)
    env.reset()
    optim = torch.optim.Adam(model.parameters(), lr=config.lr)
    epsilon_scheduler = decay_functions[config.epsilon_decay_function]

    # Target network: a frozen copy of the primary model, synced periodically below
    target_model = deepcopy(model)
    target_model.load_state_dict(model.state_dict())
    target_model.eval()

    wandb.init(
        name=f"{run_name}_{str(datetime.now().timestamp())[5:10]}",
        project=project_name or "testing_dqn",
        config=dict(config),
        save_code=True,
        group=None,
        tags=None,  # List of string tags
        notes=None,  # longer description of run
        dir=BASE_DIR,
    )
    wandb.watch(model)

    replay = PrioritizedReplay(
        buffer_size=config.replay_size,
        batch_size=config.replay_batch,
        delete_freq=config.delete_freq,
        delete_percentage=config.delete_percentage,
        transform=state_action_reward_state_2_transform,
    )
    env_recorder = EnvRecorder(config.env_record_freq, config.env_record_duration)
    sample_actions = EpsilonRandomActionSampler()

    cumulative_reward = 0
    cumulative_done = 0

    # ======= Start training ==========
    # We need _some_ initial replay buffer to start with.
    store_initial_replay(env, replay)

    for step in range(config.steps):
        log = DictConfig({"step": step})

        (
            states_replay,
            actions_replay,
            rewards_replay,
            states2_replay,
        ) = replay.get_batch()

        states = _combine(env.get_state_batch(), states_replay)
        q_pred = model(states)

        epsilon_exploration = epsilon_scheduler(config, log)
        actions_live = sample_actions(
            valid_actions=env.get_legal_actions(),
            q_values=q_pred[: config.batch_size],
            epsilon=epsilon_exploration,
        )

        # ============ Observe the reward && predict value of next state ==============
        states2, actions, rewards, dones_live = step_with_replay(
            env, actions_live, actions_replay, states2_replay, rewards_replay
        )

        with torch.no_grad():
            q_next_target = target_model(states2)

            model.eval()
            q_next_primary = model(states2)
            model.train()

        # Bellman equation: primary model picks the action, target model values it
        state2_primary_actions = torch.argmax(q_next_primary, dim=1)
        state2_value = q_next_target[range(len(q_next_target)), state2_primary_actions]
        value = rewards + config.gamma_discount * state2_value

        q_select_actions = q_pred[range(len(q_pred)), actions]

        # =========== LEARN ===============
        loss = F.mse_loss(q_select_actions, value, reduction="none")
        replay.add_batch(loss, (states, actions, rewards, states2))

        loss = torch.mean(loss)
        optim.zero_grad()
        loss.backward()
        optim.step()

        # Copy parameters every so often
        if step % config.target_model_sync_freq == 0:
            target_model.load_state_dict(model.state_dict())

        # ============ Logging =============
        log.loss = loss.item()

        max_reward = torch.amax(rewards, 0).item()
        min_reward = torch.amin(rewards, 0).item()
        mean_reward = torch.mean(rewards, 0).item()
        log.max_reward = max_reward
        log.min_reward = min_reward
        log.mean_reward = mean_reward

        cumulative_done += dones_live.sum()  # number of dones
        log.cumulative_done = int(cumulative_done)

        cumulative_reward += mean_reward
        log.cumulative_reward = cumulative_reward

        log.epsilon_exploration = epsilon_exploration

        env_recorder.record(step, env.envs, wandb)
        wandb.log(log)
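
# --- Illustrative note (not part of the training loop above) ---------------------------
# `_combine`, `store_initial_replay`, and `step_with_replay` are project helpers defined
# elsewhere. As an assumption of what `_combine` does, mixing live and replayed states
# into one forward pass could be a simple batch-dimension concatenation:
#
#     def _combine(live_states, replay_states):
#         return torch.cat([live_states, replay_states], dim=0)
#
# which would explain why `q_pred[: config.batch_size]` slices the live half back out
# before sampling actions.
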
def train(self):
    config = self.config
    env = self.env
    optim = torch.optim.Adam(self.model.parameters(), lr=config.lr)

    for episode in range(config.episodes):
        step = 0
        env.reset()

        # Monte Carlo loop
        while not env.is_done("all"):
            log = DictConfig({"step": step})

            states = env.get_state_batch()
            p_pred, q_pred = self.model(states)
            p_pred = F.softmax(p_pred, 1)

            actions = self.sample_actions(
                valid_actions=env.get_legal_actions(), probs=p_pred, noise=0.1
            )
            _, rewards, dones, _ = env.step(actions)

            self.stats.record(rewards, actions, p_pred, q_pred, dones)

            # ===== Logging =====
            mean_reward = float(np.mean(rewards))
            log.mean_reward = mean_reward

            self.stats.cumulative_reward += mean_reward
            log.cumulative_reward = self.stats.cumulative_reward

            self.stats.cumulative_done += float(np.sum(dones))
            log.cumulative_done = self.stats.cumulative_done

            # TODO: Log policy histograms
            wandb.log(log)
            step += 1

        # ======= Learn =======
        returns = self.stats.get_returns(config.gamma_discount_returns)
        probs = self.stats.get_probs()
        values = self.stats.get_values()

        loss_p = -1 * probs * (returns - values)
        loss_q = F.mse_loss(values, returns, reduction="none")
        loss = loss_p + loss_q
        loss = torch.sum(loss)

        optim.zero_grad()
        loss.backward()
        optim.step()

        # ======== Episodic logging ========
        log = DictConfig({"episode": episode})
        log.episodic_reward = self.stats.get_mean_rewards()
        wandb.log(log)
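
# --- Illustrative note (not part of the method above) ----------------------------------
# A hedged reading of the combined loss, assuming get_probs returns the probabilities of
# the taken actions and get_values the critic's state-value estimates:
#
#     advantage   = returns - values          # how much better than the critic expected
#     loss_actor  = -probs * advantage        # policy-gradient term
#     loss_critic = (values - returns) ** 2   # value-regression term (the mse_loss above)
#
# Many actor-critic implementations detach the advantage so the actor term does not also
# push gradients into the critic; whether that matters here depends on what get_values
# returns.
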
def train_dqn(
    env_class: Type[EnvWrapper],
    model: nn.Module,
    config: DictConfig,
    project_name=None,
    run_name=None,
):
    env = BatchEnvWrapper(env_class, config.batch_size)
    env.reset()
    optim = torch.optim.Adam(model.parameters(), lr=config.lr)

    wandb.init(
        name=f"{run_name}_{str(datetime.now().timestamp())[5:10]}",
        project=project_name or "testing_dqn",
        config=dict(config),
        save_code=True,
        group=None,
        tags=None,  # List of string tags
        notes=None,  # longer description of run
        dir=BASE_DIR,
    )
    wandb.watch(model)

    env_recorder = EnvRecorder(config.env_record_freq, config.env_record_duration)
    sample_actions = EpsilonRandomActionSampler()

    cumulative_reward = 0
    cumulative_done = 0

    # ======= Start training ==========
    for step in range(config.steps):
        log = DictConfig({"step": step})

        states = env.get_state_batch()
        q_pred = model(states)

        actions = sample_actions(
            valid_actions=env.get_legal_actions(),
            q_values=q_pred,
            epsilon=config.epsilon_exploration,
        )

        # ============ Observe the reward && predict value of next state ==============
        _, rewards, done_list, _ = env.step(actions)
        rewards = torch.tensor(rewards).float()
        done_list = torch.tensor(done_list, dtype=torch.int8)

        next_states = env.get_state_batch()

        model.eval()
        with torch.no_grad():
            q_next = model(next_states)
        model.train()

        value = rewards + config.gamma_discount * torch.amax(q_next, 1)
        q_actions = q_pred[range(config.batch_size), actions]

        # =========== LEARN ===============
        loss = F.mse_loss(q_actions, value)

        optim.zero_grad()
        loss.backward()
        optim.step()

        # ============ Logging =============
        log.loss = loss.item()

        max_reward = torch.amax(rewards, 0).item()
        min_reward = torch.amin(rewards, 0).item()
        mean_reward = torch.mean(rewards, 0).item()
        log.max_reward = max_reward
        log.min_reward = min_reward
        log.mean_reward = mean_reward

        cumulative_done += done_list.sum()  # number of dones
        log.cumulative_done = int(cumulative_done)

        cumulative_reward += mean_reward
        log.cumulative_reward = cumulative_reward

        env_recorder.record(step, env.envs, wandb)
        wandb.log(log)
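
# --- Usage sketch (hedged, not from the source) -----------------------------------------
# A minimal illustration of invoking train_dqn, assuming an OmegaConf DictConfig that
# supplies the keys read above; `MyEnvWrapper` and `MyQNetwork` are hypothetical
# placeholders for the project's EnvWrapper subclass and Q-network module.
#
#     from omegaconf import OmegaConf
#
#     config = OmegaConf.create({
#         "batch_size": 32, "lr": 1e-3, "steps": 10_000,
#         "epsilon_exploration": 0.1, "gamma_discount": 0.99,
#         "env_record_freq": 500, "env_record_duration": 100,
#     })
#     train_dqn(MyEnvWrapper, MyQNetwork(), config, project_name="testing_dqn", run_name="dqn")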