def __init__(self,
             env,
             act_fn,
             device,
             batch_size=32,
             rollout_length=None,
             gamma=0.99,
             lambda_=0.95,
             norm_advantages=False):
    """Init."""
    self.env = ensure_vec_env(env)
    self.nenv = self.env.num_envs
    self.act = act_fn
    self.device = device
    self.batch_size = batch_size
    self.gamma = gamma
    self.lambda_ = lambda_
    self.norm_advantages = norm_advantages
    self.rollout_length = rollout_length

    # Use fixed-length rollouts when rollout_length is given; otherwise
    # construct RolloutStorage with its default num_steps behavior.
    if rollout_length:
        self.storage = RolloutStorage(self.nenv,
                                      device=self.device,
                                      num_steps=self.rollout_length)
    else:
        self.storage = RolloutStorage(self.nenv, device=self.device)
    self._initialized = False
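
The gamma and lambda_ arguments are the standard discount factor and GAE
smoothing coefficient. As a self-contained sketch of how these two parameters
are typically combined (generic GAE for reference, not this class's internal
code):

import numpy as np

def gae_advantages(rewards, values, last_value, gamma=0.99, lambda_=0.95):
    """Generalized Advantage Estimation over one rollout (no episode ends)."""
    values = np.append(values, last_value)
    advantages = np.zeros(len(rewards))
    gae = 0.0
    for t in reversed(range(len(rewards))):
        # TD error at step t, then exponentially weighted accumulation.
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        gae = delta + gamma * lambda_ * gae
        advantages[t] = gae
    return advantages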
Example #2
def __init__(self,
             buffer,
             env,
             act_fn,
             device,
             learning_starts=1000,
             update_period=1):
    """Init."""
    self.env = ensure_vec_env(env)
    if self.env.num_envs > 1 and not isinstance(buffer,
                                                BatchedReplayBuffer):
        raise ValueError(
            "when num_envs > 1, you must pass a BatchedReplayBuffer"
            " to the ReplayBufferDataManager.")
    # Wrap a single buffer so downstream code can always assume one
    # replay buffer per environment.
    if not isinstance(buffer, BatchedReplayBuffer):
        buffer = BatchedReplayBuffer(buffer)
    if self.env.num_envs != buffer.n:
        raise ValueError(
            f"Found {self.env.num_envs} envs and {buffer.n} "
            "buffers. The number of envs must be equal to the "
            "number of buffers!")
    self.act = act_fn
    self.buffer = buffer
    self.device = device
    self.learning_starts = learning_starts
    self.update_period = update_period
    self._ob = None
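
The checks above enforce a one-buffer-per-environment convention, exposed
through the wrapper's .n attribute. A minimal illustrative sketch of that
shape (an assumption for illustration, not the library's BatchedReplayBuffer):

class SimpleBatchedBuffer:
    """Hold one replay buffer (here, a plain list) per environment."""

    def __init__(self, *buffers):
        self.buffers = list(buffers)
        self.n = len(self.buffers)  # must equal env.num_envs

    def add(self, env_idx, transition):
        # Route each environment's transition to its own buffer.
        self.buffers[env_idx].append(transition)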
Example #3
import numpy as np
import torch

# ensure_vec_env and Actor are helpers from the surrounding module.
def rl_evaluate(env,
                actor,
                nepisodes,
                outfile=None,
                device='cpu',
                save_info=False):
    """Compute episode stats for an environment and actor.

    If the environment has an EpisodeInfo Wrapper, rl_evaluate will use that
    to determine episode termination.
    Args:
        env: A Gym environment
        actor: A torch.nn.Module whose input is an observation and output has a
               '.action' attribute.
        nepisodes: The number of episodes to run.
        outfile: Where to save results (if provided)
        device: The device which contains the actor.
        save_info: Save the info dict with results.
    Returns:
        A dict of episode stats

    """
    if nepisodes == 0:
        return
    env = ensure_vec_env(env)
    ep_lengths = []
    ep_rewards = []
    all_infos = []
    actor = Actor(actor, device)
    while len(ep_lengths) < nepisodes:
        _dones = np.zeros(env.num_envs, dtype=bool)
        all_infos.extend([[] for _ in range(env.num_envs)])
        obs = env.reset()
        while not np.all(_dones):
            obs, rs, dones, infos = env.step(actor(obs))
            for i, _ in enumerate(dones):
                dones[i] = infos[i]['episode_info']['done']
                if not _dones[i] and save_info:
                    # Append to the info list created for env i in this batch.
                    all_infos[i - env.num_envs].append(infos[i])
            _dones = np.logical_or(dones, _dones)

        # save results
        for info in infos:
            ep_lengths.append(info['episode_info']['length'])
            ep_rewards.append(info['episode_info']['reward'])

    outs = {
        'episode_lengths': ep_lengths,
        'episode_rewards': ep_rewards,
        'mean_length': np.mean(ep_lengths),
        'mean_reward': np.mean(ep_rewards),
    }
    if save_info:
        outs['info'] = all_infos
    if outfile:
        torch.save(outs, outfile)
    return outs
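
A short usage sketch with placeholder names (my_env and my_policy are
assumptions, not objects defined in this module):

# Illustrative only: my_env is any Gym env, my_policy a torch.nn.Module
# whose forward pass returns an object with an '.action' attribute.
stats = rl_evaluate(my_env, my_policy, nepisodes=10,
                    outfile='eval.pt', device='cpu')
print(stats['mean_reward'], stats['mean_length'])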
Example #4

def __init__(self,
             buffer,
             env,
             act_fn,
             device,
             learning_starts=1000,
             update_period=1):
    """Init."""
    self.env = ensure_vec_env(env)
    if self.env.num_envs > 1:
        raise ValueError("ReplayBufferDataManager is only compatible with"
                         " num_envs = 1.")
    self.act = act_fn
    self.buffer = buffer
    self.device = device
    self.learning_starts = learning_starts
    self.update_period = update_period
    self._ob = None
Example #5
import os
import subprocess as sp
import tempfile
import time

import numpy as np
from imageio import imwrite

# ensure_vec_env, Actor, and logger are helpers from the surrounding module.
def rl_record(env, actor, nepisodes, outfile, device='cpu', fps=30):
    """Compute episode stats for an environment and actor.

    If the environment has an EpisodeInfo Wrapper, rl_record will use that to
    determine episode termination.
    Args:
        env: A Gym environment
        actor: A callable whose input is an observation and output has a
               '.action' attribute.
        nepisodes: The number of episodes to run.
        outfile: Where to save the video.
        device: The device which contains the actor.
        fps: The frame rate of the video.

    """
    if nepisodes == 0:
        return
    env = ensure_vec_env(env)
    tmpdir = os.path.join(tempfile.gettempdir(),
                          'video_' + str(time.monotonic()))
    os.makedirs(tmpdir)
    frame_id = 0
    actor = Actor(actor, device)
    episodes = 0

    def write_ims(ims, frame_id):
        # Write frames as numbered pngs so ffmpeg can stitch them in order.
        for im in ims:
            imwrite(os.path.join(tmpdir, '{:05d}.png'.format(frame_id)), im)
            frame_id += 1
        return frame_id

    while episodes < nepisodes:
        obs = env.reset()
        nenv = min(env.num_envs, nepisodes)
        _dones = np.zeros(nenv, dtype=bool)
        ims = [[] for _ in range(nenv)]

        # collect images
        try:
            rgbs = env.get_images()
        except Exception as e:
            logger.log(e)
            logger.log("Error while rendering.")
            return
        for i in range(nenv):
            ims[i].append(rgbs[i])

        # rollout episodes
        while not np.all(_dones):
            obs, r_, dones, infos = env.step(actor(obs))
            for i, _ in enumerate(dones):
                if 'episode_info' in infos[i]:
                    dones[i] = infos[i]['episode_info']['done']

            # collect images
            try:
                rgbs = env.get_images()
            except Exception:
                logger.log("Error while rendering.")
                return
            for i in range(nenv):
                if not _dones[i]:
                    ims[i].append(rgbs[i])
            _dones = np.logical_or(dones[:nenv], _dones)

        # save images
        for i in range(nenv):
            if episodes < nepisodes:
                frame_id = write_ims(ims[i], frame_id)
                ims[i] = []
                episodes += 1

    sp.call([
        'ffmpeg', '-r',
        str(fps), '-f', 'image2', '-i',
        os.path.join(tmpdir, '%05d.png'), '-vcodec', 'libx264', '-pix_fmt',
        'yuv420p',
        os.path.join(tmpdir, 'out.mp4')
    ])
    sp.call(['mv', os.path.join(tmpdir, 'out.mp4'), outfile])
    sp.call(['rm', '-rf', tmpdir])
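
Frames are written as numbered pngs and stitched by ffmpeg, so ffmpeg must be
on the PATH, and the mv/rm calls assume a POSIX environment. A usage sketch
with placeholder names (my_env and my_policy are assumptions):

# Illustrative only.
rl_record(my_env, my_policy, nepisodes=3, outfile='rollout.mp4', fps=30)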