def test_episode(
    policy: BasePolicy,
    collector: Collector,
    test_fn: Optional[Callable[[int, Optional[int]], None]],
    epoch: int,
    n_episode: Union[int, List[int]],
    writer: Optional[SummaryWriter] = None,
    global_step: Optional[int] = None,
) -> Dict[str, float]:
    """A simple wrapper of testing policy in collector.

    Resets the collector's envs and buffer, calls ``policy.eval()``, runs the
    optional ``test_fn`` hook, then collects ``n_episode`` episodes.

    :param policy: the policy under test; ``policy.eval()`` is called first.
    :param collector: the collector used to run the test episodes.
    :param test_fn: optional hook, called as ``test_fn(epoch, global_step)``
        before collection.
    :param epoch: current epoch, forwarded to ``test_fn``.
    :param n_episode: total number of episodes as an integer, or an explicit
        per-env list. With more than one env, an integer total is split as
        evenly as possible across the envs.
    :param writer: optional summary writer. Each entry of the collect result
        is logged as ``"test/<key>"`` when both ``writer`` and
        ``global_step`` are given.
    :param global_step: step used for the ``writer`` log entries; also
        forwarded to ``test_fn``.
    :return: the dict returned by ``collector.collect``.
    """
    collector.reset_env()
    collector.reset_buffer()
    policy.eval()
    if test_fn:
        test_fn(epoch, global_step)
    env_num = collector.get_env_num()
    # Accept NumPy integer scalars as well as Python ints.
    if env_num > 1 and isinstance(n_episode, (int, np.integer)):
        # Every env gets n_episode // env_num episodes, and the first
        # n_episode % env_num envs get one extra. Build plain Python ints
        # rather than the float64 values that np.zeros-based arithmetic gives.
        base, extra = divmod(int(n_episode), env_num)
        n_episode = [base + 1 if i < extra else base for i in range(env_num)]
    result = collector.collect(n_episode=n_episode)
    if writer is not None and global_step is not None:
        for key, value in result.items():
            writer.add_scalar("test/" + key, value, global_step=global_step)
    return result
def test_episode(policy: BasePolicy, collector: Collector,
                 test_fn: Callable[[int], None], epoch: int,
                 n_episode: Union[int, List[int]]) -> Dict[str, float]:
    """A simple wrapper of testing policy in collector.

    Resets the collector's envs and buffer, calls ``policy.eval()``, runs the
    optional ``test_fn`` hook, then collects ``n_episode`` episodes.

    :param policy: the policy under test; ``policy.eval()`` is called first.
    :param collector: the collector used to run the test episodes.
    :param test_fn: hook called as ``test_fn(epoch)`` before collection;
        skipped when falsy (e.g. ``None``).
    :param epoch: current epoch, forwarded to ``test_fn``.
    :param n_episode: total number of episodes as an integer, or an explicit
        per-env list. With more than one env, an integer total is split as
        evenly as possible across the envs.
    :return: the dict returned by ``collector.collect``.
    """
    collector.reset_env()
    collector.reset_buffer()
    policy.eval()
    if test_fn:
        test_fn(epoch)
    env_num = collector.get_env_num()
    # Only integral counts are split. np.isscalar also accepts floats, and a
    # float total would crash here when used as a slice bound.
    if env_num > 1 and isinstance(n_episode, (int, np.integer)):
        # Every env gets n_episode // env_num episodes, and the first
        # n_episode % env_num envs get one extra, as plain Python ints.
        base, extra = divmod(int(n_episode), env_num)
        n_episode = [base + 1 if i < extra else base for i in range(env_num)]
    return collector.collect(n_episode=n_episode)