Example #1
def test_episode(
    policy: BasePolicy,
    collector: Collector,
    test_fn: Optional[Callable[[int, Optional[int]], None]],
    epoch: int,
    n_episode: Union[int, List[int]],
    writer: Optional[SummaryWriter] = None,
    global_step: Optional[int] = None,
) -> Dict[str, float]:
    """A simple wrapper of testing policy in collector."""
    collector.reset_env()
    collector.reset_buffer()
    policy.eval()
    if test_fn:
        test_fn(epoch, global_step)
    if collector.get_env_num() > 1 and isinstance(n_episode, int):
        n = collector.get_env_num()
        n_ = np.zeros(n) + n_episode // n
        n_[:n_episode % n] += 1
        n_episode = list(n_)
    result = collector.collect(n_episode=n_episode)
    if writer is not None and global_step is not None:
        for k in result.keys():
            writer.add_scalar("test/" + k, result[k], global_step=global_step)
    return result
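
The env-count branch above splits an integer episode budget evenly across the vectorized test workers. A self-contained sketch of that arithmetic (NumPy only; the concrete numbers are illustrative and not from the snippet):

import numpy as np

n, n_episode = 4, 10                 # 4 parallel envs, 10 test episodes in total
n_ = np.zeros(n) + n_episode // n    # base share per env: [2., 2., 2., 2.]
n_[:n_episode % n] += 1              # hand out the remainder: [3., 3., 2., 2.]
print(list(n_))                      # per-env episode counts passed to collect()
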
Example #2
def test_collector():
    writer = SummaryWriter('log/collector')
    logger = Logger(writer)
    env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [2, 3, 4, 5]]

    venv = SubprocVectorEnv(env_fns)
    dum = DummyVectorEnv(env_fns)
    policy = MyPolicy()
    env = env_fns[0]()
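    # c0: single MyTestEnv + plain ReplayBuffer (step- and episode-based collection)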
    c0 = Collector(policy, env, ReplayBuffer(size=100), logger.preprocess_fn)
    c0.collect(n_step=3)
    assert len(c0.buffer) == 3
    assert np.allclose(c0.buffer.obs[:4, 0], [0, 1, 0, 0])
    assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1])
    c0.collect(n_episode=3)
    assert len(c0.buffer) == 8
    assert np.allclose(c0.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 0])
    assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
    c0.collect(n_step=3, random=True)
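    # c1: 4 subprocess envs + VectorReplayBuffer with one sub-buffer per env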
    c1 = Collector(policy, venv,
                   VectorReplayBuffer(total_size=100, buffer_num=4),
                   logger.preprocess_fn)
    c1.collect(n_step=8)
    obs = np.zeros(100)
    obs[[0, 1, 25, 26, 50, 51, 75, 76]] = [0, 1, 0, 1, 0, 1, 0, 1]

    assert np.allclose(c1.buffer.obs[:, 0], obs)
    assert np.allclose(c1.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
    c1.collect(n_episode=4)
    assert len(c1.buffer) == 16
    obs[[2, 3, 27, 52, 53, 77, 78, 79]] = [0, 1, 2, 2, 3, 2, 3, 4]
    assert np.allclose(c1.buffer.obs[:, 0], obs)
    assert np.allclose(c1.buffer[:].obs_next[..., 0],
                       [1, 2, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5])
    c1.collect(n_episode=4, random=True)
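    # c2: 4 in-process (dummy) envs; the collected layout may be either of two orderings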
    c2 = Collector(policy, dum, VectorReplayBuffer(total_size=100,
                                                   buffer_num=4),
                   logger.preprocess_fn)
    c2.collect(n_episode=7)
    obs1 = obs.copy()
    obs1[[4, 5, 28, 29, 30]] = [0, 1, 0, 1, 2]
    obs2 = obs.copy()
    obs2[[28, 29, 30, 54, 55, 56, 57]] = [0, 1, 2, 0, 1, 2, 3]
    c2obs = c2.buffer.obs[:, 0]
    assert np.all(c2obs == obs1) or np.all(c2obs == obs2)
    c2.reset_env()
    c2.reset_buffer()
    assert c2.collect(n_episode=8)['n/ep'] == 8
    obs[[4, 5, 28, 29, 30, 54, 55, 56, 57]] = [0, 1, 0, 1, 2, 0, 1, 2, 3]
    assert np.all(c2.buffer.obs[:, 0] == obs)
    c2.collect(n_episode=4, random=True)

    # test corner case
    with pytest.raises(TypeError):
        Collector(policy, dum, ReplayBuffer(10))
    with pytest.raises(TypeError):
        Collector(policy, dum, PrioritizedReplayBuffer(10, 0.5, 0.5))
    with pytest.raises(TypeError):
        c2.collect()
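
Example #2 relies on test helpers (MyPolicy, MyTestEnv, Logger) defined elsewhere in the test module. A minimal sketch of what a deterministic stub policy could look like, assuming the tianshou BasePolicy interface; the always-act-1 behaviour is an assumption for illustration, not the actual helper:

import numpy as np

from tianshou.data import Batch
from tianshou.policy import BasePolicy


class MyPolicy(BasePolicy):
    """Deterministic stub: always choose action 1 and learn nothing."""

    def forward(self, batch, state=None, **kwargs):
        # one action per environment in the incoming batch
        return Batch(act=np.ones(len(batch.obs)))

    def learn(self, batch, **kwargs):
        return {}
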
Example #3
def test_episode(policy: BasePolicy, collector: Collector,
                 test_fn: Callable[[int], None], epoch: int,
                 n_episode: Union[int, List[int]]) -> Dict[str, float]:
    """A simple wrapper of testing policy in collector."""
    collector.reset_env()
    collector.reset_buffer()
    policy.eval()
    if test_fn:
        test_fn(epoch)
    if collector.get_env_num() > 1 and np.isscalar(n_episode):
        n = collector.get_env_num()
        n_ = np.zeros(n) + n_episode // n
        n_[:n_episode % n] += 1
        n_episode = list(n_)
    return collector.collect(n_episode=n_episode)
Example #4
def test_episode(
    policy: BasePolicy,
    collector: Collector,
    test_fn: Optional[Callable[[int, Optional[int]], None]],
    epoch: int,
    n_episode: int,
    logger: Optional[BaseLogger] = None,
    global_step: Optional[int] = None,
    reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
) -> Dict[str, Any]:
    """A simple wrapper of testing policy in collector."""
    collector.reset_env()
    collector.reset_buffer()
    policy.eval()
    if test_fn:
        test_fn(epoch, global_step)
    result = collector.collect(n_episode=n_episode)
    if reward_metric:
        result["rews"] = reward_metric(result["rews"])
    if logger and global_step is not None:
        logger.log_test_data(result, global_step)
    return result
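
The reward_metric hook above post-processes the collected reward array before logging. A minimal sketch for a multi-agent setting, assuming result["rews"] has shape (num_episode, agent_num) as described in the Example #5 docstring below; averaging over agents is just one illustrative choice:

import numpy as np

def reward_metric(rews: np.ndarray) -> np.ndarray:
    # collapse the per-agent axis, e.g. track the mean reward over agents
    return rews.mean(axis=-1)

# illustrative input: 3 episodes, 2 agents
print(reward_metric(np.array([[1.0, 3.0], [2.0, 2.0], [0.0, 4.0]])))  # [2. 2. 2.]
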
Example #5
def onpolicy_trainer(
    policy: BasePolicy,
    train_collector: Collector,
    test_collector: Collector,
    max_epoch: int,
    step_per_epoch: int,
    repeat_per_collect: int,
    episode_per_test: int,
    batch_size: int,
    step_per_collect: Optional[int] = None,
    episode_per_collect: Optional[int] = None,
    train_fn: Optional[Callable[[int, int], None]] = None,
    test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
    stop_fn: Optional[Callable[[float], bool]] = None,
    save_fn: Optional[Callable[[BasePolicy], None]] = None,
    reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
    logger: BaseLogger = LazyLogger(),
    verbose: bool = True,
    test_in_train: bool = True,
) -> Dict[str, Union[float, str]]:
    """A wrapper for on-policy trainer procedure.

    The "step" in trainer means an environment step (a.k.a. transition).

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
    :param Collector train_collector: the collector used for training.
    :param Collector test_collector: the collector used for testing.
    :param int max_epoch: the maximum number of epochs for training. The training
        process might be finished before reaching ``max_epoch`` if ``stop_fn`` is set.
    :param int step_per_epoch: the number of transitions collected per epoch.
    :param int repeat_per_collect: the number of times the policy learns from each
        collected batch of data; for example, setting it to 2 means the policy
        learns each given batch of data twice.
    :param int episode_per_test: the number of episodes for one policy evaluation.
    :param int batch_size: the batch size of sample data, which is going to feed in the
        policy network.
    :param int step_per_collect: the number of transitions the collector would collect
        before the network update, i.e., the trainer will collect "step_per_collect"
        transitions and perform policy network updates repeatedly in each epoch.
    :param int episode_per_collect: the number of episodes the collector would collect
        before the network update, i.e., the trainer will collect "episode_per_collect"
        episodes and perform policy network updates repeatedly in each epoch.
    :param function train_fn: a hook called at the beginning of training in each epoch.
        It can be used to perform custom additional operations, with the signature ``f(
        num_epoch: int, step_idx: int) -> None``.
    :param function test_fn: a hook called at the beginning of testing in each epoch.
        It can be used to perform custom additional operations, with the signature ``f(
        num_epoch: int, step_idx: int) -> None``.
    :param function save_fn: a hook called when the undiscounted average reward in
        the evaluation phase gets better, with the signature ``f(policy: BasePolicy)
        -> None``.
    :param function stop_fn: a function with signature ``f(mean_rewards: float) ->
        bool``, which receives the average undiscounted returns of the testing result
        and returns a boolean indicating whether the goal has been reached.
    :param function reward_metric: a function with signature ``f(rewards: np.ndarray
        with shape (num_episode, agent_num)) -> np.ndarray with shape (num_episode,)``,
        used in multi-agent RL. It must return a single scalar for each episode's
        result, which is used to monitor training in the multi-agent setting. This
        function specifies the desired metric, e.g., the reward of agent 1 or the
        average reward over all agents.
    :param BaseLogger logger: A logger that logs statistics during
        training/testing/updating. Default to a logger that doesn't log anything.
    :param bool verbose: whether to print the information. Default to True.
    :param bool test_in_train: whether to test in the training phase. Default to True.

    :return: See :func:`~tianshou.trainer.gather_info`.

    .. note::

        Only one of ``step_per_collect`` and ``episode_per_collect`` can be specified.
    """
    env_step, gradient_step = 0, 0
    last_rew, last_len = 0.0, 0
    stat: Dict[str, MovAvg] = defaultdict(MovAvg)
    start_time = time.time()
    train_collector.reset_stat()
    test_collector.reset_stat()
    test_in_train = test_in_train and train_collector.policy == policy
    test_result = test_episode(policy, test_collector, test_fn, 0,
                               episode_per_test, logger, env_step,
                               reward_metric)
    best_epoch = 0
    best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        with tqdm.tqdm(total=step_per_epoch,
                       desc=f"Epoch #{epoch}",
                       **tqdm_config) as t:
            while t.n < t.total:
                if train_fn:
                    train_fn(epoch, env_step)
                result = train_collector.collect(n_step=step_per_collect,
                                                 n_episode=episode_per_collect)
                if reward_metric:
                    result["rews"] = reward_metric(result["rews"])
                env_step += int(result["n/st"])
                t.update(result["n/st"])
                logger.log_train_data(result, env_step)
                last_rew = result['rew'] if 'rew' in result else last_rew
                last_len = result['len'] if 'len' in result else last_len
                data = {
                    "env_step": str(env_step),
                    "rew": f"{last_rew:.2f}",
                    "len": str(int(last_len)),
                    "n/ep": str(int(result["n/ep"])),
                    "n/st": str(int(result["n/st"])),
                }
                if test_in_train and stop_fn and stop_fn(result["rew"]):
                    test_result = test_episode(policy, test_collector, test_fn,
                                               epoch, episode_per_test, logger,
                                               env_step)
                    if stop_fn(test_result["rew"]):
                        if save_fn:
                            save_fn(policy)
                        t.set_postfix(**data)
                        return gather_info(start_time, train_collector,
                                           test_collector, test_result["rew"],
                                           test_result["rew_std"])
                    else:
                        policy.train()
                losses = policy.update(0,
                                       train_collector.buffer,
                                       batch_size=batch_size,
                                       repeat=repeat_per_collect)
                train_collector.reset_buffer()
                step = max(
                    [1] +
                    [len(v) for v in losses.values() if isinstance(v, list)])
                gradient_step += step
                for k in losses.keys():
                    stat[k].add(losses[k])
                    losses[k] = stat[k].get()
                    data[k] = f"{losses[k]:.3f}"
                logger.log_update_data(losses, gradient_step)
                t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        test_result = test_episode(policy, test_collector, test_fn, epoch,
                                   episode_per_test, logger, env_step)
        rew, rew_std = test_result["rew"], test_result["rew_std"]
        if best_epoch == -1 or best_reward < rew:
            best_reward, best_reward_std = rew, rew_std
            best_epoch = epoch
            if save_fn:
                save_fn(policy)
        if verbose:
            print(
                f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, "
                f"best_reward: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}"
            )
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(start_time, train_collector, test_collector,
                       best_reward, best_reward_std)
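
A minimal sketch of how the Example #5 trainer might be invoked. The policy and the two collectors are placeholders assumed to be set up elsewhere (any on-policy tianshou policy with matching Collector instances); only one of step_per_collect / episode_per_collect is passed, as the docstring's note requires, and the hyperparameter values are purely illustrative:

# `policy`, `train_collector`, and `test_collector` are assumed to exist.
result = onpolicy_trainer(
    policy, train_collector, test_collector,
    max_epoch=10,
    step_per_epoch=50000,
    repeat_per_collect=2,        # learn each collected batch twice
    episode_per_test=10,
    batch_size=64,
    episode_per_collect=16,      # episode-based collection; step_per_collect stays None
    stop_fn=lambda mean_rewards: mean_rewards >= 195,
)
print(result)  # statistics gathered by gather_info (best reward, timing, ...)
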
Example #6
def onpolicy_trainer(
    policy: BasePolicy,
    train_collector: Collector,
    test_collector: Collector,
    max_epoch: int,
    step_per_epoch: int,
    collect_per_step: int,
    repeat_per_collect: int,
    episode_per_test: Union[int, List[int]],
    batch_size: int,
    train_fn: Optional[Callable[[int, int], None]] = None,
    test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
    stop_fn: Optional[Callable[[float], bool]] = None,
    save_fn: Optional[Callable[[BasePolicy], None]] = None,
    writer: Optional[SummaryWriter] = None,
    log_interval: int = 1,
    verbose: bool = True,
    test_in_train: bool = True,
) -> Dict[str, Union[float, str]]:
    """A wrapper for on-policy trainer procedure.

    The "step" in trainer means a policy network update.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param train_collector: the collector used for training.
    :type train_collector: :class:`~tianshou.data.Collector`
    :param test_collector: the collector used for testing.
    :type test_collector: :class:`~tianshou.data.Collector`
    :param int max_epoch: the maximum number of epochs for training. The
        training process might finish before reaching ``max_epoch``.
    :param int step_per_epoch: the number of steps for updating the policy
        network in one epoch.
    :param int collect_per_step: the number of episodes the collector would
        collect before the network update. In other words, collect some
        episodes and do one policy network update.
    :param int repeat_per_collect: the number of times the policy learns from
        each collected batch of data; for example, setting it to 2 means the
        policy learns each given batch of data twice.
    :param episode_per_test: the number of episodes for one policy evaluation.
    :type episode_per_test: int or list of ints
    :param int batch_size: the batch size of sample data, which is going to
        feed in the policy network.
    :param function train_fn: a function that receives the current epoch number
        and step index, and performs some operations at the beginning of
        training in this epoch.
    :param function test_fn: a function that receives the current epoch number
        and step index, and performs some operations at the beginning of
        testing in this epoch.
    :param function save_fn: a function for saving the policy when the
        undiscounted average reward in the evaluation phase gets better.
    :param function stop_fn: a function that receives the average undiscounted
        returns of the testing result and returns a boolean indicating whether
        the goal has been reached.
    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
        SummaryWriter.
    :param int log_interval: the log interval of the writer.
    :param bool verbose: whether to print the information.
    :param bool test_in_train: whether to test in the training phase.

    :return: See :func:`~tianshou.trainer.gather_info`.
    """
    env_step, gradient_step = 0, 0
    best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0
    stat: Dict[str, MovAvg] = {}
    start_time = time.time()
    train_collector.reset_stat()
    test_collector.reset_stat()
    test_in_train = test_in_train and train_collector.policy == policy
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        with tqdm.tqdm(
            total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
        ) as t:
            while t.n < t.total:
                if train_fn:
                    train_fn(epoch, env_step)
                result = train_collector.collect(n_episode=collect_per_step)
                env_step += int(result["n/st"])
                data = {
                    "env_step": str(env_step),
                    "rew": f"{result['rew']:.2f}",
                    "len": str(int(result["len"])),
                    "n/ep": str(int(result["n/ep"])),
                    "n/st": str(int(result["n/st"])),
                    "v/ep": f"{result['v/ep']:.2f}",
                    "v/st": f"{result['v/st']:.2f}",
                }
                if writer and env_step % log_interval == 0:
                    for k in result.keys():
                        writer.add_scalar(
                            "train/" + k, result[k], global_step=env_step)
                if test_in_train and stop_fn and stop_fn(result["rew"]):
                    test_result = test_episode(
                        policy, test_collector, test_fn,
                        epoch, episode_per_test, writer, env_step)
                    if stop_fn(test_result["rew"]):
                        if save_fn:
                            save_fn(policy)
                        for k in result.keys():
                            data[k] = f"{result[k]:.2f}"
                        t.set_postfix(**data)
                        return gather_info(
                            start_time, train_collector, test_collector,
                            test_result["rew"], test_result["rew_std"])
                    else:
                        policy.train()
                losses = policy.update(
                    0, train_collector.buffer,
                    batch_size=batch_size, repeat=repeat_per_collect)
                train_collector.reset_buffer()
                step = max([1] + [
                    len(v) for v in losses.values() if isinstance(v, list)])
                gradient_step += step
                for k in losses.keys():
                    if stat.get(k) is None:
                        stat[k] = MovAvg()
                    stat[k].add(losses[k])
                    data[k] = f"{stat[k].get():.6f}"
                    if writer and gradient_step % log_interval == 0:
                        writer.add_scalar(
                            k, stat[k].get(), global_step=gradient_step)
                t.update(step)
                t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        result = test_episode(policy, test_collector, test_fn, epoch,
                              episode_per_test, writer, env_step)
        if best_epoch == -1 or best_reward < result["rew"]:
            best_reward, best_reward_std = result["rew"], result["rew_std"]
            best_epoch = epoch
            if save_fn:
                save_fn(policy)
        if verbose:
            print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f} ± "
                  f"{result['rew_std']:.6f}, best_reward: {best_reward:.6f} ± "
                  f"{best_reward_std:.6f} in #{best_epoch}")
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(start_time, train_collector, test_collector,
                       best_reward, best_reward_std)
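
Both trainers above infer how many gradient steps the last update performed from the shape of the losses dict: scalar entries count as one step, list entries contribute their length. A self-contained illustration with a made-up losses dict:

losses = {"loss": [0.9, 0.7, 0.5], "loss/ent": 0.01}  # illustrative values only
step = max([1] + [len(v) for v in losses.values() if isinstance(v, list)])
print(step)  # 3: the policy reported three per-minibatch loss values
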
Example #7
def onpolicy_trainer(
        policy: BasePolicy,
        train_collector: Collector,
        test_collector: Collector,
        max_epoch: int,
        frame_per_epoch: int,
        collect_per_step: int,
        repeat_per_collect: int,
        episode_per_test: Union[int, List[int]],
        batch_size: int,
        train_fn: Optional[Callable[[int], None]] = None,
        test_fn: Optional[Callable[[int], None]] = None,
        stop_fn: Optional[Callable[[float], bool]] = None,
        save_fn: Optional[Callable[[BasePolicy], None]] = None,
        log_fn: Optional[Callable[[dict], None]] = None,
        writer: Optional[SummaryWriter] = None,
        log_interval: int = 1,
        verbose: bool = True,
        **kwargs
) -> Dict[str, Union[float, str]]:
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}
    start_time = time.time()
    test_in_train = train_collector.policy == policy
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        if train_fn:
            train_fn(epoch)
        with tqdm.tqdm(total=frame_per_epoch, desc=f'Epoch #{epoch}',
                       **tqdm_config) as t:
            while t.n < t.total:
                result = train_collector.collect(n_step=collect_per_step,
                                                 log_fn=log_fn)
                data = {}
                if test_in_train and stop_fn and stop_fn(result['rew']):
                    test_result = test_episode(
                        policy, test_collector, test_fn,
                        epoch, episode_per_test)
                    if stop_fn and stop_fn(test_result['rew']):
                        if save_fn:
                            save_fn(policy)
                        for k in result.keys():
                            data[k] = f'{result[k]:.2f}'
                        t.set_postfix(**data)
                        return gather_info(
                            start_time, train_collector, test_collector,
                            test_result['rew'])
                    else:
                        policy.train()
                        if train_fn:
                            train_fn(epoch)
                losses = policy.learn(
                    train_collector.sample(0), batch_size, repeat_per_collect)
                train_collector.reset_buffer()
                global_step += collect_per_step
                for k in result.keys():
                    data[k] = f'{result[k]:.2f}'
                    if writer and global_step % log_interval == 0:
                        writer.add_scalar(
                            k, result[k], global_step=global_step)
                for k in losses.keys():
                    if stat.get(k) is None:
                        stat[k] = MovAvg()
                    stat[k].add(losses[k])
                    data[k] = f'{stat[k].get():.6f}'
                    if writer and global_step % log_interval == 0:
                        writer.add_scalar(
                            k, stat[k].get(), global_step=global_step)
                t.update(collect_per_step)
                t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        result = test_episode(
            policy, test_collector, test_fn, epoch, episode_per_test)
        if best_epoch == -1 or best_reward < result['rew']:
            best_reward = result['rew']
            best_epoch = epoch
            if save_fn:
                save_fn(policy)
        if verbose:
            print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
                  f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(
        start_time, train_collector, test_collector, best_reward)
Example #8
def onpolicy_trainer(policy: BasePolicy,
                     train_collector: Collector,
                     test_collector: Collector,
                     max_epoch: int,
                     step_per_epoch: int,
                     collect_per_step: int,
                     repeat_per_collect: int,
                     episode_per_test: Union[int, List[int]],
                     batch_size: int,
                     train_fn: Optional[Callable[[int], None]] = None,
                     test_fn: Optional[Callable[[int], None]] = None,
                     stop_fn: Optional[Callable[[float], bool]] = None,
                     save_fn: Optional[Callable[[BasePolicy], None]] = None,
                     log_fn: Optional[Callable[[dict], None]] = None,
                     writer: Optional[SummaryWriter] = None,
                     log_interval: Optional[int] = 1,
                     verbose: Optional[bool] = True,
                     **kwargs) -> Dict[str, Union[float, str]]:
    """A wrapper for on-policy trainer procedure.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param train_collector: the collector used for training.
    :type train_collector: :class:`~tianshou.data.Collector`
    :param test_collector: the collector used for testing.
    :type test_collector: :class:`~tianshou.data.Collector`
    :param int max_epoch: the maximum number of epochs for training. The
        training process might finish before reaching ``max_epoch``.
    :param int step_per_epoch: the number of steps for updating the policy
        network in one epoch.
    :param int collect_per_step: the number of frames the collector would
        collect before the network update. In other words, collect some frames
        and do one policy network update.
    :param int repeat_per_collect: the number of times the policy learns from
        each collected batch of data; for example, setting it to 2 means the
        policy learns each given batch of data twice.
    :param episode_per_test: the number of episodes for one policy evaluation.
    :type episode_per_test: int or list of ints
    :param int batch_size: the batch size of sample data, which is going to
        feed in the policy network.
    :param function train_fn: a function that receives the current epoch index
        and performs some operations at the beginning of training in this
        epoch.
    :param function test_fn: a function that receives the current epoch index
        and performs some operations at the beginning of testing in this
        epoch.
    :param function save_fn: a function for saving the policy when the
        undiscounted average reward in the evaluation phase gets better.
    :param function stop_fn: a function that receives the average undiscounted
        returns of the testing result and returns a boolean indicating whether
        the goal has been reached.
    :param function log_fn: a function that receives the env info for logging.
    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
        SummaryWriter.
    :param int log_interval: the log interval of the writer.
    :param bool verbose: whether to print the information.

    :return: See :func:`~tianshou.trainer.gather_info`.
    """
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}
    start_time = time.time()
    test_in_train = train_collector.policy == policy
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        if train_fn:
            train_fn(epoch)
        with tqdm.tqdm(total=step_per_epoch,
                       desc=f'Epoch #{epoch}',
                       **tqdm_config) as t:
            while t.n < t.total:
                result = train_collector.collect(n_episode=collect_per_step,
                                                 log_fn=log_fn)
                data = {}
                if test_in_train and stop_fn and stop_fn(result['rew']):
                    test_result = test_episode(policy, test_collector, test_fn,
                                               epoch, episode_per_test)
                    if stop_fn and stop_fn(test_result['rew']):
                        if save_fn:
                            save_fn(policy)
                        for k in result.keys():
                            data[k] = f'{result[k]:.2f}'
                        t.set_postfix(**data)
                        return gather_info(start_time, train_collector,
                                           test_collector, test_result['rew'])
                    else:
                        policy.train()
                        if train_fn:
                            train_fn(epoch)
                losses = policy.learn(train_collector.sample(0), batch_size,
                                      repeat_per_collect)
                train_collector.reset_buffer()
                step = 1
                for k in losses.keys():
                    if isinstance(losses[k], list):
                        step = max(step, len(losses[k]))
                global_step += step
                for k in result.keys():
                    data[k] = f'{result[k]:.2f}'
                    if writer and global_step % log_interval == 0:
                        writer.add_scalar(k,
                                          result[k],
                                          global_step=global_step)
                for k in losses.keys():
                    if stat.get(k) is None:
                        stat[k] = MovAvg()
                    stat[k].add(losses[k])
                    data[k] = f'{stat[k].get():.6f}'
                    if writer and global_step % log_interval == 0:
                        writer.add_scalar(k,
                                          stat[k].get(),
                                          global_step=global_step)
                t.update(step)
                t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        result = test_episode(policy, test_collector, test_fn, epoch,
                              episode_per_test)
        if best_epoch == -1 or best_reward < result['rew']:
            best_reward = result['rew']
            best_epoch = epoch
            if save_fn:
                save_fn(policy)
        if verbose:
            print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
                  f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(start_time, train_collector, test_collector,
                       best_reward)
Example #9
def onpolicy_trainer(
        policy: BasePolicy,
        train_collector: Collector,
        test_collector: Collector,
        max_epoch: int,
        frame_per_epoch: int,
        collect_per_step: int,
        repeat_per_collect: int,
        episode_per_test: Union[int, Sequence[int]],
        batch_size: int,
        train_fn: Optional[Callable[[int, int], None]] = None,
        test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
        stop_fn: Optional[Callable[[float], bool]] = None,
        save_fn: Optional[Callable[[BasePolicy], None]] = None,
        writer: Optional[SummaryWriter] = None,
        log_interval: int = 1,
        verbose: bool = True,
        test_in_train: bool = True,
        **kwargs) -> Dict[str, Union[float, str]]:
    """Slightly modified Tianshou `onpolicy_trainer` original method to enable
    to define the maximum number of training steps instead of number of
    episodes, for consistency with other learning frameworks.
    """
    global_step = 0
    best_epoch, best_reward = -1, -1.0
    stat: Dict[str, MovAvg] = {}
    start_time = time.time()
    train_collector.reset_stat()
    test_collector.reset_stat()
    test_in_train = test_in_train and train_collector.policy == policy
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        with tqdm.tqdm(
            total=frame_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
        ) as t:
            while t.n < t.total:
                if train_fn:
                    train_fn(epoch, global_step)
                result = train_collector.collect(n_step=collect_per_step)
                data = {}
                if test_in_train and stop_fn and stop_fn(result["rew"]):
                    test_result = test_episode(
                        policy, test_collector, test_fn,
                        epoch, episode_per_test, writer, global_step)
                    if stop_fn(test_result["rew"]):
                        if save_fn:
                            save_fn(policy)
                        for k in result.keys():
                            data[k] = f"{result[k]:.2f}"
                        t.set_postfix(**data)
                        return gather_info(
                            start_time, train_collector, test_collector,
                            test_result["rew"])
                    else:
                        policy.train()
                losses = policy.update(
                    0, train_collector.buffer,
                    batch_size=batch_size, repeat=repeat_per_collect)
                train_collector.reset_buffer()
                step = 1
                for v in losses.values():
                    if isinstance(v, (list, tuple)):
                        step = max(step, len(v))
                global_step += step * collect_per_step
                for k in result.keys():
                    data[k] = f"{result[k]:.2f}"
                    if writer and global_step % log_interval == 0:
                        writer.add_scalar(
                            "train/" + k, result[k], global_step=global_step)
                for k in losses.keys():
                    if stat.get(k) is None:
                        stat[k] = MovAvg()
                    stat[k].add(losses[k])
                    data[k] = f"{stat[k].get():.6f}"
                    if writer and global_step % log_interval == 0:
                        writer.add_scalar(
                            k, stat[k].get(), global_step=global_step)
                t.update(collect_per_step)
                t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        result = test_episode(policy, test_collector, test_fn, epoch,
                              episode_per_test, writer, global_step)
        if best_epoch == -1 or best_reward < result["rew"]:
            best_reward = result["rew"]
            best_epoch = epoch
            if save_fn:
                save_fn(policy)
        if verbose:
            print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f}, "
                  f"best_reward: {best_reward:.6f} in #{best_epoch}")
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(
        start_time, train_collector, test_collector, best_reward)
Example #10
def onpolicy_trainer(
    policy: BasePolicy,
    train_collector: Collector,
    max_epoch: int,
    step_per_epoch: int,
    collect_per_step: int,
    repeat_per_collect: int,
    batch_size: int,
    train_fn: Optional[Callable[[int], None]] = None,
    log_interval: int = 1,
    verbose: bool = True,
    test_in_train: bool = True,
    writer: Optional[SummaryWriter] = None,
):
    """A wrapper for on-policy trainer procedure. The ``step`` in trainer means
    a policy network update.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param train_collector: the collector used for training.
    :type train_collector: :class:`~tianshou.data.Collector`
    :param int max_epoch: the maximum number of epochs for training. The
        training process might finish before reaching ``max_epoch``.
    :param int step_per_epoch: the number of steps for updating the policy
        network in one epoch.
    :param int collect_per_step: the number of episodes the collector would
        collect before the network update. In other words, collect some
        episodes and do one policy network update.
    :param int repeat_per_collect: the number of times the policy learns from
        each collected batch of data; for example, setting it to 2 means the
        policy learns each given batch of data twice.
    :param int batch_size: the batch size of sample data, which is going to
        feed in the policy network.
    :param function train_fn: a function that receives the current epoch index
        and performs some operations at the beginning of training in this
        epoch.
    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
        SummaryWriter.
    :param int log_interval: the log interval of the writer.
    :param bool verbose: whether to print the information.
    :param bool test_in_train: whether to test in the training phase.

    :return: the total number of gradient steps performed (``global_step``).
    """
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}
    start_time = time.time()
    test_in_train = test_in_train and train_collector.policy == policy
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        if train_fn:
            train_fn(epoch)
        with tqdm.tqdm(total=step_per_epoch,
                       desc=f'Epoch #{epoch}',
                       **tqdm_config) as t:
            while t.n < t.total:
                result = train_collector.collect(n_episode=collect_per_step)
                data = {}
                losses = policy.update(0, train_collector.buffer, batch_size,
                                       repeat_per_collect)
                train_collector.reset_buffer()
                step = 1
                for k in losses.keys():
                    if isinstance(losses[k], list):
                        step = max(step, len(losses[k]))
                global_step += step
                for k in result.keys():
                    data[k] = f'{result[k]:.2f}'
                    if writer and global_step % log_interval == 0:
                        writer.add_scalar(k,
                                          result[k],
                                          global_step=global_step)
                for k in losses.keys():
                    if stat.get(k) is None:
                        stat[k] = MovAvg()
                    stat[k].add(losses[k])
                    data[k] = f'{stat[k].get():.6f}'
                    if writer and global_step % log_interval == 0:
                        writer.add_scalar(k,
                                          stat[k].get(),
                                          global_step=global_step)
                t.update(step)
                t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
    return global_step
Example #11
def test_collector(gym_reset_kwargs):
    writer = SummaryWriter('log/collector')
    logger = Logger(writer)
    env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [2, 3, 4, 5]]

    venv = SubprocVectorEnv(env_fns)
    dum = DummyVectorEnv(env_fns)
    policy = MyPolicy()
    env = env_fns[0]()
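    # c0: single MyTestEnv + plain ReplayBuffer; also checks info fields written by logger.preprocess_fn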
    c0 = Collector(
        policy,
        env,
        ReplayBuffer(size=100),
        logger.preprocess_fn,
    )
    c0.collect(n_step=3, gym_reset_kwargs=gym_reset_kwargs)
    assert len(c0.buffer) == 3
    assert np.allclose(c0.buffer.obs[:4, 0], [0, 1, 0, 0])
    assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1])
    keys = np.zeros(100)
    keys[:3] = 1
    assert np.allclose(c0.buffer.info["key"], keys)
    for e in c0.buffer.info["env"][:3]:
        assert isinstance(e, MyTestEnv)
    assert np.allclose(c0.buffer.info["env_id"], 0)
    rews = np.zeros(100)
    rews[:3] = [0, 1, 0]
    assert np.allclose(c0.buffer.info["rew"], rews)
    c0.collect(n_episode=3, gym_reset_kwargs=gym_reset_kwargs)
    assert len(c0.buffer) == 8
    assert np.allclose(c0.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 0])
    assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
    assert np.allclose(c0.buffer.info["key"][:8], 1)
    for e in c0.buffer.info["env"][:8]:
        assert isinstance(e, MyTestEnv)
    assert np.allclose(c0.buffer.info["env_id"][:8], 0)
    assert np.allclose(c0.buffer.info["rew"][:8], [0, 1, 0, 1, 0, 1, 0, 1])
    c0.collect(n_step=3, random=True, gym_reset_kwargs=gym_reset_kwargs)

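    # c1: 4 subprocess envs + VectorReplayBuffer with one sub-buffer per env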
    c1 = Collector(policy, venv,
                   VectorReplayBuffer(total_size=100, buffer_num=4),
                   logger.preprocess_fn)
    c1.collect(n_step=8, gym_reset_kwargs=gym_reset_kwargs)
    obs = np.zeros(100)
    valid_indices = [0, 1, 25, 26, 50, 51, 75, 76]
    obs[valid_indices] = [0, 1, 0, 1, 0, 1, 0, 1]
    assert np.allclose(c1.buffer.obs[:, 0], obs)
    assert np.allclose(c1.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
    keys = np.zeros(100)
    keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1]
    assert np.allclose(c1.buffer.info["key"], keys)
    for e in c1.buffer.info["env"][valid_indices]:
        assert isinstance(e, MyTestEnv)
    env_ids = np.zeros(100)
    env_ids[valid_indices] = [0, 0, 1, 1, 2, 2, 3, 3]
    assert np.allclose(c1.buffer.info["env_id"], env_ids)
    rews = np.zeros(100)
    rews[valid_indices] = [0, 1, 0, 0, 0, 0, 0, 0]
    assert np.allclose(c1.buffer.info["rew"], rews)
    c1.collect(n_episode=4, gym_reset_kwargs=gym_reset_kwargs)
    assert len(c1.buffer) == 16
    valid_indices = [2, 3, 27, 52, 53, 77, 78, 79]
    obs[valid_indices] = [0, 1, 2, 2, 3, 2, 3, 4]
    assert np.allclose(c1.buffer.obs[:, 0], obs)
    assert np.allclose(c1.buffer[:].obs_next[..., 0],
                       [1, 2, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5])
    keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1]
    assert np.allclose(c1.buffer.info["key"], keys)
    for e in c1.buffer.info["env"][valid_indices]:
        assert isinstance(e, MyTestEnv)
    env_ids[valid_indices] = [0, 0, 1, 2, 2, 3, 3, 3]
    assert np.allclose(c1.buffer.info["env_id"], env_ids)
    rews[valid_indices] = [0, 1, 1, 0, 1, 0, 0, 1]
    assert np.allclose(c1.buffer.info["rew"], rews)
    c1.collect(n_episode=4, random=True, gym_reset_kwargs=gym_reset_kwargs)

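    # c2: 4 in-process (dummy) envs; the collected episode layout may be either of two orderings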
    c2 = Collector(policy, dum, VectorReplayBuffer(total_size=100,
                                                   buffer_num=4),
                   logger.preprocess_fn)
    c2.collect(n_episode=7, gym_reset_kwargs=gym_reset_kwargs)
    obs1 = obs.copy()
    obs1[[4, 5, 28, 29, 30]] = [0, 1, 0, 1, 2]
    obs2 = obs.copy()
    obs2[[28, 29, 30, 54, 55, 56, 57]] = [0, 1, 2, 0, 1, 2, 3]
    c2obs = c2.buffer.obs[:, 0]
    assert np.all(c2obs == obs1) or np.all(c2obs == obs2)
    c2.reset_env(gym_reset_kwargs=gym_reset_kwargs)
    c2.reset_buffer()
    assert c2.collect(n_episode=8,
                      gym_reset_kwargs=gym_reset_kwargs)['n/ep'] == 8
    valid_indices = [4, 5, 28, 29, 30, 54, 55, 56, 57]
    obs[valid_indices] = [0, 1, 0, 1, 2, 0, 1, 2, 3]
    assert np.all(c2.buffer.obs[:, 0] == obs)
    keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1, 1]
    assert np.allclose(c2.buffer.info["key"], keys)
    for e in c2.buffer.info["env"][valid_indices]:
        assert isinstance(e, MyTestEnv)
    env_ids[valid_indices] = [0, 0, 1, 1, 1, 2, 2, 2, 2]
    assert np.allclose(c2.buffer.info["env_id"], env_ids)
    rews[valid_indices] = [0, 1, 0, 0, 1, 0, 0, 0, 1]
    assert np.allclose(c2.buffer.info["rew"], rews)
    c2.collect(n_episode=4, random=True, gym_reset_kwargs=gym_reset_kwargs)

    # test corner case
    with pytest.raises(TypeError):
        Collector(policy, dum, ReplayBuffer(10))
    with pytest.raises(TypeError):
        Collector(policy, dum, PrioritizedReplayBuffer(10, 0.5, 0.5))
    with pytest.raises(TypeError):
        c2.collect()

    # test NXEnv
    for obs_type in ["array", "object"]:
        envs = SubprocVectorEnv(
            [lambda i=x, t=obs_type: NXEnv(i, t) for x in [5, 10, 15, 20]])
        c3 = Collector(policy, envs,
                       VectorReplayBuffer(total_size=100, buffer_num=4))
        c3.collect(n_step=6, gym_reset_kwargs=gym_reset_kwargs)
        assert c3.buffer.obs.dtype == object
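
Example #11 takes gym_reset_kwargs as an argument instead of hard-coding it; in a pytest suite it would typically be supplied via parametrization, roughly as sketched below (the decorator and the chosen kwargs values are assumptions, not shown in the snippet):

import pytest

# hypothetical parametrization placed above the Example #11 definition, so the
# test runs once per reset-kwargs variant
@pytest.mark.parametrize("gym_reset_kwargs", [None, {"return_info": True}])
def test_collector(gym_reset_kwargs):
    ...  # body as in Example #11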