Example #1
def train_pg(
    env_class: Type[EnvWrapper],
    model: nn.Module,
    config: DictConfig,
    project_name=None,
    run_name=None,
):
    env = DoneIgnoreBatchedEnvWrapper(env_class, config.batch_size)
    optim = torch.optim.Adam(model.parameters(), lr=config.lr)
    wandb.init(
        name=f"{run_name}_{str(datetime.now().timestamp())[5:10]}",
        project=project_name or "testing_dqn",
        config=dict(config),
        save_code=True,
        group=None,
        tags=None,  # List of string tags
        notes=None,  # longer description of run
        dir=BASE_DIR,
    )
    wandb.watch(model)
    # TODO: Episodic env recorder?
    env_recorder = EnvRecorder(config.env_record_freq,
                               config.env_record_duration)
    sample_actions = ProbabilityActionSampler()

    cumulative_reward = 0
    cumulative_done = 0

    # ======= Start training ==========

    for episode in range(config.episodes):
        stats = PGStats(config.batch_size)  # Stores (reward, policy prob)
        step = 0
        env.reset()

        # Monte Carlo loop
        while not env.is_done("all"):
            log = DictConfig({"step": step})

            states = env.get_state_batch()
            p_pred = model(states)
            p_pred = F.softmax(p_pred, 1)

            actions = sample_actions(valid_actions=env.get_legal_actions(),
                                     probs=p_pred,
                                     noise=0.1)

            _, rewards, done_list, _ = env.step(actions)

            stats.record(rewards, actions, p_pred, done_list)

            # ======== Step logging =========

            mean_reward = float(np.mean(rewards))
            log.mean_reward = mean_reward

            cumulative_reward += mean_reward
            log.cumulative_reward = cumulative_reward

            cumulative_done += float(np.sum(done_list))
            log.cumulative_done = cumulative_done

            # TODO: Log policy histograms

            wandb.log(log)

            step += 1

        returns = stats.get_returns(config.gamma_discount_returns)
        credits = stats.get_credits(config.gamma_discount_credits)
        probs = stats.get_probs()

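        # Policy-gradient loss: action probabilities weighted by per-step
        # credits and discounted returns, negated so that gradient descent
        # increases the expected return.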
        loss = -1 * (probs * credits * returns)
        loss = torch.sum(loss)

        optim.zero_grad()
        loss.backward()
        optim.step()

        # ======== Episodic logging ========

        log = DictConfig({"episode": episode})
        log.episodic_reward = stats.get_mean_rewards()

        wandb.log(log)
Example #2
        state, reward, is_done, info = env.step(action)

        with torch.no_grad():
            qs2 = model(torch.FloatTensor([state.flatten()]))[0]

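        # One-step TD target: immediate reward plus the discounted (γ = 0.9)
        # maximum Q-value of the next state. Terminal states are not masked
        # out of the bootstrap term here.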
        target = reward + 0.9 * qs2.amax()
        loss = (target - qs[action])**2

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        log = DictConfig({"episode": episode})
        log.ep_loss = loss.item()

        cumulative_reward += reward
        log.cumulative_reward = cumulative_reward

        rewards.append(reward)
        if must_record:
            video_buffer.append(deepcopy(env.render("rgb_array")))
        if is_done:
            log.ep_mean_reward = float(np.mean(rewards))
            log.ep_length = len(rewards)
            if must_record:
                log = dict(log)
                log[f"video_ep{episode}_reward{reward}"] = wandb.Video(
                    _format_video(video_buffer), fps=4, format="gif")

        wandb.log(log)
Example #3
def train_dqn_double(
    env_class: Type[EnvWrapper],
    model: nn.Module,
    config: DictConfig,
    project_name=None,
    run_name=None,
):
    env = BatchEnvWrapper(env_class, config.batch_size)
    env.reset()
    optim = torch.optim.Adam(model.parameters(), lr=config.lr)
    epsilon_scheduler = decay_functions[config.epsilon_decay_function]

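    # Target network for double DQN: a copy of the online model used to
    # evaluate next-state values, re-synced every target_model_sync_freq steps.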
    target_model = deepcopy(model)
    target_model.load_state_dict(model.state_dict())
    target_model.eval()

    wandb.init(
        name=f"{run_name}_{str(datetime.now().timestamp())[5:10]}",
        project=project_name or "testing_dqn",
        config=dict(config),
        save_code=True,
        group=None,
        tags=None,  # List of string tags
        notes=None,  # longer description of run
        dir=BASE_DIR,
    )
    wandb.watch(model)
    replay = PrioritizedReplay(
        buffer_size=config.replay_size,
        batch_size=config.replay_batch,
        delete_freq=config.delete_freq,
        delete_percentage=config.delete_percentage,
        transform=state_action_reward_state_2_transform,
    )
    env_recorder = EnvRecorder(config.env_record_freq, config.env_record_duration)
    sample_actions = EpsilonRandomActionSampler()

    cumulative_reward = 0
    cumulative_done = 0

    # ======= Start training ==========

    # We need _some_ initial replay buffer to start with.
    store_initial_replay(env, replay)

    for step in range(config.steps):
        log = DictConfig({"step": step})

        (
            states_replay,
            actions_replay,
            rewards_replay,
            states2_replay,
        ) = replay.get_batch()
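        # Stack live environment states with replayed states so a single
        # forward pass scores both; only the first config.batch_size rows
        # belong to the live environments.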
        states = _combine(env.get_state_batch(), states_replay)

        q_pred = model(states)

        epsilon_exploration = epsilon_scheduler(config, log)
        actions_live = sample_actions(
            valid_actions=env.get_legal_actions(),
            q_values=q_pred[: config.batch_size],
            epsilon=epsilon_exploration,
        )

        # ============ Observe the reward and predict the value of the next state ==============

        states2, actions, rewards, dones_live = step_with_replay(
            env, actions_live, actions_replay, states2_replay, rewards_replay
        )

        with torch.no_grad():
            q_next_target = target_model(states2)
            model.eval()
            q_next_primary = model(states2)
            model.train()

        # Bellman equation
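        # Double DQN: the online model selects the greedy action for the next
        # state, while the target model supplies its value estimate.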
        state2_primary_actions = torch.argmax(q_next_primary, dim=1)
        state2_value = q_next_target[range(len(q_next_target)), state2_primary_actions]
        value = rewards + config.gamma_discount * state2_value

        q_select_actions = q_pred[range(len(q_pred)), actions]

        # =========== LEARN ===============

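        # Per-sample TD errors (reduction="none") are handed to the
        # prioritized replay buffer, presumably to serve as sampling priorities.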
        loss = F.mse_loss(q_select_actions, value, reduction="none")

        replay.add_batch(loss, (states, actions, rewards, states2))
        loss = torch.mean(loss)

        optim.zero_grad()
        loss.backward()
        optim.step()

        # Copy parameters to the target network every so often
        if step % config.target_model_sync_freq == 0:
            target_model.load_state_dict(model.state_dict())

        # ============ Logging =============

        log.loss = loss.item()

        max_reward = torch.amax(rewards, 0).item()
        min_reward = torch.amin(rewards, 0).item()
        mean_reward = torch.mean(rewards, 0).item()
        log.max_reward = max_reward
        log.min_reward = min_reward
        log.mean_reward = mean_reward

        cumulative_done += dones_live.sum()  # number of dones
        log.cumulative_done = int(cumulative_done)

        cumulative_reward += mean_reward
        log.cumulative_reward = cumulative_reward

        log.epsilon_exploration = epsilon_exploration

        env_recorder.record(step, env.envs, wandb)

        wandb.log(log)
Example #4
    def train(self):
        config = self.config
        env = self.env

        optim = torch.optim.Adam(self.model.parameters(), lr=config.lr)

        for episode in range(config.episodes):
            step = 0
            env.reset()

            # Monte Carlo loop
            while not env.is_done("all"):
                log = DictConfig({"step": step})

                states = env.get_state_batch()
                p_pred, q_pred = self.model(states)
                p_pred = F.softmax(p_pred, 1)

                actions = self.sample_actions(
                    valid_actions=env.get_legal_actions(),
                    probs=p_pred,
                    noise=0.1)

                _, rewards, dones, _ = env.step(actions)

                self.stats.record(rewards, actions, p_pred, q_pred, dones)

                # ===== Logging =====

                mean_reward = float(np.mean(rewards))
                log.mean_reward = mean_reward

                self.stats.cumulative_reward += mean_reward
                log.cumulative_reward = self.stats.cumulative_reward

                self.stats.cumulative_done += float(np.sum(dones))
                log.cumulative_done = self.stats.cumulative_done

                # TODO: Log policy histograms

                wandb.log(log)
                step += 1

            # ======= Learn =======

            returns = self.stats.get_returns(config.gamma_discount_returns)
            probs = self.stats.get_probs()
            values = self.stats.get_values()

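            # Actor-critic update: the policy term is weighted by the advantage
            # (returns - values), and the critic is regressed onto the returns.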
            loss_p = -1 * probs * (returns - values)
            loss_q = F.mse_loss(values, returns, reduction="none")
            loss = loss_p + loss_q
            loss = torch.sum(loss)

            optim.zero_grad()
            loss.backward()
            optim.step()

            # ======== Episodic logging ========

            log = DictConfig({"episode": episode})
            log.episodic_reward = self.stats.get_mean_rewards()

            wandb.log(log)
Example #5
def train_dqn(
    env_class: Type[EnvWrapper],
    model: nn.Module,
    config: DictConfig,
    project_name=None,
    run_name=None,
):
    env = BatchEnvWrapper(env_class, config.batch_size)
    env.reset()
    optim = torch.optim.Adam(model.parameters(), lr=config.lr)
    wandb.init(
        name=f"{run_name}_{str(datetime.now().timestamp())[5:10]}",
        project=project_name or "testing_dqn",
        config=dict(config),
        save_code=True,
        group=None,
        tags=None,  # List of string tags
        notes=None,  # longer description of run
        dir=BASE_DIR,
    )
    wandb.watch(model)
    env_recorder = EnvRecorder(config.env_record_freq,
                               config.env_record_duration)
    sample_actions = EpsilonRandomActionSampler()

    cumulative_reward = 0
    cumulative_done = 0

    # ======= Start training ==========

    for step in range(config.steps):
        log = DictConfig({"step": step})

        states = env.get_state_batch()
        q_pred = model(states)

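        # Epsilon-greedy over legal actions: explore with a random valid action
        # with probability epsilon_exploration, otherwise act greedily on q_pred.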
        actions = sample_actions(
            valid_actions=env.get_legal_actions(),
            q_values=q_pred,
            epsilon=config.epsilon_exploration,
        )

        # ============ Observe the reward and predict the value of the next state ==============

        _, rewards, done_list, _ = env.step(actions)

        rewards = torch.tensor(rewards).float()
        done_list = torch.tensor(done_list, dtype=torch.int8)
        next_states = env.get_state_batch()

        model.eval()
        with torch.no_grad():
            q_next = model(next_states)
        model.train()

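        # Q-learning target: reward plus the discounted maximum next-state
        # Q-value; terminal transitions are not masked out of the bootstrap.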
        value = rewards + config.gamma_discount * torch.amax(q_next, 1)
        q_actions = q_pred[range(config.batch_size), actions]

        # =========== LEARN ===============

        loss = F.mse_loss(q_actions, value)

        optim.zero_grad()
        loss.backward()
        optim.step()

        # ============ Logging =============

        log.loss = loss.item()

        max_reward = torch.amax(rewards, 0).item()
        min_reward = torch.amin(rewards, 0).item()
        mean_reward = torch.mean(rewards, 0).item()
        log.max_reward = max_reward
        log.min_reward = min_reward
        log.mean_reward = mean_reward

        cumulative_done += done_list.sum()  # number of dones
        log.cumulative_done = int(cumulative_done)

        cumulative_reward += mean_reward
        log.cumulative_reward = cumulative_reward

        env_recorder.record(step, env.envs, wandb)

        wandb.log(log)