Example No. 1
    def learn(self,
              total_timesteps,
              callback=None,
              log_interval=1,
              tb_log_name="A2C",
              reset_num_timesteps=True):

        new_tb_log = self._init_num_timesteps(reset_num_timesteps)
        callback = self._init_callback(callback)

        with SetVerbosity(self.verbose), TensorboardWriter(
                self.graph, self.tensorboard_log, tb_log_name,
                new_tb_log) as writer:
            self._setup_learn()
            self.learning_rate_schedule = Scheduler(
                initial_value=self.learning_rate,
                n_values=total_timesteps,
                schedule=self.lr_schedule)

            t_start = time.time()
            callback.on_training_start(locals(), globals())

            for update in range(1, total_timesteps // self.n_batch + 1):

                callback.on_rollout_start()
                # true_reward is the reward without discount
                rollout = self.runner.run(callback)
                # unpack
                obs, states, rewards, masks, actions, values, ep_infos, true_reward = rollout
                callback.update_locals(locals())
                callback.on_rollout_end()

                # Early stopping due to the callback
                if not self.runner.continue_training:
                    break

                self.ep_info_buf.extend(ep_infos)
                _, value_loss, policy_entropy = self._train_step(
                    obs, states, rewards, masks, actions, values,
                    self.num_timesteps // self.n_batch, writer)
                n_seconds = time.time() - t_start
                fps = int((update * self.n_batch) / n_seconds)

                if writer is not None:
                    total_episode_reward_logger(
                        self.episode_reward,
                        true_reward.reshape((self.n_envs, self.n_steps)),
                        masks.reshape((self.n_envs, self.n_steps)), writer,
                        self.num_timesteps)

                if self.verbose >= 1 and (update % log_interval == 0
                                          or update == 1):
                    explained_var = explained_variance(values, rewards)
                    logger.record_tabular("nupdates", update)
                    logger.record_tabular("total_timesteps",
                                          self.num_timesteps)
                    logger.record_tabular("fps", fps)
                    logger.record_tabular("policy_entropy",
                                          float(policy_entropy))
                    logger.record_tabular("value_loss", float(value_loss))
                    logger.record_tabular("explained_variance",
                                          float(explained_var))
                    if len(self.ep_info_buf) > 0 and len(
                            self.ep_info_buf[0]) > 0:
                        logger.logkv(
                            'ep_reward_mean',
                            safe_mean([
                                ep_info['r'] for ep_info in self.ep_info_buf
                            ]))
                        #logger.logkv('ep_len_mean', safe_mean([ep_info['l'] for ep_info in self.ep_info_buf]))
                    logger.dump_tabular()

        callback.on_training_end()
        return self
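
For context, a `learn` loop like the one above is normally driven through the stable-baselines (v2, TensorFlow 1.x) model API. A minimal usage sketch, assuming the stock package layout and an illustrative Gym environment id ('CartPole-v1') and hyperparameters:

from stable_baselines import A2C

# Build an A2C model and run its learn() loop; the env id, log directory
# and timestep budget below are illustrative assumptions, not from the page.
model = A2C('MlpPolicy', 'CartPole-v1', verbose=1, tensorboard_log='./a2c_tb/')
model.learn(total_timesteps=25000, log_interval=10)
model.save('a2c_cartpole')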
Example No. 2
    def _run_(self):
        """
        Run a learning step of the model

        :return:
            - observations: (np.ndarray) the observations
            - rewards: (np.ndarray) the rewards
            - masks: (numpy bool) whether an episode is over or not
            - actions: (np.ndarray) the actions
            - values: (np.ndarray) the value function output
            - negative log probabilities: (np.ndarray)
            - states: (np.ndarray) the internal states of the recurrent policies
            - infos: (dict) the extra information of the model
        """
        # mb stands for minibatch

        self.obs[:] = self.env.reset()

        mb_obs, mb_rewards, mb_actions, mb_values, mb_dones, mb_neglogpacs = [], [], [], [], [], []
        mb_states = self.states
        ep_infos = []
        scores = [[] for _ in range(self.n_envs)]
        normals = [[] for _ in range(self.n_envs)]
        attacks = [[] for _ in range(self.n_envs)]
        precisions = [[] for _ in range(self.n_envs)]
        telapsed = 0
        for _ in range(self.n_steps):
            actions, values, self.states, neglogpacs = self.model.step(self.obs, self.states, self.dones)
            mb_obs.append(self.obs.copy())
            mb_actions.append(actions)
            mb_values.append(values)
            mb_neglogpacs.append(neglogpacs)
            mb_dones.append(self.dones)
            clipped_actions = actions
            # Clip the actions to avoid out of bound error
            if isinstance(self.env.action_space, gym.spaces.Box):
                clipped_actions = np.clip(actions, self.env.action_space.low, self.env.action_space.high)
            tnow = time.time()
            self.obs[:], rewards, self.dones, infos = self.env.step(clipped_actions)
            telapsed += time.time() - tnow
            for ri in range(self.n_envs):
                scores[ri].append(infos[ri]['r'])
                normals[ri].append(infos[ri]['n'])
                attacks[ri].append(infos[ri]['a'])
                precisions[ri].append(infos[ri]['p'])

            self.model.num_timesteps += self.n_envs

            if self.callback is not None:
                # Abort training early
                self.callback.update_locals(locals())
                if self.callback.on_step() is False:
                    self.continue_training = False
                    # Return dummy values
                    return [None] * 9
            mb_rewards.append(rewards)

        print('Time step: {0}'.format(telapsed / self.n_steps))

        for s, n, a, p in zip(scores, normals, attacks, precisions):
            ep_infos.append({'r': safe_mean(s), 'n': safe_mean(n), 'a': safe_mean(a), 'p': safe_mean(p)})

        # batch of steps to batch of rollouts

        mb_obs = np.asarray(mb_obs, dtype=self.obs.dtype)
        mb_rewards = np.asarray(mb_rewards, dtype=np.float32)
        mb_actions = np.asarray(mb_actions)
        mb_values = np.asarray(mb_values, dtype=np.float32)
        mb_neglogpacs = np.asarray(mb_neglogpacs, dtype=np.float32)
        mb_dones = np.asarray(mb_dones, dtype=bool)
        last_values = self.model.value(self.obs, self.states, self.dones)

        # discount/bootstrap off value fn

        mb_advs = np.zeros_like(mb_rewards)
        true_reward = np.copy(mb_rewards)
        last_gae_lam = 0
        for step in reversed(range(self.n_steps)):
            if step == self.n_steps - 1:
                nextnonterminal = 1.0 - self.dones
                nextvalues = last_values
            else:
                nextnonterminal = 1.0 - mb_dones[step + 1]
                nextvalues = mb_values[step + 1]
            delta = mb_rewards[step] + self.gamma * nextvalues * nextnonterminal - mb_values[step]
            mb_advs[step] = last_gae_lam = delta + self.gamma * self.lam * nextnonterminal * last_gae_lam
        mb_returns = mb_advs + mb_values

        mb_obs, mb_returns, mb_dones, mb_actions, mb_values, mb_neglogpacs, true_reward = map(
            swap_and_flatten, (mb_obs, mb_returns, mb_dones, mb_actions, mb_values, mb_neglogpacs, true_reward)
        )

        return mb_obs, mb_returns, mb_dones, mb_actions, mb_values, mb_neglogpacs, mb_states, ep_infos, true_reward
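
The `swap_and_flatten` and `safe_mean` helpers used above are small utilities from stable-baselines. They are roughly the following; shown here as a sketch for reference rather than a verbatim copy of the library source:

import numpy as np

def swap_and_flatten(arr):
    """Swap the time and environment axes and flatten them into one batch axis."""
    shape = arr.shape
    return arr.swapaxes(0, 1).reshape(shape[0] * shape[1], *shape[2:])

def safe_mean(arr):
    """Mean that returns NaN instead of raising on an empty sequence."""
    return np.nan if len(arr) == 0 else np.mean(arr)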
Example No. 3
    def _run(self):
        """
        Run a learning step of the model

        :return:
            - observations: (np.ndarray) the observations
            - rewards: (np.ndarray) the rewards
            - masks: (numpy bool) whether an episode is over or not
            - actions: (np.ndarray) the actions
            - values: (np.ndarray) the value function output
            - negative log probabilities: (np.ndarray)
            - states: (np.ndarray) the internal states of the recurrent policies
            - infos: (dict) the extra information of the model
        """
        # mb stands for minibatch

        self.mb_obs = [[] for _ in range(self.n_envs * self.n_runs)]
        self.mb_obs_next = [[] for _ in range(self.n_envs * self.n_runs)]
        self.mb_actions = [[] for _ in range(self.n_envs * self.n_runs)]
        self.mb_values = [[] for _ in range(self.n_envs * self.n_runs)]
        self.mb_neglogpacs = [[] for _ in range(self.n_envs * self.n_runs)]
        self.mb_dones = [[] for _ in range(self.n_envs * self.n_runs)]
        self.mb_rewards = [[] for _ in range(self.n_envs * self.n_runs)]
        self.scores = [[] for _ in range(self.n_envs * self.n_runs)]

        ep_infos = []

        for i in range(self.n_runs):
            self.obs[i * self.n_envs : (i + 1) * self.n_envs] = self.env.reset()

            # run steps in different threads

            threads = []
            for env_idx in range(self.n_envs):
                th = Thread(target=self._run_one, args=(env_idx, i * self.n_envs + env_idx))
                th.start()
                threads.append(th)
            for th in threads:
                th.join()

        # combine data gathered into batches

        mb_obs = [np.stack([self.mb_obs[idx][step] for idx in range(self.n_envs * self.n_runs)]) for step in range(self.n_steps)]
        mb_obs_next = [np.stack([self.mb_obs_next[idx][step] for idx in range(self.n_envs * self.n_runs)]) for step in range(self.n_steps)]
        mb_rewards = [np.hstack([self.mb_rewards[idx][step] for idx in range(self.n_envs * self.n_runs)]) for step in range(self.n_steps)]
        mb_actions = [np.hstack([self.mb_actions[idx][step] for idx in range(self.n_envs * self.n_runs)]) for step in range(self.n_steps)]
        mb_values = [np.hstack([self.mb_values[idx][step] for idx in range(self.n_envs * self.n_runs)]) for step in range(self.n_steps)]
        mb_neglogpacs = [np.hstack([self.mb_neglogpacs[idx][step] for idx in range(self.n_envs * self.n_runs)]) for step in range(self.n_steps)]
        mb_dones = [np.hstack([self.mb_dones[idx][step] for idx in range(self.n_envs * self.n_runs)]) for step in range(self.n_steps)]
        mb_scores = [np.vstack([self.scores[idx][step] for step in range(self.n_steps)]) for idx in range(self.n_envs * self.n_runs)]
        mb_states = self.states
        self.dones = np.array(self.dones)

        for scores_in_env in mb_scores:
            maybe_ep_info = {
                'r': safe_mean(scores_in_env[:, 0]),
                'n': safe_mean(scores_in_env[:, 1]),
                'a': safe_mean(scores_in_env[:, 2]),
                'p': safe_mean(scores_in_env[:, 3])
            }
            ep_infos.append(maybe_ep_info)

        # batch of steps to batch of rollouts

        mb_obs = np.asarray(mb_obs, dtype=self.obs.dtype)
        mb_obs_next = np.asarray(mb_obs_next, dtype=self.obs.dtype)
        mb_rewards = np.asarray(mb_rewards, dtype=np.float32)
        mb_actions = np.asarray(mb_actions)
        mb_values = np.asarray(mb_values, dtype=np.float32)
        mb_neglogpacs = np.asarray(mb_neglogpacs, dtype=np.float32)
        mb_dones = np.asarray(mb_dones, dtype=np.bool)
        last_values = self.model.value(self.obs, self.states, self.dones)

        # discount/bootstrap off value fn

        mb_advs = np.zeros_like(mb_rewards)
        true_reward = np.copy(mb_rewards)
        last_gae_lam = 0
        for step in reversed(range(self.n_steps)):
            if step == self.n_steps - 1:
                nextnonterminal = 1.0 - self.dones
                nextvalues = last_values
            else:
                nextnonterminal = 1.0 - mb_dones[step + 1]
                nextvalues = mb_values[step + 1]
            delta = mb_rewards[step] + self.gamma * nextvalues * nextnonterminal - mb_values[step]
            mb_advs[step] = last_gae_lam = delta + self.gamma * self.lam * nextnonterminal * last_gae_lam
        mb_returns = mb_advs + mb_values

        mb_obs, mb_obs_next, mb_returns, mb_dones, mb_actions, mb_values, mb_neglogpacs, true_reward = map(
            swap_and_flatten, (mb_obs, mb_obs_next, mb_returns, mb_dones, mb_actions, mb_values, mb_neglogpacs, true_reward)
        )

        # reset data

        self.mb_obs = [[] for _ in range(self.n_envs * self.n_runs)]
        self.mb_obs_next = [[] for _ in range(self.n_envs * self.n_runs)]
        self.mb_actions = [[] for _ in range(self.n_envs * self.n_runs)]
        self.mb_values = [[] for _ in range(self.n_envs * self.n_runs)]
        self.mb_neglogpacs = [[] for _ in range(self.n_envs * self.n_runs)]
        self.mb_dones = [[] for _ in range(self.n_envs * self.n_runs)]
        self.mb_rewards = [[] for _ in range(self.n_envs * self.n_runs)]
        self.scores = [[] for _ in range(self.n_envs * self.n_runs)]

        return mb_obs, mb_obs_next, mb_returns, mb_dones, mb_actions, mb_values, mb_neglogpacs, mb_states, ep_infos, true_reward
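
Both runners share the same GAE(lambda) bootstrap loop (the `mb_advs` block above). As a standalone reference, a minimal self-contained sketch of that advantage computation, assuming arrays shaped (n_steps, n_envs):

import numpy as np

def compute_gae(rewards, values, dones, last_values, last_dones, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation over a (n_steps, n_envs) rollout."""
    n_steps = rewards.shape[0]
    advs = np.zeros_like(rewards)
    last_gae_lam = 0.0
    for step in reversed(range(n_steps)):
        if step == n_steps - 1:
            nextnonterminal = 1.0 - last_dones
            nextvalues = last_values
        else:
            nextnonterminal = 1.0 - dones[step + 1]
            nextvalues = values[step + 1]
        delta = rewards[step] + gamma * nextvalues * nextnonterminal - values[step]
        advs[step] = last_gae_lam = delta + gamma * lam * nextnonterminal * last_gae_lam
    returns = advs + values
    return advs, returns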
Example No. 4
    def learn(self, total_timesteps, callback=None, log_interval=1, tb_log_name="PPO2", reset_num_timesteps=True):

        # Transform to callable if needed

        self.learning_rate = get_schedule_fn(self.learning_rate)
        self.cliprange = get_schedule_fn(self.cliprange)
        cliprange_vf = get_schedule_fn(self.cliprange_vf)

        new_tb_log = self._init_num_timesteps(reset_num_timesteps)
        callback = self._init_callback(callback)

        with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name, new_tb_log) as writer:
            self._setup_learn()

            t_first_start = time.time()
            n_updates = total_timesteps // self.n_batch

            callback.on_training_start(locals(), globals())

            for update in range(1, n_updates + 1):

                assert self.n_batch % self.nminibatches == 0, ("The number of minibatches (`nminibatches`) "
                                                               "is not a factor of the total number of samples "
                                                               "collected per rollout (`n_batch`), "
                                                               "some samples won't be used."
                                                               )
                batch_size = self.n_batch // self.nminibatches
                t_start = time.time()
                frac = 1.0 - (update - 1.0) / n_updates
                lr_now = self.learning_rate(frac)
                cliprange_now = self.cliprange(frac)
                cliprange_vf_now = cliprange_vf(frac)

                callback.on_rollout_start()
                # true_reward is the reward without discount
                rollout = self.runner.run(callback)

                # Unpack

                obs, obs_next, returns, masks, actions, values, neglogpacs, states, ep_infos, true_reward = rollout

                #for item in [obs, obs_next, returns, masks, actions, values, neglogpacs, states, true_reward]:
                #    if item is not None:
                #        print(item.shape)
                #print(ep_infos)

                callback.on_rollout_end()

                # Early stopping due to the callback
                if not self.runner.continue_training:
                    break

                self.ep_info_buf.extend(ep_infos)
                mb_loss_vals = []
                if states is None:  # nonrecurrent version
                    update_fac = max(self.n_batch // self.nminibatches // self.noptepochs, 1)
                    inds = np.arange(self.n_batch)
                    for epoch_num in range(self.noptepochs):
                        np.random.shuffle(inds)
                        for start in range(0, self.n_batch, batch_size):
                            timestep = self.num_timesteps // update_fac + ((epoch_num * self.n_batch + start) // batch_size)
                            end = start + batch_size
                            mbinds = inds[start:end]
                            slices = (arr[mbinds] for arr in (obs, obs_next, returns, true_reward, masks, actions, values, neglogpacs))
                            mb_loss_vals.append(self._train_step(lr_now, cliprange_now, *slices, writer=writer, update=timestep, cliprange_vf=cliprange_vf_now))
                else:  # recurrent version
                    update_fac = max(self.n_batch // self.nminibatches // self.noptepochs // self.n_steps, 1)
                    assert self.n_envs % self.nminibatches == 0
                    env_indices = np.arange(self.n_envs)
                    flat_indices = np.arange(self.n_envs * self.n_steps).reshape(self.n_envs, self.n_steps)
                    envs_per_batch = batch_size // self.n_steps
                    for epoch_num in range(self.noptepochs):
                        np.random.shuffle(env_indices)
                        for start in range(0, self.n_envs, envs_per_batch):
                            timestep = self.num_timesteps // update_fac + ((epoch_num * self.n_envs + start) // envs_per_batch)
                            end = start + envs_per_batch
                            mb_env_inds = env_indices[start:end]
                            mb_flat_inds = flat_indices[mb_env_inds].ravel()
                            slices = (arr[mb_flat_inds] for arr in (obs, returns, masks, actions, values, neglogpacs))
                            mb_states = states[mb_env_inds]
                            mb_loss_vals.append(self._train_step(lr_now, cliprange_now, *slices, update=timestep,
                                                                 writer=writer, states=mb_states,
                                                                 cliprange_vf=cliprange_vf_now))

                loss_vals = np.mean(mb_loss_vals, axis=0)
                t_now = time.time()
                fps = int(self.n_batch / (t_now - t_start))

                if writer is not None:
                    total_episode_reward_logger(self.episode_reward,
                                                true_reward.reshape((self.n_envs * self.n_runs, self.n_steps)),
                                                masks.reshape((self.n_envs * self.n_runs, self.n_steps)),
                                                writer, self.num_timesteps)

                if self.verbose >= 1 and (update % log_interval == 0 or update == 1):
                    explained_var = explained_variance(values, returns)
                    logger.logkv("serial_timesteps", update * self.n_steps)
                    logger.logkv("n_updates", update)
                    logger.logkv("total_timesteps", self.num_timesteps)
                    logger.logkv("fps", fps)
                    logger.logkv("explained_variance", float(explained_var))
                    if len(self.ep_info_buf) > 0 and len(self.ep_info_buf[0]) > 0:
                        logger.logkv('ep_reward_mean', safe_mean([ep_info['r'] for ep_info in self.ep_info_buf]))
                        logger.logkv('ep_normal_mean', safe_mean([ep_info['n'] for ep_info in self.ep_info_buf]))
                        logger.logkv('ep_attack_mean', safe_mean([ep_info['a'] for ep_info in self.ep_info_buf]))
                        logger.logkv('ep_precision_mean', safe_mean([ep_info['p'] for ep_info in self.ep_info_buf]))
                    logger.logkv('time_elapsed', t_start - t_first_start)
                    for (loss_val, loss_name) in zip(loss_vals, self.loss_names):
                        logger.logkv(loss_name, loss_val)
                    logger.dumpkvs()

            callback.on_training_end()
            return self
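
The `get_schedule_fn` calls at the top of this `learn` accept either a constant or a callable taking the remaining-progress fraction `frac` (1.0 at the start of training, approaching 0.0 at the end). A minimal sketch of that convention; `linear_schedule` is an illustrative example, not code from the page above:

def get_schedule_fn(value_schedule):
    """Return a callable schedule; wrap plain numbers as constant schedules."""
    if isinstance(value_schedule, (float, int)):
        value = float(value_schedule)
        return lambda _frac: value
    assert callable(value_schedule)
    return value_schedule

def linear_schedule(initial_value):
    """Illustrative schedule: decays linearly as frac goes from 1.0 to 0.0."""
    def schedule(frac):
        return frac * initial_value
    return schedule

# e.g. PPO2('MlpPolicy', env, learning_rate=linear_schedule(2.5e-4), cliprange=0.2)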
Example No. 5
    def learn(self,
              total_timesteps,
              callback=None,
              log_interval=1,
              tb_log_name="ACKTR",
              reset_num_timesteps=True):

        new_tb_log = self._init_num_timesteps(reset_num_timesteps)
        callback = self._init_callback(callback)

        with SetVerbosity(self.verbose), TensorboardWriter(
                self.graph, self.tensorboard_log, tb_log_name,
                new_tb_log) as writer:
            self._setup_learn()
            self.n_batch = self.n_envs * self.n_steps

            self.learning_rate_schedule = Scheduler(
                initial_value=self.learning_rate,
                n_values=total_timesteps,
                schedule=self.lr_schedule)

            # The FIFO queue of the q_runner thread is closed at the end of the learn function.
            # As a result, it needs to be redefined at every call.
            with self.graph.as_default():
                with tf.compat.v1.variable_scope(
                        "kfac_apply",
                        reuse=self.trained,
                        custom_getter=tf_util.outer_scope_getter(
                            "kfac_apply")):
                    # Some of the variables are not in a scope when they are created,
                    # so we make a note of any previously uninitialized variables
                    tf_vars = tf.compat.v1.global_variables()
                    is_uninitialized = self.sess.run([
                        tf.compat.v1.is_variable_initialized(var)
                        for var in tf_vars
                    ])
                    old_uninitialized_vars = [
                        v for (v, f) in zip(tf_vars, is_uninitialized) if not f
                    ]

                    self.train_op, self.q_runner = self.optim.apply_gradients(
                        list(zip(self.grads_check, self.params)))

                    # then we check for new uninitialized variables and initialize them
                    tf_vars = tf.compat.v1.global_variables()
                    is_uninitialized = self.sess.run([
                        tf.compat.v1.is_variable_initialized(var)
                        for var in tf_vars
                    ])
                    new_uninitialized_vars = [
                        v for (v, f) in zip(tf_vars, is_uninitialized)
                        if not f and v not in old_uninitialized_vars
                    ]

                    if len(new_uninitialized_vars) != 0:
                        self.sess.run(
                            tf.compat.v1.variables_initializer(
                                new_uninitialized_vars))

            self.trained = True

            t_start = time.time()
            coord = tf.train.Coordinator()
            if self.q_runner is not None:
                enqueue_threads = self.q_runner.create_threads(self.sess,
                                                               coord=coord,
                                                               start=True)
            else:
                enqueue_threads = []

            callback.on_training_start(locals(), globals())

            for update in range(1, total_timesteps // self.n_batch + 1):

                callback.on_rollout_start()

                # pytype:disable=bad-unpacking
                # true_reward is the reward without discount
                if isinstance(self.runner, PPO2Runner):
                    # We are using GAE
                    rollout = self.runner.run(callback)
                    obs, returns, masks, actions, values, _, states, ep_infos, true_reward = rollout
                else:
                    rollout = self.runner.run(callback)
                    obs, states, returns, masks, actions, values, ep_infos, true_reward = rollout
                # pytype:enable=bad-unpacking
                callback.update_locals(locals())
                callback.on_rollout_end()

                # Early stopping due to the callback
                if not self.runner.continue_training:
                    break

                self.ep_info_buf.extend(ep_infos)
                policy_loss, value_loss, policy_entropy = self._train_step(
                    obs, states, returns, masks, actions, values,
                    self.num_timesteps // (self.n_batch + 1), writer)
                n_seconds = time.time() - t_start
                fps = int((update * self.n_batch) / n_seconds)

                if writer is not None:
                    total_episode_reward_logger(
                        self.episode_reward,
                        true_reward.reshape((self.n_envs, self.n_steps)),
                        masks.reshape((self.n_envs, self.n_steps)), writer,
                        self.num_timesteps)

                if self.verbose >= 1 and (update % log_interval == 0
                                          or update == 1):
                    explained_var = explained_variance(values, returns)
                    logger.record_tabular("nupdates", update)
                    logger.record_tabular("total_timesteps",
                                          self.num_timesteps)
                    logger.record_tabular("fps", fps)
                    logger.record_tabular("policy_entropy",
                                          float(policy_entropy))
                    logger.record_tabular("policy_loss", float(policy_loss))
                    logger.record_tabular("value_loss", float(value_loss))
                    logger.record_tabular("explained_variance",
                                          float(explained_var))
                    if len(self.ep_info_buf) > 0 and len(
                            self.ep_info_buf[0]) > 0:
                        logger.logkv(
                            'ep_reward_mean',
                            safe_mean([
                                ep_info['r'] for ep_info in self.ep_info_buf
                            ]))
                        logger.logkv(
                            'ep_normal_mean',
                            safe_mean([
                                ep_info['n'] for ep_info in self.ep_info_buf
                            ]))
                        logger.logkv(
                            'ep_attack_mean',
                            safe_mean([
                                ep_info['a'] for ep_info in self.ep_info_buf
                            ]))
                        logger.logkv(
                            'ep_precision_mean',
                            safe_mean([
                                ep_info['p'] for ep_info in self.ep_info_buf
                            ]))
                    logger.dump_tabular()

            coord.request_stop()
            coord.join(enqueue_threads)

        callback.on_training_end()
        return self
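
The variable-initialization dance inside the `kfac_apply` scope above (snapshotting which variables are uninitialized before `apply_gradients`, then initializing only the newly created ones) can be factored into a small helper. A sketch of that pattern, assuming a TF1-style session:

import tensorflow as tf

def initialize_new_variables(sess, previously_uninitialized):
    """Initialize only the global variables created since the given snapshot."""
    tf_vars = tf.compat.v1.global_variables()
    is_initialized = sess.run([tf.compat.v1.is_variable_initialized(v) for v in tf_vars])
    new_vars = [v for v, init in zip(tf_vars, is_initialized)
                if not init and v not in previously_uninitialized]
    if new_vars:
        sess.run(tf.compat.v1.variables_initializer(new_vars))
    return new_vars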
Example No. 6
    def learn(self,
              total_timesteps,
              callback=None,
              log_interval=4,
              tb_log_name="SAC",
              reset_num_timesteps=True,
              replay_wrapper=None):

        new_tb_log = self._init_num_timesteps(reset_num_timesteps)
        callback = self._init_callback(callback)

        if replay_wrapper is not None:
            self.replay_buffer = replay_wrapper(self.replay_buffer)

        with SetVerbosity(self.verbose), TensorboardWriter(self.graph, self.tensorboard_log, tb_log_name, new_tb_log) \
                as writer:

            self._setup_learn()

            # Transform to callable if needed
            self.learning_rate = get_schedule_fn(self.learning_rate)
            # Initial learning rate
            current_lr = self.learning_rate(1)

            start_time = time.time()
            episode_rewards = [0.0]
            episode_successes = []
            if self.action_noise is not None:
                self.action_noise.reset()
            obs = self.env.reset()
            # Retrieve unnormalized observation for saving into the buffer
            if self._vec_normalize_env is not None:
                obs_ = self._vec_normalize_env.get_original_obs().squeeze()

            n_updates = 0
            infos_values = []

            callback.on_training_start(locals(), globals())
            callback.on_rollout_start()

            for step in range(total_timesteps):
                # Before training starts, randomly sample actions
                # from a uniform distribution for better exploration.
                # Afterwards, use the learned policy
                # if random_exploration is set to 0 (normal setting)
                if self.num_timesteps < self.learning_starts or np.random.rand() < self.random_exploration:
                    # Actions sampled from the action space are in the environment-specific range,
                    # but the algorithm operates on tanh-squashed actions, so simple scaling is used
                    unscaled_action = self.env.action_space.sample()
                    action = scale_action(self.action_space, unscaled_action)
                else:
                    action = self.policy_tf.step(obs[None], deterministic=False).flatten()
                    # Add noise to the action (improve exploration,
                    # not needed in general)
                    if self.action_noise is not None:
                        action = np.clip(action + self.action_noise(), -1, 1)
                    # inferred actions need to be transformed to environment action_space before stepping
                    unscaled_action = unscale_action(self.action_space, action)

                assert action.shape == self.env.action_space.shape

                new_obs, reward, done, info = self.env.step(unscaled_action)

                self.num_timesteps += 1

                # Only stop training if return value is False, not when it is None. This is for backwards
                # compatibility with callbacks that have no return statement.
                callback.update_locals(locals())
                if callback.on_step() is False:
                    break

                # Store only the unnormalized version
                if self._vec_normalize_env is not None:
                    new_obs_ = self._vec_normalize_env.get_original_obs().squeeze()
                    reward_ = self._vec_normalize_env.get_original_reward().squeeze()
                else:
                    # Avoid changing the original ones
                    obs_, new_obs_, reward_ = obs, new_obs, reward

                # Store transition in the replay buffer.
                self.replay_buffer_add(obs_, action, reward_, new_obs_, done,
                                       info)
                obs = new_obs
                # Save the unnormalized observation
                if self._vec_normalize_env is not None:
                    obs_ = new_obs_

                # Retrieve reward and episode length if using Monitor wrapper
                #maybe_ep_info = info.get('episode')
                #if maybe_ep_info is not None:
                self.ep_info_buf.extend([{'r': reward}])

                if writer is not None:
                    # Write reward per episode to tensorboard
                    ep_reward = np.array([reward_]).reshape((1, -1))
                    ep_done = np.array([done]).reshape((1, -1))
                    tf_util.total_episode_reward_logger(
                        self.episode_reward, ep_reward, ep_done, writer,
                        self.num_timesteps)

                if self.num_timesteps % self.train_freq == 0:
                    callback.on_rollout_end()

                    mb_infos_vals = []
                    # Update policy, critics and target networks
                    for grad_step in range(self.gradient_steps):
                        # Break if the warmup phase is not over
                        # or if there are not enough samples in the replay buffer
                        if not self.replay_buffer.can_sample(self.batch_size) \
                           or self.num_timesteps < self.learning_starts:
                            break
                        n_updates += 1
                        # Compute current learning_rate
                        frac = 1.0 - step / total_timesteps
                        current_lr = self.learning_rate(frac)
                        # Update policy and critics (q functions)
                        mb_infos_vals.append(
                            self._train_step(step, writer, current_lr))
                        # Update target network
                        if (step + grad_step) % self.target_update_interval == 0:
                            self.sess.run(self.target_update_op)
                    # Log losses and entropy, useful for monitoring training
                    if len(mb_infos_vals) > 0:
                        infos_values = np.mean(mb_infos_vals, axis=0)

                    callback.on_rollout_start()

                episode_rewards[-1] += reward_
                if done:
                    if self.action_noise is not None:
                        self.action_noise.reset()
                    if not isinstance(self.env, VecEnv):
                        obs = self.env.reset()
                    episode_rewards.append(0.0)

                    maybe_is_success = info.get('is_success')
                    if maybe_is_success is not None:
                        episode_successes.append(float(maybe_is_success))

                if len(episode_rewards[-101:-1]) == 0:
                    mean_reward = -np.inf
                else:
                    mean_reward = round(
                        float(np.mean(episode_rewards[-101:-1])), 1)

                num_episodes = len(episode_rewards)
                # Display training infos
                #if self.verbose >= 1 and done and log_interval is not None and len(episode_rewards) % log_interval == 0:
                if (step + 1) % (self.batch_size * log_interval) == 0:
                    fps = int(step / (time.time() - start_time))
                    #logger.logkv("episodes", num_episodes)
                    #logger.logkv("mean 100 episode reward", mean_reward)
                    if len(self.ep_info_buf) > 0 and len(
                            self.ep_info_buf[0]) > 0:
                        logger.logkv(
                            'ep_reward_mean',
                            safe_mean([
                                ep_info['r'] for ep_info in self.ep_info_buf
                            ]))
                        #logger.logkv('eplenmean', safe_mean([ep_info['l'] for ep_info in self.ep_info_buf]))
                    logger.logkv("n_updates", n_updates)
                    logger.logkv("current_lr", current_lr)
                    logger.logkv("fps", fps)
                    logger.logkv('time_elapsed', int(time.time() - start_time))
                    if len(episode_successes) > 0:
                        logger.logkv("success rate",
                                     np.mean(episode_successes[-100:]))
                    if len(infos_values) > 0:
                        for (name, val) in zip(self.infos_names, infos_values):
                            logger.logkv(name, val)
                    logger.logkv("total_timesteps", self.num_timesteps)
                    logger.dumpkvs()
                    # Reset infos:
                    infos_values = []
            callback.on_training_end()
            return self
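
SAC works on actions squashed to [-1, 1], so the loop above converts between the environment's `Box` range and the squashed range via `scale_action` / `unscale_action`. These helpers are roughly the following affine maps (a sketch, assuming a `gym.spaces.Box` action space):

import numpy as np

def scale_action(action_space, action):
    """Map an action from [low, high] to [-1, 1]."""
    low, high = action_space.low, action_space.high
    return 2.0 * ((action - low) / (high - low)) - 1.0

def unscale_action(action_space, scaled_action):
    """Map an action from [-1, 1] back to [low, high]."""
    low, high = action_space.low, action_space.high
    return low + 0.5 * (scaled_action + 1.0) * (high - low)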