Example #1
 def __init__(self, policy, env, buffer=None, stat_size=100, **kwargs):
     super().__init__()
     self.env = env
     self.env_num = 1
     self.collect_step = 0
     self.collect_episode = 0
     self.collect_time = 0
     if buffer is None:
         self.buffer = ReplayBuffer(100)
     else:
         self.buffer = buffer
     self.policy = policy
     self.process_fn = policy.process_fn
     self._multi_env = isinstance(env, BaseVectorEnv)
     self._multi_buf = False  # True if buf is a list
     # need multiple cache buffers only if storing in one buffer
     self._cached_buf = []
     if self._multi_env:
         self.env_num = len(env)
         if isinstance(self.buffer, list):
             assert len(self.buffer) == self.env_num, \
                 'The number of data buffer does not match the number of ' \
                 'input env.'
             self._multi_buf = True
         elif isinstance(self.buffer, ReplayBuffer):
             self._cached_buf = [
                 ListReplayBuffer() for _ in range(self.env_num)]
         else:
             raise TypeError('The buffer in data collector is invalid!')
     self.reset_env()
     self.reset_buffer()
     # state over batch is either a list, an np.ndarray, or a torch.Tensor
     self.state = None
     self.step_speed = MovAvg(stat_size)
     self.episode_speed = MovAvg(stat_size)
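The constructor above dispatches on the type of ``buffer``; here is a hedged sketch of the three accepted configurations (``policy``, ``env`` and ``envs`` are assumed to be defined as in the later examples, and the buffer sizes are arbitrary):

from tianshou.data import ReplayBuffer

# 1. no buffer given: a small default ReplayBuffer(100) is created internally
collector = Collector(policy, env)
# 2. a single shared ReplayBuffer with a vectorized env: per-env cache buffers
#    (ListReplayBuffer) are filled first and then merged into the shared buffer
collector = Collector(policy, envs, buffer=ReplayBuffer(20000))
# 3. one ReplayBuffer per environment; the list length must equal len(envs)
collector = Collector(policy, envs,
                      buffer=[ReplayBuffer(5000) for _ in range(len(envs))])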
Example #2
 def reset(self) -> None:
     """Reset all related variables in the collector."""
     self.data = Batch(state={}, obs={}, act={}, rew={}, done={}, info={},
                       obs_next={}, policy={})
     self.reset_env()
     self.reset_buffer()
     self.step_speed = MovAvg(self.stat_size)
     self.episode_speed = MovAvg(self.stat_size)
     self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0
     if self._action_noise is not None:
         self._action_noise.reset()
Example #3
 def reset(self):
     """Reset all related variables in the collector."""
     self.reset_env()
     self.reset_buffer()
     # state over batch is either a list, an np.ndarray, or a torch.Tensor
     self.state = None
     self.step_speed = MovAvg(self.stat_size)
     self.episode_speed = MovAvg(self.stat_size)
     self.collect_step = 0
     self.collect_episode = 0
     self.collect_time = 0
Example #4
def offpolicy_trainer(
    policy: BasePolicy,
    train_collector,
    max_epoch: int,
    step_per_epoch: int,
    collect_per_step: int,
    batch_size: int,
    update_per_step: int = 1,
    train_fn: Optional[Callable[[int], None]] = None,
    writer: Optional[SummaryWriter] = None,
    log_interval: int = 1000,
) -> int:
    """A wrapper for off-policy trainer procedure. The ``step`` in trainer
    means a policy network update.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param train_collector: the collector used for training.
    :type train_collector: :class:`~tianshou.data.Collector`
    :param int max_epoch: the maximum number of epochs for training. The
        training process might be finished before reaching ``max_epoch``.
    :param int step_per_epoch: the number of steps for updating the policy
        network in one epoch.
    :param int collect_per_step: the number of frames the collector would
        collect before the network update. In other words, collect some frames
        and then do some policy network updates.
    :param int batch_size: the batch size of the sample data, which is going
        to be fed into the policy network.
    :param int update_per_step: the number of times the policy network is
        updated after the frames are collected; for example, setting it to 256
        means the policy is updated 256 times once ``collect_per_step`` frames
        have been collected.
    :param function train_fn: a function that receives the current epoch index
        and performs some operations at the beginning of training in this
        epoch.
    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
        SummaryWriter.
    :param int log_interval: the log interval of the writer.

    :return: See :func:`~tianshou.trainer.gather_info`.
    """
    global_step = 0
    best_epoch, best_reward = -1, -1.
    stat = {}
    start_time = time.time()
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        if train_fn:
            train_fn(epoch)
        with tqdm.tqdm(total=step_per_epoch,
                       desc=f'Epoch #{epoch}',
                       **tqdm_config) as t:
            results = collections.deque(maxlen=100)
            while t.n < t.total:
                assert train_collector.policy == policy
                result = train_collector.collect(n_step=collect_per_step)
                results.append(result)
                data = {}
                for i in range(update_per_step * min(
                        min(100, result['n/st']) // collect_per_step,
                        t.total - t.n)):
                    losses = policy.update(batch_size, train_collector.buffer)
                    global_step += collect_per_step
                    for k in result.keys():
                        data[k] = f'{result[k]:.2f}'
                        if writer and global_step % log_interval == 0:
                            writer.add_scalar('train/' + k,
                                              np.mean([r[k] for r in results]),
                                              global_step=global_step)
                    for k in losses.keys():
                        if stat.get(k) is None:
                            stat[k] = MovAvg()
                        stat[k].add(losses[k])
                        data[k] = f'{stat[k].get():.6f}'
                        if writer and global_step % log_interval == 0:
                            writer.add_scalar(k,
                                              stat[k].get(),
                                              global_step=global_step)
                    data['exp_noise'] = policy._noise._sigma
                    t.update(1)
                    t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
    return global_step
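A quick worked check of the update-count expression in the inner loop above, with made-up numbers (``collect_per_step=10``, ``update_per_step=1``, a collect result of 64 steps and 100 progress-bar ticks remaining):

# Worked example of the update count used in the loop above (made-up numbers).
collect_per_step, update_per_step = 10, 1
n_st = 64          # result['n/st'] returned by the last collect()
remaining = 100    # t.total - t.n
n_updates = update_per_step * min(min(100, n_st) // collect_per_step, remaining)
assert n_updates == 6  # at most 100 frames are counted, 10 frames per update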
Example #5
def onpolicy_trainer(
        policy: BasePolicy,
        train_collector: Collector,
        test_collector: Collector,
        max_epoch: int,
        frame_per_epoch: int,
        collect_per_step: int,
        repeat_per_collect: int,
        episode_per_test: Union[int, List[int]],
        batch_size: int,
        train_fn: Optional[Callable[[int], None]] = None,
        test_fn: Optional[Callable[[int], None]] = None,
        stop_fn: Optional[Callable[[float], bool]] = None,
        save_fn: Optional[Callable[[BasePolicy], None]] = None,
        log_fn: Optional[Callable[[dict], None]] = None,
        writer: Optional[SummaryWriter] = None,
        log_interval: int = 1,
        verbose: bool = True,
        **kwargs
) -> Dict[str, Union[float, str]]:
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}
    start_time = time.time()
    test_in_train = train_collector.policy == policy
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        if train_fn:
            train_fn(epoch)
        with tqdm.tqdm(total=frame_per_epoch, desc=f'Epoch #{epoch}',
                       **tqdm_config) as t:
            while t.n < t.total:
                result = train_collector.collect(n_step=collect_per_step,
                                                 log_fn=log_fn)
                data = {}
                if test_in_train and stop_fn and stop_fn(result['rew']):
                    test_result = test_episode(
                        policy, test_collector, test_fn,
                        epoch, episode_per_test)
                    if stop_fn and stop_fn(test_result['rew']):
                        if save_fn:
                            save_fn(policy)
                        for k in result.keys():
                            data[k] = f'{result[k]:.2f}'
                        t.set_postfix(**data)
                        return gather_info(
                            start_time, train_collector, test_collector,
                            test_result['rew'])
                    else:
                        policy.train()
                        if train_fn:
                            train_fn(epoch)
                losses = policy.learn(
                    train_collector.sample(0), batch_size, repeat_per_collect)
                train_collector.reset_buffer()
                global_step += collect_per_step
                for k in result.keys():
                    data[k] = f'{result[k]:.2f}'
                    if writer and global_step % log_interval == 0:
                        writer.add_scalar(
                            k, result[k], global_step=global_step)
                for k in losses.keys():
                    if stat.get(k) is None:
                        stat[k] = MovAvg()
                    stat[k].add(losses[k])
                    data[k] = f'{stat[k].get():.6f}'
                    if writer and global_step % log_interval == 0:
                        writer.add_scalar(
                            k, stat[k].get(), global_step=global_step)
                t.update(collect_per_step)
                t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        result = test_episode(
            policy, test_collector, test_fn, epoch, episode_per_test)
        if best_epoch == -1 or best_reward < result['rew']:
            best_reward = result['rew']
            best_epoch = epoch
            if save_fn:
                save_fn(policy)
        if verbose:
            print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
                  f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(
        start_time, train_collector, test_collector, best_reward)
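A minimal sketch of the ``stop_fn`` and ``save_fn`` callbacks this trainer expects; the reward threshold of 195 and the ``log`` directory are placeholders, not part of the original code:

import os
import torch

def stop_fn(mean_reward):
    # receives the average undiscounted test return; True stops the training
    return mean_reward >= 195

def save_fn(policy):
    # called whenever the best test reward improves
    torch.save(policy.state_dict(), os.path.join('log', 'policy.pth'))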
Example #6
def onpolicy_trainer(policy,
                     train_collector,
                     test_collector,
                     max_epoch,
                     step_per_epoch,
                     collect_per_step,
                     repeat_per_collect,
                     episode_per_test,
                     batch_size,
                     train_fn=None,
                     test_fn=None,
                     stop_fn=None,
                     save_fn=None,
                     log_fn=None,
                     writer=None,
                     log_interval=1,
                     verbose=True,
                     task='',
                     **kwargs):
    """A wrapper for on-policy trainer procedure.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param train_collector: the collector used for training.
    :type train_collector: :class:`~tianshou.data.Collector`
    :param test_collector: the collector used for testing.
    :type test_collector: :class:`~tianshou.data.Collector`
    :param int max_epoch: the maximum number of epochs for training. The
        training process might be finished before reaching ``max_epoch``.
    :param int step_per_epoch: the number of steps for updating the policy
        network in one epoch.
    :param int collect_per_step: the number of episodes the collector would
        collect before the network update. In other words, collect some
        episodes and do one policy network update.
    :param int repeat_per_collect: the number of times the policy learns from
        the collected data; for example, setting it to 2 means the policy
        learns each given batch of data twice.
    :param episode_per_test: the number of episodes for one policy evaluation.
    :type episode_per_test: int or list of ints
    :param int batch_size: the batch size of the sample data, which is going
        to be fed into the policy network.
    :param function train_fn: a function that receives the current epoch index
        and performs some operations at the beginning of training in this
        epoch.
    :param function test_fn: a function that receives the current epoch index
        and performs some operations at the beginning of testing in this
        epoch.
    :param function save_fn: a function for saving the policy when the
        undiscounted average reward in the evaluation phase improves.
    :param function stop_fn: a function that receives the average undiscounted
        return of the test result and returns a boolean indicating whether the
        goal has been reached.
    :param function log_fn: a function that receives the env info, typically
        for logging.
    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
        SummaryWriter.
    :param int log_interval: the log interval of the writer.
    :param bool verbose: whether to print the information.

    :return: See :func:`~tianshou.trainer.gather_info`.
    """
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}
    start_time = time.time()
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        if train_fn:
            train_fn(epoch)
        with tqdm.tqdm(total=step_per_epoch,
                       desc=f'Epoch #{epoch}',
                       **tqdm_config) as t:
            while t.n < t.total:
                result = train_collector.collect(n_episode=collect_per_step,
                                                 log_fn=log_fn)
                data = {}
                if stop_fn and stop_fn(result['rew']):
                    test_result = test_episode(policy, test_collector, test_fn,
                                               epoch, episode_per_test)
                    if stop_fn and stop_fn(test_result['rew']):
                        if save_fn:
                            save_fn(policy)
                        for k in result.keys():
                            data[k] = f'{result[k]:.2f}'
                        t.set_postfix(**data)
                        return gather_info(start_time, train_collector,
                                           test_collector, test_result['rew'])
                    else:
                        policy.train()
                        if train_fn:
                            train_fn(epoch)
                losses = policy.learn(train_collector.sample(0), batch_size,
                                      repeat_per_collect)
                train_collector.reset_buffer()
                step = 1
                for k in losses.keys():
                    if isinstance(losses[k], list):
                        step = max(step, len(losses[k]))
                global_step += step
                for k in result.keys():
                    data[k] = f'{result[k]:.2f}'
                    if writer and global_step % log_interval == 0:
                        writer.add_scalar(k + '_' + task if task else k,
                                          result[k],
                                          global_step=global_step)
                for k in losses.keys():
                    if stat.get(k) is None:
                        stat[k] = MovAvg()
                    stat[k].add(losses[k])
                    data[k] = f'{stat[k].get():.6f}'
                    if writer and global_step % log_interval == 0:
                        writer.add_scalar(k + '_' + task if task else k,
                                          stat[k].get(),
                                          global_step=global_step)
                t.update(step)
                t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        result = test_episode(policy, test_collector, test_fn, epoch,
                              episode_per_test)
        if best_epoch == -1 or best_reward < result['rew']:
            best_reward = result['rew']
            best_epoch = epoch
            if save_fn:
                save_fn(policy)
        if verbose:
            print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
                  f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(start_time, train_collector, test_collector,
                       best_reward)
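In the loop above, the progress bar advances by the length of the longest loss list returned by ``policy.learn``, i.e. by the number of gradient updates performed in the repeat loop. A tiny illustration with a made-up losses dict:

# Illustration of the `step` computation above (made-up losses dict).
losses = {'loss': [0.50, 0.42, 0.37], 'loss/entropy': 0.01}
step = 1
for k in losses.keys():
    if isinstance(losses[k], list):
        step = max(step, len(losses[k]))
assert step == 3  # three gradient updates -> three progress-bar ticks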
Example #7
class Collector(object):
    """The :class:`~tianshou.data.Collector` enables the policy to interact
    with different types of environments conveniently.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param env: a ``gym.Env`` environment or an instance of the
        :class:`~tianshou.env.BaseVectorEnv` class.
    :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer`
        class, or a list of :class:`~tianshou.data.ReplayBuffer`. If set to
        ``None``, it will automatically assign a small-size
        :class:`~tianshou.data.ReplayBuffer`.
    :param function preprocess_fn: a function called before the data is added
        to the buffer, see issue #42, defaults to ``None``.
    :param int stat_size: the window size of the moving average used to record
        the collecting speed, defaults to 100.
    :param BaseNoise action_noise: add noise to continuous actions. Normally a
        policy already has a noise parameter for exploration during training,
        so this option is mainly recommended for the test collector.

    ``preprocess_fn`` is a function called, in batch format, before the data is
    added to the buffer. It receives up to 7 keys, as listed in
    :class:`~tianshou.data.Batch`, and receives only ``obs`` when the collector
    resets the environment. It returns either a dict or a
    :class:`~tianshou.data.Batch` with the modified keys and values. Examples
    are in "test/base/test_collector.py" (a short sketch also follows this
    example).

    Example:
    ::

        policy = PGPolicy(...)  # or other policies if you wish
        env = gym.make('CartPole-v0')
        replay_buffer = ReplayBuffer(size=10000)
        # here we set up a collector with a single environment
        collector = Collector(policy, env, buffer=replay_buffer)

        # the collector supports vectorized environments as well
        envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)])
        buffers = [ReplayBuffer(size=5000) for _ in range(3)]
        # you can also pass a list of replay buffer to collector, for multi-env
        # collector = Collector(policy, envs, buffer=buffers)
        collector = Collector(policy, envs, buffer=replay_buffer)

        # collect at least 3 episodes
        collector.collect(n_episode=3)
        # collect 1 episode for the first env, 3 for the third env
        collector.collect(n_episode=[1, 0, 3])
        # collect at least 2 steps
        collector.collect(n_step=2)
        # collect episodes with visual rendering (the render argument is the
        #   sleep time between rendering consecutive frames)
        collector.collect(n_episode=1, render=0.03)

        # sample data with a given number of batch-size:
        batch_data = collector.sample(batch_size=64)
        # policy.learn(batch_data)  # btw, vanilla policy gradient only
        #   supports on-policy training, so here we pick all data in the buffer
        batch_data = collector.sample(batch_size=0)
        policy.learn(batch_data)
        # on-policy algorithms use the collected data only once, so here we
        #   clear the buffer
        collector.reset_buffer()

    When collecting data from multiple environments into a single buffer, the
    cache buffers are turned on automatically. The collector may then return
    more data than the given limit.

    .. note::

        Please make sure the given environment has a time limitation.
    """
    def __init__(self,
                 policy: BasePolicy,
                 env: Union[gym.Env, BaseVectorEnv],
                 buffer: Optional[Union[ReplayBuffer,
                                        List[ReplayBuffer]]] = None,
                 preprocess_fn: Callable[[Any], Union[dict, Batch]] = None,
                 stat_size: Optional[int] = 100,
                 action_noise: Optional[BaseNoise] = None,
                 **kwargs) -> None:
        super().__init__()
        self.env = env
        self.env_num = 1
        self.collect_time = 0
        self.collect_step = 0
        self.collect_episode = 0
        self.buffer = buffer
        self.policy = policy
        self.preprocess_fn = preprocess_fn
        # if preprocess_fn is None:
        #     def _prep(**kwargs):
        #         return kwargs
        #     self.preprocess_fn = _prep
        self.process_fn = policy.process_fn
        self._multi_env = isinstance(env, BaseVectorEnv)
        self._multi_buf = False  # True if buf is a list
        # need multiple cache buffers only if storing in one buffer
        self._cached_buf = []
        if self._multi_env:
            self.env_num = len(env)
            if isinstance(self.buffer, list):
                assert len(self.buffer) == self.env_num, \
                    'The number of data buffer does not match the number of ' \
                    'input env.'
                self._multi_buf = True
            elif isinstance(self.buffer, ReplayBuffer) or self.buffer is None:
                self._cached_buf = [
                    ListReplayBuffer() for _ in range(self.env_num)
                ]
            else:
                raise TypeError('The buffer in data collector is invalid!')
        self.stat_size = stat_size
        self._action_noise = action_noise
        self.reset()

    def reset(self) -> None:
        """Reset all related variables in the collector."""
        self.reset_env()
        self.reset_buffer()
        # state over batch is either a list, an np.ndarray, or a torch.Tensor
        self.state = None
        self.step_speed = MovAvg(self.stat_size)
        self.episode_speed = MovAvg(self.stat_size)
        self.collect_step = 0
        self.collect_episode = 0
        self.collect_time = 0
        if self._action_noise is not None:
            self._action_noise.reset()

    def reset_buffer(self) -> None:
        """Reset the main data buffer."""
        if self._multi_buf:
            for b in self.buffer:
                b.reset()
        else:
            if self.buffer is not None:
                self.buffer.reset()

    def get_env_num(self) -> int:
        """Return the number of environments the collector have."""
        return self.env_num

    def reset_env(self) -> None:
        """Reset all of the environment(s)' states and reset all of the cache
        buffers (if need).
        """
        self._obs = self.env.reset()
        if not self._multi_env:
            self._obs = self._make_batch(self._obs)
        if self.preprocess_fn:
            self._obs = self.preprocess_fn(obs=self._obs).get('obs', self._obs)
        self._act = self._rew = self._done = self._info = None
        if self._multi_env:
            self.reward = np.zeros(self.env_num)
            self.length = np.zeros(self.env_num)
        else:
            self.reward, self.length = 0, 0
        for b in self._cached_buf:
            b.reset()

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None:
        """Reset all the seed(s) of the given environment(s)."""
        if hasattr(self.env, 'seed'):
            return self.env.seed(seed)

    def render(self, **kwargs) -> None:
        """Render all the environment(s)."""
        if hasattr(self.env, 'render'):
            return self.env.render(**kwargs)

    def close(self) -> None:
        """Close the environment(s)."""
        if hasattr(self.env, 'close'):
            self.env.close()

    def _make_batch(self, data: Any) -> np.ndarray:
        """Return [data]."""
        if isinstance(data, np.ndarray):
            return data[None]
        else:
            return np.array([data])

    def _reset_state(self, id: Union[int, List[int]]) -> None:
        """Reset self.state[id]."""
        if self.state is None:
            return
        if isinstance(self.state, list):
            self.state[id] = None
        elif isinstance(self.state, torch.Tensor):
            self.state[id].zero_()
        elif isinstance(self.state, np.ndarray):
            if self.state.dtype == np.object:
                self.state[id] = None
            else:
                self.state[id] = 0
        elif isinstance(self.state, Batch):
            self.state.empty_(id)

    def collect(
            self,
            n_step: int = 0,
            n_episode: Union[int, List[int]] = 0,
            random: bool = False,
            render: Optional[float] = None,
            log_fn: Optional[Callable[[dict],
                                      None]] = None) -> Dict[str, float]:
        """Collect a specified number of step or episode.

        :param int n_step: how many steps you want to collect.
        :param n_episode: how many episodes you want to collect (in each
            environment).
        :type n_episode: int or list
        :param bool random: whether to use random policy for collecting data,
            defaults to ``False``.
        :param float render: the sleep time between rendering consecutive
            frames, defaults to ``None`` (no rendering).
        :param function log_fn: a function which receives env info, typically
            for tensorboard logging.

        .. note::

            One and only one collection number specification is permitted,
            either ``n_step`` or ``n_episode``.

        :return: A dict including the following keys

            * ``n/ep`` the collected number of episodes.
            * ``n/st`` the collected number of steps.
            * ``v/st`` the speed of steps per second.
            * ``v/ep`` the speed of episodes per second.
            * ``rew`` the mean reward over collected episodes.
            * ``len`` the mean length over collected episodes.
        """
        warning_count = 0
        if not self._multi_env:
            n_episode = np.sum(n_episode)
        start_time = time.time()
        assert sum([(n_step != 0), (n_episode != 0)]) == 1, \
            "One and only one collection number specification is permitted!"
        cur_step = 0
        cur_episode = np.zeros(self.env_num) if self._multi_env else 0
        reward_sum = 0
        length_sum = 0
        while True:
            if warning_count >= 100000:
                warnings.warn(
                    'There are already many steps in an episode. '
                    'You should add a time limitation to your environment!',
                    Warning)
            batch = Batch(obs=self._obs,
                          act=self._act,
                          rew=self._rew,
                          done=self._done,
                          obs_next=None,
                          info=self._info,
                          policy=None)
            if random:
                action_space = self.env.action_space
                if isinstance(action_space, list):
                    result = Batch(act=[a.sample() for a in action_space])
                else:
                    result = Batch(act=self._make_batch(action_space.sample()))
            else:
                with torch.no_grad():
                    result = self.policy(batch, self.state)

            # save hidden state to policy._state, in order to save into buffer
            self.state = result.get('state', None)
            if hasattr(result, 'policy'):
                self._policy = to_numpy(result.policy)
                if self.state is not None:
                    self._policy._state = self.state
            elif self.state is not None:
                self._policy = Batch(_state=self.state)
            else:
                self._policy = [{}] * self.env_num

            self._act = to_numpy(result.act)
            if self._action_noise is not None:
                self._act += self._action_noise(self._act.shape)
            obs_next, self._rew, self._done, self._info = self.env.step(
                self._act if self._multi_env else self._act[0])
            if not self._multi_env:
                obs_next = self._make_batch(obs_next)
                self._rew = self._make_batch(self._rew)
                self._done = self._make_batch(self._done)
                self._info = self._make_batch(self._info)
            if log_fn:
                log_fn(self._info if self._multi_env else self._info[0])
            if render:
                self.env.render()
                if render > 0:
                    time.sleep(render)
            self.length += 1
            self.reward += self._rew
            if self.preprocess_fn:
                result = self.preprocess_fn(obs=self._obs,
                                            act=self._act,
                                            rew=self._rew,
                                            done=self._done,
                                            obs_next=obs_next,
                                            info=self._info,
                                            policy=self._policy)
                self._obs = result.get('obs', self._obs)
                self._act = result.get('act', self._act)
                self._rew = result.get('rew', self._rew)
                self._done = result.get('done', self._done)
                obs_next = result.get('obs_next', obs_next)
                self._info = result.get('info', self._info)
                self._policy = result.get('policy', self._policy)
            if self._multi_env:
                for i in range(self.env_num):
                    data = {
                        'obs': self._obs[i],
                        'act': self._act[i],
                        'rew': self._rew[i],
                        'done': self._done[i],
                        'obs_next': obs_next[i],
                        'info': self._info[i],
                        'policy': self._policy[i]
                    }
                    if self._cached_buf:
                        warning_count += 1
                        self._cached_buf[i].add(**data)
                    elif self._multi_buf:
                        warning_count += 1
                        self.buffer[i].add(**data)
                        cur_step += 1
                    else:
                        warning_count += 1
                        if self.buffer is not None:
                            self.buffer.add(**data)
                        cur_step += 1
                    if self._done[i]:
                        if n_step != 0 or np.isscalar(n_episode) or \
                                cur_episode[i] < n_episode[i]:
                            cur_episode[i] += 1
                            reward_sum += self.reward[i]
                            length_sum += self.length[i]
                            if self._cached_buf:
                                cur_step += len(self._cached_buf[i])
                                if self.buffer is not None:
                                    self.buffer.update(self._cached_buf[i])
                        self.reward[i], self.length[i] = 0, 0
                        if self._cached_buf:
                            self._cached_buf[i].reset()
                        self._reset_state(i)
                if sum(self._done):
                    obs_next = self.env.reset(np.where(self._done)[0])
                    if self.preprocess_fn:
                        obs_next = self.preprocess_fn(obs=obs_next).get(
                            'obs', obs_next)
                if n_episode != 0:
                    if isinstance(n_episode, list) and \
                            (cur_episode >= np.array(n_episode)).all() or \
                            np.isscalar(n_episode) and \
                            cur_episode.sum() >= n_episode:
                        break
            else:
                if self.buffer is not None:
                    self.buffer.add(self._obs[0], self._act[0], self._rew[0],
                                    self._done[0], obs_next[0], self._info[0],
                                    self._policy[0])
                cur_step += 1
                if self._done:
                    cur_episode += 1
                    reward_sum += self.reward[0]
                    length_sum += self.length
                    self.reward, self.length = 0, 0
                    self.state = None
                    obs_next = self._make_batch(self.env.reset())
                    if self.preprocess_fn:
                        obs_next = self.preprocess_fn(obs=obs_next).get(
                            'obs', obs_next)
                if n_episode != 0 and cur_episode >= n_episode:
                    break
            if n_step != 0 and cur_step >= n_step:
                break
            self._obs = obs_next
        self._obs = obs_next
        if self._multi_env:
            cur_episode = sum(cur_episode)
        duration = max(time.time() - start_time, 1e-9)
        self.step_speed.add(cur_step / duration)
        self.episode_speed.add(cur_episode / duration)
        self.collect_step += cur_step
        self.collect_episode += cur_episode
        self.collect_time += duration
        if isinstance(n_episode, list):
            n_episode = np.sum(n_episode)
        else:
            n_episode = max(cur_episode, 1)
        return {
            'n/ep': cur_episode,
            'n/st': cur_step,
            'v/st': self.step_speed.get(),
            'v/ep': self.episode_speed.get(),
            'rew': reward_sum / n_episode,
            'len': length_sum / n_episode,
        }

    def sample(self, batch_size: int) -> Batch:
        """Sample a data batch from the internal replay buffer. It will call
        :meth:`~tianshou.policy.BasePolicy.process_fn` before returning
        the final batch data.

        :param int batch_size: ``0`` means it will extract all the data from
            the buffer, otherwise it will extract the data with the given
            batch_size.
        """
        if self._multi_buf:
            if batch_size > 0:
                lens = [len(b) for b in self.buffer]
                total = sum(lens)
                batch_index = np.random.choice(len(self.buffer),
                                               batch_size,
                                               p=np.array(lens) / total)
            else:
                batch_index = np.array([])
            batch_data = Batch()
            for i, b in enumerate(self.buffer):
                cur_batch = (batch_index == i).sum()
                if batch_size and cur_batch or batch_size <= 0:
                    batch, indice = b.sample(cur_batch)
                    batch = self.process_fn(batch, b, indice)
                    batch_data.cat_(batch)
        else:
            batch_data, indice = self.buffer.sample(batch_size)
            batch_data = self.process_fn(batch_data, self.buffer, indice)
        return batch_data
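A hedged sketch of the ``preprocess_fn`` hook documented above: the reward-clipping transform is only an illustration, and the commented ``GaussianNoise`` lines assume ``tianshou.exploration.GaussianNoise`` as the source of the ``action_noise`` parameter.

import numpy as np

def clip_reward(**kwargs):
    # Receives only `obs` on environment reset and up to 7 keys otherwise;
    # returning a dict with just the modified keys is enough.
    if 'rew' in kwargs:
        return {'rew': np.clip(kwargs['rew'], -1.0, 1.0)}
    return {}

# collector = Collector(policy, env, ReplayBuffer(10000),
#                       preprocess_fn=clip_reward)
# test-time exploration noise via the `action_noise` parameter:
# from tianshou.exploration import GaussianNoise
# test_collector = Collector(policy, env, action_noise=GaussianNoise(sigma=0.1))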
Example #8
def imitation_trainer(policy, learner, expert_collector, test_collector,
                      max_epoch, step_per_epoch, collect_per_step,
                      repeat_per_collect, episode_per_test, batch_size,
                      train_fn=None, test_fn=None, stop_fn=None,
                      writer=None, task='', peer=0, peer_decay_steps=0):
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}
    start_time = time.time()
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        if train_fn:
            train_fn(epoch)
        with tqdm.tqdm(
                total=step_per_epoch, desc=f'Epoch #{epoch}',
                **tqdm_config) as t:
            while t.n < t.total:
                expert_collector.collect(n_episode=collect_per_step)
                result = test_collector.collect(n_episode=episode_per_test)

                data = {}
                if stop_fn and stop_fn(result['rew']):
                    for k in result.keys():
                        data[k] = f'{result[k]:.2f}'
                    t.set_postfix(**data)
                    return gather_info(
                        start_time, expert_collector, test_collector,
                        result['rew'])
                else:
                    policy.train()
                    if train_fn:
                        train_fn(epoch)

                decay = 1. if not peer_decay_steps else \
                    max(0., 1 - global_step / peer_decay_steps)
                losses = learner(policy, expert_collector.sample(0),
                                 batch_size, repeat_per_collect, peer * decay)
                expert_collector.reset_buffer()
                step = 1
                for k in losses.keys():
                    if isinstance(losses[k], list):
                        step = max(step, len(losses[k]))
                global_step += step
                for k in result.keys():
                    data[k] = f'{result[k]:.2f}'
                    if writer:
                        writer.add_scalar(
                            k + '_' + task if task else k,
                            result[k], global_step=global_step)
                for k in losses.keys():
                    if stat.get(k) is None:
                        stat[k] = MovAvg()
                    stat[k].add(losses[k])
                    data[k] = f'{stat[k].get():.6f}'
                    if writer and global_step:
                        writer.add_scalar(
                            k + '_' + task if task else k,
                            stat[k].get(), global_step=global_step)
                t.update(step)
                t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        result = test_episode(
            policy, test_collector, test_fn, epoch, episode_per_test)
        if best_epoch == -1 or best_reward < result['rew']:
            best_reward = result['rew']
            best_epoch = epoch
        print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
              f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(
        start_time, expert_collector, test_collector, best_reward)
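The peer coefficient in the trainer above is annealed linearly over ``peer_decay_steps``; a small standalone sketch of that schedule (the numeric values are arbitrary):

# Linear decay of the peer coefficient, as computed inside imitation_trainer.
def peer_weight(peer, global_step, peer_decay_steps):
    decay = 1. if not peer_decay_steps else \
        max(0., 1 - global_step / peer_decay_steps)
    return peer * decay

assert peer_weight(0.5, 0, 10000) == 0.5       # full weight at the start
assert peer_weight(0.5, 5000, 10000) == 0.25   # halfway through the schedule
assert peer_weight(0.5, 20000, 10000) == 0.0   # clipped to zero afterwards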
Example #9
class Collector(object):
    """The :class:`~tianshou.data.Collector` enables the policy to interact
    with different types of environments conveniently.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param env: an environment or an instance of the
        :class:`~tianshou.env.BaseVectorEnv` class.
    :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer`
        class, or a list of :class:`~tianshou.data.ReplayBuffer`. If set to
        ``None``, it will automatically assign a small-size
        :class:`~tianshou.data.ReplayBuffer`.
    :param int stat_size: the window size of the moving average used to record
        the collecting speed, defaults to 100.

    Example:
    ::

        policy = PGPolicy(...)  # or other policies if you wish
        env = gym.make('CartPole-v0')
        replay_buffer = ReplayBuffer(size=10000)
        # here we set up a collector with a single environment
        collector = Collector(policy, env, buffer=replay_buffer)

        # the collector supports vectorized environments as well
        envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)])
        buffers = [ReplayBuffer(size=5000) for _ in range(3)]
        # you can also pass a list of replay buffer to collector, for multi-env
        # collector = Collector(policy, envs, buffer=buffers)
        collector = Collector(policy, envs, buffer=replay_buffer)

        # collect at least 3 episodes
        collector.collect(n_episode=3)
        # collect 1 episode for the first env, 3 for the third env
        collector.collect(n_episode=[1, 0, 3])
        # collect at least 2 steps
        collector.collect(n_step=2)
        # collect episodes with visual rendering (the render argument is the
        #   sleep time between rendering consecutive frames)
        collector.collect(n_episode=1, render=0.03)

        # sample data with a given number of batch-size:
        batch_data = collector.sample(batch_size=64)
        # policy.learn(batch_data)  # btw, vanilla policy gradient only
        #   supports on-policy training, so here we pick all data in the buffer
        batch_data = collector.sample(batch_size=0)
        policy.learn(batch_data)
        # on-policy algorithms use the collected data only once, so here we
        #   clear the buffer
        collector.reset_buffer()

    When collecting data from multiple environments into a single buffer, the
    cache buffers are turned on automatically. The collector may then return
    more data than the given limit.

    .. note::

        Please make sure the given environment has a time limitation.
    """
    def __init__(self,
                 policy,
                 env,
                 buffer=None,
                 stat_size=100,
                 repeat=0,
                 **kwargs):
        super().__init__()
        self.repeat = repeat
        self.env = env
        self.env_num = 1
        self.collect_step = 0
        self.collect_episode = 0
        self.collect_time = 0
        self.buffer = buffer
        self.policy = policy
        self.process_fn = policy.process_fn
        self._multi_env = isinstance(env, BaseVectorEnv)
        self._multi_buf = False  # True if buf is a list
        # need multiple cache buffers only if storing in one buffer
        self._cached_buf = []
        if self._multi_env:
            self.env_num = len(env)
            if isinstance(self.buffer, list):
                assert len(self.buffer) == self.env_num, \
                    'The number of data buffer does not match the number of ' \
                    'input env.'
                self._multi_buf = True
            elif isinstance(self.buffer, ReplayBuffer) or self.buffer is None:
                self._cached_buf = [
                    ListReplayBuffer() for _ in range(self.env_num)
                ]
            else:
                raise TypeError('The buffer in data collector is invalid!')
        self.stat_size = stat_size
        self.reset()

    def reset(self):
        """Reset all related variables in the collector."""
        self.reset_env()
        self.reset_buffer()
        # state over batch is either a list, an np.ndarray, or a torch.Tensor
        self.state = None
        self.step_speed = MovAvg(self.stat_size)
        self.episode_speed = MovAvg(self.stat_size)
        self.collect_step = 0
        self.collect_episode = 0
        self.collect_time = 0

    def reset_buffer(self):
        """Reset the main data buffer."""
        if self._multi_buf:
            for b in self.buffer:
                b.reset()
        else:
            if self.buffer is not None:
                self.buffer.reset()

    def get_env_num(self):
        """Return the number of environments the collector has."""
        return self.env_num

    def reset_env(self):
        """Reset all of the environment(s)' states and reset all of the cache
        buffers (if need).
        """
        self._obs = self.env.reset()
        self._act = self._rew = self._done = self._info = None
        if self._multi_env:
            self.reward = np.zeros(self.env_num)
            self.length = np.zeros(self.env_num)
        else:
            self.reward, self.length = 0, 0
        for b in self._cached_buf:
            b.reset()

    def seed(self, seed=None):
        """Reset all the seed(s) of the given environment(s)."""
        if hasattr(self.env, 'seed'):
            return self.env.seed(seed)

    def render(self, **kwargs):
        """Render all the environment(s)."""
        if hasattr(self.env, 'render'):
            return self.env.render(**kwargs)

    def close(self):
        """Close the environment(s)."""
        if hasattr(self.env, 'close'):
            self.env.close()

    def _make_batch(self, data):
        """Return [data]."""
        if isinstance(data, np.ndarray):
            return data[None]
        else:
            return np.array([data])

    def _reset_state(self, id):
        """Reset self.state[id]."""
        if self.state is None:
            return
        if isinstance(self.state, list):
            self.state[id] = None
        elif isinstance(self.state, dict):
            for k in self.state:
                if isinstance(self.state[k], list):
                    self.state[k][id] = None
                elif isinstance(self.state[k], torch.Tensor) or \
                        isinstance(self.state[k], np.ndarray):
                    self.state[k][id] = 0
        elif isinstance(self.state, torch.Tensor) or \
                isinstance(self.state, np.ndarray):
            self.state[id] = 0

    def collect(self, n_step=0, n_episode=0, render=None, log_fn=None):
        """Collect a specified number of step or episode.

        :param int n_step: how many steps you want to collect.
        :param n_episode: how many episodes you want to collect (in each
            environment).
        :type n_episode: int or list
        :param float render: the sleep time between rendering consecutive
            frames, defaults to ``None`` (no rendering).
        :param function log_fn: a function which receives env info, typically
            for tensorboard logging.

        .. note::

            One and only one collection number specification is permitted,
            either ``n_step`` or ``n_episode``.

        :return: A dict including the following keys

            * ``n/ep`` the collected number of episodes.
            * ``n/st`` the collected number of steps.
            * ``v/st`` the speed of steps per second.
            * ``v/ep`` the speed of episodes per second.
            * ``rew`` the mean reward over collected episodes.
            * ``len`` the mean length over collected episodes.
        """
        warning_count = 0
        if not self._multi_env:
            n_episode = np.sum(n_episode)
        start_time = time.time()
        assert sum([(n_step != 0), (n_episode != 0)]) == 1, \
            "One and only one collection number specification is permitted!"
        cur_step = 0
        cur_episode = np.zeros(self.env_num) if self._multi_env else 0
        reward_sum = 0
        length_sum = 0
        while True:
            if warning_count >= 100000:
                warnings.warn(
                    'There are already many steps in an episode. '
                    'You should add a time limitation to your environment!',
                    Warning)
            if self._multi_env:
                batch_data = Batch(obs=self._obs,
                                   act=self._act,
                                   rew=self._rew,
                                   done=self._done,
                                   obs_next=None,
                                   info=self._info)
            else:
                batch_data = Batch(obs=self._make_batch(self._obs),
                                   act=self._make_batch(self._act),
                                   rew=self._make_batch(self._rew),
                                   done=self._make_batch(self._done),
                                   obs_next=None,
                                   info=self._make_batch(self._info))
            with torch.no_grad():
                result = self.policy(batch_data, self.state)
            self.state = result.state if hasattr(result, 'state') else None
            if isinstance(result.act, torch.Tensor):
                self._act = result.act.detach().cpu().numpy()
            elif not isinstance(result.act, np.ndarray):
                self._act = np.array(result.act)
            else:
                self._act = result.act
            obs_next, self._rew, self._done, self._info = self.env.step(
                self._act if self._multi_env else self._act[0])
            if log_fn is not None:
                log_fn(self._info)
            if render is not None:
                self.env.render()
                if render > 0:
                    time.sleep(render)
            self.length += 1
            self.reward += self._rew
            if self._multi_env:
                for i in range(self.env_num):
                    data = {
                        'obs': self._obs[i],
                        'act': self._act[i],
                        'rew': self._rew[i],
                        'done': self._done[i],
                        'obs_next': obs_next[i],
                        'info': self._info[i]
                    }
                    if self._cached_buf:
                        warning_count += 1
                        self._cached_buf[i].add(**data)
                        if data['act'] != 3:
                            for _ in range(self.repeat):
                                self._cached_buf[i].add(**data)
                    elif self._multi_buf:
                        warning_count += 1
                        self.buffer[i].add(**data)
                        cur_step += 1
                    else:
                        warning_count += 1
                        if self.buffer is not None:
                            self.buffer.add(**data)
                        cur_step += 1
                    if self._done[i]:
                        if n_step != 0 or np.isscalar(n_episode) or \
                                cur_episode[i] < n_episode[i]:
                            cur_episode[i] += 1
                            reward_sum += self.reward[i]
                            length_sum += self.length[i]
                            if self._cached_buf:
                                cur_step += len(self._cached_buf[i])
                                if self.buffer is not None:
                                    self.buffer.update(self._cached_buf[i])
                        self.reward[i], self.length[i] = 0, 0
                        if self._cached_buf:
                            self._cached_buf[i].reset()
                        self._reset_state(i)
                if sum(self._done):
                    obs_next = self.env.reset(np.where(self._done)[0])
                if n_episode != 0:
                    if isinstance(n_episode, list) and \
                            (cur_episode >= np.array(n_episode)).all() or \
                            np.isscalar(n_episode) and \
                            cur_episode.sum() >= n_episode:
                        break
            else:
                if self.buffer is not None:
                    self.buffer.add(self._obs, self._act[0], self._rew,
                                    self._done, obs_next, self._info)
                cur_step += 1
                if self._done:
                    cur_episode += 1
                    reward_sum += self.reward
                    length_sum += self.length
                    self.reward, self.length = 0, 0
                    self.state = None
                    obs_next = self.env.reset()
                if n_episode != 0 and cur_episode >= n_episode:
                    break
            if n_step != 0 and cur_step >= n_step:
                break
            self._obs = obs_next
        self._obs = obs_next
        if self._multi_env:
            cur_episode = sum(cur_episode)
        duration = max(time.time() - start_time, 1e-9)
        self.step_speed.add(cur_step / duration)
        self.episode_speed.add(cur_episode / duration)
        self.collect_step += cur_step
        self.collect_episode += cur_episode
        self.collect_time += duration
        if isinstance(n_episode, list):
            n_episode = np.sum(n_episode)
        else:
            n_episode = max(cur_episode, 1)
        return {
            'n/ep': cur_episode,
            'n/st': cur_step,
            'v/st': self.step_speed.get(),
            'v/ep': self.episode_speed.get(),
            'rew': reward_sum / n_episode,
            'len': length_sum / n_episode,
        }

    def sample(self, batch_size):
        """Sample a data batch from the internal replay buffer. It will call
        :meth:`~tianshou.policy.BasePolicy.process_fn` before returning
        the final batch data.

        :param int batch_size: ``0`` means it will extract all the data from
            the buffer, otherwise it will extract the data with the given
            batch_size.
        """
        if self._multi_buf:
            if batch_size > 0:
                lens = [len(b) for b in self.buffer]
                total = sum(lens)
                batch_index = np.random.choice(len(self.buffer),
                                               batch_size,
                                               p=np.array(lens) / total)
            else:
                batch_index = np.array([])
            batch_data = Batch()
            for i, b in enumerate(self.buffer):
                cur_batch = (batch_index == i).sum()
                if batch_size and cur_batch or batch_size <= 0:
                    batch, indice = b.sample(cur_batch)
                    batch = self.process_fn(batch, b, indice)
                    batch_data.append(batch)
        else:
            batch_data, indice = self.buffer.sample(batch_size)
            batch_data = self.process_fn(batch_data, self.buffer, indice)
        return batch_data
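The multi-buffer branch of ``sample`` above draws buffer indices in proportion to the buffer lengths; a standalone illustration with made-up sizes:

import numpy as np

lens = [100, 300, 600]          # made-up lengths of three replay buffers
total = sum(lens)
batch_index = np.random.choice(len(lens), 64, p=np.array(lens) / total)
# buffer i then contributes (batch_index == i).sum() samples,
# roughly 6, 19 and 38 of the 64 on average
counts = [(batch_index == i).sum() for i in range(len(lens))]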
Example #10
def Myonpolicy_trainer(
    policy: BasePolicy,
    train_collector: MyCollector,
    test_collector: MyCollector,
    max_epoch: int,
    step_per_epoch: int,
    collect_per_step: int,
    repeat_per_collect: int,
    episode_per_test: Union[int, List[int]],  # number of episodes per evaluation
    batch_size: int,
    train_fn: Optional[Callable[[int, int], None]] = None,
    test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
    stop_fn: Optional[Callable[[float], bool]] = None,
    save_fn: Optional[Callable[[BasePolicy], None]] = None,
    writer: Optional[SummaryWriter] = None,
    log_interval: int = 1,
    verbose: bool = True,
    test_in_train: bool = True,
    test_probs: bool = False,
) -> Dict[str, Union[float, str]]:
    """A wrapper for on-policy trainer procedure.

    The "step" in trainer means a policy network update.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param train_collector: the collector used for training.
    :type train_collector: :class:`~tianshou.data.Collector`
    :param test_collector: the collector used for testing.
    :type test_collector: :class:`~tianshou.data.Collector`
    :param int max_epoch: the maximum number of epochs for training. The
        training process might be finished before reaching ``max_epoch``.
    :param int step_per_epoch: the number of steps for updating the policy
        network in one epoch (at most this many network updates per epoch).
    :param int collect_per_step: the number of episodes the collector would
        collect before the network update (i.e. how much data to gather for
        one update). In other words, collect some episodes and do one policy
        network update.
    :param int repeat_per_collect: the number of times the policy learns from
        the collected data; for example, setting it to 2 means the policy
        learns each given batch of data twice.
    :param episode_per_test: the number of episodes for one policy evaluation.
    :type episode_per_test: int or list of ints
    :param int batch_size: the batch size of the sample data, which is going
        to be fed into the policy network.
    :param function train_fn: a function that receives the current epoch number
        and step index, and performs some operations at the beginning of
        training in this epoch.
    :param function test_fn: a function that receives the current epoch number
        and step index, and performs some operations at the beginning of
        testing in this epoch.
    :param function save_fn: a function for saving the policy when the
        undiscounted average reward in the evaluation phase improves.
    :param function stop_fn: a function that receives the average undiscounted
        return of the test result and returns a boolean indicating whether the
        goal has been reached.
    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
        SummaryWriter.
    :param int log_interval: the log interval of the writer.
    :param bool verbose: whether to print the information.
    :param bool test_in_train: whether to test in the training phase.
    :param bool test_probs: whether to run an extra evaluation pass on the
        test set with a second probability threshold.

    :return: See :func:`~tianshou.trainer.gather_info`.
    """
    env_step, gradient_step = 0, 0
    best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0
    best_rate = 0.
    best_mate_num = 0.
    best_avg_len = 0.
    stat: Dict[str, MovAvg] = {}
    start_time = time.time()
    # print("policy reset")
    train_collector.reset_stat()
    test_collector.reset_stat()
    test_in_train = test_in_train and train_collector.policy == policy
    # print('policy')
    for epoch in range(1, 1 + max_epoch):
        # train

        policy.train()
        with tqdm.tqdm(total=step_per_epoch,
                       desc=f"Epoch #{epoch}",
                       **tqdm_config) as t:
            while t.n < t.total:
                if train_fn:
                    train_fn(epoch, env_step)
                result = train_collector.collect(
                    n_episode=collect_per_step
                )  # collect() returns a result dict: the data gathered for one update
                env_step += int(result["n/st"])
                data = {
                    "env_step": str(env_step),
                    "rew": f"{result['rew']:.2f}",
                    "len": str(int(result["len"])),
                    "n/ep": str(int(result["n/ep"])),
                    "n/st": str(int(result["n/st"])),
                    "v/ep": f"{result['v/ep']:.2f}",
                    "v/st": f"{result['v/st']:.2f}",
                    "rate": f"{result['hit_rate']:.2f}",
                }
                if writer and env_step % log_interval == 0:
                    for k in result.keys():
                        if "class" not in k:
                            writer.add_scalar("train/" + k,
                                              result[k],
                                              global_step=env_step)
                if test_in_train and stop_fn and stop_fn(result["rew"]):
                    test_result = test_episode(policy, test_collector, test_fn,
                                               epoch, episode_per_test, writer,
                                               env_step)
                    if stop_fn(test_result["rew"]):
                        if save_fn:
                            save_fn(policy)
                        for k in result.keys():
                            data[k] = f"{result[k]:.2f}"
                        t.set_postfix(**data)
                        return gather_info(start_time, train_collector,
                                           test_collector, test_result["rew"],
                                           test_result["rew_std"])
                    else:
                        policy.train()
                # print("what the f**k")
                losses = policy.update(
                    0,
                    train_collector.buffer,  # training data is whatever is in the collector's buffer
                    batch_size=batch_size,
                    repeat=repeat_per_collect)
                # print("youxi")
                train_collector.reset_buffer()
                step = max(
                    [1] +
                    [len(v) for v in losses.values() if isinstance(v, list)
                     ])  # step = number of gradient steps taken in this update (longest loss list)
                gradient_step += step
                for k in losses.keys():
                    if stat.get(k) is None:
                        stat[k] = MovAvg()
                    stat[k].add(losses[k])
                    data[k] = f"{stat[k].get():.6f}"
                    if writer and gradient_step % log_interval == 0:
                        writer.add_scalar(k,
                                          stat[k].get(),
                                          global_step=gradient_step)
                t.update(step)
                t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        # print("episode_per_test: ", episode_per_test)
        # print("test_collector_env_num: ", test_collector.get_env_num())
        start_time = time.time()
        result = test_episode(policy, test_collector, test_fn, epoch,
                              episode_per_test, writer,
                              env_step)  # does this cover the whole test set, and exactly once?
        end_time = time.time()
        print("total_time: ", (end_time - start_time) / 60,
              time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime()))
        if test_probs:  # evaluate with the second probability threshold
            # policy.actor.threshold = 0.9
            test_probs_result = test_episode(policy,
                                             test_collector,
                                             test_fn,
                                             epoch,
                                             episode_per_test,
                                             writer,
                                             env_step,
                                             name='test_prob1/')
            print(result['hit_rate'], test_probs_result['hit_rate'])
            best_rate = max(best_rate, test_probs_result['hit_rate'])
            # policy.actor.threshold = 0.95
        test_hit_rate = result['hit_rate']
        # best_rate = max(best_rate, test_hit_rate)
        best_flag = 0
        if best_epoch == -1 or best_rate < test_hit_rate:
            best_reward, best_reward_std = result["rew"], result["rew_std"]
            best_rate = test_hit_rate
            best_epoch = epoch
            best_mate_num = result['mate_num']
            best_flag = 1
            best_avg_len = result['len']
        if best_rate == test_hit_rate and result['mate_num'] > best_mate_num:
            best_mate_num = result['mate_num']
            best_avg_len = result['len']
            best_flag = 1
        if save_fn and best_flag == 1:
            print("happy")
            save_fn(policy)
        if verbose:
            print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f} ± "
                  f"{result['rew_std']:.6f}, best_reward: {best_reward:.6f} ± "
                  f"{best_reward_std:.6f}"
                  f"  hit_rate: {test_hit_rate}:.3f"
                  f" mate_num:  {result['mate_num']}"
                  f" avg_len:  {result['len']}")
            print(f"  best_rate: {best_rate}:.3f"
                  f"  best_mate_num: {best_mate_num}:.3f"
                  f"  best_len: {best_avg_len}:.3f"
                  f"  in: #{best_epoch}:.3f")
            # if result['class_rate'] is not None:  # print the per-class hit rates
            #     ans = {}
            #     right_class_num = total_class_num()
            #     for ke, val in result['class_rate'].items():
            #         if ke in right_class_num.keys():
            #             ans[ke] = val / right_class_num[ke]
            #     print(ans)
            #     print(result['class_rate'])

        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(start_time, train_collector, test_collector,
                       best_reward, best_reward_std, best_rate, best_mate_num,
                       best_avg_len)
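
A minimal call sketch for this trainer; `MyCollector`, `policy` and `make_env` below are placeholder names assumed to be defined elsewhere in this code base, not part of the original source:

import torch
from torch.utils.tensorboard import SummaryWriter

# hypothetical wiring: `policy`, `MyCollector` and `make_env` are assumed
# to exist elsewhere in the project
writer = SummaryWriter('log/my_onpolicy')
train_collector = MyCollector(policy, make_env(training=True))
test_collector = MyCollector(policy, make_env(training=False))

result = Myonpolicy_trainer(
    policy, train_collector, test_collector,
    max_epoch=10, step_per_epoch=100, collect_per_step=8,
    repeat_per_collect=2, episode_per_test=100, batch_size=64,
    save_fn=lambda p: torch.save(p.state_dict(), 'policy.pth'),
    writer=writer, verbose=True, test_probs=False)
print(result)
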
Exemple #11
0
def offpolicy_trainer_with_views(A,
                                 B,
                                 max_epoch,
                                 step_per_epoch,
                                 collect_per_step,
                                 episode_per_test,
                                 batch_size,
                                 copier=False,
                                 peer=0.,
                                 verbose=True,
                                 test_fn=None,
                                 task=''):
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}
    start_time = time.time()

    for epoch in range(1, 1 + max_epoch):
        # train
        A.train()
        B.train()
        with tqdm.tqdm(total=step_per_epoch,
                       desc=f'Epoch #{epoch}',
                       **tqdm_config) as t:
            while t.n < t.total:
                for view, other_view in zip([A, B], [B, A]):
                    result = view.train_collector.collect(
                        n_step=collect_per_step)
                    data = {}
                    if view.stop_fn(result['rew']):
                        test_result = test_episode(view.policy,
                                                   view.test_collector,
                                                   test_fn, epoch,
                                                   episode_per_test)
                        if view.stop_fn(test_result['rew']):
                            for k in result.keys():
                                data[k] = f'{result[k]:.2f}'
                            t.set_postfix(**data)
                            return gather_info(start_time,
                                               view.train_collector,
                                               view.test_collector,
                                               test_result['rew'])
                        else:
                            view.policy.train()
                    for i in range(
                            min(result['n/st'] // collect_per_step,
                                t.total - t.n)):
                        global_step += 1
                        batch = view.train_collector.sample(batch_size)
                        losses = view.policy.learn(batch)

                        # Learn from demonstration
                        if copier:
                            demo = other_view.policy(batch)
                            view.learn_from_demos(batch, demo, peer=peer)

                        for k in result.keys():
                            data[k] = f'{result[k]:.2f}'
                            view.writer.add_scalar(k + '_' +
                                                   task if task else k,
                                                   result[k],
                                                   global_step=global_step)
                        for k in losses.keys():
                            if stat.get(k) is None:
                                stat[k] = MovAvg()
                            stat[k].add(losses[k])
                            data[k] = f'{stat[k].get():.4f}'
                            view.writer.add_scalar(k + '_' +
                                                   task if task else k,
                                                   stat[k].get(),
                                                   global_step=global_step)
                        t.update(1)
                        t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        brk = False
        for view in A, B:
            result = test_episode(view.policy, view.test_collector, test_fn,
                                  epoch, episode_per_test)
            if best_epoch == -1 or best_reward < result['rew']:
                best_reward = result['rew']
                best_epoch = epoch
            if verbose:
                print(f'Epoch #{epoch}: test_reward: {result["rew"]:.4f}, '
                      f'best_reward: {best_reward:.4f} in #{best_epoch}')
            if view.stop_fn(best_reward):
                brk = True
        if brk:
            break
    return (
        gather_info(start_time, A.train_collector, A.test_collector,
                    best_reward),
        gather_info(start_time, B.train_collector, B.test_collector,
                    best_reward),
    )
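
Each of A and B passed to this trainer is expected to bundle a policy with its own collectors, stop criterion and writer; a minimal container satisfying that interface could look like the sketch below (all names are illustrative assumptions, not from the original source):

class View:
    # Minimal stand-in for one co-training "view": it only exposes the
    # attributes and methods offpolicy_trainer_with_views touches.
    def __init__(self, policy, train_collector, test_collector,
                 stop_fn, writer):
        self.policy = policy
        self.train_collector = train_collector
        self.test_collector = test_collector
        self.stop_fn = stop_fn
        self.writer = writer

    def train(self):
        # switch the underlying policy to training mode
        self.policy.train()

    def learn_from_demos(self, batch, demo, peer=0.):
        # project-specific: how demonstrations from the other view are used
        # (e.g. an imitation / peer loss on demo.act) is left to the caller
        raise NotImplementedError
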
Exemple #12
0
class Collector(object):
    """docstring for Collector"""
    def __init__(self, policy, env, buffer=None, stat_size=100):
        super().__init__()
        self.env = env
        self.env_num = 1
        self.collect_step = 0
        self.collect_episode = 0
        self.collect_time = 0
        if buffer is None:
            self.buffer = ReplayBuffer(100)
        else:
            self.buffer = buffer
        self.policy = policy
        self.process_fn = policy.process_fn
        self._multi_env = isinstance(env, BaseVectorEnv)
        self._multi_buf = False  # True if buf is a list
        # need multiple cache buffers only if storing in one buffer
        self._cached_buf = []
        if self._multi_env:
            self.env_num = len(env)
            if isinstance(self.buffer, list):
                assert len(self.buffer) == self.env_num, \
                    'The number of data buffer does not match the number of ' \
                    'input env.'
                self._multi_buf = True
            elif isinstance(self.buffer, ReplayBuffer):
                self._cached_buf = [
                    ListReplayBuffer() for _ in range(self.env_num)
                ]
            else:
                raise TypeError('The buffer in data collector is invalid!')
        self.reset_env()
        self.reset_buffer()
        # state over batch is either a list, an np.ndarray, or a torch.Tensor
        self.state = None
        self.step_speed = MovAvg(stat_size)
        self.episode_speed = MovAvg(stat_size)

    def reset_buffer(self):
        if self._multi_buf:
            for b in self.buffer:
                b.reset()
        else:
            self.buffer.reset()

    def get_env_num(self):
        return self.env_num

    def reset_env(self):
        self._obs = self.env.reset()
        self._act = self._rew = self._done = self._info = None
        if self._multi_env:
            self.reward = np.zeros(self.env_num)
            self.length = np.zeros(self.env_num)
        else:
            self.reward, self.length = 0, 0
        for b in self._cached_buf:
            b.reset()

    def seed(self, seed=None):
        if hasattr(self.env, 'seed'):
            return self.env.seed(seed)

    def render(self, **kwargs):
        if hasattr(self.env, 'render'):
            return self.env.render(**kwargs)

    def close(self):
        if hasattr(self.env, 'close'):
            self.env.close()

    def _make_batch(self, data):
        if isinstance(data, np.ndarray):
            return data[None]
        else:
            return np.array([data])

    def collect(self, n_step=0, n_episode=0, render=0):
        warning_count = 0
        if not self._multi_env:
            n_episode = np.sum(n_episode)
        start_time = time.time()
        assert sum([(n_step != 0), (n_episode != 0)]) == 1, \
            "One and only one collection number specification permitted!"
        cur_step = 0
        cur_episode = np.zeros(self.env_num) if self._multi_env else 0
        reward_sum = 0
        length_sum = 0
        while True:
            if warning_count >= 100000:
                warnings.warn(
                    'There are already many steps in an episode. '
                    'You should add a time limitation to your environment!',
                    Warning)
            if self._multi_env:
                batch_data = Batch(obs=self._obs,
                                   act=self._act,
                                   rew=self._rew,
                                   done=self._done,
                                   obs_next=None,
                                   info=self._info)
            else:
                batch_data = Batch(obs=self._make_batch(self._obs),
                                   act=self._make_batch(self._act),
                                   rew=self._make_batch(self._rew),
                                   done=self._make_batch(self._done),
                                   obs_next=None,
                                   info=self._make_batch(self._info))
            result = self.policy(batch_data, self.state)
            self.state = result.state if hasattr(result, 'state') else None
            if isinstance(result.act, torch.Tensor):
                self._act = result.act.detach().cpu().numpy()
            elif not isinstance(self._act, np.ndarray):
                self._act = np.array(result.act)
            else:
                self._act = result.act
            obs_next, self._rew, self._done, self._info = self.env.step(
                self._act if self._multi_env else self._act[0])
            if render > 0:
                self.env.render()
                time.sleep(render)
            self.length += 1
            self.reward += self._rew
            if self._multi_env:
                for i in range(self.env_num):
                    data = {
                        'obs': self._obs[i],
                        'act': self._act[i],
                        'rew': self._rew[i],
                        'done': self._done[i],
                        'obs_next': obs_next[i],
                        'info': self._info[i]
                    }
                    if self._cached_buf:
                        warning_count += 1
                        self._cached_buf[i].add(**data)
                    elif self._multi_buf:
                        warning_count += 1
                        self.buffer[i].add(**data)
                        cur_step += 1
                    else:
                        warning_count += 1
                        self.buffer.add(**data)
                        cur_step += 1
                    if self._done[i]:
                        if n_step != 0 or np.isscalar(n_episode) or \
                                cur_episode[i] < n_episode[i]:
                            cur_episode[i] += 1
                            reward_sum += self.reward[i]
                            length_sum += self.length[i]
                            if self._cached_buf:
                                cur_step += len(self._cached_buf[i])
                                self.buffer.update(self._cached_buf[i])
                        self.reward[i], self.length[i] = 0, 0
                        if self._cached_buf:
                            self._cached_buf[i].reset()
                        if isinstance(self.state, list):
                            self.state[i] = None
                        elif self.state is not None:
                            if isinstance(self.state[i], dict):
                                self.state[i] = {}
                            else:
                                self.state[i] = self.state[i] * 0
                            if isinstance(self.state, torch.Tensor):
                                # remove ref count in pytorch (?)
                                self.state = self.state.detach()
                if sum(self._done):
                    obs_next = self.env.reset(np.where(self._done)[0])
                if n_episode != 0:
                    if isinstance(n_episode, list) and \
                            (cur_episode >= np.array(n_episode)).all() or \
                            np.isscalar(n_episode) and \
                            cur_episode.sum() >= n_episode:
                        break
            else:
                self.buffer.add(self._obs, self._act[0], self._rew, self._done,
                                obs_next, self._info)
                cur_step += 1
                if self._done:
                    cur_episode += 1
                    reward_sum += self.reward
                    length_sum += self.length
                    self.reward, self.length = 0, 0
                    self.state = None
                    obs_next = self.env.reset()
                if n_episode != 0 and cur_episode >= n_episode:
                    break
            if n_step != 0 and cur_step >= n_step:
                break
            self._obs = obs_next
        self._obs = obs_next
        if self._multi_env:
            cur_episode = sum(cur_episode)
        duration = time.time() - start_time
        self.step_speed.add(cur_step / duration)
        self.episode_speed.add(cur_episode / duration)
        self.collect_step += cur_step
        self.collect_episode += cur_episode
        self.collect_time += duration
        if isinstance(n_episode, list):
            n_episode = np.sum(n_episode)
        else:
            n_episode = max(cur_episode, 1)
        return {
            'n/ep': cur_episode,
            'n/st': cur_step,
            'v/st': self.step_speed.get(),
            'v/ep': self.episode_speed.get(),
            'rew': reward_sum / n_episode,
            'len': length_sum / n_episode,
        }

    def sample(self, batch_size):
        if self._multi_buf:
            if batch_size > 0:
                lens = [len(b) for b in self.buffer]
                total = sum(lens)
                batch_index = np.random.choice(len(self.buffer),
                                               batch_size,
                                               p=np.array(lens) / total)
            else:
                batch_index = np.array([])
            batch_data = Batch()
            for i, b in enumerate(self.buffer):
                cur_batch = (batch_index == i).sum()
                if (batch_size and cur_batch) or batch_size <= 0:
                    batch, indice = b.sample(cur_batch)
                    batch = self.process_fn(batch, b, indice)
                    batch_data.append(batch)
        else:
            batch_data, indice = self.buffer.sample(batch_size)
            batch_data = self.process_fn(batch_data, self.buffer, indice)
        return batch_data
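
A short usage sketch for this Collector variant, assuming a gym environment and a tianshou-style policy named `policy` defined elsewhere (placeholder names, not from the original source):

import gym

# `policy` is a placeholder for any policy exposing forward() and process_fn
env = gym.make('CartPole-v0')
collector = Collector(policy, env, buffer=ReplayBuffer(20000))

collector.collect(n_step=1000)           # fill the buffer with 1000 steps
batch = collector.sample(batch_size=64)  # processed batch ready for learning
policy.learn(batch)
collector.close()
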
def offpolicy_trainer(
    policy: BasePolicy,
    train_collector,
    max_epoch: int,
    step_per_epoch: int,
    collect_per_step: int,
    batch_size: int,
    update_per_step: int = 1,
    train_fn: Optional[Callable[[int], None]] = None,
    writer: Optional[SummaryWriter] = None,
    log_interval: int = 100,
) -> int:
    """A wrapper for off-policy trainer procedure. The ``step`` in trainer
    means a policy network update.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param train_collector: the collector used for training.
    :type train_collector: :class:`~tianshou.data.Collector`
    :param test_collector: the collector used for testing.
    :type test_collector: :class:`~tianshou.data.Collector`
    :param int max_epoch: the maximum of epochs for training. The training
        process might be finished before reaching the ``max_epoch``.
    :param int step_per_epoch: the number of step for updating policy network
        in one epoch.
    :param int collect_per_step: the number of frames the collector would
        collect before the network update. In other words, collect some frames
        and do some policy network update.
    :param episode_per_test: the number of episodes for one policy evaluation.
    :param int batch_size: the batch size of sample data, which is going to
        feed in the policy network.
    :param int update_per_step: the number of times the policy network would
        be updated after frames are collected, for example, set it to 256 means
        it updates policy 256 times once after ``collect_per_step`` frames are
        collected.
    :param function train_fn: a function receives the current number of epoch
        index and performs some operations at the beginning of training in this
        epoch.
    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
        SummaryWriter.
    :param int log_interval: the log interval of the writer.

    :return: See :func:`~tianshou.trainer.gather_info`.
    """
    global_step = 0
    update_step = 0
    best_epoch, best_reward = -1, -1.
    stat = {}
    start_time = time.time()
    results = collections.deque(maxlen=300)
    world_results = [collections.deque(maxlen=10) for _ in range(WORLD_NUM)]
    world_count = [1] * WORLD_NUM
    world_pcount = [1] * WORLD_NUM
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        if train_fn:
            train_fn(epoch)
        with tqdm.tqdm(total=step_per_epoch,
                       desc=f'Epoch #{epoch}',
                       **tqdm_config) as t:
            while t.n < t.total:
                assert train_collector.policy == policy
                result = train_collector.collect(n_step=collect_per_step)
                world_pcount = world_count.copy()
                for i, w in enumerate(result["world"]):
                    world_results[w].append({"ep_rew": result["ep_rew"][i],\
                                          "ep_len": result["ep_len"][i],\
                                          "success": result["success"][i],\
                                          "global_step": global_step})
                    world_count[w] += 1
                for w in range(WORLD_NUM):
                    if world_count[w] // 10 > world_pcount[w] // 10:
                        for k in world_results[w][0].keys():
                            writer.add_scalar(
                                'world_%d/' % (w) + k,
                                np.mean([r[k] for r in world_results[w]]),
                                global_step=world_count[w])
                n_ep = len(result["success"])
                result = [{"ep_rew":result["ep_rew"][i],\
                           "ep_len":result["ep_len"][i],\
                           "success":result["success"][i]}\
                           for i in range(n_ep)]
                results.extend(result)
                data = {"n_ep": n_ep}
                n_step = sum([r["ep_len"] for r in result])
                global_step += n_step
                n_step = np.clip(n_step, 10, 5000)
                for i in range(update_per_step *
                               min(n_step // collect_per_step, t.total - t.n)):
                    # for i in range(update_per_step):# * min(n_step // collect_per_step, t.total - t.n)):
                    losses = policy.update(batch_size, train_collector.buffer)
                    update_step += 1
                    if len(result) > 0:
                        for k in result[0].keys():
                            data[k] = f"{np.mean([r[k] for r in result]):.2f}"
                            if writer and update_step % log_interval == 0:
                                writer.add_scalar('train/' + k,
                                                  np.mean(
                                                      [r[k] for r in results]),
                                                  global_step=global_step)
                    for k in losses.keys():
                        if stat.get(k) is None:
                            stat[k] = MovAvg()
                        stat[k].add(losses[k])
                        data[k] = f'{stat[k].get():.6f}'
                        if writer and update_step % log_interval == 0:
                            writer.add_scalar(k,
                                              stat[k].get(),
                                              global_step=update_step)
                    try:
                        data['exp_noise'] = policy._noise._sigma
                    except AttributeError:
                        data['exp_noise'] = policy._noise
                    t.update(1)
                    t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
    return global_step
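
The per-world bookkeeping in the loop above (a bounded deque of the most recent episodes per world, flushed every 10 finished episodes) can be exercised on its own; the sketch below uses only numpy and collections with fabricated episode records (illustration only, not part of the original source):

import collections
import numpy as np

WORLD_NUM = 4
world_results = [collections.deque(maxlen=10) for _ in range(WORLD_NUM)]
world_count = [1] * WORLD_NUM

def record_episode(world, ep_rew, ep_len, success):
    # append one finished episode; return the running means whenever this
    # world crosses a multiple of 10 recorded episodes, else None
    world_results[world].append(
        {'ep_rew': ep_rew, 'ep_len': ep_len, 'success': success})
    prev = world_count[world]
    world_count[world] += 1
    if world_count[world] // 10 > prev // 10:
        return {k: float(np.mean([r[k] for r in world_results[world]]))
                for k in world_results[world][0]}
    return None

for ep in range(60):
    stats = record_episode(ep % WORLD_NUM, ep_rew=float(ep), ep_len=10,
                           success=ep % 2)
    if stats is not None:
        print(stats)
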
Exemple #14
0
def onpolicy_trainer(
        policy: BasePolicy,
        train_collector: Collector,
        test_collector: Collector,
        max_epoch: int,
        frame_per_epoch: int,
        collect_per_step: int,
        repeat_per_collect: int,
        episode_per_test: Union[int, Sequence[int]],
        batch_size: int,
        train_fn: Optional[Callable[[int, int], None]] = None,
        test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
        stop_fn: Optional[Callable[[float], bool]] = None,
        save_fn: Optional[Callable[[BasePolicy], None]] = None,
        writer: Optional[SummaryWriter] = None,
        log_interval: int = 1,
        verbose: bool = True,
        test_in_train: bool = True,
        **kwargs) -> Dict[str, Union[float, str]]:
    """Slightly modified Tianshou `onpolicy_trainer` original method to enable
    to define the maximum number of training steps instead of number of
    episodes, for consistency with other learning frameworks.
    """
    global_step = 0
    best_epoch, best_reward = -1, -1.0
    stat: Dict[str, MovAvg] = {}
    start_time = time.time()
    train_collector.reset_stat()
    test_collector.reset_stat()
    test_in_train = test_in_train and train_collector.policy == policy
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        with tqdm.tqdm(
            total=frame_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
        ) as t:
            while t.n < t.total:
                if train_fn:
                    train_fn(epoch, global_step)
                result = train_collector.collect(n_step=collect_per_step)
                data = {}
                if test_in_train and stop_fn and stop_fn(result["rew"]):
                    test_result = test_episode(
                        policy, test_collector, test_fn,
                        epoch, episode_per_test, writer, global_step)
                    if stop_fn(test_result["rew"]):
                        if save_fn:
                            save_fn(policy)
                        for k in result.keys():
                            data[k] = f"{result[k]:.2f}"
                        t.set_postfix(**data)
                        return gather_info(
                            start_time, train_collector, test_collector,
                            test_result["rew"])
                    else:
                        policy.train()
                losses = policy.update(
                    0, train_collector.buffer,
                    batch_size=batch_size, repeat=repeat_per_collect)
                train_collector.reset_buffer()
                step = 1
                for v in losses.values():
                    if isinstance(v, (list, tuple)):
                        step = max(step, len(v))
                global_step += step * collect_per_step
                for k in result.keys():
                    data[k] = f"{result[k]:.2f}"
                    if writer and global_step % log_interval == 0:
                        writer.add_scalar(
                            "train/" + k, result[k], global_step=global_step)
                for k in losses.keys():
                    if stat.get(k) is None:
                        stat[k] = MovAvg()
                    stat[k].add(losses[k])
                    data[k] = f"{stat[k].get():.6f}"
                    if writer and global_step % log_interval == 0:
                        writer.add_scalar(
                            k, stat[k].get(), global_step=global_step)
                t.update(collect_per_step)
                t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        result = test_episode(policy, test_collector, test_fn, epoch,
                              episode_per_test, writer, global_step)
        if best_epoch == -1 or best_reward < result["rew"]:
            best_reward = result["rew"]
            best_epoch = epoch
            if save_fn:
                save_fn(policy)
        if verbose:
            print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f}, "
                  f"best_reward: {best_reward:.6f} in #{best_epoch}")
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(
        start_time, train_collector, test_collector, best_reward)
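
A minimal call sketch for this frame-based variant, assuming the policy, both collectors and a writer already exist (placeholder names, not part of the original source); note that frame_per_epoch and collect_per_step are both counted in environment frames here, and the stop/save callbacks are example values:

import torch

# `policy`, `train_collector`, `test_collector` and `writer` are assumed
# to be constructed elsewhere (placeholders)
result = onpolicy_trainer(
    policy, train_collector, test_collector,
    max_epoch=20,
    frame_per_epoch=50000,    # environment frames per epoch
    collect_per_step=2000,    # frames collected before each update round
    repeat_per_collect=4,
    episode_per_test=10,
    batch_size=256,
    stop_fn=lambda mean_reward: mean_reward >= 195,
    save_fn=lambda p: torch.save(p.state_dict(), 'policy.pth'),
    writer=writer)
print(result)
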
Exemple #15
0
def onpolicy_trainer(
    policy: BasePolicy,
    train_collector: Collector,
    max_epoch: int,
    step_per_epoch: int,
    collect_per_step: int,
    repeat_per_collect: int,
    batch_size: int,
    train_fn: Optional[Callable[[int], None]] = None,
    log_interval: int = 1,
    verbose: bool = True,
    test_in_train: bool = True,
    writer: Optional[SummaryWriter] = None,
):
    """A wrapper for on-policy trainer procedure. The ``step`` in trainer means
    a policy network update.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param train_collector: the collector used for training.
    :type train_collector: :class:`~tianshou.data.Collector`
    :param test_collector: the collector used for testing.
    :type test_collector: :class:`~tianshou.data.Collector`
    :param int max_epoch: the maximum of epochs for training. The training
        process might be finished before reaching the ``max_epoch``.
    :param int step_per_epoch: the number of step for updating policy network
        in one epoch.
    :param int collect_per_step: the number of episodes the collector would
        collect before the network update. In other words, collect some
        episodes and do one policy network update.
    :param int repeat_per_collect: the number of repeat time for policy
        learning, for example, set it to 2 means the policy needs to learn each
        given batch data twice.
    :param int batch_size: the batch size of sample data, which is going to
        feed in the policy network.
    :param function train_fn: a function receives the current number of epoch
        index and performs some operations at the beginning of training in this
        epoch.
    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
        SummaryWriter.
    :param int log_interval: the log interval of the writer.
    :param bool verbose: whether to print the information.
    :param bool test_in_train: whether to test in the training phase.

    :return: See :func:`~tianshou.trainer.gather_info`.
    """
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}
    start_time = time.time()
    test_in_train = test_in_train and train_collector.policy == policy
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        if train_fn:
            train_fn(epoch)
        with tqdm.tqdm(total=step_per_epoch,
                       desc=f'Epoch #{epoch}',
                       **tqdm_config) as t:
            while t.n < t.total:
                result = train_collector.collect(n_episode=collect_per_step)
                data = {}
                losses = policy.update(0, train_collector.buffer, batch_size,
                                       repeat_per_collect)
                train_collector.reset_buffer()
                step = 1
                for k in losses.keys():
                    if isinstance(losses[k], list):
                        step = max(step, len(losses[k]))
                global_step += step
                for k in result.keys():
                    data[k] = f'{result[k]:.2f}'
                    if writer and global_step % log_interval == 0:
                        writer.add_scalar(k,
                                          result[k],
                                          global_step=global_step)
                for k in losses.keys():
                    if stat.get(k) is None:
                        stat[k] = MovAvg()
                    stat[k].add(losses[k])
                    data[k] = f'{stat[k].get():.6f}'
                    if writer and global_step % log_interval == 0:
                        writer.add_scalar(k,
                                          stat[k].get(),
                                          global_step=global_step)
                t.update(step)
                t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
    return global_step
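
Because this variant drops the test collector and simply returns the number of gradient steps, a call sketch is short (`policy` and `train_collector` are placeholder names assumed to exist elsewhere, not part of the original source):

n_updates = onpolicy_trainer(
    policy, train_collector,
    max_epoch=10,
    step_per_epoch=100,
    collect_per_step=16,     # episodes collected before each update
    repeat_per_collect=2,
    batch_size=64,
    writer=None)
print(f'performed {n_updates} policy network updates')
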
class Collector(object):
    """The :class:`~tianshou.data.Collector` enables the policy to interact
    with different types of environments conveniently.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param env: a ``gym.Env`` environment or an instance of the
        :class:`~tianshou.env.BaseVectorEnv` class.
    :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer`
        class, or a list of :class:`~tianshou.data.ReplayBuffer`. If set to
        ``None``, it will automatically assign a small-size
        :class:`~tianshou.data.ReplayBuffer`.
    :param function preprocess_fn: a function called before the data has been
        added to the buffer, see issue #42 and :ref:`preprocess_fn`, defaults
        to ``None``.
    :param int stat_size: for the moving average of recording speed, defaults
        to 100.
    :param BaseNoise action_noise: add a noise to continuous action. Normally
        a policy already has a noise param for exploration in training phase,
        so this is recommended to use in test collector for some purpose.
    :param function reward_metric: to be used in multi-agent RL. The reward to
        report is of shape [agent_num], but we need to return a single scalar
        to monitor training. This function specifies what is the desired
        metric, e.g., the reward of agent 1 or the average reward over all
        agents. By default, the behavior is to select the reward of agent 1.

    The ``preprocess_fn`` is a function called before the data has been added
    to the buffer with batch format, which receives up to 7 keys as listed in
    :class:`~tianshou.data.Batch`. It will receive only ``obs`` when the
    collector resets the environment. It returns either a dict or a
    :class:`~tianshou.data.Batch` with the modified keys and values. Examples
    are in "test/base/test_collector.py".

    Example:
    ::

        policy = PGPolicy(...)  # or other policies if you wish
        env = gym.make('CartPole-v0')
        replay_buffer = ReplayBuffer(size=10000)
        # here we set up a collector with a single environment
        collector = Collector(policy, env, buffer=replay_buffer)

        # the collector supports vectorized environments as well
        envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)])
        buffers = [ReplayBuffer(size=5000) for _ in range(3)]
        # you can also pass a list of replay buffer to collector, for multi-env
        # collector = Collector(policy, envs, buffer=buffers)
        collector = Collector(policy, envs, buffer=replay_buffer)

        # collect at least 3 episodes
        collector.collect(n_episode=3)
        # collect 1 episode for the first env, 3 for the third env
        collector.collect(n_episode=[1, 0, 3])
        # collect at least 2 steps
        collector.collect(n_step=2)
        # collect episodes with visual rendering (the render argument is the
        #   sleep time between rendering consecutive frames)
        collector.collect(n_episode=1, render=0.03)

        # sample data with a given number of batch-size:
        batch_data = collector.sample(batch_size=64)
        # policy.learn(batch_data)  # btw, vanilla policy gradient only
        #   supports on-policy training, so here we pick all data in the buffer
        batch_data = collector.sample(batch_size=0)
        policy.learn(batch_data)
        # on-policy algorithms use the collected data only once, so here we
        #   clear the buffer
        collector.reset_buffer()

    For the scenario of collecting data from multiple environments to a single
    buffer, the cache buffers will turn on automatically. It may return more
    data than the given limit.

    .. note::

        Please make sure the given environment has a time limitation.
    """
    def __init__(
        self,
        policy: BasePolicy,
        env: Union[gym.Env, BaseVectorEnv],
        buffer: Optional[ReplayBuffer] = None,
        preprocess_fn: Callable[[Any], Union[dict, Batch]] = None,
        stat_size: Optional[int] = 100,
        action_noise: Optional[BaseNoise] = None,
        reward_metric: Optional[Callable[[np.ndarray], float]] = None,
    ) -> None:
        super().__init__()
        self.env = env
        self.env_num = 1
        self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0
        self.buffer = buffer
        self.policy = policy
        self.preprocess_fn = preprocess_fn
        self.process_fn = policy.process_fn
        self._multi_env = isinstance(env, BaseVectorEnv)
        # need multiple cache buffers only if storing in one buffer
        self._cached_buf = []
        if self._multi_env:
            self.env_num = len(env)
            self._cached_buf = [
                ListReplayBuffer() for _ in range(self.env_num)
            ]
        self.stat_size = stat_size
        self._action_noise = action_noise

        self._rew_metric = reward_metric or Collector._default_rew_metric
        self.reset()

    @staticmethod
    def _default_rew_metric(x):
        # this internal function is designed for single-agent RL
        # for multi-agent RL, a reward_metric must be provided
        assert np.asanyarray(x).size == 1, \
            'Please specify the reward_metric ' \
            'since the reward is not a scalar.'
        return x

    def reset(self) -> None:
        """Reset all related variables in the collector."""
        self.data = Batch(state={},
                          obs={},
                          act={},
                          rew={},
                          done={},
                          info={},
                          obs_next={},
                          policy={})
        self.reset_env()
        self.reset_buffer()
        self.step_speed = MovAvg(self.stat_size)
        self.episode_speed = MovAvg(self.stat_size)
        self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0
        if self._action_noise is not None:
            self._action_noise.reset()

    def reset_buffer(self) -> None:
        """Reset the main data buffer."""
        if self.buffer is not None:
            self.buffer.reset()

    def get_env_num(self) -> int:
        """Return the number of environments the collector have."""
        return self.env_num

    def reset_env(self) -> None:
        """Reset all of the environment(s)' states and reset all of the cache
        buffers (if needed).
        """
        obs = self.env.reset()
        if not self._multi_env:
            obs = self._make_batch(obs)
        if self.preprocess_fn:
            obs = self.preprocess_fn(obs=obs).get('obs', obs)
        self.data.obs = obs
        self.reward = 0.  # will be specified when the first data is ready
        self.length = np.zeros(self.env_num)
        for b in self._cached_buf:
            b.reset()

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None:
        """Reset all the seed(s) of the given environment(s)."""
        return self.env.seed(seed)

    def render(self, **kwargs) -> None:
        """Render all the environment(s)."""
        return self.env.render(**kwargs)

    def close(self) -> None:
        """Close the environment(s)."""
        self.env.close()

    def _make_batch(self, data: Any) -> np.ndarray:
        """Return [data]."""
        if isinstance(data, np.ndarray):
            return data[None]
        else:
            return np.array([data])

    def _reset_state(self, id: Union[int, List[int]]) -> None:
        """Reset self.data.state[id]."""
        state = self.data.state  # it is a reference
        if isinstance(state, torch.Tensor):
            state[id].zero_()
        elif isinstance(state, np.ndarray):
            state[id] = None if state.dtype == np.object else 0
        elif isinstance(state, Batch):
            state.empty_(id)

    def collect(
            self,
            n_step: int = 0,
            n_episode: Union[int, List[int]] = 0,
            random: bool = False,
            render: Optional[float] = None,
            log_fn: Optional[Callable[[dict],
                                      None]] = None) -> Dict[str, float]:
        """Collect a specified number of step or episode.

        :param int n_step: how many steps you want to collect.
        :param n_episode: how many episodes you want to collect (in each
            environment).
        :type n_episode: int or list
        :param bool random: whether to use random policy for collecting data,
            defaults to ``False``.
        :param float render: the sleep time between rendering consecutive
            frames, defaults to ``None`` (no rendering).
        :param function log_fn: a function which receives env info, typically
            for tensorboard logging.

        .. note::

            One and only one collection number specification is permitted,
            either ``n_step`` or ``n_episode``.

        :return: A dict including the following keys

            * ``n/ep`` the collected number of episodes.
            * ``n/st`` the collected number of steps.
            * ``v/st`` the speed of steps per second.
            * ``v/ep`` the speed of episode per second.
            * ``rew`` the mean reward over collected episodes.
            * ``len`` the mean length over collected episodes.
        """
        if not self._multi_env:
            n_episode = np.sum(n_episode)
        start_time = time.time()
        assert sum([(n_step != 0), (n_episode != 0)]) == 1, \
            "One and only one collection number specification is permitted!"
        cur_step, cur_episode = 0, np.zeros(self.env_num)
        reward_sum, length_sum = 0., 0

        # change
        ty1_succ_rate_1 = 0.
        ty1_succ_rate_2 = 0.
        ty1_succ_rate_3 = 0.
        ty1_succ_rate_4 = 0.
        Q_len_1 = 0.
        Q_len_2 = 0.
        Q_len_3 = 0.
        Q_len_4 = 0.
        energy_effi_1 = 0.
        energy_effi_2 = 0.
        energy_effi_3 = 0.
        energy_effi_4 = 0.
        avg_rate = 0.
        avg_power = 0.

        while True:
            if cur_step >= 100000 and cur_episode.sum() == 0:
                warnings.warn(
                    'There are already many steps in an episode. '
                    'You should add a time limitation to your environment!',
                    Warning)

            # restore the state and the input data
            last_state = self.data.state
            if last_state.is_empty():
                last_state = None
            self.data.update(state=Batch(), obs_next=Batch(), policy=Batch())

            # calculate the next action
            if random:
                action_space = self.env.action_space
                if isinstance(action_space, list):
                    result = Batch(act=[a.sample() for a in action_space])
                else:
                    result = Batch(act=self._make_batch(action_space.sample()))
            else:
                with torch.no_grad():
                    result = self.policy(self.data, last_state)

            # convert None to Batch(), since None is reserved for 0-init
            state = result.get('state', Batch())
            if state is None:
                state = Batch()
            self.data.state = state
            if hasattr(result, 'policy'):
                self.data.policy = to_numpy(result.policy)
            # save hidden state to policy._state, in order to save into buffer
            self.data.policy._state = self.data.state

            self.data.act = to_numpy(result.act)
            if self._action_noise is not None:
                self.data.act += self._action_noise(self.data.act.shape)

            # step in env
            obs_next, rew, done, info = self.env.step(
                self.data.act if self._multi_env else self.data.act[0])

            # move data to self.data
            if not self._multi_env:
                obs_next = self._make_batch(obs_next)
                rew = self._make_batch(rew)
                done = self._make_batch(done)
                info = self._make_batch(info)
            self.data.obs_next = obs_next
            self.data.rew = rew
            self.data.done = done
            self.data.info = info

            if log_fn:
                log_fn(info if self._multi_env else info[0])
            if render:
                self.render()
                if render > 0:
                    time.sleep(render)

            # add data into the buffer
            self.length += 1
            self.reward += self.data.rew
            if self.preprocess_fn:
                result = self.preprocess_fn(**self.data)
                self.data.update(result)
            if self._multi_env:  # cache_buffer branch
                # change
                if self.data.done[0]:
                    ty1_succ_rate_1 += self.data.info[0]['ty1_succ_rate_1']
                    ty1_succ_rate_2 += self.data.info[0]['ty1_succ_rate_2']
                    ty1_succ_rate_3 += self.data.info[0]['ty1_succ_rate_3']
                    ty1_succ_rate_4 += self.data.info[0]['ty1_succ_rate_4']
                    Q_len_1 += self.data.info[0]['Q_len_1']
                    Q_len_2 += self.data.info[0]['Q_len_2']
                    Q_len_3 += self.data.info[0]['Q_len_3']
                    Q_len_4 += self.data.info[0]['Q_len_4']
                    energy_effi_1 += self.data.info[0]['energy_effi_1']
                    energy_effi_2 += self.data.info[0]['energy_effi_2']
                    energy_effi_3 += self.data.info[0]['energy_effi_3']
                    energy_effi_4 += self.data.info[0]['energy_effi_4']
                    avg_rate += self.data.info[0]['avg_rate']
                    avg_power += self.data.info[0]['avg_power']
                for i in range(self.env_num):
                    self._cached_buf[i].add(**self.data[i])
                    if self.data.done[i]:
                        if n_step != 0 or np.isscalar(n_episode) or \
                                cur_episode[i] < n_episode[i]:
                            cur_episode[i] += 1
                            reward_sum += self.reward[i]
                            length_sum += self.length[i]
                            if self._cached_buf:
                                cur_step += len(self._cached_buf[i])
                                if self.buffer is not None:
                                    self.buffer.update(self._cached_buf[i])
                        self.reward[i], self.length[i] = 0., 0
                        if self._cached_buf:
                            self._cached_buf[i].reset()
                        self._reset_state(i)
                obs_next = self.data.obs_next
                if sum(self.data.done):
                    env_ind = np.where(self.data.done)[0]
                    obs_reset = self.env.reset(env_ind)
                    if self.preprocess_fn:
                        obs_next[env_ind] = self.preprocess_fn(
                            obs=obs_reset).get('obs', obs_reset)
                    else:
                        obs_next[env_ind] = obs_reset
                self.data.obs_next = obs_next
                if n_episode != 0:
                    if isinstance(n_episode, list) and \
                            (cur_episode >= np.array(n_episode)).all() or \
                            np.isscalar(n_episode) and \
                            cur_episode.sum() >= n_episode:
                        break
            else:  # single buffer, without cache_buffer
                if self.buffer is not None:
                    self.buffer.add(**self.data[0])
                cur_step += 1
                if self.data.done[0]:
                    # change
                    ty1_succ_rate_1 += self.data.info['ty1_succ_rate_1']
                    ty1_succ_rate_2 += self.data.info['ty1_succ_rate_2']
                    ty1_succ_rate_3 += self.data.info['ty1_succ_rate_3']
                    ty1_succ_rate_4 += self.data.info['ty1_succ_rate_4']
                    Q_len_1 += self.data.info['Q_len_1']
                    Q_len_2 += self.data.info['Q_len_2']
                    Q_len_3 += self.data.info['Q_len_3']
                    Q_len_4 += self.data.info['Q_len_4']
                    energy_effi_1 += self.data.info['energy_effi_1']
                    energy_effi_2 += self.data.info['energy_effi_2']
                    energy_effi_3 += self.data.info['energy_effi_3']
                    energy_effi_4 += self.data.info['energy_effi_4']
                    avg_rate += self.data.info['avg_rate']
                    avg_power += self.data.info['avg_power']
                    cur_episode += 1
                    reward_sum += self.reward[0]
                    length_sum += self.length[0]
                    self.reward, self.length = 0., np.zeros(self.env_num)
                    self.data.state = Batch()
                    obs_next = self._make_batch(self.env.reset())
                    if self.preprocess_fn:
                        obs_next = self.preprocess_fn(obs=obs_next).get(
                            'obs', obs_next)
                    self.data.obs_next = obs_next
                if n_episode != 0 and cur_episode >= n_episode:
                    break
            if n_step != 0 and cur_step >= n_step:
                break
            self.data.obs = self.data.obs_next
        self.data.obs = self.data.obs_next

        # generate the statistics
        cur_episode = sum(cur_episode)
        duration = max(time.time() - start_time, 1e-9)
        self.step_speed.add(cur_step / duration)
        self.episode_speed.add(cur_episode / duration)
        self.collect_step += cur_step
        self.collect_episode += cur_episode
        self.collect_time += duration
        if isinstance(n_episode, list):
            n_episode = np.sum(n_episode)
        else:
            n_episode = max(cur_episode, 1)
        reward_sum /= n_episode
        if np.asanyarray(reward_sum).size > 1:  # non-scalar reward_sum
            reward_sum = self._rew_metric(reward_sum)
        # change
        return {
            'n/ep': cur_episode,
            'n/st': cur_step,
            'v/st': self.step_speed.get(),
            'v/ep': self.episode_speed.get(),
            'rew': reward_sum,
            'len': length_sum / n_episode,
            'ty1s_1': ty1_succ_rate_1,
            'ty1s_2': ty1_succ_rate_2,
            'ty1s_3': ty1_succ_rate_3,
            'ty1s_4': ty1_succ_rate_4,
            'ql_1': Q_len_1,
            'ql_2': Q_len_2,
            'ql_3': Q_len_3,
            'ql_4': Q_len_4,
            'ee_1': energy_effi_1,
            'ee_2': energy_effi_2,
            'ee_3': energy_effi_3,
            'ee_4': energy_effi_4,
            'avg_r': avg_rate,
            'avg_p': avg_power,
        }

    def sample(self, batch_size: int) -> Batch:
        """Sample a data batch from the internal replay buffer. It will call
        :meth:`~tianshou.policy.BasePolicy.process_fn` before returning
        the final batch data.

        :param int batch_size: ``0`` means it will extract all the data from
            the buffer, otherwise it will extract the data with the given
            batch_size.
        """
        batch_data, indice = self.buffer.sample(batch_size)
        batch_data = self.process_fn(batch_data, self.buffer, indice)
        return batch_data
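
The extra keys this collect() returns ('ty1s_*', 'ql_*', 'ee_*', 'avg_r', 'avg_p') are accumulated from the environment's info dict at episode end and summed over the episodes of one call; a minimal logging sketch follows (`collector`, `writer` and the running `env_step` counter are placeholders assumed to exist, and the per-episode averaging is a suggestion, not part of the original source):

# `collector`, `writer` and `env_step` are placeholders assumed to exist
stats = collector.collect(n_episode=4)
extra_keys = ['ty1s_1', 'ty1s_2', 'ty1s_3', 'ty1s_4',
              'ql_1', 'ql_2', 'ql_3', 'ql_4',
              'ee_1', 'ee_2', 'ee_3', 'ee_4',
              'avg_r', 'avg_p']
n_ep = max(stats['n/ep'], 1)
for k in extra_keys:
    # the accumulators are summed over finished episodes, so divide by
    # n/ep to report a per-episode average
    writer.add_scalar('collect/' + k, stats[k] / n_ep, global_step=env_step)
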
Exemple #17
0
def offpolicy_trainer(policy: BasePolicy,
                      train_collector: Collector,
                      test_collector: Collector,
                      max_epoch: int,
                      step_per_epoch: int,
                      collect_per_step: int,
                      episode_per_test: Union[int, List[int]],
                      batch_size: int,
                      update_per_step: int = 1,
                      train_fn: Optional[Callable[[int], None]] = None,
                      test_fn: Optional[Callable[[int], None]] = None,
                      stop_fn: Optional[Callable[[float], bool]] = None,
                      save_fn: Optional[Callable[[BasePolicy], None]] = None,
                      log_fn: Optional[Callable[[dict], None]] = None,
                      writer: Optional[SummaryWriter] = None,
                      log_interval: int = 1,
                      verbose: bool = True,
                      **kwargs) -> Dict[str, Union[float, str]]:
    """A wrapper for off-policy trainer procedure.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param train_collector: the collector used for training.
    :type train_collector: :class:`~tianshou.data.Collector`
    :param test_collector: the collector used for testing.
    :type test_collector: :class:`~tianshou.data.Collector`
    :param int max_epoch: the maximum number of epochs for training. The
        training process might be finished before reaching ``max_epoch``.
    :param int step_per_epoch: the number of steps for updating the policy
        network in one epoch.
    :param int collect_per_step: the number of frames the collector would
        collect before the network update. In other words, collect some frames
        and then do some policy network updates.
    :param episode_per_test: the number of episodes for one policy evaluation.
    :param int batch_size: the batch size of the sample data, which is going to
        be fed into the policy network.
    :param int update_per_step: the number of times the policy network is
        updated after the frames are collected. In other words, collect some
        frames and then do some policy network updates.
    :param function train_fn: a function that receives the current epoch index
        and performs some operations at the beginning of training in this
        epoch.
    :param function test_fn: a function that receives the current epoch index
        and performs some operations at the beginning of testing in this
        epoch.
    :param function save_fn: a function for saving the policy when the
        undiscounted average reward in the evaluation phase gets better.
    :param function stop_fn: a function that receives the average undiscounted
        return of the testing result and returns a boolean indicating whether
        the goal has been reached.
    :param function log_fn: a function that receives the env info for logging.
    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
        SummaryWriter.
    :param int log_interval: the log interval of the writer.
    :param bool verbose: whether to print the information.

    :return: See :func:`~tianshou.trainer.gather_info`.
    """
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}
    start_time = time.time()
    test_in_train = train_collector.policy == policy
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        if train_fn:
            train_fn(epoch)
        with tqdm.tqdm(total=step_per_epoch,
                       desc=f'Epoch #{epoch}',
                       **tqdm_config) as t:
            while t.n < t.total:
                result = train_collector.collect(n_step=collect_per_step,
                                                 log_fn=log_fn)
                data = {}
                if test_in_train and stop_fn and stop_fn(result['rew']):
                    test_result = test_episode(policy, test_collector, test_fn,
                                               epoch, episode_per_test)
                    if stop_fn and stop_fn(test_result['rew']):
                        if save_fn:
                            save_fn(policy)
                        for k in result.keys():
                            data[k] = f'{result[k]:.2f}'
                        t.set_postfix(**data)
                        return gather_info(start_time, train_collector,
                                           test_collector, test_result['rew'])
                    else:
                        policy.train()
                        if train_fn:
                            train_fn(epoch)
                for i in range(update_per_step * min(
                        result['n/st'] // collect_per_step, t.total - t.n)):
                    global_step += 1
                    losses = policy.learn(train_collector.sample(batch_size))
                    for k in result.keys():
                        data[k] = f'{result[k]:.2f}'
                        if writer and global_step % log_interval == 0:
                            writer.add_scalar(k,
                                              result[k],
                                              global_step=global_step)
                    for k in losses.keys():
                        if stat.get(k) is None:
                            stat[k] = MovAvg()
                        stat[k].add(losses[k])
                        data[k] = f'{stat[k].get():.6f}'
                        if writer and global_step % log_interval == 0:
                            writer.add_scalar(k,
                                              stat[k].get(),
                                              global_step=global_step)
                    t.update(1)
                    t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        result = test_episode(policy, test_collector, test_fn, epoch,
                              episode_per_test)
        if best_epoch == -1 or best_reward < result['rew']:
            best_reward = result['rew']
            best_epoch = epoch
            if save_fn:
                save_fn(policy)
        if verbose:
            print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
                  f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(start_time, train_collector, test_collector,
                       best_reward)
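A hedged call sketch for the trainer above; ``policy`` and the two collectors are assumed to be constructed elsewhere, and every number is a placeholder rather than a recommended setting.

# Assumed to exist already: policy (a BasePolicy), train_collector, test_collector.
result = offpolicy_trainer(
    policy, train_collector, test_collector,
    max_epoch=10,           # train for at most 10 epochs
    step_per_epoch=1000,    # policy updates per epoch
    collect_per_step=10,    # frames collected before each update round
    episode_per_test=100,   # evaluation episodes per test phase
    batch_size=64,
    stop_fn=lambda mean_reward: mean_reward >= 195,  # illustrative threshold
    writer=None,            # or a torch.utils.tensorboard.SummaryWriter
)
print(result)  # the dict assembled by gather_info(), e.g. best_reward and timings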
Exemple #18
def offpolicy_trainer(policy,
                      train_collector,
                      test_collector,
                      max_epoch,
                      step_per_epoch,
                      collect_per_step,
                      episode_per_test,
                      batch_size,
                      train_fn=None,
                      test_fn=None,
                      stop_fn=None,
                      writer=None,
                      log_interval=1,
                      verbose=True,
                      task=''):
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}
    start_time = time.time()
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        if train_fn:
            train_fn(epoch)
        with tqdm.tqdm(total=step_per_epoch,
                       desc=f'Epoch #{epoch}',
                       **tqdm_config) as t:
            while t.n < t.total:
                result = train_collector.collect(n_step=collect_per_step)
                data = {}
                if stop_fn and stop_fn(result['rew']):
                    test_result = test_episode(policy, test_collector, test_fn,
                                               epoch, episode_per_test)
                    if stop_fn and stop_fn(test_result['rew']):
                        for k in result.keys():
                            data[k] = f'{result[k]:.2f}'
                        t.set_postfix(**data)
                        return gather_info(start_time, train_collector,
                                           test_collector, test_result['rew'])
                    else:
                        policy.train()
                        if train_fn:
                            train_fn(epoch)
                for i in range(
                        min(result['n/st'] // collect_per_step,
                            t.total - t.n)):
                    global_step += 1
                    losses = policy.learn(train_collector.sample(batch_size))
                    for k in result.keys():
                        data[k] = f'{result[k]:.2f}'
                        if writer and global_step % log_interval == 0:
                            writer.add_scalar(k + '_' + task if task else k,
                                              result[k],
                                              global_step=global_step)
                    for k in losses.keys():
                        if stat.get(k) is None:
                            stat[k] = MovAvg()
                        stat[k].add(losses[k])
                        data[k] = f'{stat[k].get():.6f}'
                        if writer and global_step % log_interval == 0:
                            writer.add_scalar(k + '_' + task if task else k,
                                              stat[k].get(),
                                              global_step=global_step)
                    t.update(1)
                    t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        result = test_episode(policy, test_collector, test_fn, epoch,
                              episode_per_test)
        if best_epoch == -1 or best_reward < result['rew']:
            best_reward = result['rew']
            best_epoch = epoch
        if verbose:
            print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
                  f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(start_time, train_collector, test_collector,
                       best_reward)
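The ``train_fn``/``test_fn``/``stop_fn`` arguments above are plain callables that receive the epoch index (or a test reward, for ``stop_fn``). A small sketch, assuming a DQN-style ``policy`` with a ``set_eps`` method and an illustrative reward threshold; neither is part of the snippet above.

# Illustrative callbacks for the trainer above; `policy` and `reward_threshold`
# are assumed to exist.
def train_fn(epoch):
    policy.set_eps(0.1)    # explore more while training (assumed method)

def test_fn(epoch):
    policy.set_eps(0.05)   # act almost greedily during evaluation

def stop_fn(mean_reward):
    return mean_reward >= reward_threshold  # stop once the task counts as solved

# With task='CartPole-v0' the scalars are written as 'rew_CartPole-v0',
# 'len_CartPole-v0', ...; with task='' they keep their plain keys.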
Exemple #19
def onpolicy_trainer(
    policy: BasePolicy,
    train_collector: Collector,
    test_collector: Collector,
    max_epoch: int,
    step_per_epoch: int,
    collect_per_step: int,
    repeat_per_collect: int,
    episode_per_test: Union[int, List[int]],
    batch_size: int,
    train_fn: Optional[Callable[[int, int], None]] = None,
    test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
    stop_fn: Optional[Callable[[float], bool]] = None,
    save_fn: Optional[Callable[[BasePolicy], None]] = None,
    writer: Optional[SummaryWriter] = None,
    log_interval: int = 1,
    verbose: bool = True,
    test_in_train: bool = True,
) -> Dict[str, Union[float, str]]:
    """A wrapper for on-policy trainer procedure.

    The "step" in trainer means a policy network update.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param train_collector: the collector used for training.
    :type train_collector: :class:`~tianshou.data.Collector`
    :param test_collector: the collector used for testing.
    :type test_collector: :class:`~tianshou.data.Collector`
    :param int max_epoch: the maximum number of epochs for training. The
        training process might be finished before reaching ``max_epoch``.
    :param int step_per_epoch: the number of steps for updating the policy
        network in one epoch.
    :param int collect_per_step: the number of episodes the collector would
        collect before the network update. In other words, collect some
        episodes and do one policy network update.
    :param int repeat_per_collect: the number of repeat times for policy
        learning; for example, setting it to 2 means the policy needs to learn
        each given batch of data twice.
    :param episode_per_test: the number of episodes for one policy evaluation.
    :type episode_per_test: int or list of ints
    :param int batch_size: the batch size of the sample data, which is going to
        be fed into the policy network.
    :param function train_fn: a function that receives the current epoch and
        step index, and performs some operations at the beginning of training
        in this epoch.
    :param function test_fn: a function that receives the current epoch and
        step index, and performs some operations at the beginning of testing
        in this epoch.
    :param function save_fn: a function for saving the policy when the
        undiscounted average reward in the evaluation phase gets better.
    :param function stop_fn: a function that receives the average undiscounted
        return of the testing result and returns a boolean indicating whether
        the goal has been reached.
    :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard
        SummaryWriter.
    :param int log_interval: the log interval of the writer.
    :param bool verbose: whether to print the information.
    :param bool test_in_train: whether to test in the training phase.

    :return: See :func:`~tianshou.trainer.gather_info`.
    """
    env_step, gradient_step = 0, 0
    best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0
    stat: Dict[str, MovAvg] = {}
    start_time = time.time()
    train_collector.reset_stat()
    test_collector.reset_stat()
    test_in_train = test_in_train and train_collector.policy == policy
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        with tqdm.tqdm(
            total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
        ) as t:
            while t.n < t.total:
                if train_fn:
                    train_fn(epoch, env_step)
                result = train_collector.collect(n_episode=collect_per_step)
                env_step += int(result["n/st"])
                data = {
                    "env_step": str(env_step),
                    "rew": f"{result['rew']:.2f}",
                    "len": str(int(result["len"])),
                    "n/ep": str(int(result["n/ep"])),
                    "n/st": str(int(result["n/st"])),
                    "v/ep": f"{result['v/ep']:.2f}",
                    "v/st": f"{result['v/st']:.2f}",
                }
                if writer and env_step % log_interval == 0:
                    for k in result.keys():
                        writer.add_scalar(
                            "train/" + k, result[k], global_step=env_step)
                if test_in_train and stop_fn and stop_fn(result["rew"]):
                    test_result = test_episode(
                        policy, test_collector, test_fn,
                        epoch, episode_per_test, writer, env_step)
                    if stop_fn(test_result["rew"]):
                        if save_fn:
                            save_fn(policy)
                        for k in result.keys():
                            data[k] = f"{result[k]:.2f}"
                        t.set_postfix(**data)
                        return gather_info(
                            start_time, train_collector, test_collector,
                            test_result["rew"], test_result["rew_std"])
                    else:
                        policy.train()
                losses = policy.update(
                    0, train_collector.buffer,
                    batch_size=batch_size, repeat=repeat_per_collect)
                train_collector.reset_buffer()
                step = max([1] + [
                    len(v) for v in losses.values() if isinstance(v, list)])
                gradient_step += step
                for k in losses.keys():
                    if stat.get(k) is None:
                        stat[k] = MovAvg()
                    stat[k].add(losses[k])
                    data[k] = f"{stat[k].get():.6f}"
                    if writer and gradient_step % log_interval == 0:
                        writer.add_scalar(
                            k, stat[k].get(), global_step=gradient_step)
                t.update(step)
                t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        result = test_episode(policy, test_collector, test_fn, epoch,
                              episode_per_test, writer, env_step)
        if best_epoch == -1 or best_reward < result["rew"]:
            best_reward, best_reward_std = result["rew"], result["rew_std"]
            best_epoch = epoch
            if save_fn:
                save_fn(policy)
        if verbose:
            print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f} ± "
                  f"{result['rew_std']:.6f}, best_reward: {best_reward:.6f} ± "
                  f"{best_reward_std:.6f} in #{best_epoch}")
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(start_time, train_collector, test_collector,
                       best_reward, best_reward_std)
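In the update loop above, the progress bar advances by the number of gradient steps reported in ``losses``; a small worked example of that line with made-up loss lists:

# List-valued losses hold one entry per gradient step of this collect round
# (the numbers are made up for illustration).
losses = {'loss/actor': [0.9, 0.8, 0.7], 'loss/critic': [1.2, 1.1, 1.0]}
step = max([1] + [len(v) for v in losses.values() if isinstance(v, list)])
assert step == 3   # so t.update(3) and gradient_step += 3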
Exemple #20
def onpolicy_trainer(policy,
                     train_collector,
                     test_collector,
                     max_epoch,
                     step_per_epoch,
                     collect_per_step,
                     repeat_per_collect,
                     episode_per_test,
                     batch_size,
                     train_fn=None,
                     test_fn=None,
                     stop_fn=None,
                     writer=None,
                     log_interval=1,
                     verbose=True,
                     task='',
                     **kwargs):
    """A wrapper for on-policy trainer procedure.

    Parameters
        * **policy** – an instance of the :class:`~tianshou.policy.BasePolicy`\
            class.
        * **train_collector** – the collector used for training.
        * **test_collector** – the collector used for testing.
        * **max_epoch** – the maximum number of epochs for training. The \
            training process might be finished before reaching ``max_epoch``.
        * **step_per_epoch** – the number of steps for updating the policy \
            network in one epoch.
        * **collect_per_step** – the number of frames the collector would \
            collect before the network update. In other words, collect some \
            frames and do one policy network update.
        * **repeat_per_collect** – the number of repeat times for policy \
            learning; for example, setting it to 2 means the policy needs to \
            learn each given batch of data twice.
        * **episode_per_test** – the number of episodes for one policy \
            evaluation.
        * **batch_size** – the batch size of the sample data, which is going \
            to be fed into the policy network.
        * **train_fn** – a function that receives the current epoch index and \
            performs some operations at the beginning of training in this \
            epoch.
        * **test_fn** – a function that receives the current epoch index and \
            performs some operations at the beginning of testing in this \
            epoch.
        * **stop_fn** – a function that receives the average undiscounted \
            return of the testing result and returns a boolean indicating \
            whether the goal has been reached.
        * **writer** – a SummaryWriter provided from TensorBoard.
        * **log_interval** – an int indicating the log interval of the writer.
        * **verbose** – a boolean indicating whether to print the information.

    :return: See :func:`~tianshou.trainer.gather_info`.
    """
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}
    start_time = time.time()
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        if train_fn:
            train_fn(epoch)
        with tqdm.tqdm(total=step_per_epoch,
                       desc=f'Epoch #{epoch}',
                       **tqdm_config) as t:
            while t.n < t.total:
                result = train_collector.collect(n_episode=collect_per_step)
                data = {}
                if stop_fn and stop_fn(result['rew']):
                    test_result = test_episode(policy, test_collector, test_fn,
                                               epoch, episode_per_test)
                    if stop_fn and stop_fn(test_result['rew']):
                        for k in result.keys():
                            data[k] = f'{result[k]:.2f}'
                        t.set_postfix(**data)
                        return gather_info(start_time, train_collector,
                                           test_collector, test_result['rew'])
                    else:
                        policy.train()
                        if train_fn:
                            train_fn(epoch)
                losses = policy.learn(train_collector.sample(0), batch_size,
                                      repeat_per_collect)
                train_collector.reset_buffer()
                step = 1
                for k in losses.keys():
                    if isinstance(losses[k], list):
                        step = max(step, len(losses[k]))
                global_step += step
                for k in result.keys():
                    data[k] = f'{result[k]:.2f}'
                    if writer and global_step % log_interval == 0:
                        writer.add_scalar(k + '_' + task if task else k,
                                          result[k],
                                          global_step=global_step)
                for k in losses.keys():
                    if stat.get(k) is None:
                        stat[k] = MovAvg()
                    stat[k].add(losses[k])
                    data[k] = f'{stat[k].get():.6f}'
                    if writer and global_step % log_interval == 0:
                        writer.add_scalar(k + '_' + task if task else k,
                                          stat[k].get(),
                                          global_step=global_step)
                t.update(step)
                t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        result = test_episode(policy, test_collector, test_fn, epoch,
                              episode_per_test)
        if best_epoch == -1 or best_reward < result['rew']:
            best_reward = result['rew']
            best_epoch = epoch
        if verbose:
            print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
                  f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if stop_fn and stop_fn(best_reward):
            break
    return gather_info(start_time, train_collector, test_collector,
                       best_reward)
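For the on-policy variant above, collection is counted in episodes, each collected batch is replayed ``repeat_per_collect`` times, and the buffer is cleared afterwards. A hedged call sketch with placeholder values, assuming ``policy`` and the collectors already exist:

# Assumed to exist already: policy, train_collector, test_collector.
result = onpolicy_trainer(
    policy, train_collector, test_collector,
    max_epoch=10,
    step_per_epoch=1000,
    collect_per_step=16,      # episodes (not frames) collected per update round
    repeat_per_collect=2,     # learn each collected batch twice
    episode_per_test=100,
    batch_size=64,
    stop_fn=lambda mean_reward: mean_reward >= 195,  # illustrative threshold
)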
Exemple #21
def test_moving_average():
    stat = MovAvg(10)
    assert np.allclose(stat.get(), 0)
    assert np.allclose(stat.mean(), 0)
    assert np.allclose(stat.std() ** 2, 0)
    stat.add(torch.tensor([1]))
    stat.add(np.array([2]))
    stat.add([3, 4])
    stat.add(5.)
    assert np.allclose(stat.get(), 3)
    assert np.allclose(stat.mean(), 3)
    assert np.allclose(stat.std() ** 2, 2)
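A minimal sketch of a MovAvg-like class that would pass the test above if used in place of ``MovAvg`` (it is not the library's implementation): scalars, lists, arrays and tensors are flattened into a bounded deque, and ``get``/``mean``/``std`` report statistics over the stored values.

import numpy as np
import torch
from collections import deque


class SimpleMovAvg:
    """Bounded moving average, written only to satisfy test_moving_average."""

    def __init__(self, size=100):
        self.cache = deque(maxlen=size)  # keep at most `size` recent values

    def add(self, x):
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
        for v in np.atleast_1d(np.asarray(x, dtype=float)).ravel():
            self.cache.append(float(v))

    def get(self):
        return 0.0 if not self.cache else float(np.mean(self.cache))

    def mean(self):
        return self.get()

    def std(self):
        return 0.0 if not self.cache else float(np.std(self.cache))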
Exemple #22
def onpolicy_trainer_with_views(A, B, max_epoch, step_per_epoch, collect_per_step,
                                repeat_per_collect, episode_per_test, batch_size,
                                copier=False, peer=0, verbose=True, test_fn=None,
                                task='', copier_batch_size=0):
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}
    start_time = time.time()
    for epoch in range(1, 1 + max_epoch):
        # train
        A.train()
        B.train()
        with tqdm.tqdm(
                total=step_per_epoch, desc=f'Epoch #{epoch}',
                **tqdm_config) as t:
            while t.n < t.total:
                for view, other_view in zip([A, B], [B, A]):
                    result = view.train_collector.collect(n_episode=collect_per_step)
                    data = {}
                    if view.stop_fn and view.stop_fn(result['rew']):
                        test_result = test_episode(
                            view.policy, view.test_collector, test_fn,
                            epoch, episode_per_test)
                        if view.stop_fn and view.stop_fn(test_result['rew']):
                            for k in result.keys():
                                data[k] = f'{result[k]:.2f}'
                            t.set_postfix(**data)
                            return gather_info(
                                start_time, view.train_collector, view.test_collector,
                                test_result['rew'])
                        else:
                            view.train()

                    batch = view.train_collector.sample(0)
                    for obs in batch.obs:
                        view.buffer.add(obs, {}, {}, {}, {}, {})  # only need obs
                    losses = view.policy.learn(batch, batch_size, repeat_per_collect)

                    # Learn from demonstration
                    if copier:
                        copier_batch, indice = view.buffer.sample(copier_batch_size)
                        # copier_batch = view.train_collector.process_fn(
                        #     copier_batch, view.buffer, indice)
                        demo = other_view.policy(copier_batch)
                        view.learn_from_demos(copier_batch, demo, peer=peer)

                    view.train_collector.reset_buffer()
                    step = 1
                    for k in losses.keys():
                        if isinstance(losses[k], list):
                            step = max(step, len(losses[k]))
                    global_step += step
                    for k in result.keys():
                        data[k] = f'{result[k]:.2f}'
                        view.writer.add_scalar(
                            k + '_' + task if task else k,
                            result[k], global_step=global_step)
                    for k in losses.keys():
                        if stat.get(k) is None:
                            stat[k] = MovAvg()
                        stat[k].add(losses[k])
                        data[k] = f'{stat[k].get():.4f}'
                        if global_step:
                            view.writer.add_scalar(
                                k + '_' + task if task else k,
                                stat[k].get(), global_step=global_step)
                    t.update(step)
                    t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # test
        brk = False
        for view in A, B:
            result = test_episode(
                view.policy, view.test_collector, test_fn, epoch, episode_per_test)
            if best_epoch == -1 or best_reward < result['rew']:
                best_reward = result['rew']
                best_epoch = epoch
            if verbose:
                print(f'Epoch #{epoch}: test_reward: {result["rew"]:.4f}, '
                      f'best_reward: {best_reward:.4f} in #{best_epoch}')
            if view.stop_fn and view.stop_fn(best_reward):
                brk = True
        if brk:
            break
    return (
        gather_info(start_time, A.train_collector, A.test_collector, best_reward),
        gather_info(start_time, B.train_collector, B.test_collector, best_reward),
    )
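The trainer above never defines its ``A``/``B`` arguments; from the attributes it touches, each "view" bundles a policy with its collectors, an observation buffer, a writer and a stop criterion. A hedged sketch of that implied container, inferred only from the usage above (the ``learn_from_demos`` body is a placeholder, not the original code):

class View:
    def __init__(self, policy, train_collector, test_collector,
                 buffer, writer, stop_fn=None):
        self.policy = policy
        self.train_collector = train_collector
        self.test_collector = test_collector
        self.buffer = buffer        # stores observations for the copier step
        self.writer = writer        # TensorBoard SummaryWriter
        self.stop_fn = stop_fn      # may be None

    def train(self):
        self.policy.train()         # put the underlying networks in train mode

    def learn_from_demos(self, batch, demo, peer=0):
        # Placeholder: update self.policy to imitate the other view's output
        # `demo` on `batch`; the actual loss is not shown in the snippet above.
        raise NotImplementedError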