def sample(self, idx=None, envx=None):
        assert self.can_sample()
        idx = np.random.randint(self._num_in_buffer,
                                size=self.num_envs) if idx is None else idx
        num_envs = self.num_envs

        envx = np.arange(num_envs) if envx is None else envx

        take = lambda x: self.take(x, idx, envx)

        # (nstep, num_envs)
        states = self.take_block(self.state_block, idx, envx, 0)
        next_states = self.take_block(self.state_block, idx, envx, 1)
        actions = take(self.actions)
        mus = take(self.mus)
        rewards = take(self.rewards)
        dones = take(self.dones)
        timeouts = take(self.timeouts)
        infos = take(self.infos)

        samples = Dataset(dtype=self.dtype,
                          max_size=self.num_envs * self.n_steps)
        steps = [
            states, actions, next_states, mus, rewards, dones, timeouts, infos
        ]
        steps = list(map(flatten_first_2_dims, steps))
        samples.extend(np.rec.fromarrays(steps, dtype=self.dtype))
        return samples
    def compute_advantage(self, vfn: BaseVFunction, samples: Dataset):
        n_steps = len(samples) // self.n_envs
        samples = samples.reshape((n_steps, self.n_envs))
        if not self.add_absorbing_state:
            use_next_vf = ~samples.done
            use_next_adv = ~(samples.done | samples.timeout)
        else:
            absorbing_mask = samples.mask == Mask.ABSORBING
            use_next_vf = np.ones_like(samples.done)
            use_next_adv = ~(absorbing_mask | samples.timeout)

        next_values = vfn.get_values(samples.reshape(-1).next_state).reshape(
            n_steps, self.n_envs)
        values = vfn.get_values(samples.reshape(-1).state).reshape(
            n_steps, self.n_envs)
        advantages = np.zeros((n_steps, self.n_envs), dtype=np.float32)
        last_gae_lambda = 0

        for t in reversed(range(n_steps)):
            delta = (samples[t].reward
                     + self.gamma * next_values[t] * use_next_vf[t]
                     - values[t])
            last_gae_lambda = (delta + self.gamma * self.lambda_
                               * last_gae_lambda * use_next_adv[t])
            advantages[t] = last_gae_lambda
        return advantages.reshape(-1), values.reshape(-1)
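The loop above is the standard GAE(lambda) recursion. Below is a minimal, self-contained sketch of the same recursion on toy data; the shapes and values are illustrative only and not taken from this project.

import numpy as np

def gae_sketch(rewards, values, next_values, use_next_vf, use_next_adv,
               gamma=0.99, lambda_=0.95):
    # rewards, values, next_values, masks: (n_steps, n_envs)
    n_steps = rewards.shape[0]
    advantages = np.zeros_like(rewards, dtype=np.float32)
    last_gae_lambda = 0.0
    for t in reversed(range(n_steps)):
        delta = rewards[t] + gamma * next_values[t] * use_next_vf[t] - values[t]
        last_gae_lambda = delta + gamma * lambda_ * last_gae_lambda * use_next_adv[t]
        advantages[t] = last_gae_lambda
    return advantages

rewards = np.ones((4, 2), dtype=np.float32)
values = np.zeros((4, 2), dtype=np.float32)
next_values = np.zeros((4, 2), dtype=np.float32)
mask = np.ones((4, 2), dtype=np.float32)
print(gae_sketch(rewards, values, next_values, mask, mask))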
Example #3
    def compute_advantage(self,
                          vfn: BaseVFunction,
                          samples: Dataset,
                          task=None):
        n_steps = len(samples) // self.n_envs
        samples = samples.reshape((n_steps, self.n_envs))
        use_next_vf = ~samples.done
        use_next_adv = ~(samples.done | samples.timeout)

        next_values = vfn.get_values(samples[-1].next_state)
        values = vfn.get_values(samples.reshape(-1).state).reshape(
            n_steps, self.n_envs)
        advantages = np.zeros((n_steps, self.n_envs), dtype=np.float32)
        advantages_shadow = np.zeros((n_steps, self.n_envs), dtype=np.float32)

        next_values_all = np.zeros_like(values, dtype=np.float32)
        next_values_all[:-1] = values[1:] * (1.0 - samples.done[1:])
        next_values_all[-1] = next_values
        # Note: td excludes the reward term; it is gamma * V(s') * mask - V(s).
        td = self.gamma * next_values_all * use_next_vf - values

        coef_mat = np.zeros([n_steps, n_steps, self.n_envs], np.float32)
        coef_mat_returns = np.zeros([n_steps, n_steps, self.n_envs],
                                    np.float32)
        # coef_mat[i][j] accumulates (gamma * lambda)^(j - i) with episode-boundary
        # masking; coef_mat_returns[i][j] accumulates gamma^(j - i) for returns.
        for i in range(n_steps):
            coef = np.ones([self.n_envs], dtype=np.float32)
            coef_r = np.ones([self.n_envs], dtype=np.float32)
            coef_mat[i][i] = coef
            coef_mat_returns[i][i] = coef_r
            for j in range(i + 1, n_steps):
                coef = coef * self.gamma * self.lambda_ * use_next_adv[j - 1]
                coef_mat[i][j] = coef
                coef_r = coef_r * self.gamma * use_next_vf[j - 1]
                coef_mat_returns[i][j] = coef_r
        coef_mat = np.transpose(coef_mat, (2, 0, 1))
        coef_mat_returns = np.transpose(coef_mat_returns, (2, 0, 1))

        reward_ctrl_list = np.array(self.reward_ctrl_list, dtype=np.float32)
        reward_state_list = np.array(self.reward_state_list, dtype=np.float32)

        last_gae_lambda = 0
        next_values = vfn.get_values(samples[-1].next_state)
        for t in reversed(range(n_steps)):
            delta = (samples[t].reward
                     + self.gamma * next_values * use_next_vf[t]
                     - values[t])
            last_gae_lambda = (delta + self.gamma * self.lambda_
                               * last_gae_lambda * use_next_adv[t])
            advantages[t] = last_gae_lambda
            next_values = values[t]

        advantages_params = None
        return (advantages.reshape(-1), advantages_params, values.reshape(-1),
                td, coef_mat, coef_mat_returns, reward_ctrl_list,
                reward_state_list, self.begin_mark)
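For reference, here is a toy check that a coefficient matrix built the same way as coef_mat above reproduces the recursive GAE when applied to per-step TD residuals. How coef_mat is consumed downstream is not shown in this snippet, so all names and values below are illustrative assumptions.

import numpy as np

gamma, lambda_ = 0.99, 0.95
n_steps, n_envs = 5, 1
rng = np.random.default_rng(0)
delta = rng.normal(size=(n_steps, n_envs)).astype(np.float32)  # r + gamma*V' - V
use_next_adv = (rng.random((n_steps, n_envs)) > 0.2).astype(np.float32)

# Recursive form.
adv_rec = np.zeros_like(delta)
last = 0.0
for t in reversed(range(n_steps)):
    last = delta[t] + gamma * lambda_ * last * use_next_adv[t]
    adv_rec[t] = last

# Matrix form, built as in the loop above (before the transpose).
coef_mat = np.zeros((n_steps, n_steps, n_envs), np.float32)
for i in range(n_steps):
    coef = np.ones(n_envs, np.float32)
    coef_mat[i][i] = coef
    for j in range(i + 1, n_steps):
        coef = coef * gamma * lambda_ * use_next_adv[j - 1]
        coef_mat[i][j] = coef
adv_mat = np.einsum('ije,je->ie', coef_mat, delta)

assert np.allclose(adv_rec, adv_mat, atol=1e-5)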
Example #4
    def run(self,
            policy: Actor,
            n_samples: int,
            classifier=None,
            stochastic=True):
        ep_infos = []
        n_steps = n_samples // self.n_envs
        assert n_steps * self.n_envs == n_samples
        dataset = Dataset(self._dtype, n_samples)

        if self._actions is None:
            self._actions = self._get_action(policy, self._states, stochastic)
        for T in range(n_steps):
            unscaled_actions = self._actions.copy()
            if self.rescale_action:
                lo, hi = self.env.action_space.low, self.env.action_space.high
                actions = (lo + (unscaled_actions + 1.) * 0.5 * (hi - lo))
            else:
                actions = unscaled_actions

            next_states, rewards, dones, infos = self.env.step(actions)
            if classifier is not None:
                rewards = classifier.get_rewards(self._states,
                                                 unscaled_actions, next_states)
            next_actions = self._get_action(policy, next_states, stochastic)
            dones = dones.astype(bool)
            self._returns += rewards
            self._n_steps += 1
            timeouts = self._n_steps == self.max_steps

            steps = [
                self._states.copy(), unscaled_actions,
                next_states.copy(),
                next_actions.copy(), rewards, dones, timeouts,
                self._n_steps.copy()
            ]
            dataset.extend(np.rec.fromarrays(steps, dtype=self._dtype))

            indices = np.where(dones | timeouts)[0]
            if len(indices) > 0:
                next_states = next_states.copy()
                next_states[indices] = self.env.partial_reset(indices)
                next_actions = next_actions.copy()
                next_actions[indices] = self._get_action(
                    policy, next_states, stochastic)[indices]
                for index in indices:
                    infos[index]['episode'] = {
                        'return': self._returns[index],
                        'length': self._n_steps[index]
                    }
                self._n_steps[indices] = 0
                self._returns[indices] = 0.

            self._states = next_states.copy()
            self._actions = next_actions.copy()
            ep_infos.extend(
                [info['episode'] for info in infos if 'episode' in info])

        return dataset, ep_infos
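The rescale_action branch above maps policy outputs from [-1, 1] to the environment's action bounds with a per-dimension affine transform. A tiny sketch with illustrative bounds:

import numpy as np

low = np.array([-2.0, 0.0])
high = np.array([2.0, 1.0])
unscaled = np.array([-1.0, 1.0])          # policy output in [-1, 1]
scaled = low + (unscaled + 1.0) * 0.5 * (high - low)
assert np.allclose(scaled, [-2.0, 1.0])   # the endpoints map onto the bounds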
Example #5
    def run(self, policy: BasePolicy, n_samples: int):
        ep_infos = []
        self.rewards_params_list = []
        self.reward_ctrl_list = []
        self.reward_state_list = []
        n_steps = n_samples // self.n_envs
        assert n_steps * self.n_envs == n_samples
        dataset = Dataset(self._dtype, n_samples)
        self.begin_mark = np.zeros((n_steps, self.n_envs), dtype=np.float32)
        start = np.array([0 for _ in range(self.n_envs)])

        for T in range(n_steps):
            unscaled_actions = policy.get_actions(self._states)
            if self.rescale_action:
                lo, hi = self.env.action_space.low, self.env.action_space.high
                actions = (lo + (unscaled_actions + 1.) * 0.5 * (hi - lo))
            else:
                actions = unscaled_actions

            next_states, rewards, dones, infos = self.env.step(actions)
            self.reward_ctrl_list.append([i['reward_ctrl'] for i in infos])
            self.reward_state_list.append([i['reward_state'] for i in infos])
            dones = dones.astype(bool)
            self._returns += rewards
            self._n_steps += 1
            timeouts = self._n_steps == self.max_steps

            steps = [
                self._states.copy(), unscaled_actions,
                next_states.copy(), rewards, dones, timeouts
            ]
            dataset.extend(np.rec.fromarrays(steps, dtype=self._dtype))

            indices = np.where(dones | timeouts)[0]
            if len(indices) > 0:
                next_states = next_states.copy()
                next_states[indices] = self.env.partial_reset(indices)
                for index in indices:
                    infos[index]['episode'] = {'return': self._returns[index]}
                    self.begin_mark[start[index]][index] = 1
                self._n_steps[indices] = 0
                self._returns[indices] = 0.
                start[indices] = T + 1

            self._states = next_states.copy()
            ep_infos.extend(
                [info['episode'] for info in infos if 'episode' in info])

        assert len(ep_infos) > 0, 'no episode finished within the sampled horizon'
        return dataset, ep_infos
    def run(self, policy: BasePolicy, n_samples: int):
        ep_infos = []
        n_steps = n_samples // self.n_envs
        assert n_steps * self.n_envs == n_samples
        dataset = Dataset(self._dtype, n_samples)

        for T in range(n_steps):
            unscaled_actions = policy.get_actions(self._states)
            if self.rescale_action:
                lo, hi = self.env.action_space.low, self.env.action_space.high
                actions = lo + (unscaled_actions + 1.) * 0.5 * (hi - lo)
            else:
                actions = unscaled_actions

            next_states, rewards, dones, infos = self.env.step(actions)
            dones = dones.astype(bool)
            self._returns += rewards
            self._n_steps += 1
            timeouts = self._n_steps == self.max_steps
            terminals = np.copy(dones)
            for e, info in enumerate(infos):
                if self.partial_episode_bootstrapping and info.get(
                        'TimeLimit.truncated', False):
                    terminals[e] = False

            steps = [
                self._states.copy(), unscaled_actions,
                next_states.copy(), rewards, terminals, timeouts
            ]
            dataset.extend(np.rec.fromarrays(steps, dtype=self._dtype))

            indices = np.where(dones | timeouts)[0]
            if len(indices) > 0:
                next_states = next_states.copy()
                next_states[indices] = self.env.partial_reset(indices)
                for index in indices:
                    infos[index]['episode'] = {
                        'return': self._returns[index],
                        'length': self._n_steps[index]
                    }
                self._n_steps[indices] = 0
                self._returns[indices] = 0.

            self._states = next_states.copy()
            ep_infos.extend(
                [info['episode'] for info in infos if 'episode' in info])

        return dataset, ep_infos
Example #7
    def compute_advantage(self, vfn: BaseVFunction, samples: Dataset):
        n_steps = len(samples) // self.n_envs
        samples = samples.reshape((n_steps, self.n_envs))
        use_next_vf = ~samples.done
        use_next_adv = ~(samples.done | samples.timeout)

        next_values = vfn.get_values(samples[-1].next_state)
        values = vfn.get_values(samples.reshape(-1).state).reshape(
            n_steps, self.n_envs)
        advantages = np.zeros((n_steps, self.n_envs), dtype=np.float32)
        last_gae_lambda = 0

        for t in reversed(range(n_steps)):
            delta = (samples[t].reward
                     + self.gamma * next_values * use_next_vf[t]
                     - values[t])
            last_gae_lambda = (delta + self.gamma * self.lambda_
                               * last_gae_lambda * use_next_adv[t])
            advantages[t] = last_gae_lambda
            next_values = values[t]
        return advantages.reshape(-1), values.reshape(-1)
Example #8
    def train_vf(self, dataset: Dataset):
        for _ in range(self.n_vf_iters):
            for subset in dataset.iterator(64):
                self.get_vf_loss(subset.state,
                                 subset.return_,
                                 fetch='train_vf vf_loss')
        for param in self.parameters():
            param.invalidate()
        vf_loss = self.get_vf_loss(dataset.state,
                                   dataset.return_,
                                   fetch='vf_loss')
        return vf_loss
    def run(self, policy: BasePolicy, n_samples: int, stochastic=True):
        assert self.n_envs == 1, 'Only support 1 env.'
        ep_infos = []
        n_steps = n_samples // self.n_envs
        assert n_steps * self.n_envs == n_samples
        dataset = Dataset(self._dtype, n_samples)

        for t in range(n_samples):
            if stochastic:
                unscaled_action = policy.get_actions(self._state[None])[0]
            else:
                unscaled_action = policy.get_actions(self._state[None],
                                                     fetch='actions_mean')[0]
            if self.rescale_action:
                lo, hi = self.env.action_space.low, self.env.action_space.high
                action = lo + (unscaled_action + 1.) * 0.5 * (hi - lo)
            else:
                action = unscaled_action

            next_state, reward, done, info = self.env.step(action)
            self._return += reward
            self._n_step += 1
            timeout = self._n_step == self.max_steps
            if (not done) or timeout:
                mask = Mask.NOT_DONE.value  # timeouts are not treated as true terminals
            else:
                mask = Mask.DONE.value

            if self.add_absorbing_state and done and self._n_step < self.max_steps:
                next_state = self.env.get_absorbing_state()
            steps = [
                self._state.copy(), unscaled_action,
                next_state.copy(), reward, done, timeout, mask,
                np.copy(self._n_step)
            ]
            dataset.append(np.rec.array(steps, dtype=self._dtype))

            if done or timeout:
                if self.add_absorbing_state and self._n_step < self.max_steps:
                    action = np.zeros(self.env.action_space.shape)
                    absorbing_state = self.env.get_absorbing_state()
                    steps = [
                        absorbing_state, action, absorbing_state, 0.0, False,
                        False, Mask.ABSORBING.value,
                        np.copy(self._n_step)  # keep the field layout consistent with the regular step above
                    ]
                    dataset.append(np.rec.array(steps, dtype=self._dtype))
                next_state = self.env.reset()
                ep_infos.append({
                    'return': self._return,
                    'length': self._n_step
                })
                self._n_step = 0
                self._return = 0.
            self._state = next_state.copy()

        return dataset, ep_infos
    def store_episode(self, data: Dataset):
        data = data.reshape([self.n_steps, self.num_envs])

        if self.state_block is None:
            self.obs_shape, self.obs_dtype = list(
                data.state.shape[2:]), data.state.dtype
            self.state_block = np.empty([self._size], dtype=object)
            self.actions = np.empty([self._size] + list(data.action.shape),
                                    dtype=data.action.dtype)
            self.rewards = np.empty([self._size] + list(data.reward.shape),
                                    dtype=data.reward.dtype)
            self.mus = np.empty([self._size] + list(data.mu.shape),
                                dtype=data.mu.dtype)
            self.dones = np.empty([self._size] + list(data.done.shape),
                                  dtype=bool)
            self.timeouts = np.empty([self._size] + list(data.timeout.shape),
                                     dtype=bool)
            self.infos = np.empty([self._size] + list(data.info.shape),
                                  dtype=object)

        terminals = data.done | data.timeout
        if self.stacked_frame:
            self.state_block[self._next_idx] = StackedFrame(
                data.state, data.next_state, terminals)
        else:
            self.state_block[self._next_idx] = StateBlock(
                data.state, data.next_state, terminals)
        self.actions[self._next_idx] = data.action
        self.rewards[self._next_idx] = data.reward
        self.mus[self._next_idx] = data.mu
        self.dones[self._next_idx] = data.done
        self.timeouts[self._next_idx] = data.timeout
        self.infos[self._next_idx] = data.info

        self._next_idx = (self._next_idx + 1) % self._size
        self._total_size += 1
        self._num_in_buffer = min(self._size, self._num_in_buffer + 1)
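The last three lines implement a circular write index: once the buffer holds _size episodes, the oldest slot is overwritten. A minimal sketch of that indexing (the class and names below are hypothetical):

class TinyRingBuffer:
    def __init__(self, size):
        self._size = size
        self._next_idx = 0
        self._num_in_buffer = 0
        self.slots = [None] * size

    def store(self, episode):
        self.slots[self._next_idx] = episode
        self._next_idx = (self._next_idx + 1) % self._size
        self._num_in_buffer = min(self._size, self._num_in_buffer + 1)

ring = TinyRingBuffer(3)
for i in range(5):
    ring.store(i)
assert ring.slots == [3, 4, 2]            # episodes 0 and 1 were overwritten
assert ring._num_in_buffer == 3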
Example #11
    def train(self, ent_coef, samples, advantages, values):
        returns = advantages + values
        advantages = (advantages - advantages.mean()) / np.maximum(
            advantages.std(), 1e-8)
        assert np.isfinite(advantages).all()
        self.sync_old()
        old_loss, grad, dist_std, mean_kl, dist_mean = self.get_loss(
            samples.state,
            samples.action,
            samples.next_state,
            advantages,
            ent_coef,
            fetch='loss flat_grad dist_std mean_kl dist_mean')

        if np.allclose(grad, 0):
            logger.info('Zero gradient, not updating...')
            return

        def fisher_vec_prod(x):
            return self.get_hessian_vec_prod(
                samples.state, samples.action, x,
                samples.next_state) + self.cg_damping * x

        assert np.isfinite(grad).all()
        nat_grad, cg_residual = conj_grad(fisher_vec_prod,
                                          grad,
                                          n_iters=self.n_cg_iters,
                                          verbose=False)
        grad_norm = np.linalg.norm(grad)
        nat_grad_norm = np.linalg.norm(nat_grad)

        assert np.isfinite(nat_grad).all()

        old_params = self.flatten.get_flat()
        step_size = np.sqrt(2 * self.max_kl /
                            nat_grad.dot(fisher_vec_prod(nat_grad)))

        for _ in range(10):
            new_params = old_params + nat_grad * step_size
            self.flatten.set_flat(new_params)
            loss, mean_kl = self.get_loss(samples.state,
                                          samples.action,
                                          samples.next_state,
                                          advantages,
                                          ent_coef,
                                          fetch='loss mean_kl')
            improve = loss - old_loss
            if not np.isfinite([loss, mean_kl]).all():
                logger.info('Got non-finite loss.')
            elif mean_kl > self.max_kl * 1.5:
                logger.info(
                    'Violated kl constraints, shrinking step... mean_kl = %.6f, max_kl = %.6f',
                    mean_kl, self.max_kl)
            elif improve < 0:
                logger.info(
                    "Surrogate didn't improve, shrinking step... %.6f => %.6f",
                    old_loss, loss)
            else:
                break
            step_size *= 0.5
        else:
            logger.info("Couldn't find a good step.")
            self.flatten.set_flat(old_params)
        for param in self.policy.parameters():
            param.invalidate()

        # optimize value function
        vf_dataset = Dataset.fromarrays(
            [samples.state, samples.action, returns],
            dtype=[('state', ('f8', self.dim_state)),
                   ('action', ('f8', self.dim_action)), ('return_', 'f8')])
        vf_loss = self.train_vf(vf_dataset)

        info = dict(dist_mean=dist_mean,
                    dist_std=dist_std,
                    vf_loss=np.mean(vf_loss),
                    grad_norm=grad_norm,
                    nat_grad_norm=nat_grad_norm,
                    cg_residual=cg_residual,
                    step_size=step_size)
        return info
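The conj_grad helper used above is not shown in this snippet. The following is a sketch of the standard conjugate-gradient solver it is presumably implementing (not the project's code): it solves F x = g given only the Fisher-vector product.

import numpy as np

def conj_grad_sketch(f_Ax, g, n_iters=10, tol=1e-10):
    x = np.zeros_like(g)
    r = g.copy()                      # residual of F x = g at x = 0
    p = r.copy()
    r_dot = r.dot(r)
    for _ in range(n_iters):
        Ap = f_Ax(p)
        alpha = r_dot / (p.dot(Ap) + 1e-12)
        x = x + alpha * p
        r = r - alpha * Ap
        new_r_dot = r.dot(r)
        if new_r_dot < tol:
            r_dot = new_r_dot
            break
        p = r + (new_r_dot / r_dot) * p
        r_dot = new_r_dot
    return x, np.sqrt(r_dot)

# Toy usage with a symmetric positive-definite matrix standing in for the Fisher.
A = np.array([[4.0, 1.0], [1.0, 3.0]])
g = np.array([1.0, 2.0])
x, residual = conj_grad_sketch(lambda v: A @ v, g, n_iters=25)
assert np.allclose(A @ x, g, atol=1e-6)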
Example #12
File: main.py Project: liziniu/RLX
def main():
    FLAGS.set_seed()
    FLAGS.freeze()

    env = make_env(FLAGS.env.id,
                   FLAGS.env.env_type,
                   num_env=FLAGS.env.num_env,
                   seed=FLAGS.seed,
                   log_dir=FLAGS.log_dir,
                   rescale_action=FLAGS.env.rescale_action)
    env_eval = make_env(FLAGS.env.id,
                        FLAGS.env.env_type,
                        num_env=4,
                        seed=FLAGS.seed + 1000,
                        log_dir=FLAGS.log_dir)
    dim_state = env.observation_space.shape[0]
    dim_action = env.action_space.shape[0]

    actor = Actor(dim_state,
                  dim_action,
                  hidden_sizes=FLAGS.TD3.actor_hidden_sizes)
    critic = Critic(dim_state,
                    dim_action,
                    hidden_sizes=FLAGS.TD3.critic_hidden_sizes)
    td3 = TD3(dim_state,
              dim_action,
              actor=actor,
              critic=critic,
              **FLAGS.TD3.algo.as_dict())

    tf.get_default_session().run(tf.global_variables_initializer())
    td3.update_actor_target(tau=0.0)
    td3.update_critic_target(tau=0.0)

    dtype = gen_dtype(env, 'state action next_state reward done timeout')
    buffer = Dataset(dtype=dtype, max_size=FLAGS.TD3.buffer_size)
    saver = nn.ModuleDict({'actor': actor, 'critic': critic})
    print(saver)

    n_steps = np.zeros(env.n_envs)
    n_returns = np.zeros(env.n_envs)

    train_returns = collections.deque(maxlen=40)
    train_lengths = collections.deque(maxlen=40)
    states = env.reset()
    time_st = time.time()
    for t in range(FLAGS.TD3.total_timesteps):
        if t < FLAGS.TD3.init_random_steps:
            actions = np.array(
                [env.action_space.sample() for _ in range(env.n_envs)])
        else:
            raw_actions = actor.get_actions(states)
            noises = np.random.normal(loc=0.,
                                      scale=FLAGS.TD3.explore_noise,
                                      size=raw_actions.shape)
            actions = np.clip(raw_actions + noises, -1, 1)
        next_states, rewards, dones, infos = env.step(actions)
        n_returns += rewards
        n_steps += 1
        timeouts = n_steps == env.max_episode_steps
        terminals = np.copy(dones)
        for e, info in enumerate(infos):
            if info.get('TimeLimit.truncated', False):
                terminals[e] = False

        transitions = [
            states, actions,
            next_states.copy(), rewards, terminals,
            timeouts.copy()
        ]
        buffer.extend(np.rec.fromarrays(transitions, dtype=dtype))

        indices = np.where(dones | timeouts)[0]
        if len(indices) > 0:
            next_states[indices] = env.partial_reset(indices)

            train_returns.extend(n_returns[indices])
            train_lengths.extend(n_steps[indices])
            n_returns[indices] = 0
            n_steps[indices] = 0
        states = next_states.copy()

        if t == 2000:
            assert env.n_envs == 1
            samples = buffer.sample(size=None, indices=np.arange(2000))
            masks = 1 - (samples.done | samples.timeout)[..., np.newaxis]
            masks = masks[:-1]
            assert np.allclose(samples.state[1:] * masks,
                               samples.next_state[:-1] * masks)

        if t >= FLAGS.TD3.init_random_steps:
            samples = buffer.sample(FLAGS.TD3.batch_size)
            train_info = td3.train(samples)
            if t % FLAGS.TD3.log_freq == 0:
                fps = int(t / (time.time() - time_st))
                train_info['fps'] = fps
                log_kvs(prefix='TD3',
                        kvs=dict(iter=t,
                                 episode=dict(
                                     returns=np.mean(train_returns)
                                     if len(train_returns) > 0 else 0.,
                                     lengths=int(
                                         np.mean(train_lengths)
                                         if len(train_lengths) > 0 else 0)),
                                 **train_info))

        if t % FLAGS.TD3.eval_freq == 0:
            eval_returns, eval_lengths = evaluate(actor,
                                                  env_eval,
                                                  deterministic=False)
            log_kvs(prefix='Evaluate',
                    kvs=dict(iter=t,
                             episode=dict(returns=np.mean(eval_returns),
                                          lengths=int(np.mean(eval_lengths)))))

        if t % FLAGS.TD3.save_freq == 0:
            np.save('{}/stage-{}'.format(FLAGS.log_dir, t), saver.state_dict())
            np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())

    np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())
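update_actor_target(tau=0.0) and update_critic_target(tau=0.0) near the top synchronize the target networks before training. Their implementation is not shown here; the sketch below assumes the convention target <- tau * target + (1 - tau) * online, under which tau=0.0 is a hard copy.

import numpy as np

def soft_update_sketch(target_params, online_params, tau):
    # Polyak averaging; tau close to 1 makes the target track the online network slowly.
    return [tau * t + (1.0 - tau) * o for t, o in zip(target_params, online_params)]

online = [np.ones(3)]
target = [np.zeros(3)]
target = soft_update_sketch(target, online, tau=0.0)                # hard copy at init
assert np.allclose(target[0], online[0])
target = soft_update_sketch(target, [np.full(3, 2.0)], tau=0.995)   # slow tracking afterwards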
Example #13
def main():
    FLAGS.set_seed()
    FLAGS.freeze()

    env = make_env(FLAGS.env.id,
                   FLAGS.env.env_type,
                   num_env=FLAGS.env.num_env,
                   seed=FLAGS.seed,
                   log_dir=FLAGS.log_dir,
                   rescale_action=FLAGS.env.rescale_action)
    env_eval = make_env(FLAGS.env.id,
                        FLAGS.env.env_type,
                        num_env=4,
                        seed=FLAGS.seed + 1000,
                        log_dir=FLAGS.log_dir)
    dim_state = env.observation_space.shape[0]
    dim_action = env.action_space.shape[0]

    actor = Actor(dim_state,
                  dim_action,
                  hidden_sizes=FLAGS.SAC.actor_hidden_sizes)
    critic = Critic(dim_state,
                    dim_action,
                    hidden_sizes=FLAGS.SAC.critic_hidden_sizes)
    target_entropy = FLAGS.SAC.target_entropy
    if target_entropy is None:
        target_entropy = -dim_action
    sac = SAC(dim_state,
              dim_action,
              actor=actor,
              critic=critic,
              target_entropy=target_entropy,
              **FLAGS.SAC.algo.as_dict())

    tf.get_default_session().run(tf.global_variables_initializer())
    sac.update_critic_target(tau=0.0)

    dtype = gen_dtype(env, 'state action next_state reward done')
    buffer = Dataset(dtype=dtype, max_size=FLAGS.SAC.buffer_size)
    saver = nn.ModuleDict({'actor': actor, 'critic': critic})
    print(saver)

    n_steps = np.zeros(env.n_envs)
    n_returns = np.zeros(env.n_envs)

    train_returns = collections.deque(maxlen=40)
    train_lengths = collections.deque(maxlen=40)
    states = env.reset()
    time_st = time.time()
    for t in range(FLAGS.SAC.total_timesteps):
        if t < FLAGS.SAC.init_random_steps:
            actions = np.array(
                [env.action_space.sample() for _ in range(env.n_envs)])
        else:
            actions = actor.get_actions(states)
        next_states, rewards, dones, infos = env.step(actions)
        n_returns += rewards
        n_steps += 1
        timeouts = n_steps == env.max_episode_steps
        terminals = np.copy(dones)
        for e, info in enumerate(infos):
            if FLAGS.SAC.peb and info.get('TimeLimit.truncated', False):
                terminals[e] = False

        transitions = [states, actions, next_states.copy(), rewards, terminals]
        buffer.extend(np.rec.fromarrays(transitions, dtype=dtype))

        indices = np.where(dones | timeouts)[0]
        if len(indices) > 0:
            next_states[indices] = env.partial_reset(indices)

            train_returns.extend(n_returns[indices])
            train_lengths.extend(n_steps[indices])
            n_returns[indices] = 0
            n_steps[indices] = 0
        states = next_states.copy()

        if t >= FLAGS.SAC.init_random_steps:
            samples = buffer.sample(FLAGS.SAC.batch_size)
            train_info = sac.train(samples)
            if t % FLAGS.SAC.log_freq == 0:
                fps = int(t / (time.time() - time_st))
                train_info['fps'] = fps
                log_kvs(prefix='SAC',
                        kvs=dict(iter=t,
                                 episode=dict(
                                     returns=np.mean(train_returns)
                                     if len(train_returns) > 0 else 0.,
                                     lengths=int(
                                         np.mean(train_lengths)
                                         if len(train_lengths) > 0 else 0)),
                                 **train_info))

        if t % FLAGS.SAC.eval_freq == 0:
            eval_returns, eval_lengths = evaluate(actor, env_eval)
            log_kvs(prefix='Evaluate',
                    kvs=dict(iter=t,
                             episode=dict(returns=np.mean(eval_returns),
                                          lengths=int(np.mean(eval_lengths)))))

        if t % FLAGS.SAC.save_freq == 0:
            np.save('{}/stage-{}'.format(FLAGS.log_dir, t), saver.state_dict())
            np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())

    np.save('{}/final'.format(FLAGS.log_dir), saver.state_dict())