Example #1
def make_vizdoom_env(task, frame_skip, res, save_lmp, seed, training_num, test_num):
    test_num = min(os.cpu_count() - 1, test_num)
    if envpool is not None:
        task_id = "".join([i.capitalize() for i in task.split("_")]) + "-v1"
        lmp_save_dir = "lmps/" if save_lmp else ""
        reward_config = {
            "KILLCOUNT": [20.0, -20.0],
            "HEALTH": [1.0, 0.0],
            "AMMO2": [1.0, -1.0],
        }
        if "battle" in task:
            reward_config["HEALTH"] = [1.0, -1.0]
        env = train_envs = envpool.make_gym(
            task_id,
            frame_skip=frame_skip,
            stack_num=res[0],
            seed=seed,
            num_envs=training_num,
            reward_config=reward_config,
            use_combined_action=True,
            max_episode_steps=2625,
            use_inter_area_resize=False,
        )
        test_envs = envpool.make_gym(
            task_id,
            frame_skip=frame_skip,
            stack_num=res[0],
            lmp_save_dir=lmp_save_dir,
            seed=seed,
            num_envs=test_num,
            reward_config=reward_config,
            use_combined_action=True,
            max_episode_steps=2625,
            use_inter_area_resize=False,
        )
    else:
        cfg_path = f"maps/{task}.cfg"
        env = Env(cfg_path, frame_skip, res)
        train_envs = ShmemVectorEnv(
            [lambda: Env(cfg_path, frame_skip, res) for _ in range(training_num)]
        )
        test_envs = ShmemVectorEnv(
            [
                lambda: Env(cfg_path, frame_skip, res, save_lmp)
                for _ in range(test_num)
            ]
        )
        train_envs.seed(seed)
        test_envs.seed(seed)
    return env, train_envs, test_envs
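A minimal usage sketch of the function above (the task name and counts are illustrative; the resolution follows the 4-frame, 84x84 convention used in the other examples):

env, train_envs, test_envs = make_vizdoom_env(
    task="D1_basic",      # illustrative task name
    frame_skip=4,
    res=(4, 84, 84),      # stack_num x H x W
    save_lmp=False,       # True records .lmp demos in the test envs
    seed=0,
    training_num=10,
    test_num=10,
)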
Example #2
def make_atari_env(task, seed, training_num, test_num, **kwargs):
    """Wrapper function for Atari env.

    If EnvPool is installed, it will automatically switch to EnvPool's Atari env.

    :return: a tuple of (single env, training envs, test envs).
    """
    if envpool is not None:
        if kwargs.get("scale", 0):
            warnings.warn(
                "EnvPool does not include ScaledFloatFrame wrapper, "
                "please set `x = x / 255.0` inside CNN network's forward function."
            )
        # parameter conversion
        train_envs = env = envpool.make_gym(
            task.replace("NoFrameskip-v4", "-v5"),
            num_envs=training_num,
            seed=seed,
            episodic_life=True,
            reward_clip=True,
            stack_num=kwargs.get("frame_stack", 4),
        )
        test_envs = envpool.make_gym(
            task.replace("NoFrameskip-v4", "-v5"),
            num_envs=test_num,
            seed=seed,
            episodic_life=False,
            reward_clip=False,
            stack_num=kwargs.get("frame_stack", 4),
        )
    else:
        warnings.warn("Recommend using envpool (pip install envpool) "
                      "to run Atari games more efficiently.")
        env = wrap_deepmind(task, **kwargs)
        train_envs = ShmemVectorEnv([
            lambda: wrap_deepmind(
                task, episode_life=True, clip_rewards=True, **kwargs)
            for _ in range(training_num)
        ])
        test_envs = ShmemVectorEnv([
            lambda: wrap_deepmind(
                task, episode_life=False, clip_rewards=False, **kwargs)
            for _ in range(test_num)
        ])
        env.seed(seed)
        train_envs.seed(seed)
        test_envs.seed(seed)
    return env, train_envs, test_envs
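A minimal usage sketch (task name and counts are illustrative). As the warning in the EnvPool branch notes, EnvPool does not apply ScaledFloatFrame, so if scaled observations are wanted the CNN forward pass should divide by 255.0 itself:

env, train_envs, test_envs = make_atari_env(
    "PongNoFrameskip-v4",   # becomes "Pong-v5" when EnvPool is available
    seed=0,
    training_num=10,
    test_num=10,
    frame_stack=4,
)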
Example #3
def test_vecenv(size=10, num=8, sleep=0.001):
    env_fns = [
        lambda i=i: MyTestEnv(size=i, sleep=sleep, recurse_state=True)
        for i in range(size, size + num)
    ]
    venv = [
        DummyVectorEnv(env_fns),
        SubprocVectorEnv(env_fns),
        ShmemVectorEnv(env_fns),
    ]
    if has_ray():
        venv += [RayVectorEnv(env_fns)]
    for v in venv:
        v.seed(0)
    action_list = [1] * 5 + [0] * 10 + [1] * 20
    o = [v.reset() for v in venv]
    for a in action_list:
        o = []
        for v in venv:
            A, B, C, D = v.step([a] * num)
            if sum(C):
                A = v.reset(np.where(C)[0])
            o.append([A, B, C, D])
        for index, infos in enumerate(zip(*o)):
            if index == 3:  # do not check info here
                continue
            for info in infos:
                assert recurse_comp(infos[0], info)

    if __name__ == '__main__':
        t = [0] * len(venv)
        for i, e in enumerate(venv):
            t[i] = time.time()
            e.reset()
            for a in action_list:
                done = e.step([a] * num)[2]
                if sum(done) > 0:
                    e.reset(np.where(done)[0])
            t[i] = time.time() - t[i]
        for i, v in enumerate(venv):
            print(f'{type(v)}: {t[i]:.6f}s')

    for v in venv:
        assert v.size == list(range(size, size + num))
        assert v.env_num == num
        assert v.action_space == [Discrete(2)] * num
    for v in venv:
        v.close()
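Note the `lambda i=i:` in `env_fns` above: binding `i` as a default argument freezes the per-env value at definition time, whereas a plain closure over the loop variable would leave every factory pointing at its final value. A self-contained sketch of the difference:

# Late binding: every lambda sees the loop variable's final value.
fns_bad = [lambda: i for i in range(3)]
print([f() for f in fns_bad])    # [2, 2, 2]

# Default-argument capture, as used in env_fns above.
fns_good = [lambda i=i: i for i in range(3)]
print([f() for f in fns_good])   # [0, 1, 2]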
Example #4
def test_c51(args=get_args()):
    args.cfg_path = f"maps/{args.task}.cfg"
    args.wad_path = f"maps/{args.task}.wad"
    args.res = (args.skip_num, 84, 84)
    env = Env(args.cfg_path, args.frames_stack, args.res)
    args.state_shape = args.res
    args.action_shape = env.action_space.shape or env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    # make environments
    train_envs = ShmemVectorEnv([
        lambda: Env(args.cfg_path, args.frames_stack, args.res)
        for _ in range(args.training_num)
    ])
    test_envs = ShmemVectorEnv([
        lambda: Env(args.cfg_path, args.frames_stack, args.res, args.save_lmp)
        for _ in range(min(os.cpu_count() - 1, args.test_num))
    ])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # define model
    net = C51(*args.state_shape, args.action_shape, args.num_atoms,
              args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    # define policy
    policy = C51Policy(net,
                       optim,
                       args.gamma,
                       args.num_atoms,
                       args.v_min,
                       args.v_max,
                       args.n_step,
                       target_update_freq=args.target_update_freq).to(
                           args.device)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(
            torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)
    # replay buffer: `save_only_last_obs` and `stack_num` can be removed together
    # when you have enough RAM
    buffer = VectorReplayBuffer(args.buffer_size,
                                buffer_num=len(train_envs),
                                ignore_obs_next=True,
                                save_only_last_obs=True,
                                stack_num=args.frames_stack)
    # collector
    train_collector = Collector(policy,
                                train_envs,
                                buffer,
                                exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    # log
    log_path = os.path.join(args.logdir, args.task, 'c51')
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = TensorboardLogger(writer)

    def save_best_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        if env.spec.reward_threshold:
            return mean_rewards >= env.spec.reward_threshold
        elif 'Pong' in args.task:
            return mean_rewards >= 20
        else:
            return False

    def train_fn(epoch, env_step):
        # nature DQN setting, linear decay in the first 1M steps
        if env_step <= 1e6:
            eps = args.eps_train - env_step / 1e6 * \
                (args.eps_train - args.eps_train_final)
        else:
            eps = args.eps_train_final
        policy.set_eps(eps)
        if env_step % 1000 == 0:
            logger.write("train/env_step", env_step, {"train/eps": eps})

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    # watch agent's performance
    def watch():
        print("Setup test envs ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        if args.save_buffer_name:
            print(f"Generate buffer with size {args.buffer_size}")
            buffer = VectorReplayBuffer(args.buffer_size,
                                        buffer_num=len(test_envs),
                                        ignore_obs_next=True,
                                        save_only_last_obs=True,
                                        stack_num=args.frames_stack)
            collector = Collector(policy,
                                  test_envs,
                                  buffer,
                                  exploration_noise=True)
            result = collector.collect(n_step=args.buffer_size)
            print(f"Save buffer into {args.save_buffer_name}")
            # Unfortunately, pickle will cause oom with 1M buffer size
            buffer.save_hdf5(args.save_buffer_name)
        else:
            print("Testing agent ...")
            test_collector.reset()
            result = test_collector.collect(n_episode=args.test_num,
                                            render=args.render)
        rew = result["rews"].mean()
        lens = result["lens"].mean() * args.skip_num
        print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
        print(f'Mean length (over {result["n/ep"]} episodes): {lens}')

    if args.watch:
        watch()
        exit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * args.training_num)
    # trainer
    result = offpolicy_trainer(policy,
                               train_collector,
                               test_collector,
                               args.epoch,
                               args.step_per_epoch,
                               args.step_per_collect,
                               args.test_num,
                               args.batch_size,
                               train_fn=train_fn,
                               test_fn=test_fn,
                               stop_fn=stop_fn,
                               save_best_fn=save_best_fn,
                               logger=logger,
                               update_per_step=args.update_per_step,
                               test_in_train=False)

    pprint.pprint(result)
    watch()
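A standalone sketch of the linear epsilon schedule implemented by `train_fn` above (the values follow the Nature DQN setting referenced in the comment and are otherwise illustrative):

def linear_eps(env_step, eps_train=1.0, eps_train_final=0.05, decay_steps=1e6):
    # Decay epsilon linearly over the first `decay_steps` environment steps,
    # then hold it at the final value.
    if env_step <= decay_steps:
        return eps_train - env_step / decay_steps * (eps_train - eps_train_final)
    return eps_train_final

print(linear_eps(0))          # 1.0
print(linear_eps(500_000))    # 0.525
print(linear_eps(2_000_000))  # 0.05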
Example #5
def test_dqn(args=get_args()):
    env = make_atari_env(args)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    # make environments
    train_envs = ShmemVectorEnv(
        [lambda: make_atari_env(args) for _ in range(args.training_num)])
    test_envs = ShmemVectorEnv(
        [lambda: make_atari_env_watch(args) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # define model
    net = DQN(*args.state_shape, args.action_shape,
              args.device).to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    # define policy
    policy = DQNPolicy(net,
                       optim,
                       args.gamma,
                       args.n_step,
                       target_update_freq=args.target_update_freq)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(
            torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)
    # replay buffer: `save_only_last_obs` and `stack_num` can be removed together
    # when you have enough RAM
    buffer = VectorReplayBuffer(args.buffer_size,
                                buffer_num=len(train_envs),
                                ignore_obs_next=True,
                                save_only_last_obs=True,
                                stack_num=args.frames_stack)
    # collector
    train_collector = Collector(policy,
                                train_envs,
                                buffer,
                                exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    # log
    log_path = os.path.join(args.logdir, args.task, 'dqn')
    if args.logger == "tensorboard":
        writer = SummaryWriter(log_path)
        writer.add_text("args", str(args))
        logger = TensorboardLogger(writer)
    else:
        logger = WandbLogger(
            save_interval=1,
            project=args.task,
            name='dqn',
            run_id=args.resume_id,
            config=args,
        )

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        if env.spec.reward_threshold:
            return mean_rewards >= env.spec.reward_threshold
        elif 'Pong' in args.task:
            return mean_rewards >= 20
        else:
            return False

    def train_fn(epoch, env_step):
        # nature DQN setting, linear decay in the first 1M steps
        if env_step <= 1e6:
            eps = args.eps_train - env_step / 1e6 * \
                (args.eps_train - args.eps_train_final)
        else:
            eps = args.eps_train_final
        policy.set_eps(eps)
        if env_step % 1000 == 0:
            logger.write("train/env_step", env_step, {"train/eps": eps})

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    def save_checkpoint_fn(epoch, env_step, gradient_step):
        # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
        ckpt_path = os.path.join(log_path, 'checkpoint.pth')
        torch.save({'model': policy.state_dict()}, ckpt_path)
        return ckpt_path

    # watch agent's performance
    def watch():
        print("Setup test envs ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        if args.save_buffer_name:
            print(f"Generate buffer with size {args.buffer_size}")
            buffer = VectorReplayBuffer(args.buffer_size,
                                        buffer_num=len(test_envs),
                                        ignore_obs_next=True,
                                        save_only_last_obs=True,
                                        stack_num=args.frames_stack)
            collector = Collector(policy,
                                  test_envs,
                                  buffer,
                                  exploration_noise=True)
            result = collector.collect(n_step=args.buffer_size)
            print(f"Save buffer into {args.save_buffer_name}")
            # Unfortunately, pickle will cause oom with 1M buffer size
            buffer.save_hdf5(args.save_buffer_name)
        else:
            print("Testing agent ...")
            test_collector.reset()
            result = test_collector.collect(n_episode=args.test_num,
                                            render=args.render)
        rew = result["rews"].mean()
        print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')

    if args.watch:
        watch()
        exit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * args.training_num)
    # trainer
    result = offpolicy_trainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.step_per_collect,
        args.test_num,
        args.batch_size,
        train_fn=train_fn,
        test_fn=test_fn,
        stop_fn=stop_fn,
        save_fn=save_fn,
        logger=logger,
        update_per_step=args.update_per_step,
        test_in_train=False,
        resume_from_log=args.resume_id is not None,
        save_checkpoint_fn=save_checkpoint_fn,
    )

    pprint.pprint(result)
    watch()
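A minimal sketch of restoring the checkpoint written by `save_checkpoint_fn` above (assuming the same `log_path`, `policy`, and `args` as in the example; the checkpoint dict stores the policy under the 'model' key):

ckpt_path = os.path.join(log_path, 'checkpoint.pth')
checkpoint = torch.load(ckpt_path, map_location=args.device)
policy.load_state_dict(checkpoint['model'])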
Example #6
def test_vecenv(size=10, num=8, sleep=0.001):
    env_fns = [
        lambda i=i: MyTestEnv(size=i, sleep=sleep, recurse_state=True)
        for i in range(size, size + num)
    ]
    venv = [
        DummyVectorEnv(env_fns),
        SubprocVectorEnv(env_fns),
        ShmemVectorEnv(env_fns),
    ]
    if has_ray() and sys.platform == "linux":
        venv += [RayVectorEnv(env_fns)]
    for v in venv:
        v.seed(0)
    action_list = [1] * 5 + [0] * 10 + [1] * 20
    o = [v.reset() for v in venv]
    for a in action_list:
        o = []
        for v in venv:
            A, B, C, D = v.step([a] * num)
            if sum(C):
                A = v.reset(np.where(C)[0])
            o.append([A, B, C, D])
        for index, infos in enumerate(zip(*o)):
            if index == 3:  # do not check info here
                continue
            for info in infos:
                assert recurse_comp(infos[0], info)

    if __name__ == '__main__':
        t = [0] * len(venv)
        for i, e in enumerate(venv):
            t[i] = time.time()
            e.reset()
            for a in action_list:
                done = e.step([a] * num)[2]
                if sum(done) > 0:
                    e.reset(np.where(done)[0])
            t[i] = time.time() - t[i]
        for i, v in enumerate(venv):
            print(f'{type(v)}: {t[i]:.6f}s')

    def assert_get(v, expected):
        assert v.get_env_attr("size") == expected
        assert v.get_env_attr("size", id=0) == [expected[0]]
        assert v.get_env_attr("size", id=[0, 1, 2]) == expected[:3]

    for v in venv:
        assert_get(v, list(range(size, size + num)))
        assert v.env_num == num
        assert v.action_space == [Discrete(2)] * num

        v.set_env_attr("size", 0)
        assert_get(v, [0] * num)

        v.set_env_attr("size", 1, 0)
        assert_get(v, [1] + [0] * (num - 1))

        v.set_env_attr("size", 2, [1, 2, 3])
        assert_get(v, [1] + [2] * 3 + [0] * (num - 4))

    for v in venv:
        v.close()
Example #7
def make_mujoco_env(task, seed, training_num, test_num, obs_norm):
    """Wrapper function for Mujoco env.

    If EnvPool is installed, it will automatically switch to EnvPool's Mujoco env.

    :return: a tuple of (single env, training envs, test envs).
    """
    if envpool is not None:
        train_envs = env = envpool.make_gym(task, num_envs=training_num, seed=seed)
        test_envs = envpool.make_gym(task, num_envs=test_num, seed=seed)
    else:
        warnings.warn(
            "Recommend using envpool (pip install envpool) "
            "to run Mujoco environments more efficiently."
        )
        env = gym.make(task)
        train_envs = ShmemVectorEnv(
            [lambda: gym.make(task) for _ in range(training_num)]
        )
        test_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(test_num)])
        env.seed(seed)
        train_envs.seed(seed)
        test_envs.seed(seed)
    if obs_norm:
        # obs norm wrapper
        train_envs = VectorEnvNormObs(train_envs)
        test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False)
        test_envs.set_obs_rms(train_envs.get_obs_rms())
    return env, train_envs, test_envs
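A minimal usage sketch (the task id is illustrative). With `obs_norm=True`, the training envs keep updating the running observation statistics while the test envs only apply them:

env, train_envs, test_envs = make_mujoco_env(
    "HalfCheetah-v3",   # illustrative MuJoCo task id
    seed=0,
    training_num=10,
    test_num=10,
    obs_norm=True,
)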
Example #8
def test_discrete_cql(args=get_args()):
    # envs
    env = make_atari_env(args)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    # make environments
    test_envs = ShmemVectorEnv(
        [lambda: make_atari_env_watch(args) for _ in range(args.test_num)]
    )
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = QRDQN(*args.state_shape, args.action_shape, args.num_quantiles, args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    # define policy
    policy = DiscreteCQLPolicy(
        net,
        optim,
        args.gamma,
        args.num_quantiles,
        args.n_step,
        args.target_update_freq,
        min_q_weight=args.min_q_weight
    ).to(args.device)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)
    # buffer
    assert os.path.exists(args.load_buffer_name), \
        "Please run atari_qrdqn.py first to get expert's data buffer."
    if args.load_buffer_name.endswith('.pkl'):
        buffer = pickle.load(open(args.load_buffer_name, "rb"))
    elif args.load_buffer_name.endswith('.hdf5'):
        buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
    else:
        print(f"Unknown buffer format: {args.load_buffer_name}")
        exit(0)

    # collector
    test_collector = Collector(policy, test_envs, exploration_noise=True)

    # log
    log_path = os.path.join(
        args.logdir, args.task, 'cql',
        f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}'
    )
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = TensorboardLogger(writer, update_interval=args.log_interval)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return False

    # watch agent's performance
    def watch():
        print("Setup test envs ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        print("Testing agent ...")
        test_collector.reset()
        result = test_collector.collect(n_episode=args.test_num, render=args.render)
        pprint.pprint(result)
        rew = result["rews"].mean()
        print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')

    if args.watch:
        watch()
        exit(0)

    result = offline_trainer(
        policy,
        buffer,
        test_collector,
        args.epoch,
        args.update_per_epoch,
        args.test_num,
        args.batch_size,
        stop_fn=stop_fn,
        save_fn=save_fn,
        logger=logger
    )

    pprint.pprint(result)
    watch()
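The expert buffer loaded above is the counterpart of the `buffer.save_hdf5(...)` call in the `watch()` functions of the earlier online examples. A minimal sketch (the file name is hypothetical):

from tianshou.data import VectorReplayBuffer

buffer = VectorReplayBuffer.load_hdf5("expert_pong.hdf5")
print(len(buffer))  # number of stored transitions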
Example #9
def test_ppo(args=get_args()):
    args.cfg_path = f"maps/{args.task}.cfg"
    args.wad_path = f"maps/{args.task}.wad"
    args.res = (args.skip_num, 84, 84)
    env = Env(args.cfg_path, args.frames_stack, args.res)
    args.state_shape = args.res
    args.action_shape = env.action_space.shape or env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    # make environments
    train_envs = ShmemVectorEnv([
        lambda: Env(args.cfg_path, args.frames_stack, args.res)
        for _ in range(args.training_num)
    ])
    test_envs = ShmemVectorEnv([
        lambda: Env(args.cfg_path, args.frames_stack, args.res, args.save_lmp)
        for _ in range(min(os.cpu_count() - 1, args.test_num))
    ])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # define model
    net = DQN(*args.state_shape,
              args.action_shape,
              device=args.device,
              features_only=True,
              output_dim=args.hidden_size)
    actor = Actor(net,
                  args.action_shape,
                  device=args.device,
                  softmax_output=False)
    critic = Critic(net, device=args.device)
    optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(),
                             lr=args.lr)

    lr_scheduler = None
    if args.lr_decay:
        # decay learning rate to 0 linearly
        max_update_num = np.ceil(
            args.step_per_epoch / args.step_per_collect) * args.epoch

        lr_scheduler = LambdaLR(
            optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)

    # define policy
    def dist(p):
        return torch.distributions.Categorical(logits=p)

    policy = PPOPolicy(actor,
                       critic,
                       optim,
                       dist,
                       discount_factor=args.gamma,
                       gae_lambda=args.gae_lambda,
                       max_grad_norm=args.max_grad_norm,
                       vf_coef=args.vf_coef,
                       ent_coef=args.ent_coef,
                       reward_normalization=args.rew_norm,
                       action_scaling=False,
                       lr_scheduler=lr_scheduler,
                       action_space=env.action_space,
                       eps_clip=args.eps_clip,
                       value_clip=args.value_clip,
                       dual_clip=args.dual_clip,
                       advantage_normalization=args.norm_adv,
                       recompute_advantage=args.recompute_adv).to(args.device)
    if args.icm_lr_scale > 0:
        feature_net = DQN(*args.state_shape,
                          args.action_shape,
                          device=args.device,
                          features_only=True,
                          output_dim=args.hidden_size)
        action_dim = np.prod(args.action_shape)
        feature_dim = feature_net.output_dim
        icm_net = IntrinsicCuriosityModule(feature_net.net,
                                           feature_dim,
                                           action_dim,
                                           device=args.device)
        icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
        policy = ICMPolicy(policy, icm_net, icm_optim, args.icm_lr_scale,
                           args.icm_reward_scale,
                           args.icm_forward_loss_weight).to(args.device)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(
            torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)
    # replay buffer: `save_only_last_obs` and `stack_num` can be removed together
    # when you have enough RAM
    buffer = VectorReplayBuffer(args.buffer_size,
                                buffer_num=len(train_envs),
                                ignore_obs_next=True,
                                save_only_last_obs=True,
                                stack_num=args.frames_stack)
    # collector
    train_collector = Collector(policy,
                                train_envs,
                                buffer,
                                exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    # log
    log_name = 'ppo_icm' if args.icm_lr_scale > 0 else 'ppo'
    log_path = os.path.join(args.logdir, args.task, log_name)
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = TensorboardLogger(writer)

    def save_best_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        if env.spec.reward_threshold:
            return mean_rewards >= env.spec.reward_threshold
        elif 'Pong' in args.task:
            return mean_rewards >= 20
        else:
            return False

    # watch agent's performance
    def watch():
        print("Setup test envs ...")
        policy.eval()
        test_envs.seed(args.seed)
        if args.save_buffer_name:
            print(f"Generate buffer with size {args.buffer_size}")
            buffer = VectorReplayBuffer(args.buffer_size,
                                        buffer_num=len(test_envs),
                                        ignore_obs_next=True,
                                        save_only_last_obs=True,
                                        stack_num=args.frames_stack)
            collector = Collector(policy,
                                  test_envs,
                                  buffer,
                                  exploration_noise=True)
            result = collector.collect(n_step=args.buffer_size)
            print(f"Save buffer into {args.save_buffer_name}")
            # Unfortunately, pickle will cause oom with 1M buffer size
            buffer.save_hdf5(args.save_buffer_name)
        else:
            print("Testing agent ...")
            test_collector.reset()
            result = test_collector.collect(n_episode=args.test_num,
                                            render=args.render)
        rew = result["rews"].mean()
        lens = result["lens"].mean() * args.skip_num
        print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
        print(f'Mean length (over {result["n/ep"]} episodes): {lens}')

    if args.watch:
        watch()
        exit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * args.training_num)
    # trainer
    result = onpolicy_trainer(policy,
                              train_collector,
                              test_collector,
                              args.epoch,
                              args.step_per_epoch,
                              args.repeat_per_collect,
                              args.test_num,
                              args.batch_size,
                              step_per_collect=args.step_per_collect,
                              stop_fn=stop_fn,
                              save_best_fn=save_best_fn,
                              logger=logger,
                              test_in_train=False)

    pprint.pprint(result)
    watch()
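A standalone sketch of the linear learning-rate decay configured above: after k policy updates the LambdaLR multiplier is 1 - k / max_update_num, reaching zero at the last scheduled update (the step counts are illustrative):

import numpy as np

step_per_epoch, step_per_collect, epoch = 100_000, 1_000, 300  # illustrative values
max_update_num = np.ceil(step_per_epoch / step_per_collect) * epoch
for k in (0, int(max_update_num) // 2, int(max_update_num)):
    print(f"update {k}: lr multiplier = {1 - k / max_update_num:.3f}")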