Example #1
def test_gae_4():
    t = 5
    reward_t = torch.ones(1, t)
    value_t = torch.arange(1, t + 1, dtype=torch.float).flip(0).unsqueeze(0)
    value_prime = torch.zeros(1)
    done_t = torch.zeros(1, t, dtype=torch.bool)

    for lambda_ in torch.linspace(0, 1, steps=11):  # steps is required in recent PyTorch
        actual = utils.generalized_advantage_estimation(
            reward_t, value_t, value_prime, done_t, 1, lambda_)
        expected = torch.zeros(1, t)

        assert torch.allclose(actual, expected)
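
In this test the one-step TD residuals vanish identically: with gamma = 1, each reward exactly offsets the drop in value (1 + 4 - 5 = 0 on the first step, down to 1 + 0 - 1 = 0 on the bootstrapped last step), so the advantage is zero for every lambda, which is why the test sweeps lambda over [0, 1]. A minimal check of that claim, assuming the standard residual definition delta_t = r_t + gamma * V_{t+1} - V_t:

import torch

value = torch.tensor([5.0, 4.0, 3.0, 2.0, 1.0])
value_next = torch.cat([value[1:], torch.zeros(1)])  # bootstrap with value_prime = 0
delta = 1.0 + value_next - value                     # reward is 1 and gamma is 1 at every step
assert torch.allclose(delta, torch.zeros(5))
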
Example #2
def test_gae_3():
    t = 5
    reward_t = torch.ones(1, t) * 2
    value_t = torch.ones(1, t) * 3
    value_prime = torch.ones(1) * 3
    done_t = torch.zeros(1, t, dtype=torch.bool)
    done_t[:, 2] = True

    actual = utils.generalized_advantage_estimation(reward_t, value_t,
                                                    value_prime, done_t, 1, 0)
    expected = torch.ones(1, t) * 2
    expected[:, 2] = -1

    assert torch.allclose(actual, expected)
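
With lambda_ = 0 the estimator reduces to the one-step TD residual delta_t = r_t + gamma * (1 - done_t) * V_{t+1} - V_t, which here is 2 + 3 - 3 = 2 on ordinary steps and 2 + 0 - 3 = -1 on the terminal step, exactly the expected tensor. A quick sketch under that assumption:

import torch

reward = torch.full((5,), 2.0)
value = torch.full((5,), 3.0)
value_next = torch.full((5,), 3.0)  # V_{t+1}, bootstrapped with value_prime = 3 at the end
done = torch.tensor([False, False, True, False, False])
delta = reward + (~done) * value_next - value
assert torch.allclose(delta, torch.tensor([2.0, 2.0, -1.0, 2.0, 2.0]))
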
Example #3
def test_gae_2():
    t = 5
    reward_t = torch.ones(1, t) * 2
    value_t = torch.ones(1, t) * 3
    value_prime = torch.ones(1) * 3
    done_t = torch.zeros(1, t, dtype=torch.bool)
    done_t[:, 2] = True

    actual = utils.generalized_advantage_estimation(reward_t, value_t,
                                                    value_prime, done_t, 1, 1)
    expected = utils.n_step_bootstrapped_return(reward_t, done_t, value_prime,
                                                1) - value_t

    assert torch.allclose(actual, expected)
Example #4
def test_gae():
    reward_t = torch.tensor([[1.0, 1.0, 1.0, 1.0, 1.0, 1.0]],
                            dtype=torch.float)
    value_t = torch.tensor([[3.0, 4.0, 5.0, 3.0, 4.0, 5.0]], dtype=torch.float)
    value_prime = torch.tensor([6.0], dtype=torch.float)
    done_t = torch.tensor([[False, False, True, False, False, False]],
                          dtype=torch.bool)

    actual = utils.generalized_advantage_estimation(reward_t,
                                                    value_t,
                                                    value_prime,
                                                    done_t,
                                                    gamma=0.9,
                                                    lambda_=0.8)
    expected = torch.tensor([[0.6064, -1.38, -4.0, 3.40576, 2.508, 1.4]])

    assert torch.allclose(actual, expected)
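
The expected tensor follows from the usual GAE backward recursion: delta_t = r_t + gamma * (1 - done_t) * V_{t+1} - V_t and A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}. The reference below is an assumption about what utils.generalized_advantage_estimation computes, not its actual source, but it is consistent with the four tests above; with gamma * lambda = 0.72 it gives A_5 = 1.4, A_4 = 1.5 + 0.72 * 1.4 = 2.508, and so on down to A_0 = 0.6064.

import torch

def gae_reference(reward_t, value_t, value_prime, done_t, gamma, lambda_):
    # Backward recursion over a (batch, time) rollout; the done mask zeroes both the bootstrap
    # term and the propagation of later advantages across episode boundaries.
    not_done = (~done_t).float()
    value_next = torch.cat([value_t[:, 1:], value_prime.unsqueeze(1)], dim=1)
    delta = reward_t + gamma * not_done * value_next - value_t
    advantage = torch.zeros_like(delta)
    carry = torch.zeros_like(value_prime)
    for t in reversed(range(reward_t.size(1))):
        carry = delta[:, t] + gamma * lambda_ * not_done[:, t] * carry
        advantage[:, t] = carry
    return advantage
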
Example #5
def main(**kwargs):
    config = C(
        random_seed=42,
        learning_rate=1e-4,
        horizon=32,
        discount=0.99,
        num_episodes=100000,
        num_workers=32,
        entropy_weight=1e-2,
        log_interval=100,
    )
    for k in kwargs:
        config[k] = kwargs[k]

    utils.random_seed(config.random_seed)
    writer = SummaryWriter(config.experiment_path)

    # build env
    env = VecEnv([build_env for _ in range(config.num_workers)])
    env = wrappers.TensorboardBatchMonitor(env,
                                           writer,
                                           log_interval=config.log_interval,
                                           fps_mul=0.5)
    env = wrappers.Torch(env)

    # build agent and optimizer
    agent = Agent(env.observation_space, env.action_space)
    optimizer = torch.optim.Adam(agent.parameters(),
                                 config.learning_rate * config.num_workers,
                                 betas=(0.0, 0.999))

    # train
    metrics = {
        "episode/return": Stack(),
        "episode/length": Stack(),
        "rollout/reward": Stack(),
        "rollout/value_target": Stack(),
        "rollout/value": Stack(),
        "rollout/td_error": Stack(),
        "rollout/entropy": Stack(),
        "rollout/actor_loss": Stack(),
        "rollout/critic_loss": Stack(),
        "rollout/loss": Stack(),
    }

    episode = 0
    opt_step = 0
    pbar = tqdm(total=config.num_episodes)

    env.seed(config.random_seed)
    obs = env.reset()
    action = torch.zeros(config.num_workers, dtype=torch.int)
    memory = agent.zero_memory(config.num_workers)

    while episode < config.num_episodes:
        memory = tuple(x.detach() for x in memory)

        history = collect_rollout()
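        # NOTE: collect_rollout() above, and the obs_prime / memory_prime used below, are
        # defined in code not shown in this excerpt; Example #7 spells out the equivalent
        # rollout-collection loop.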

        rollout = history.build()

        _, value_prime, _ = agent(obs_prime, action, memory_prime)

        # value_target = utils.n_step_bootstrapped_return(
        #     reward_t=rollout.reward,
        #     done_t=rollout.done,
        #     value_prime=value_prime.detach(),
        #     gamma=config.discount,
        # )

        value_target = utils.generalized_advantage_estimation(
            reward_t=rollout.reward,
            value_t=rollout.value.detach(),
            value_prime=value_prime.detach(),
            done_t=rollout.done,
            gamma=config.discount,
            lambda_=0.96,
        )
        value_target += rollout.value.detach()

        td_error = value_target - rollout.value
        critic_loss = td_error.pow(2)
        actor_loss = (-rollout.log_prob * td_error.detach() -
                      config.entropy_weight * rollout.entropy)
        loss = actor_loss + 0.5 * critic_loss

        optimizer.zero_grad()
        loss.sum(1).mean().backward()
        # nn.utils.clip_grad_norm_(agent.parameters(), 0.01)
        optimizer.step()
        opt_step += 1

        metrics["rollout/reward"].update(rollout.reward.detach())
        metrics["rollout/value"].update(rollout.value.detach())
        metrics["rollout/value_target"].update(value_target.detach())
        metrics["rollout/td_error"].update(td_error.detach())
        metrics["rollout/entropy"].update(rollout.entropy.detach())
        metrics["rollout/actor_loss"].update(actor_loss.detach())
        metrics["rollout/critic_loss"].update(critic_loss.detach())
        metrics["rollout/loss"].update(loss.detach())

        if opt_step % 10 == 0:
            # td_error_std_normalized = td_error.std() / value_target.std()
            print("log rollout")

            total_norm = torch.norm(
                torch.stack([
                    torch.norm(p.grad.detach(), 2.0)
                    for p in agent.parameters()
                ]), 2.0)
            writer.add_scalar(f"rollout/grad_norm",
                              total_norm,
                              global_step=episode)

            for k in [
                    "rollout/reward",
                    "rollout/value_target",
                    "rollout/value",
                    "rollout/td_error",
            ]:
                v = metrics[k].compute_and_reset()
                writer.add_scalar(f"{k}/mean", v.mean(), global_step=episode)
                writer.add_histogram(f"{k}/hist", v, global_step=episode)

            for k in [
                    "rollout/entropy",
                    "rollout/actor_loss",
                    "rollout/critic_loss",
                    "rollout/loss",
            ]:
                v = metrics[k].compute_and_reset()
                writer.add_scalar(f"{k}/mean", v.mean(), global_step=episode)

            writer.flush()

            # writer.add_scalar(
            #     "rollout/td_error_std_normalized", td_error_std_normalized, global_step=episode
            # )
            # writer.add_histogram("rollout/reward", rollout.reward, global_step=episode)

    env.close()
    writer.close()
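
One detail worth noting in this example: because the GAE output is added back onto the detached value (value_target += rollout.value.detach()), the td_error computed a few lines later is numerically just the GAE advantage, and the extra value term only re-opens a gradient path for the critic loss. A tiny illustration of that identity with stand-in tensors:

import torch

value = torch.randn(2, 6, requires_grad=True)
advantage = torch.randn(2, 6)

value_target = advantage + value.detach()
td_error = value_target - value           # numerically equal to the advantage
assert torch.allclose(td_error.detach(), advantage)
assert td_error.requires_grad             # the critic loss still gets gradients through value
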
Example #6
    def episodic_training(self, train_results, tail):

        episode = self.replay_buffer.get_tail(tail)

        sl = episode['s']
        sl = list(torch.chunk(sl, int((len(sl) / self.batch) + 1)))

        s, r, t, e = [episode[k] for k in ['s', 'r', 't', 'e']]

        v = []
        for s in sl:
            v.append(self.v_net(s))

        v.append(torch.zeros_like(v[0][:1]))
        v = torch.cat(v).detach()
        v1, v2 = v[:-1], v[1:]

        adv, v_target = generalized_advantage_estimation(
            r,
            t,
            e,
            v1,
            v2,
            self.gamma,
            self.lambda_gae,
            norm=self.norm_rewards)

        episode['adv'] = adv
        episode['v_target'] = v_target

        if self.batch_ppo:
            n = self.steps_per_episode * self.batch
            indices = torch.randperm(tail * max(1, n // tail + 1)) % tail
            indices = indices[:n].unsqueeze(1).view(self.steps_per_episode,
                                                    self.batch)

            samples = {k: v[indices] for k, v in episode.items()}
            iterator_pi = iter_dict(samples)
            iterator_v = iter_dict(samples)
        else:
            iterator_pi = itertools.repeat(episode, self.steps_per_episode)
            iterator_v = itertools.repeat(episode, self.steps_per_episode)

        for i, sample in enumerate(iterator_pi):
            s, a, r, t, stag, adv, v_target, log_pi_old = [
                sample[k] for k in
                ['s', 'a', 'r', 't', 'stag', 'adv', 'v_target', 'logp']
            ]
            self.pi_net(s)
            log_pi = self.pi_net.log_prob(a)
            ratio = torch.exp((log_pi - log_pi_old).sum(dim=1))

            clip_adv = torch.clamp(ratio, 1 - self.eps_ppo,
                                   1 + self.eps_ppo) * adv
            loss_p = -(torch.min(ratio * adv, clip_adv)).mean()

            approx_kl = -float((log_pi - log_pi_old).sum(dim=1).mean())
            ent = float(self.pi_net.entropy().sum(dim=1).mean())

            if approx_kl > self.target_kl:
                train_results['scalar']['pi_opt_rounds'].append(i)
                break

            clipped = ratio.gt(1 + self.eps_ppo) | ratio.lt(1 - self.eps_ppo)
            clipfrac = float(
                torch.as_tensor(clipped, dtype=torch.float32).mean())

            self.optimizer_p.zero_grad()
            loss_p.backward()
            if self.clip_p:
                nn.utils.clip_grad_norm_(self.pi_net.parameters(), self.clip_p)
            self.optimizer_p.step()

            train_results['scalar']['loss_p'].append(float(loss_p))
            train_results['scalar']['approx_kl'].append(approx_kl)
            train_results['scalar']['ent'].append(ent)
            train_results['scalar']['clipfrac'].append(clipfrac)

        for sample in iterator_v:
            s, a, r, t, stag, adv, v_target, log_pi_old = [
                sample[k] for k in
                ['s', 'a', 'r', 't', 'stag', 'adv', 'v_target', 'logp']
            ]

            v = self.v_net(s)
            loss_v = F.mse_loss(v, v_target, reduction='mean')

            self.optimizer_v.zero_grad()
            loss_v.backward()
            if self.clip_q:
                nn.utils.clip_grad_norm_(self.v_net.parameters(), self.clip_q)
            self.optimizer_v.step()

            train_results['scalar']['loss_v'].append(float(loss_v))

        return train_results
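
iter_dict is not defined in this excerpt. A minimal stand-in that is consistent with how it is used here (samples holds tensors shaped (steps_per_episode, batch, ...) and the two training loops consume one per-step dict of mini-batches at a time) might look like the hypothetical sketch below:

def iter_dict(samples):
    # Hypothetical helper, not from the original code: yield one dict per optimization step,
    # slicing the leading steps_per_episode dimension of every tensor in samples.
    num_steps = next(iter(samples.values())).shape[0]
    for i in range(num_steps):
        yield {k: v[i] for k, v in samples.items()}
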
Example #7
def main(**kwargs):
    config = C(
        random_seed=42,
        learning_rate=1e-3,
        horizon=16,
        discount=0.995,
        num_observations=1000000,
        num_workers=32,
        entropy_weight=1e-2,
        episode_log_interval=100,
        opt_log_interval=10,
        average_reward_lr=0.001,
        clip_grad_norm=None,
        model=C(
            num_features=64,
            encoder=C(type="minigrid", ),
            memory=C(type="lstm", ),
        ),
    )
    for k in kwargs:
        config[k] = kwargs[k]

    utils.random_seed(config.random_seed)
    writer = SummaryWriter(config.experiment_path)

    # build env
    env = VecEnv([build_env for _ in range(config.num_workers)])
    env = wrappers.TensorboardBatchMonitor(
        env, writer, log_interval=config.episode_log_interval, fps_mul=0.5)
    env = wrappers.Torch(env)

    # build agent and optimizer
    agent = Agent(
        env.observation_space,
        env.action_space,
        **config.model,
    )
    optimizer = torch.optim.Adam(
        agent.parameters(),
        config.learning_rate,
        betas=(0.0, 0.999),
    )
    average_reward = 0

    # load state
    # state = torch.load("./state.pth")
    # agent.load_state_dict(state["agent"])
    # optimizer.load_state_dict(state["optimizer"])

    # train
    metrics = {
        "episode/return": Stack(),
        "episode/length": Stack(),
        "rollout/reward": Stack(),
        "rollout/value_target": Stack(),
        "rollout/value": Stack(),
        "rollout/td_error": Stack(),
        "rollout/entropy": Stack(),
        "rollout/actor_loss": Stack(),
        "rollout/critic_loss": Stack(),
        "rollout/loss": Stack(),
    }

    opt_step = 0
    observation_step = 0
    pbar = tqdm(total=config.num_observations)

    env.seed(config.random_seed)
    obs = env.reset()
    action = torch.zeros(config.num_workers, dtype=torch.int)
    memory = agent.zero_memory(config.num_workers)

    # r_stats = utils.RunningStats()

    while observation_step < config.num_observations:
        history = History()
        memory = agent.detach_memory(memory)

        for i in range(config.horizon):
            transition = history.append_transition()

            dist, value, memory_prime = agent(obs, action, memory)
            transition.record(value=value, entropy=dist.entropy())
            action = select_action(dist)
            transition.record(log_prob=dist.log_prob(action))

            obs_prime, reward, done, info = env.step(action)
            observation_step += config.num_workers
            pbar.update(config.num_workers)

            # for r in reward:
            #     r_stats.push(r)
            # reward = reward / r_stats.standard_deviation()
            transition.record(reward=reward, done=done)
            memory_prime = agent.reset_memory(memory_prime, done)

            obs, memory = obs_prime, memory_prime

            for worker_info in info:  # avoid shadowing the horizon loop variable i
                if "episode" not in worker_info:
                    continue

                metrics["episode/return"].update(worker_info["episode"]["r"])
                metrics["episode/length"].update(worker_info["episode"]["l"])

        rollout = history.build()

        _, value_prime, _ = agent(obs_prime, action, memory_prime)

        # value_target = utils.n_step_bootstrapped_return(
        #     reward_t=rollout.reward,
        #     done_t=rollout.done,
        #     value_prime=value_prime.detach(),
        #     discount=config.discount,
        # )

        advantage = utils.generalized_advantage_estimation(
            reward_t=rollout.reward,
            value_t=rollout.value.detach(),
            value_prime=value_prime.detach(),
            done_t=rollout.done,
            gamma=config.discount,
            lambda_=0.96,
        )
        value_target = advantage + rollout.value.detach()

        # value_target = utils.differential_n_step_bootstrapped_return(
        #     reward_t=rollout.reward,
        #     done_t=rollout.done,
        #     value_prime=value_prime.detach(),
        #     average_reward=average_reward,
        # )

        td_error = value_target - rollout.value

        critic_loss = 0.5 * td_error.pow(2)
        actor_loss = (-rollout.log_prob * td_error.detach() -
                      config.entropy_weight * rollout.entropy)
        loss = actor_loss + critic_loss

        optimizer.zero_grad()
        agg(loss).backward()
        if config.clip_grad_norm is not None:
            nn.utils.clip_grad_norm_(agent.parameters(), config.clip_grad_norm)
        optimizer.step()
        average_reward += config.average_reward_lr * agg(
            td_error.detach())  # TODO: do not use td-error
        opt_step += 1

        metrics["rollout/reward"].update(rollout.reward.detach())
        metrics["rollout/value"].update(rollout.value.detach())
        metrics["rollout/value_target"].update(value_target.detach())
        metrics["rollout/td_error"].update(td_error.detach())
        metrics["rollout/entropy"].update(rollout.entropy.detach())
        metrics["rollout/actor_loss"].update(actor_loss.detach())
        metrics["rollout/critic_loss"].update(critic_loss.detach())
        metrics["rollout/loss"].update(loss.detach())

        if opt_step % config.opt_log_interval == 0:
            print("log metrics")

            writer.add_scalar("rollout/average_reward",
                              average_reward,
                              global_step=observation_step)
            grad_norm = torch.norm(
                torch.stack([
                    torch.norm(p.grad.detach(), 2.0)
                    for p in agent.parameters()
                ]), 2.0)
            writer.add_scalar("rollout/grad_norm",
                              grad_norm,
                              global_step=observation_step)

            for k in [
                    "rollout/reward",
                    "rollout/value_target",
                    "rollout/value",
                    "rollout/td_error",
            ]:
                v = metrics[k].compute_and_reset()
                writer.add_scalar(f"{k}/mean",
                                  v.mean(),
                                  global_step=observation_step)
                writer.add_histogram(f"{k}/hist",
                                     v,
                                     global_step=observation_step)

            for k in [
                    "rollout/entropy",
                    "rollout/actor_loss",
                    "rollout/critic_loss",
                    "rollout/loss",
            ]:
                v = metrics[k].compute_and_reset()
                writer.add_scalar(f"{k}/mean",
                                  v.mean(),
                                  global_step=observation_step)

            for k in [
                    "episode/return",
                    "episode/length",
            ]:
                v = metrics[k].compute_and_reset()
                writer.add_scalar(f"{k}/mean",
                                  v.mean(),
                                  global_step=observation_step)
                writer.add_histogram(f"{k}/hist",
                                     v,
                                     global_step=observation_step)

            writer.flush()

            # torch.save(
            #     {
            #         "agent": agent.state_dict(),
            #         "optimizer": optimizer.state_dict(),
            #         "average_reward": average_reward,
            #     },
            #     "./state.pth",
            # )

    env.close()
    writer.close()