Exemplo n.º 1
0
 def sample(self, policy, params=None, gamma=0.95, device='cpu'):
     """Roll out ``self.batch_size`` episodes with ``policy`` on ``self.envs``.

     Each of the ``self.num_workers`` environments is tagged with an episode
     index (its batch id). When an environment reports ``done`` it receives the
     next unassigned index, or ``None`` once every index has been handed out.
     Transitions tagged ``None`` are still passed to ``episodes.append``; they
     are presumably skipped there (TODO confirm against BatchEpisodes).

     Args:
         policy: callable ``policy(obs_tensor, params=...)`` returning a
             distribution with a ``.sample()`` method.
         params: optional parameters forwarded to ``policy``.
         gamma: discount factor stored on the returned ``BatchEpisodes``.
         device: torch device the observations are moved to.

     Returns:
         BatchEpisodes: the collected transitions.
     """
     episodes = BatchEpisodes(batch_size=self.batch_size,
                              gamma=gamma,
                              device=device)
     observations = np.float32(self.envs.reset())
     # Only ids < batch_size are valid; surplus workers (num_workers >
     # batch_size) start idle instead of writing out-of-range batch ids.
     indexs = [i if i < self.batch_size else None
               for i in range(self.num_workers)]
     next_index = self.num_workers
     done_count = 0
     while done_count < self.batch_size:
         with torch.no_grad():
             observations_tensor = torch.from_numpy(observations).to(
                 device=device)
             actions_tensor = policy(observations_tensor,
                                     params=params).sample()
             actions = actions_tensor.cpu().numpy()
         new_observations, rewards, dones, _ = self.envs.step(actions)
         new_observations = np.float32(new_observations)
         rewards = np.float32(rewards)
         episodes.append(observations, actions, rewards, indexs)
         observations = new_observations
         for i, done in enumerate(dones):
             # Idle workers keep stepping; their terminations must not be
             # counted, or the loop can stop before active episodes finish.
             if not done or indexs[i] is None:
                 continue
             done_count += 1
             if next_index < self.batch_size:
                 indexs[i] = next_index
                 next_index += 1
             else:
                 indexs[i] = None
     return episodes
    def create_episodes(self,
                        gamma=0.95,
                        gae_lambda=1.0,
                        device='cpu'):
        """Collect ``self.batch_size`` trajectories and compute their advantages.

        Every item yielded by ``self.sample_trajectories()`` is appended to a
        fresh ``BatchEpisodes``. The creation time and the sampling duration
        are logged on it. ``self.baseline`` is then fitted to the batch, and
        normalized GAE advantages are computed.

        Args:
            gamma: discount factor for the batch.
            gae_lambda: GAE lambda passed to ``compute_advantages``.
            device: device passed to ``BatchEpisodes``.

        Returns:
            BatchEpisodes: the filled batch with advantages computed.
        """
        batch = BatchEpisodes(batch_size=self.batch_size,
                              gamma=gamma,
                              device=device)
        batch.log('_createdAt', datetime.now(timezone.utc))
        # Process-name logging is deliberately not done in this variant.

        started = time.time()
        for trajectory in self.sample_trajectories():
            batch.append(*trajectory)
        batch.log('duration', time.time() - started)

        self.baseline.fit(batch)
        batch.compute_advantages(self.baseline,
                                 gae_lambda=gae_lambda,
                                 normalize=True)
        return batch
Exemplo n.º 3
0
    def sample(self, policy, params=None, gamma=0.95, device='cpu'):
        """Sample ``self.batch_size`` episodes using the shared task queue.

        Episode ids ``0..batch_size-1`` are enqueued, followed by one ``None``
        per worker. All environments are stepped until every worker reports
        done and the queue has been drained.

        Returns:
            BatchEpisodes: transitions keyed by the batch id of the
            observation that produced them.
        """
        episodes = BatchEpisodes(batch_size=self.batch_size,
                                 gamma=gamma,
                                 device=device)
        for episode_id in range(self.batch_size):
            self.queue.put(episode_id)
        for _ in range(self.num_workers):
            self.queue.put(None)

        obs, ids = self.envs.reset()
        dones = [False]
        while not (all(dones) and self.queue.empty()):
            with torch.no_grad():
                obs_tensor = torch.from_numpy(obs).to(device=device)
                distribution = policy(obs_tensor, params=params)
                actions = distribution.sample().cpu().numpy()
            next_obs, rewards, dones, next_ids, _ = self.envs.step(actions)
            # Record the transition under the ids of the pre-step observation.
            episodes.append(obs, actions, rewards, ids)
            obs, ids = next_obs, next_ids
        return episodes
Exemplo n.º 4
0
    def sample(self, policy, params=None, gamma=0.95, device='cpu'):
        """Collect ``self.batch_size`` episodes from the agent worker queues.

        Each round sends the actor parameters and each agent's task to the
        per-agent input queues. It then reads one
        ``(s_batch, a_batch, r_batch, terminal, info)`` result back from every
        agent that was dispatched. Only as many agents as are still needed are
        dispatched, so no unread result is left behind in ``exp_queues``
        (where a later call would mistake it for a fresh rollout). A
        non-positive ``batch_size`` returns an empty batch instead of looping
        forever.

        Args:
            policy: module whose ``named_parameters()`` are sent when
                ``params`` is None.
            params: optional name -> tensor mapping; sent detached.
            gamma: discount factor for the batch.
            device: device passed to ``BatchEpisodes``.

        Returns:
            BatchEpisodes: one entry per received rollout, ids ``0..batch_size-1``.
        """
        episodes = BatchEpisodes(batch_size=self.batch_size, gamma=gamma, device=device)
        if params is None:
            actor_params = OrderedDict(policy.named_parameters())
        else:
            # Detach so the tensors handed to workers carry no autograd graph.
            actor_params = {name: params[name].detach() for name in params}

        bid = 0
        while bid < self.batch_size:
            num_active = min(NUM_AGENTS, self.batch_size - bid)
            for i in range(num_active):
                # block if necessary until a free slot is available
                self.net_params_queues[i].put(actor_params)
                self.tasks_queues[i].put(self.tasks[i])
            for i in range(num_active):
                s_batch, a_batch, r_batch, terminal, info = self.exp_queues[i].get()
                episodes.append(s_batch, a_batch, r_batch, info, bid)
                bid += 1

        return episodes
Exemplo n.º 5
0
 def sample(self, policy, params=None, gamma=0.95, device='cpu'):
     """Sample ``self.batch_size`` episodes with row-wise observation filtering.

     Every observation row (after reset and after each step) is passed through
     a ``ZFilter`` with ``clip=5`` and overwritten in place. The filter is
     created inside this call, so its state is not carried across calls.
     Transitions are stored under the batch id of the pre-step observation.
     """
     episodes = BatchEpisodes(batch_size=self.batch_size,
                              gamma=gamma,
                              device=device)
     for episode_id in range(self.batch_size):
         self.queue.put(episode_id)
     for _ in range(self.num_workers):
         self.queue.put(None)

     obs, ids = self.envs.reset()
     # ZFilter presumably tracks running statistics per feature; confirm.
     normalizer = ZFilter((obs.shape[1], ), clip=5)

     def normalize_rows(batch):
         # Feed the filter one row at a time, writing each result back in place.
         for row in range(batch.shape[0]):
             batch[row, :] = normalizer(batch[row, :])

     normalize_rows(obs)
     dones = [False]
     while not (all(dones) and self.queue.empty()):
         with torch.no_grad():
             obs_tensor = torch.from_numpy(obs).to(device=device)
             actions = policy(obs_tensor, params=params).sample().cpu().numpy()
         next_obs, rewards, dones, next_ids, _ = self.envs.step(actions)
         episodes.append(obs, actions, rewards, ids)
         obs, ids = next_obs, next_ids
         normalize_rows(obs)
     return episodes
Exemplo n.º 6
0
    def sample(self, policy, params=None, prey=None, gamma=0.95, device='cpu'):
        """Sample ``self.batch_size`` predator trajectories against a fixed prey.

        The length of each trajectory comes from the Gym env registration in
        ./maml_rl/envs/__init__.py. Predator actions are sampled from
        ``policy``. Prey actions come from
        ``prey.select_deterministic_action``. Only the predator's
        observations, actions and rewards (column 0) are stored.

        Args:
            policy: predator policy; ``policy(obs, params=...)`` returns a
                distribution with ``.sample()``.
            params: optional parameters forwarded to ``policy``.
            prey: required; provides ``select_deterministic_action``.
            gamma: discount factor for the batch.
            device: torch device for the observation tensors.

        Returns:
            BatchEpisodes: predator-only transitions.
        """
        assert prey is not None

        episodes = BatchEpisodes(batch_size=self.batch_size,
                                 gamma=gamma,
                                 device=device)
        for episode_id in range(self.batch_size):
            self.queue.put(episode_id)
        for _ in range(self.num_workers):
            self.queue.put(None)

        observations, worker_ids = self.envs.reset()  # TODO reset needs to be fixed
        dones = [False]
        while not (all(dones) and self.queue.empty()):
            with torch.no_grad():
                predator_obs, prey_obs = self.split_observations(observations)
                predator_obs_t = torch.from_numpy(predator_obs).to(device=device)
                prey_obs_t = torch.from_numpy(prey_obs).to(device=device)

                predator_actions = policy(predator_obs_t,
                                          params=params).sample().cpu().numpy()
                prey_actions = prey.select_deterministic_action(
                    prey_obs_t).cpu().numpy()

            joint_actions = np.concatenate([predator_actions, prey_actions], axis=1)
            new_observations, rewards, step_dones, new_worker_ids, _ = self.envs.step(
                copy.deepcopy(joint_actions))
            # Predator and prey columns must terminate together.
            assert np.sum(step_dones[:, 0]) == np.sum(step_dones[:, 1])
            dones = step_dones[:, 0]

            # NOTE(review): result unused; call kept in case it has side effects.
            self.split_observations(new_observations)

            episodes.append(predator_obs, predator_actions, rewards[:, 0],
                            worker_ids)
            observations, worker_ids = new_observations, new_worker_ids

        return episodes
Exemplo n.º 7
0
    def sample(self,
               policy,
               task,
               tree=None,
               params=None,
               gamma=0.95,
               device='cpu'):
        """Sample episodes whose observations are augmented with a task embedding.

        ``tree.forward`` maps the task descriptor to an embedding. The
        descriptor is ``task["position"]`` for ``AntPos-v0`` and
        ``[task["velocity"]]`` for ``AntVel-v1``. The embedding's first row is
        concatenated to every observation row before the policy is queried,
        and those augmented observations are what is stored in the returned
        batch.

        Args:
            policy: callable ``policy(obs, task=..., params=..., enhanced=False)``
                returning a distribution with ``.sample()``.
            task: dict holding ``"position"`` (a numpy array) or ``"velocity"``.
            tree: module whose ``forward`` returns ``(_, embedding)``.
            params: optional parameters forwarded to ``policy``.
            gamma: discount factor for the batch.
            device: torch device for all tensors.

        Returns:
            BatchEpisodes: transitions with task-augmented observations.

        Raises:
            ValueError: if ``self.env_name`` has no known task encoding
                (previously an UnboundLocalError deep inside the loop).
        """
        # The embedding depends only on the task, so compute it once, and fail
        # before touching the queue or envs if the env is unsupported.
        if self.env_name == 'AntPos-v0':
            task_input = torch.from_numpy(task["position"]).float()
        elif self.env_name == 'AntVel-v1':
            task_input = torch.from_numpy(np.array([task["velocity"]])).float()
        else:
            raise ValueError(
                'No task embedding defined for env_name {!r}'.format(self.env_name))
        with torch.no_grad():
            _, embedding = tree.forward(task_input.to(device=device))
        task_features = embedding[0]

        episodes = BatchEpisodes(batch_size=self.batch_size,
                                 gamma=gamma,
                                 device=device)
        for i in range(self.batch_size):
            self.queue.put(i)
        for _ in range(self.num_workers):
            self.queue.put(None)
        observations, batch_ids = self.envs.reset()
        dones = [False]
        while (not all(dones)) or (not self.queue.empty()):
            with torch.no_grad():
                obs = torch.from_numpy(observations).float().to(device=device)
                # Append the same embedding to every observation row using
                # tensor ops (no numpy round-trip, so this also works on CUDA).
                observations_tensor = torch.cat(
                    [obs, task_features.unsqueeze(0).expand(obs.shape[0], -1)],
                    dim=1)
                actions_tensor = policy(observations_tensor,
                                        task=task,
                                        params=params,
                                        enhanced=False).sample()
                actions = actions_tensor.cpu().numpy()
            new_observations, rewards, dones, new_batch_ids, _ = self.envs.step(
                actions)
            episodes.append(observations_tensor.cpu().numpy(), actions,
                            rewards, batch_ids)
            observations, batch_ids = new_observations, new_batch_ids
        return episodes
Exemplo n.º 8
0
    def sample_for_pretraining(self, tasks, first_order=False):
        """Sample one batch of trajectories per task and merge them.

        For each task in ``tasks`` the sampler is reset to that task and a
        batch is sampled with ``self.policy`` (no adapted parameters). All
        batches are merged into a single ``BatchEpisodes`` via
        ``extend_episodes``.

        ``first_order`` is accepted for interface compatibility and is unused.
        """
        merged = BatchEpisodes(batch_size=0, gamma=self.gamma, device=self.device)
        for task in tasks:
            self.sampler.reset_task(task)
            merged.extend_episodes(
                self.sampler.sample(self.policy,
                                    gamma=self.gamma,
                                    device=self.device))
        return merged
 def sample(self, policy, params=None, gamma=0.95, device='cpu'):
     """Gather ``self.batch_size`` episodes from the vectorised workers.

     Batch ids ``0..batch_size-1`` are queued, followed by one ``None`` per
     worker. Stepping stops once all workers are done and the queue is empty.
     """
     episodes = BatchEpisodes(batch_size=self.batch_size,
                              gamma=gamma,
                              device=device)
     for slot in list(range(self.batch_size)) + [None] * self.num_workers:
         self.queue.put(slot)

     obs, ids = self.envs.reset()
     dones = [False]
     while not (all(dones) and self.queue.empty()):
         with torch.no_grad():
             dist = policy(torch.from_numpy(obs).to(device=device), params=params)
             actions = dist.sample().cpu().numpy()
         next_obs, rewards, dones, next_ids, _ = self.envs.step(actions)
         episodes.append(obs, actions, rewards, ids)
         obs, ids = next_obs, next_ids
     return episodes
Exemplo n.º 10
0
    def sample(self, policy, params=None, gamma=0.95):
        """Sample ``self.batch_size`` episodes with a TensorFlow policy.

        Observations are fed to ``policy`` as returned by the environments.
        The sampled actions are converted to numpy inside a
        ``tf.device('/CPU:0')`` scope.
        """
        episodes = BatchEpisodes(batch_size=self.batch_size, gamma=gamma)
        for episode_id in range(self.batch_size):
            self.queue.put(episode_id)
        for _ in range(self.num_workers):
            self.queue.put(None)

        obs, ids = self.envs.reset()
        dones = [False]
        while not (all(dones) and self.queue.empty()):
            sampled = policy(obs, params=params).sample()
            with tf.device('/CPU:0'):
                actions = sampled.numpy()
            next_obs, rewards, dones, next_ids, _ = self.envs.step(actions)
            episodes.append(obs, actions, rewards, ids)
            obs, ids = next_obs, next_ids

        return episodes
    def sample(self, policy, params=None, gamma=0.95, device='cpu'):
        """Sample ``self.batch_size`` complete episodes via the worker queue.

        Observations are cast to float32 both after reset and after every step.
        This gives the policy input and the stored transitions a consistent
        dtype. Before this change only the reset observations were cast. The
        debug ``print`` calls that ran on every step have been removed.

        Returns:
            BatchEpisodes: transitions keyed by the pre-step batch ids.
        """
        episodes = BatchEpisodes(batch_size=self.batch_size,
                                 gamma=gamma,
                                 device=device)
        for i in range(self.batch_size):
            self.queue.put(i)
        for _ in range(self.num_workers):
            self.queue.put(None)

        observations, batch_ids = self.envs.reset()
        observations = np.asarray(observations, dtype=np.float32)

        dones = [False]
        # Stop once every worker is done and no batch id is left queued, i.e.
        # each batch id has produced one complete episode.
        while (not all(dones)) or (not self.queue.empty()):
            with torch.no_grad():
                observations_tensor = torch.from_numpy(observations).to(
                    device=device)
                # .sample() draws actions from the distribution the policy returns.
                actions_tensor = policy(observations_tensor.float(),
                                        params=params).sample()
                actions = actions_tensor.cpu().numpy()

            new_observations, rewards, dones, new_batch_ids, _ = self.envs.step(
                actions)
            episodes.append(observations, actions, rewards, batch_ids)
            observations = np.asarray(new_observations, dtype=np.float32)
            batch_ids = new_batch_ids

        return episodes
Exemplo n.º 12
0
    def sample(self, policy, params=None, gamma=0.95, device='cpu'):
        """Sample ``self.batch_size`` episodes from the vectorised environments.

        :param policy: callable ``policy(obs_tensor, params=...)`` returning a
            distribution with ``.sample()``.
        :param params: optional parameters forwarded to ``policy``.
        :param gamma: discount factor for the returned batch.
        :param device: torch device the observations are moved to.
        :return: the collected ``BatchEpisodes``.
        """
        episodes = BatchEpisodes(batch_size=self.batch_size, gamma=gamma, device=device)
        for batch_id in range(self.batch_size):
            self.queue.put(batch_id)
        for _ in range(self.num_workers):
            self.queue.put(None)  # one None marker per worker

        observations, batch_ids = self.envs.reset()
        dones = [False]
        # Keep stepping until every worker is done AND no batch id is left queued.
        while not (all(dones) and self.queue.empty()):
            # Rollouts need no gradients.
            with torch.no_grad():
                obs_tensor = torch.from_numpy(observations).to(device=device)
                action_dist = policy(obs_tensor, params=params)
                actions = action_dist.sample().cpu().numpy()

            step_result = self.envs.step(actions)
            new_observations, rewards, dones, new_batch_ids, _ = step_result
            # Store the transition under the observation that produced it,
            # not the post-step one.
            episodes.append(observations, actions, rewards, batch_ids)
            observations, batch_ids = new_observations, new_batch_ids

        return episodes
Exemplo n.º 13
0
    def sample(self, policy, params=None, gamma=0.95, device='cpu'):
        """Sample ``self.batch_size`` episodes, also recording per-step ``infos``.

        Observations are converted to float32 tensors on ``device`` before being
        passed to the policy. The environment ``infos`` are stored with each
        transition.
        """
        episodes = BatchEpisodes(batch_size=self.batch_size, gamma=gamma, device=device)
        for item in [*range(self.batch_size), *([None] * self.num_workers)]:
            self.queue.put(item)
        obs, ids = self.envs.reset()
        dones = [False]
        while not (all(dones) and self.queue.empty()):
            with torch.no_grad():
                obs_t = torch.from_numpy(obs).to(device=device, dtype=torch.float32)
                actions = policy(obs_t, params=params).sample().cpu().numpy()
            next_obs, rewards, dones, next_ids, infos = self.envs.step(actions)
            # info keys: reachDist, pickRew, epRew, goalDist, success, goal, task_name
            # NOTE: last infos will be absent if batch_size % num_workers != 0
            episodes.append(obs, actions, rewards, ids, infos)
            obs, ids = next_obs, next_ids
        return episodes
Exemplo n.º 14
0
    def sample(self, policy, params=None, gamma=0.95, device=None):
        """Roll out ``self.batch_size`` fixed-length episodes rewarded by ``self.rewarder``.

        Episodes are collected in rounds of ``self.num_processes`` parallel
        environments, each lasting ``self.args.episode_length`` steps. Rewards
        returned by the environment are discarded; ``self.rewarder`` computes
        them from (observation, action) instead. Per-step reward info goes to
        ``self._append_to_log``, and the full batch goes to
        ``self.rewarder.save_episodes``. Both are told whether this is a
        pre-update batch (``params is None``).

        Args:
            policy: callable ``policy(obs, params=...)`` returning a
                distribution with ``.sample()``.
            params: optional parameters forwarded to ``policy``.
            gamma: discount factor for the batch.
            device: torch device; defaults to ``self.args.device``.

        Returns:
            BatchEpisodes: the collected transitions.
        """
        if device is None:
            device = self.args.device

        episodes = BatchEpisodes(batch_size=self.batch_size,
                                 gamma=gamma,
                                 device=device)
        pre_update = params is None

        assert self.batch_size % self.num_processes == 0, "for looping to work"
        rounds = self.batch_size // self.num_processes

        for round_idx in range(rounds):
            first_id = round_idx * self.num_processes
            batch_ids = list(range(first_id, first_id + self.num_processes))
            obs = self.envs.reset()
            self.rewarder.reset()
            # one extra append at end of for loop, but that's okay
            self.rewarder.append(obs)

            for _ in range(self.args.episode_length):
                with torch.no_grad():
                    actions = policy(obs.to(device), params=params).sample()
                next_obs, _, _, _ = self.envs.step(actions)
                rewards, rewards_info = self.rewarder.calculate_reward(
                    obs, actions)

                episodes.append(obs.cpu().numpy(),
                                actions.cpu().numpy(),
                                rewards.cpu().numpy(), batch_ids)
                # NOTE(review): this appends the pre-step observation, so the
                # reset observation reaches the rewarder twice and the final
                # observation never does; confirm the rewarder expects that.
                self.rewarder.append(obs)
                obs = next_obs

                self._append_to_log(rewards_info, is_pre_update=pre_update)

        self.rewarder.save_episodes(episodes, is_pre_update=pre_update)
        return episodes
def test_batch_episodes_append():
    """Appending variable-length rollouts pads every episode to the longest one.

    Checks the episode count and lengths, the padded tensor shapes and dtypes,
    that observations and actions past each episode's end are zero, and that
    the mask is 1 inside each episode and 0 after it. Expected sizes are
    derived from ``lengths`` rather than hard-coded.
    """
    lengths = [2, 3, 7, 5, 13, 11, 17]
    max_length, num_episodes = max(lengths), len(lengths)
    envs = SyncVectorEnv([make_unittest_env(length) for length in lengths])
    episodes = BatchEpisodes(batch_size=num_episodes, gamma=0.95)

    observations = envs.reset()
    while not envs.dones.all():
        actions = [envs.single_action_space.sample() for _ in observations]
        new_observations, rewards, _, infos = envs.step(actions)
        episodes.append(observations, actions, rewards, infos['batch_ids'])
        observations = new_observations

    assert len(episodes) == max_length
    assert episodes.lengths == lengths

    # The unittest env presumably emits 64x64x3 observations and 2-d actions.
    assert episodes.observations.shape == (max_length, num_episodes, 64, 64, 3)
    assert episodes.observations.dtype == torch.float32
    assert episodes.actions.shape == (max_length, num_episodes, 2)
    assert episodes.actions.dtype == torch.float32
    assert episodes.mask.shape == (max_length, num_episodes)
    assert episodes.mask.dtype == torch.float32

    for i, length in enumerate(lengths):
        # Padding past the end of each episode is zero...
        assert (episodes.observations[length:, i] == 0).all()
        assert (episodes.actions[length:, i] == 0).all()
        # ...and the mask marks exactly the valid steps.
        assert (episodes.mask[:length, i] == 1).all()
        assert (episodes.mask[length:, i] == 0).all()
Exemplo n.º 16
0
    def sample(self, policy, params=None, gamma=0.95, device='cpu'):
        """Sample episodes together with the hidden observations from ``get_peds``.

        After every environment step ``self.envs.get_peds()`` is queried
        (presumably pedestrian states; confirm). Its result is stored next to
        the pre-step observations, actions, rewards and batch ids.
        """
        episodes = BatchEpisodes(batch_size=self.batch_size,
                                 gamma=gamma,
                                 device=device)
        for episode_id in range(self.batch_size):
            self.queue.put(episode_id)
        for _ in range(self.num_workers):
            self.queue.put(None)

        obs, ids = self.envs.reset()
        dones = [False]
        while not (all(dones) and self.queue.empty()):
            with torch.no_grad():
                obs_tensor = torch.from_numpy(obs).to(device=device)
                actions = policy(obs_tensor, params=params).sample().cpu().numpy()
            next_obs, rewards, dones, next_ids, _ = self.envs.step(actions)

            hidden_obs = self.envs.get_peds()

            episodes.append(obs, hidden_obs, actions, rewards, ids)
            obs, ids = next_obs, next_ids
        return episodes
Exemplo n.º 17
0
    def inner_loss_ppo_noUpdate(self,
                                episodes,
                                first_order,
                                params=None,
                                ent_coef=0,
                                vf_coef=0,
                                nenvs=1):
        """Compute a PPO clipped-ratio surrogate loss without updating the policy.

        The inner loss is PPO with clipped ratio = new_pi/old_pi. It is
        evaluated over ``self.noptepochs`` epochs of mini-batches
        (``self.nminibatches`` per epoch), taken in order from a flattened copy
        of ``episodes``. The cliprange could be made adaptive.

        No optimizer step is taken and ``self.policy`` is never written here.
        The "old" and current policies therefore share the same parameters
        throughout, and the ratio is expected to be 1 unless the policy's
        forward pass is itself stochastic.
        NOTE(review): confirm this is intended.

        Args:
            episodes: batch whose ``observations``/``actions``/``returns`` are
                indexed ``[step][episode]``; ``observations.shape[0]`` is the
                number of steps and ``shape[1]`` the number of episodes.
            first_order: unused.
            params: unused.
            ent_coef: unused.
            vf_coef: unused.
            nenvs: number of workers (nsteps is defined by the env). It
                multiplies ``nbatch``. NOTE(review): only
                ``nsteps * nepisodes`` entries are flattened, so ``nenvs > 1``
                presumably indexes past the end of ``episodes_flat``; confirm.

        Returns:
            Scalar tensor: mean of the per-minibatch clipped PG losses across
            all epochs.

        Raises:
            UnboundLocalError: on the first mini-batch if
                ``self.baseline_type`` is neither ``'linear'`` nor
                ``'critic separate'`` (``values`` is never assigned).
        """
        # episodes tensors are laid out as [num of steps, num of episodes, obs_space]
        # TODO: change the advantage calculation to use a critic.
        losses = []

        self.logger.info("cliprange: " + str(self.cliprange) +
                         "; noptepochs: " + str(self.noptepochs) +
                         ";nminibaches: " + str(self.nminibatches) +
                         ";ppo_lr: " + str(self.ppo_lr))
        # Snapshot the current policy to serve as pi_old.
        old_policy = copy.deepcopy(self.policy)
        old_params = parameters_to_vector(old_policy.parameters())

        # Mini-batch sizing for the (intended) multi-step PPO update.
        nepisodes = episodes.observations.shape[1]
        nsteps = episodes.observations.shape[0]
        nbatch = nenvs * nsteps * nepisodes
        # If nminibatches > nbatch this is 0 and range() below raises ValueError.
        nbatch_train = nbatch // self.nminibatches
        # NOTE(review): never appended to; logged as an empty list at the end.
        mblossvals = []

        # Flatten the episodes into single-step entries, one per
        # (episode, step) pair, each tagged with its own batch id i.
        episodes_flat = BatchEpisodes(batch_size=nbatch)
        i = 0
        for ep in range(nepisodes):
            for step in range(nsteps):
                episodes_flat.append([episodes.observations[step][ep].numpy()],
                                     [episodes.actions[step][ep].numpy()],
                                     [episodes.returns[step][ep].numpy()],
                                     (i, ))
                i += 1

        inds = np.arange(nbatch)

        # For the case with linear baseline (not used below).
        vf_loss = -1

        for epoch in range(self.noptepochs):

            # Randomize the indexes (shuffling is currently disabled, so
            # mini-batches are taken in order).
            #np.random.shuffle(inds)
            # Accumulated critic loss; computed but never added to the loss.
            mb_vf_loss = torch.zeros(1)
            grad_norm = []  # unused
            # 0 to batch_size with batch_train_size step
            for start in range(0, nbatch, nbatch_train):

                mb_obs, mb_returns, mb_masks, mb_actions = [], [], [], []
                mb_episodes = BatchEpisodes(batch_size=nbatch_train)

                end = start + nbatch_train
                mbinds = inds[start:end]

                # Copy the selected flattened entries into a mini-batch.
                # Index [0] presumably selects the single time step of each
                # flattened entry; confirm against BatchEpisodes layout.
                for i in range(len(mbinds)):
                    mb_obs.append(
                        episodes_flat.observations[0][mbinds[i]].numpy())
                    mb_returns.append(
                        episodes_flat.returns[0][mbinds[i]].numpy())
                    mb_masks.append(episodes_flat.mask[0][mbinds[i]].numpy())
                    mb_actions.append(
                        episodes_flat.actions[0][mbinds[i]].numpy())
                    mb_episodes.append([mb_obs[i]], [mb_actions[i]],
                                       [mb_returns[i]], (i, ))

                if self.baseline_type == 'linear':
                    values = self.baseline(mb_episodes)
                elif self.baseline_type == 'critic separate':
                    values = self.baseline(mb_episodes.observations)
                    # find value loss sum [(R-V(s))^2]
                    R = torch.FloatTensor(np.array(mb_returns))
                    mb_vf_loss = (((values - R)**2).mean()) + mb_vf_loss

                #values = self.baseline(mb_episodes)

                # GAE advantages, then normalized with uniform weights.
                advantages = mb_episodes.gae(values, tau=self.tau)
                advantages_unnorm = advantages
                advantages = weighted_normalize(advantages.type(torch.float32),
                                                weights=torch.ones(
                                                    1, advantages.shape[1]))

                mb_returns_sum = np.sum(mb_returns)
                self.logger.info("iter: " + "epoch:" + str(epoch) + "; mb:" +
                                 str(start / nbatch_train))
                self.logger.info("mb returns: " + str(mb_returns_sum))

                pi = self.policy(mb_episodes.observations)
                log_probs = pi.log_prob(mb_episodes.actions)

                # Reload pi_old from old_params before evaluating it.
                vector_to_parameters(old_params, old_policy.parameters())
                pi_old = old_policy(mb_episodes.observations)

                log_probs_old = pi_old.log_prob(mb_episodes.actions)

                # Sum per-dimension log-probs over dim 2 (presumably the
                # action dimension) to get one log-prob per step.
                if log_probs.dim() > 2:
                    log_probs_old = torch.sum(log_probs_old, dim=2)
                    log_probs = torch.sum(log_probs, dim=2)

                ratio = torch.exp(log_probs - log_probs_old)

                self.logger.info("max pi: ")
                self.logger.info(torch.max(pi.mean))

                # Warn when any of the first 10 ratios of row 0 is extreme.
                for x in ratio[0][:10]:
                    if x > 1E5 or x < 1E-5:
                        #pdb.set_trace()
                        self.logger.info("ratio too large or too small.")
                        self.logger.info(ratio[0][:10])

                self.logger.info("policy ratio: ")
                self.logger.info(ratio[0][:10])

                # Clipped surrogate: unclipped and clipped PG losses.
                pg_losses = -advantages * ratio
                pg_losses2 = -advantages * torch.clamp(
                    ratio, 1.0 - self.cliprange, 1.0 + self.cliprange)

                # Final PG loss: weighted mean of the elementwise max.
                pg_loss = weighted_mean(torch.max(pg_losses, pg_losses2),
                                        weights=torch.ones(
                                            1, advantages.shape[1]))

                self.logger.debug("policy mu weights: ")
                self.logger.debug(self.policy.mu.weight)

                sum_adv = torch.sum(advantages_unnorm).numpy()
                self.logger.info("unnormalized advantages: " + str(sum_adv))

                # Total loss (value and entropy terms are not included).
                loss = pg_loss

                self.logger.info("max_action: " + str(np.max(mb_actions)))
                self.logger.info("max_action index: " +
                                 str(np.argmax(mb_actions)))

                # Use the current policy parameters as pi_old for the next
                # mini-batch (unchanged here, since nothing updates self.policy).
                old_params = parameters_to_vector(self.policy.parameters())
                losses.append(loss)

        self.logger.info("inner loss for each mb and epoch: ")
        self.logger.info(mblossvals)
        return torch.mean(torch.stack(losses, dim=0))
Exemplo n.º 18
0
    def _create_episodes(self,
                         params=None,
                         gamma=0.95,
                         gae_lambda=1.0,
                         device='cpu'):
        """Sample one batch of trajectories with ``params`` and compute advantages.

        The creation time, ``self.name`` and the sampling duration are logged
        on the batch. ``self.baseline`` is fitted to it, and normalized GAE
        advantages are computed.

        Args:
            params: parameters passed to ``self._sample_trajectories``.
            gamma: discount factor for the batch.
            gae_lambda: GAE lambda passed to ``compute_advantages``.
            device: device passed to ``BatchEpisodes``.

        Returns:
            BatchEpisodes: the filled batch with advantages computed.
        """
        batch = BatchEpisodes(batch_size=self.batch_size, gamma=gamma, device=device)
        batch.log('_createdAt', datetime.now(timezone.utc))
        batch.log('process_name', self.name)

        start_time = time.time()
        for trajectory in self._sample_trajectories(params=params):
            batch.append(*trajectory)
        batch.log('duration', time.time() - start_time)

        self.baseline.fit(batch)
        batch.compute_advantages(self.baseline, gae_lambda=gae_lambda, normalize=True)
        return batch
def test_batch_episodes(batch_size):
    episodes = BatchEpisodes(batch_size, gamma=0.95)
    assert episodes.batch_size == batch_size
    assert episodes.gamma == 0.95