Example #1
 def update(
     self,
     timestep: dm_env.TimeStep,
     action: base.Action,
     new_timestep: dm_env.TimeStep,
 ):
   """Adds a transition to the trajectory buffer and periodically does SGD."""
   if new_timestep.last():
     self._state = self._state._replace(rnn_state=self._initial_rnn_state)
   self._buffer.append(timestep, action, new_timestep)
   if self._buffer.full() or new_timestep.last():
     trajectory = self._buffer.drain()
     self._state = self._sgd_step(self._state, trajectory)
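This update is meant to be driven by an outer environment loop. A minimal sketch of such a loop, assuming a dm_env-style environment and an agent that exposes select_action alongside the update shown above (the function and attribute names here are illustrative, not part of the example's project):

import dm_env


def run_episode(agent, environment: dm_env.Environment) -> float:
  """Runs one episode, feeding every transition to agent.update."""
  episode_return = 0.0
  timestep = environment.reset()
  while not timestep.last():
    action = agent.select_action(timestep)
    new_timestep = environment.step(action)
    agent.update(timestep, action, new_timestep)
    episode_return += new_timestep.reward
    timestep = new_timestep
  return episode_return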
Example #2
    def add_first(self, timestep: dm_env.TimeStep):
        """Record the first observation of a trajectory."""
        if not timestep.first():
            raise ValueError(
                'adder.add_first must be called with an initial timestep '
                '(i.e. one for which timestep.first() is True).')

        # Record the next observation but leave the history buffer row open by
        # passing `partial_step=True`.
        self._writer.append(dict(observation=timestep.observation,
                                 start_of_episode=timestep.first()),
                            partial_step=True)
        self._add_first_called = True
Example #3
    def add(
        self,
        actions: Dict[str, types.NestedArray],
        next_timestep: dm_env.TimeStep,
        next_extras: Dict[str, types.NestedArray] = {},
    ) -> None:
        """Record an action and the following timestep."""
        if self._next_observations is None:
            raise ValueError(
                "adder.add_first must be called before adder.add.")

        discount = next_timestep.discount
        if next_timestep.last():
            # Terminal timesteps created by dm_env.termination() will have a scalar
            # discount of 0.0. This may not match the array shape / nested structure
            # of the previous timesteps' discounts. The below will match
            # next_timestep.discount's shape/structure to that of
            # self._buffer[-1].discount.
            if self._buffer and not tree.is_nested(next_timestep.discount):
                discount = tree.map_structure(
                    lambda d: np.broadcast_to(next_timestep.discount,
                                              np.shape(d)),
                    self._buffer[-1].discount,
                )

        self._buffer.append(
            Step(
                observations=self._next_observations,
                actions=actions,
                rewards=next_timestep.reward,
                discounts=discount,
                start_of_episode=self._start_of_episode,
                extras=self._next_extras
                if self._use_next_extras else next_extras,
            ))

        # Write the last "dangling" observation.
        if next_timestep.last():
            self._start_of_episode = False
            self._write()
            self._write_last()
            self.reset()
        else:
            # Record the next observation and write.
            # Possibly store next_extras
            if self._use_next_extras:
                self._next_extras = next_extras
            self._next_observations = next_timestep.observation
            self._start_of_episode = False
            self._write()
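The discount-broadcasting branch above can be exercised on its own. A small self-contained sketch of the same pattern with dm-tree and NumPy (the per-agent structure below is made up purely for illustration):

import numpy as np
import tree

# Discounts as stored for a previous, non-terminal multi-agent step.
previous_discounts = {
    "agent_0": np.ones(3, dtype=np.float32),
    "agent_1": np.ones(3, dtype=np.float32),
}

# dm_env.termination() yields a scalar discount of 0.0; broadcast it to the
# shape and nested structure of the stored discounts.
terminal_discount = 0.0
matched_discount = tree.map_structure(
    lambda d: np.broadcast_to(terminal_discount, np.shape(d)),
    previous_discounts)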
Example #4
    def update(self, old_step: dm_env.TimeStep, action: int,
               new_step: dm_env.TimeStep):
        """Takes in a transition from the environment."""
        # Count steps and episodes.
        self._total_steps += 1
        if new_step.last():
            self._total_episodes += 1

        # Add the transition to the replay buffer.
        if not old_step.last():
            self._replay.add(
                TransitionWithMaskAndNoise(
                    x0=old_step.observation,
                    a0=action,
                    r1=new_step.reward,
                    gamma1=new_step.discount,
                    x1=new_step.observation,
                ))

        # Keep gathering data
        if self._replay.size < self._min_replay_size:
            return

        # SGD training step.
        if self._total_steps % self._sgd_period == 0:
            x0, a0, r1, gamma1, x1 = self._replay.sample(self._batch_size)
            target = self._compute_target(r1, gamma1, x1)
            with tf.GradientTape() as tape:
                q0 = self._compute_prediction(x0, a0)
                # loss = -tf.reduce_sum( q0.log_prob(target) ) #q0 is not a distribution
                td_error = target - q0
                loss = tf.reduce_sum(tf.square(td_error))
            gradients = tape.gradient(loss,
                                      self._online_network.trainable_variables)
            self._optimizer.apply_gradients(
                zip(gradients, self._online_network.trainable_variables))

        # Recompute the posterior.
        if self._total_steps % self._posterior_update_period == 0:
            self._compute_posterior()

        # Sample new out weights (mus).
        if self._total_steps % self._sample_out_mus_period == 0:
            self._sample_out_mus()

        # Update the target networks.
        if self._total_steps % self._target_update_period == 0:
            self._update_target_nets()
Example #5
    def update(self, timestep: TimeStep, action: int,
               new_timestep: TimeStep) -> float:
        """True online Sarsa(λ) update using accumulating traces.
        See Sutton R., Barto, G. 2018. Reinforcement Learning: an Introduction, pp. 300-306"""
        g, l, a = timestep.discount, self.lamda, self.alpha
        x_m1, x = timestep.observation, new_timestep.observation
        q_m1, q = self.q_values(x_m1), self.q_values(x)

        # Sarsa TD error, bootstrapping on the next state-action value.
        td = new_timestep.reward + g * q - q_m1

        # Update the dutch-style eligibility traces.
        self.z = g * l * self.z + (
            1 - a * g * l * jnp.dot(self.z.T, x_m1)) * x_m1

        # True online weight update.
        self.w = (self.w + a * (td + q_m1 - self._q_old) * self.z -
                  a * (q_m1 - self._q_old) * x_m1)

        # update value estimate
        error = q - self._q_old
        self._q_old = int(not new_timestep.last()) * q

        # log
        self.logger.log({"loss": error})

        return error
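For reference, the true online Sarsa(λ) updates with linear function approximation, as given in Sutton and Barto (2018), which the trace and weight updates above follow:

\begin{aligned}
\delta_t &= R_{t+1} + \gamma\,\hat{q}(x_{t+1}) - \hat{q}(x_t), \\
z_t &= \gamma\lambda\, z_{t-1} + \bigl(1 - \alpha\gamma\lambda\, z_{t-1}^{\top} x_t\bigr)\, x_t, \\
w_{t+1} &= w_t + \alpha\bigl(\delta_t + \hat{q}(x_t) - q_{\mathrm{old}}\bigr)\, z_t - \alpha\bigl(\hat{q}(x_t) - q_{\mathrm{old}}\bigr)\, x_t, \\
q_{\mathrm{old}} &\leftarrow \hat{q}(x_{t+1}),
\end{aligned}

where \hat{q}(x) = w^{\top} x and x is the feature vector of the current state-action pair; q_old is reset to zero when the episode terminates, as in the code above.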
Example #6
def get_expected_parallel_timesteps_1() -> TimeStep:
    return TimeStep(
        step_type=StepType.FIRST,
        reward={
            "agent_0": 0.0,
            "agent_1": 0.0,
            "agent_2": 0.0
        },
        discount={
            "agent_0": 1.0,
            "agent_1": 1.0,
            "agent_2": 1.0
        },
        observation={
            "agent_0":
            OLT(
                observation=[0.1, 0.3, 0.7],
                legal_actions=[1],
                terminal=[0.0],
            ),
            "agent_1":
            OLT(
                observation=[0.1, 0.3, 0.7],
                legal_actions=[1],
                terminal=[0.0],
            ),
            "agent_2":
            OLT(
                observation=[0.1, 0.3, 0.7],
                legal_actions=[1],
                terminal=[0.0],
            ),
        },
    )
Example #7
  def append(
      self,
      timestep: dm_env.TimeStep,
      action: base.Action,
      new_timestep: dm_env.TimeStep,
  ):
    """Appends an observation, action, reward, and discount to the buffer."""
    if self.full():
      raise ValueError('Cannot append; sequence buffer is full.')

    # Start a new sequence with an initial observation, if required.
    if self._needs_reset:
      self._t = 0
      self._observations[self._t] = timestep.observation
      self._needs_reset = False

    # Append (o, a, r, d) to the sequence buffer.
    print('sequence save obs type is', type(new_timestep.observation))
    self._observations[self._t + 1] = new_timestep.observation
    self._actions[self._t] = action
    self._rewards[self._t] = new_timestep.reward
    self._discounts[self._t] = new_timestep.discount
    self._t += 1

    # Don't accumulate sequences that cross episode boundaries.
    # It is up to the caller to drain the buffer in this case.
    if new_timestep.last():
      self._needs_reset = True
Example #8
    def add(
        self,
        timestep: dm_env.TimeStep,
        action: int,
        new_timestep: dm_env.TimeStep,
        preprocess=lambda x: x,
    ) -> None:
        #  start of a new episode
        if not self.collecting():
            self.trajectories.append(self._current)
            self._reset()

        # add new transition to the trajectory
        store = lambda x: device_array(x, device=self.device)
        self._current.observations[self._t] = preprocess(
            store(timestep.observation))
        self._current.actions[self._t] = store(int(action))
        self._current.rewards[self._t] = store(float(new_timestep.reward))
        self._current.discounts[self._t] = store(float(new_timestep.discount))
        self._t += 1

        # ready to store, just add the final (trailing) observation
        if not self.collecting():
            self._current.observations[self._t] = preprocess(
                jnp.array(new_timestep.observation, dtype=jnp.float32))
        # if not enough samples, and we can't sample the env anymore, reset
        elif new_timestep.last():
            self._reset()
        return
Example #9
 def policy(self, timestep: dm_env.TimeStep) -> base.Action:
   """Selects actions according to the latest softmax policy."""
   action, self.state = self._policy_fn(
       np.expand_dims(timestep.observation, 0),
       np.expand_dims(float(not timestep.first()), 0),
       self.state)
   return np.int32(action)
Example #10
    def update(self, step: dm_env.TimeStep, action: int,
               next_step: dm_env.TimeStep) -> None:
        """
        Adds experience to the replay memory, performs an optimization_step and updates the q_target neural network.
        Args:
            step(dm_env.TimeStep): Current observation from the environment
            action(int): The action that was performed by the agent.
            next_step(dm_env.TimeStep): Next observation from the environment
        Returns:
            None
        """

        observation = np.array(step.observation).flatten()
        next_observation = np.array(next_step.observation).flatten()
        done = next_step.last()
        exp = Experience(observation, action, next_step.reward,
                         next_step.discount, next_observation, 0, done)
        self.memory.add(exp)

        if self.memory.number_samples() < self.start_optimization:
            return

        if self.number_steps % self.update_qnet_every == 0:
            s0, a0, n_step_reward, discount, s1, _, dones, indices, weights = self.memory.sample_batch(
                self.batch_size)
            self.optimization_step(s0, a0, n_step_reward, discount, s1,
                                   indices, weights)

        if self.number_steps % self.update_target_every == 0:
            self.q_target.load_state_dict(self.qnet.state_dict())
        return
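The target computation happens inside self.optimization_step and is not shown. Purely as an illustrative sketch (not this project's code), a standard one-step Q-learning target with a target network could look like this in PyTorch:

import torch


def one_step_q_target(reward: torch.Tensor, discount: torch.Tensor,
                      next_observation: torch.Tensor,
                      q_target: torch.nn.Module) -> torch.Tensor:
    """Illustrative only: r + discount * max_a Q_target(s', a), no gradients."""
    with torch.no_grad():
        next_q = q_target(next_observation).max(dim=1).values
    return reward + discount * next_q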
Example #11
  def update(
      self,
      timestep: dm_env.TimeStep,
      action: types.Action,
      next_timestep: dm_env.TimeStep,
  ) -> dm_env.TimeStep:
    # Add the true transition to replay.
    transition = [
        timestep.observation,
        action,
        next_timestep.reward,
        next_timestep.discount,
        next_timestep.observation,
    ]
    self._replay.add(transition)

    # Step the model to generate a synthetic transition.
    ts = self.step(action)

    # Copy the *true* state on update.
    self._state = next_timestep.observation.copy()

    if ts.last() or next_timestep.last():
      # Model believes that a termination has happened.
      # This will result in a crash during planning if the true environment
      # didn't terminate here as well. So, we indicate that we need a reset.
      self._needs_reset = True

    # Sample from replay and do SGD.
    if self._replay.size >= self._batch_size:
      batch = self._replay.sample(self._batch_size)
      self._step(*batch)

    return ts
Example #12
    def add(
        self,
        timestep: dm_env.TimeStep,
        action: int,
        new_timestep: dm_env.TimeStep,
        trace_decay: TraceDecay = 1.0,
        preprocess: Callable[[Observation], Observation] = lambda x: x,
    ) -> None:
        #  if buffer is full, prepare for new trajectory
        if self.full():
            self._reset()

        # add new transition to the trajectory
        self.trajectory.observations[self._t] = preprocess(
            jnp.array(timestep.observation, dtype=jnp.float32))
        self.trajectory.actions[self._t] = action
        self.trajectory.rewards[self._t] = new_timestep.reward
        self.trajectory.discounts[self._t] = new_timestep.discount
        self.trajectory.trace_decays[self._t] = trace_decay
        self._t += 1

        #  if we have enough transitions, add last obs and return
        if self.full():
            self.trajectory.observations[self._t] = preprocess(
                jnp.array(new_timestep.observation, dtype=jnp.float32))
        #  if the episode ended before the buffer filled, discard and reset
        elif new_timestep.last():
            self._reset()
        return
Example #13
 def policy(self, timestep: dm_env.TimeStep) -> base.Action:
     """Selects actions according to the latest softmax policy."""
     observation = tf.expand_dims(timestep.observation, axis=0)
     mask = tf.expand_dims(float(not timestep.first()), axis=0)
     (logits, _), self._state = self._forward((observation, mask),
                                              self._state)
     return tf.random.categorical(logits, num_samples=1).numpy().squeeze()
Example #14
File: base.py Project: srsohn/mtsgi
    def observe(
        self,
        action: types.NestedArray,
        next_timestep: dm_env.TimeStep,
    ):
        action = np.expand_dims(action, axis=-1)
        next_rewards = np.expand_dims(next_timestep.reward, axis=-1)
        is_first = np.expand_dims(next_timestep.first(), axis=-1)  # for mask
        avg_rewards = np.take_along_axis(self._avg_rewards, action, axis=-1)
        counts = np.take_along_axis(self._counts, action, axis=-1)

        # Compute & update avg rewards.
        update_values = 1 / counts * (next_rewards - avg_rewards)
        next_avg_rewards = avg_rewards + np.where(
            is_first, 0, update_values)  # skip first timestep.
        np.put_along_axis(self._avg_rewards,
                          action,
                          values=next_avg_rewards,
                          axis=-1)

        # Update counts.
        np.put_along_axis(self._counts,
                          action,
                          values=counts + (1 - is_first),
                          axis=-1)
        self._total_counts += (1 - is_first).squeeze()
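The masked update above is the standard incremental sample-average rule for action-value estimation,

Q(a) \leftarrow Q(a) + \frac{1}{N(a)}\bigl(r - Q(a)\bigr), \qquad N(a) \leftarrow N(a) + 1,

where N(a) is the count stored for action a at the time of the update; both the value and count updates are skipped on FIRST timesteps, whose reward is not meaningful.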
Example #15
 def _augment_observation(self, action: types.NestedArray,
                          reward: types.NestedArray,
                          timestep: dm_env.TimeStep) -> dm_env.TimeStep:
   oar = OAR(observation=timestep.observation,
             action=action,
             reward=reward)
   return timestep._replace(observation=oar)
Example #16
    def add(self,
            action: types.NestedArray,
            next_timestep: dm_env.TimeStep,
            extras: types.NestedArray = ()):
        """Record an action and the following timestep."""
        if self._next_observation is None:
            raise ValueError(
                'adder.add_first must be called before adder.add.')

        # Add the timestep to the buffer.
        self._buffer.append(
            Step(
                observation=self._next_observation,
                action=action,
                reward=next_timestep.reward,
                discount=next_timestep.discount,
                extras=extras,
            ))

        # Record the next observation and write.
        self._next_observation = next_timestep.observation
        self._write()

        # Write the last "dangling" observation.
        if next_timestep.last():
            self._write_last()
            self.reset()
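This add method assumes add_first has already recorded the initial observation, which is the usual adder protocol: call add_first once with the FIRST timestep of an episode, then add for every subsequent action/timestep pair. A minimal sketch of a loop driving such an adder, where the environment and policy are stand-ins for illustration:

import dm_env


def collect_episode(environment: dm_env.Environment, adder, policy) -> None:
    """Feeds one episode of experience to an adder via add_first/add."""
    timestep = environment.reset()
    adder.add_first(timestep)
    while not timestep.last():
        action = policy(timestep)
        timestep = environment.step(action)
        adder.add(action, timestep)
    # On the LAST timestep, add() writes the final item and resets the adder.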
Example #17
    def update(
        self,
        timestep: dm_env.TimeStep,
        action: base.Action,
        new_timestep: dm_env.TimeStep,
    ):
        """Update the agent: add transition to replay and periodically do SGD."""
        if new_timestep.last():
            self._active_head = np.random.randint(self._num_ensemble)

        self._replay.add(
            TransitionWithMaskAndNoise(
                o_tm1=timestep.observation,
                a_tm1=action,
                r_t=np.float32(new_timestep.reward),
                d_t=np.float32(new_timestep.discount),
                o_t=new_timestep.observation,
                m_t=self._rng.binomial(1, self._mask_prob,
                                       self._num_ensemble).astype(np.float32),
                z_t=self._rng.randn(self._num_ensemble).astype(np.float32) *
                self._noise_scale,
            ))

        if self._replay.size < self._min_replay_size:
            return

        if tf.math.mod(self._total_steps, self._sgd_period) == 0:
            minibatch = self._replay.sample(self._batch_size)
            minibatch = [tf.convert_to_tensor(x) for x in minibatch]
            self._step(minibatch)
Example #18
 def select_action(self, timestep: dm_env.TimeStep) -> base.Action:
   """Selects actions according to the latest softmax policy."""
   if timestep.first():
     self._state = self._network.initial_state(1)
     self._rollout_initial_state = self._network.initial_state(1)
   observation = tf.expand_dims(timestep.observation, axis=0)
   (logits, _), self._state = self._forward(observation, self._state)
   return tf.random.categorical(logits, num_samples=1).numpy().squeeze()
Example #19
 def _add_reward_noise(self, timestep: dm_env.TimeStep):
     if timestep.first():
         return timestep
     reward = timestep.reward + self._noise_scale * self._rng.randn()
     return dm_env.TimeStep(step_type=timestep.step_type,
                            reward=reward,
                            discount=timestep.discount,
                            observation=timestep.observation)
Example #20
 def _rescale_rewards(self, timestep: dm_env.TimeStep):
     if timestep.first():
         return timestep
     reward = timestep.reward * self._reward_scale
     return dm_env.TimeStep(step_type=timestep.step_type,
                            reward=reward,
                            discount=timestep.discount,
                            observation=timestep.observation)
Example #21
    def step(self, timestep_t: dm_env.TimeStep,
             a_t: parts.Action) -> Iterable[Transition]:
        """Accumulates timestep and resulting action, maybe yields transitions."""
        if timestep_t.first():
            self.reset()

        # There are no transitions on the first timestep.
        if self._timestep_tm1 is None:
            assert self._a_tm1 is None
            if not timestep_t.first():
                raise ValueError('Expected FIRST timestep, got %s.' %
                                 str(timestep_t))
            self._timestep_tm1 = timestep_t
            self._a_tm1 = a_t
            return  # Empty iterable.

        self._transitions.append(
            Transition(
                s_tm1=self._timestep_tm1.observation,
                a_tm1=self._a_tm1,
                r_t=timestep_t.reward,
                discount_t=timestep_t.discount,
                s_t=timestep_t.observation,
                a_t=a_t,
                mc_return_tm1=None,
            ))

        self._timestep_tm1 = timestep_t
        self._a_tm1 = a_t

        if timestep_t.last():
            # Annotate all episode transitions with their MC returns.
            mc_return = 0
            mc_transitions = []
            while self._transitions:
                transition = self._transitions.pop()
                mc_return = transition.discount_t * mc_return + transition.r_t
                mc_transitions.append(
                    transition._replace(mc_return_tm1=mc_return))
            for transition in reversed(mc_transitions):
                yield transition

        else:
            # Wait for episode end before yielding anything.
            return
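At episode end, the transitions are re-annotated in a single backward pass with their discounted Monte Carlo returns,

G_{t-1} = r_t + \gamma_t\, G_t,

where the return beyond the final transition is taken to be zero, so the last transition's annotation is simply its immediate reward.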
Example #22
def get_seq_timesteps_1() -> TimeStep:
    return TimeStep(
        step_type=StepType.FIRST,
        reward=0.0,
        discount=1.0,
        observation=OLT(observation=[0.1, 0.3, 0.7],
                        legal_actions=[1],
                        terminal=[0.0]),
    )
Example #23
 def add(self,
         action: types.NestedArray,
         next_timestep: dm_env.TimeStep,
         extras: types.NestedArray = ()):
     updated_timestep = next_timestep._replace(
         reward=self._rewarder.append_and_compute_reward(
             observation=self._latest_observation, action=action))
     self._latest_observation = next_timestep.observation
     self._adder.add(action, updated_timestep, extras)
Example #24
  def step(self, timestep: dm_env.TimeStep) -> None:
    """Accumulates statistics from timestep."""
    if timestep.first():
      if self._current_episode_rewards:
        raise ValueError('Current episode reward list should be empty.')
      if self._current_episode_step != 0:
        raise ValueError('Current episode step should be zero.')
    else:
      # First reward is invalid, all other rewards are appended.
      self._current_episode_rewards.append(timestep.reward)

    self._num_steps_since_reset += 1
    self._current_episode_step += 1

    if timestep.last():
      self._episode_returns.append(sum(self._current_episode_rewards))
      self._current_episode_rewards = []
      self._num_steps_over_episodes += self._current_episode_step
      self._current_episode_step = 0
Example #25
def get_seq_timesteps_dict_2() -> Dict[str, SeqTimestepDict]:
    return {
        "agent_0": {
            "timestep":
            TimeStep(
                step_type=StepType.FIRST,
                reward=-1,
                discount=0.8,
                observation=OLT(observation=[0.1, 0.5, 0.7],
                                legal_actions=[1],
                                terminal=[0.0]),
            ),
            "action":
            0,
        },
        "agent_1": {
            "timestep":
            TimeStep(
                step_type=StepType.FIRST,
                reward=0.0,
                discount=0.8,
                observation=OLT(observation=[0.8, 0.3, 0.7],
                                legal_actions=[1],
                                terminal=[0.0]),
            ),
            "action":
            2,
        },
        "agent_2": {
            "timestep":
            TimeStep(
                step_type=StepType.FIRST,
                reward=1,
                discount=1.0,
                observation=OLT(observation=[0.9, 0.9, 0.8],
                                legal_actions=[1],
                                terminal=[0.0]),
            ),
            "action":
            1,
        },
    }
Example #26
    def step(self, timestep_t: dm_env.TimeStep,
             a_t: parts.Action) -> Iterable[Transition]:
        """Accumulates timestep and resulting action, yields transitions."""
        if timestep_t.first():
            self.reset()

        # There are no transitions on the first timestep.
        if self._timestep_tm1 is None:
            assert self._a_tm1 is None
            if not timestep_t.first():
                raise ValueError('Expected FIRST timestep, got %s.' %
                                 str(timestep_t))
            self._timestep_tm1 = timestep_t
            self._a_tm1 = a_t
            return  # Empty iterator.

        self._transitions.append(
            Transition(
                s_tm1=self._timestep_tm1.observation,
                a_tm1=self._a_tm1,
                r_t=timestep_t.reward,
                discount_t=timestep_t.discount,
                s_t=timestep_t.observation,
            ))

        self._timestep_tm1 = timestep_t
        self._a_tm1 = a_t

        if timestep_t.last():
            # Yield any remaining n, n-1, ..., 1-step transitions at episode end.
            while self._transitions:
                yield _build_n_step_transition(self._transitions)
                self._transitions.popleft()
        else:
            # Wait for n transitions before yielding anything.
            if len(self._transitions) < self._transitions.maxlen:
                return  # Empty iterator.

            assert len(self._transitions) == self._transitions.maxlen

            # This is the typical case, yield a single n-step transition.
            yield _build_n_step_transition(self._transitions)
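_build_n_step_transition itself is not shown; its role is to fold the buffered one-step transitions into a single n-step transition. A self-contained sketch of that idea (an assumption about the intent, not the project's actual implementation):

import collections

Transition = collections.namedtuple(
    'Transition', ['s_tm1', 'a_tm1', 'r_t', 'discount_t', 's_t'])


def build_n_step_transition_sketch(transitions) -> Transition:
    """Illustrative only: accumulates the n-step return and discount product."""
    r_t, discount_t = 0.0, 1.0
    for transition in transitions:
        r_t += discount_t * transition.r_t
        discount_t *= transition.discount_t
    first, last = transitions[0], transitions[-1]
    return Transition(s_tm1=first.s_tm1, a_tm1=first.a_tm1,
                      r_t=r_t, discount_t=discount_t, s_t=last.s_t)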
Example #27
File: base.py Project: wzyxwqx/acme
    def add(self,
            action: types.NestedArray,
            next_timestep: dm_env.TimeStep,
            extras: types.NestedArray = ()):
        """Record an action and the following timestep."""
        if self._next_observation is None:
            raise ValueError(
                'adder.add_first must be called before adder.add.')

        discount = next_timestep.discount
        if next_timestep.last():
            # Terminal timesteps created by dm_env.termination() will have a scalar
            # discount of 0.0. This may not match the array shape / nested structure
            # of the previous timesteps' discounts. The below will match
            # next_timestep.discount's shape/structure to that of
            # self._buffer[-1].discount.
            if self._buffer and not tree.is_nested(next_timestep.discount):
                discount = tree.map_structure(
                    lambda d: np.broadcast_to(next_timestep.discount,
                                              np.shape(d)),
                    self._buffer[-1].discount)

        # Add the timestep to the buffer.
        self._buffer.append(
            Step(
                observation=self._next_observation,
                action=action,
                reward=next_timestep.reward,
                discount=discount,
                start_of_episode=self._start_of_episode,
                extras=extras,
            ))

        # Record the next observation and write.
        self._next_observation = next_timestep.observation
        self._start_of_episode = False
        self._write()

        # Write the last "dangling" observation.
        if next_timestep.last():
            self._write_last()
            self.reset()
Example #28
  def observe_first(self, timestep: dm_env.TimeStep):
    if self._queue is not None:
      self._queue.add_first(timestep)

    # Set the state to None so that we re-initialize at the next policy call.
    assert timestep.step_type.shape[0] == 1, \
        'Currently only supports single worker.'

    # Reset hidden state every new episode.
    if timestep.first():
      self._state = None
Example #29
 def _add_reward_noise(self, timestep: dm_env.TimeStep):
     if timestep.first():
         return timestep
     if self._bernoulli:
         reward = self._rng.binomial(p=timestep.reward, n=1,
                                     size=1).astype(np.float64)[0]
     else:
         reward = timestep.reward + self._noise_scale * self._rng.randn()
     return dm_env.TimeStep(step_type=timestep.step_type,
                            reward=reward,
                            discount=timestep.discount,
                            observation=timestep.observation)
Example #30
    def step(self, timestep: dm_env.TimeStep, loss, shaped_reward,
             penalties) -> None:
        """Accumulates statistics from timestep."""

        if timestep.first():
            if self._current_episode_rewards:
                raise ValueError(
                    'Current episode reward list should be empty.')
            if self._current_episode_step != 0:
                raise ValueError('Current episode step should be zero.')
        else:
            # First reward is invalid, all other rewards are appended.
            self._current_episode_rewards.append(timestep.reward)

        if shaped_reward is not None:
            if isinstance(shaped_reward, list):
                self._current_episode_shaped_rewards.extend(shaped_reward)
            else:
                self._current_episode_shaped_rewards.append(shaped_reward)
        if loss is not None:
            self._current_episode_loss += loss

        if penalties is not None:
            if isinstance(penalties, list):
                self._current_episode_penalties.extend(penalties)
            else:
                self._current_episode_penalties.append(penalties)

        self._num_steps_since_reset += 1
        self._current_episode_step += 1

        if timestep.last():
            self._episode_returns.append(sum(self._current_episode_rewards))
            self._current_episode_rewards = []
            self._current_episode_shaped_rewards = []
            self._current_episode_penalties = []
            self._current_episode_loss = 0
            self._num_steps_over_episodes += self._current_episode_step
            self._current_episode_step = 0