Example #1
    def _track(self, timestep: dm_env.TimeStep):
        # Count transitions only.
        if not timestep.first():
            self._steps += 1
            self._episode_len += 1
        if timestep.last():
            self._episode += 1
        self._episode_return += timestep.reward or 0.0
        self._total_return += timestep.reward or 0.0

        # Log statistics periodically, either by step or by episode.
        if self._log_by_step:
            if _logarithmic_logging(self._steps) or self._log_every:
                self._log_bsuite_data()

        elif timestep.last():
            if _logarithmic_logging(self._episode) or self._log_every:
                self._log_bsuite_data()

        # Perform bookkeeping at the end of episodes.
        if timestep.last():
            self._episode_len = 0
            self._episode_return = 0.0

        if self._episode == self._env.bsuite_num_episodes:
            self.flush()
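
The snippet above relies on a _logarithmic_logging predicate that thins out logging as step or episode counts grow. The helper below is a minimal sketch of such a schedule (roughly ten logging points per decade), written for illustration rather than taken from bsuite:

def _logarithmic_logging(count: int) -> bool:
    """Returns True on a roughly logarithmic schedule (illustrative sketch)."""
    if count <= 0:
        return False
    # Log every step up to 9, every 10 steps up to 99, every 100 up to 999, ...
    spacing = 10 ** (len(str(count)) - 1)
    return count % spacing == 0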
Example #2
    def update(self, old_step: dm_env.TimeStep, action: int,
               new_step: dm_env.TimeStep):
        """Takes in a transition from the environment."""
        self._total_steps += 1
        if new_step.last():
            self._total_episodes += 1
            self._active_head = np.random.randint(self._num_ensemble)

        if not old_step.last():
            self._replay.add(
                TransitionWithMaskAndNoise(
                    o_tm1=old_step.observation,
                    a_tm1=action,
                    r_t=new_step.reward,
                    d_t=new_step.discount,
                    o_t=new_step.observation,
                    m_t=self._rng.binomial(1, self._mask_prob,
                                           self._num_ensemble),
                    z_t=self._rng.randn(self._num_ensemble) *
                    self._noise_scale,
                ))

        if self._replay.size < self._min_replay_size:
            return

        if self._total_steps % self._sgd_period == 0:
            minibatch = self._replay.sample(self._batch_size)
            self._sgd_step(*minibatch)

        if self._total_steps % self._target_update_period == 0:
            self._update_target_nets()
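
update(old_step, action, new_step) methods like the one above are typically driven by a standard dm_env interaction loop. A minimal sketch of such a loop, assuming the agent exposes select_action and update (the names are assumptions here, not taken from the snippet):

import dm_env

def run_episode(agent, environment: dm_env.Environment) -> float:
    """Runs a single episode, feeding every transition to agent.update (sketch)."""
    episode_return = 0.0
    timestep = environment.reset()
    while not timestep.last():
        action = agent.select_action(timestep)
        new_timestep = environment.step(action)
        agent.update(timestep, action, new_timestep)
        episode_return += new_timestep.reward
        timestep = new_timestep
    return episode_return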
Example #3
    def update(self, old_step: dm_env.TimeStep, action: int,
               new_step: dm_env.TimeStep):
        """Takes in a transition from the environment."""
        self._total_steps += 1
        if new_step.last():
            self._total_episodes += 1

        if not old_step.last():
            self._replay.add(
                TransitionWithMaskAndNoise(
                    x0=old_step.observation,
                    a0=action,
                    r1=new_step.reward,
                    gamma1=new_step.discount,
                    x1=new_step.observation,
                ))

        if self._replay.size < self._min_replay_size:
            return

        if self._total_steps % self._sgd_period == 0:
            x0, a0, r1, gamma1, x1 = self._replay.sample(self._batch_size)
            target = self._compute_target(r1, gamma1, x1)
            with tf.GradientTape() as tape:
                q0 = self._compute_prediction(x0, a0)
                td_error = target - q0
                loss = tf.reduce_sum(tf.square(td_error))

            gradients = tape.gradient(loss,
                                      self._online_network.trainable_variables)
            self._optimizer.apply_gradients(
                zip(gradients, self._online_network.trainable_variables))

        if self._total_steps % self._target_update_period == 0:
            self._update_target_nets()
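
Example #3 leaves _compute_target and _compute_prediction to the rest of the class. For a plain DQN-style agent they typically look something like the sketch below; the free-function form and the network arguments are assumptions for illustration, assuming networks that map a batch of observations to per-action Q-values:

import tensorflow as tf

def compute_target(target_network, r1, gamma1, x1):
    """One-step TD target: r + gamma * max_a Q_target(x', a) (sketch)."""
    q1 = target_network(x1)                 # [batch, num_actions]
    return r1 + gamma1 * tf.reduce_max(q1, axis=-1)

def compute_prediction(online_network, x0, a0):
    """Q-value of the action actually taken (sketch)."""
    q0 = online_network(x0)                 # [batch, num_actions]
    return tf.gather(q0, a0, batch_dims=1)  # [batch]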
Example #4
 def update(
     self,
     timestep: dm_env.TimeStep,
     action: base.Action,
     new_timestep: dm_env.TimeStep,
 ):
   """Adds a transition to the trajectory buffer and periodically does SGD."""
   if new_timestep.last():
     self._state = self._state._replace(rnn_state=self._initial_rnn_state)
   self._buffer.append(timestep, action, new_timestep)
   if self._buffer.full() or new_timestep.last():
     trajectory = self._buffer.drain()
     self._state = self._sgd_step(self._state, trajectory)
Example #5
    def add(
        self,
        actions: Dict[str, types.NestedArray],
        next_timestep: dm_env.TimeStep,
        next_extras: Dict[str, types.NestedArray] = {},
    ) -> None:
        """Record an action and the following timestep."""
        if self._next_observations is None:
            raise ValueError(
                "adder.add_first must be called before adder.add.")

        discount = next_timestep.discount
        if next_timestep.last():
            # Terminal timesteps created by dm_env.termination() will have a scalar
            # discount of 0.0. This may not match the array shape / nested structure
            # of the previous timesteps' discounts. The below will match
            # next_timestep.discount's shape/structure to that of
            # self._buffer[-1].discount.
            if self._buffer and not tree.is_nested(next_timestep.discount):
                discount = tree.map_structure(
                    lambda d: np.broadcast_to(next_timestep.discount,
                                              np.shape(d)),
                    self._buffer[-1].discount,
                )

        self._buffer.append(
            Step(
                observations=self._next_observations,
                actions=actions,
                rewards=next_timestep.reward,
                discounts=discount,
                start_of_episode=self._start_of_episode,
                extras=self._next_extras
                if self._use_next_extras else next_extras,
            ))

        # Write the last "dangling" observation.
        if next_timestep.last():
            self._start_of_episode = False
            self._write()
            self._write_last()
            self.reset()
        else:
            # Record the next observation and write.
            # Possibly store next_extras
            if self._use_next_extras:
                self._next_extras = next_extras
            self._next_observations = next_timestep.observation
            self._start_of_episode = False
            self._write()
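
The scalar-discount handling above only kicks in when dm_env.termination() produced a scalar 0.0 but earlier steps used a nested or batched discount. A small self-contained illustration of the same tree.map_structure / np.broadcast_to pattern (the dictionary is invented for the example):

import numpy as np
import tree  # dm-tree

previous_discounts = {'agent_a': np.ones((2,), dtype=np.float32),
                      'agent_b': np.ones((), dtype=np.float32)}
terminal_discount = 0.0  # scalar discount from dm_env.termination()

matched = tree.map_structure(
    lambda d: np.broadcast_to(terminal_discount, np.shape(d)),
    previous_discounts)
# matched == {'agent_a': array([0., 0.]), 'agent_b': array(0.)}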
Example #6
    def update(self, old_step: dm_env.TimeStep, action: int,
               new_step: dm_env.TimeStep):
        """Takes in a transition from the environment."""
        # Count steps and episodes.
        self._total_steps += 1
        if new_step.last():
            self._total_episodes += 1

        # Add the transition to the replay buffer.
        if not old_step.last():
            self._replay.add(
                TransitionWithMaskAndNoise(
                    x0=old_step.observation,
                    a0=action,
                    r1=new_step.reward,
                    gamma1=new_step.discount,
                    x1=new_step.observation,
                ))

        # Keep gathering data
        if self._replay.size < self._min_replay_size:
            return

        # Training step.
        if self._total_steps % self._sgd_period == 0:
            x0, a0, r1, gamma1, x1 = self._replay.sample(self._batch_size)
            target = self._compute_target(r1, gamma1, x1)
            with tf.GradientTape() as tape:
                q0 = self._compute_prediction(x0, a0)
                # loss = -tf.reduce_sum( q0.log_prob(target) ) #q0 is not a distribution
                td_error = target - q0
                loss = tf.reduce_sum(tf.square(td_error))
            gradients = tape.gradient(loss,
                                      self._online_network.trainable_variables)
            self._optimizer.apply_gradients(
                zip(gradients, self._online_network.trainable_variables))

        # Calculate the posterior.
        if self._total_steps % self._posterior_update_period == 0:
            self._compute_posterior()

        # Sample new output weights (mus).
        if self._total_steps % self._sample_out_mus_period == 0:
            self._sample_out_mus()

        # Update the target networks.
        if self._total_steps % self._target_update_period == 0:
            self._update_target_nets()
Example #7
  def update(
      self,
      timestep: dm_env.TimeStep,
      action: types.Action,
      next_timestep: dm_env.TimeStep,
  ) -> dm_env.TimeStep:
    # Add the true transition to replay.
    transition = [
        timestep.observation,
        action,
        next_timestep.reward,
        next_timestep.discount,
        next_timestep.observation,
    ]
    self._replay.add(transition)

    # Step the model to generate a synthetic transition.
    ts = self.step(action)

    # Copy the *true* state on update.
    self._state = next_timestep.observation.copy()

    if ts.last() or next_timestep.last():
      # Model believes that a termination has happened.
      # This will result in a crash during planning if the true environment
      # didn't terminate here as well. So, we indicate that we need a reset.
      self._needs_reset = True

    # Sample from replay and do SGD.
    if self._replay.size >= self._batch_size:
      batch = self._replay.sample(self._batch_size)
      self._step(*batch)

    return ts
Example #8
    def update(self, step: dm_env.TimeStep, action: int,
               next_step: dm_env.TimeStep) -> None:
        """
        Adds experience to the replay memory, performs an optimization_step and updates the q_target neural network.
        Args:
            step(dm_env.TimeStep): Current observation from the environment
            action(int): The action that was performed by the agent.
            next_step(dm_env.TimeStep): Next observation from the environment
        Returns:
            None
        """

        observation = np.array(step.observation).flatten()
        next_observation = np.array(next_step.observation).flatten()
        done = next_step.last()
        exp = Experience(observation, action, next_step.reward,
                         next_step.discount, next_observation, 0, done)
        self.memory.add(exp)

        if self.memory.number_samples() < self.start_optimization:
            return

        if self.number_steps % self.update_qnet_every == 0:
            s0, a0, n_step_reward, discount, s1, _, dones, indices, weights = self.memory.sample_batch(
                self.batch_size)
            self.optimization_step(s0, a0, n_step_reward, discount, s1,
                                   indices, weights)

        if self.number_steps % self.update_target_every == 0:
            self.q_target.load_state_dict(self.qnet.state_dict())
        return
Example #9
    def update(self, timestep: TimeStep, action: int,
               new_timestep: TimeStep) -> float:
        """True online Sarsa(λ) update using accumulating traces.
        See Sutton R., Barto, G. 2018. Reinforcement Learning: an Introduction, pp. 300-306"""
        g, l, a = timestep.discount, self.lamda, self.alpha
        x_m1, x = timestep.observation, new_timestep.observation
        q_m1, q = self.q_values(x_m1), self.q_values(x)

        # TD error for Sarsa(λ) control
        td = new_timestep.reward + g * q - q_m1

        # update the dutch trace
        self.z = (g * l * self.z +
                  (1 - a * g * l * jnp.dot(self.z.T, x_m1)) * x_m1)

        # true online weight update
        self.w = (self.w + a * (td + q_m1 - self._q_old) * self.z -
                  a * (q_m1 - self._q_old) * x_m1)

        # update value estimate
        error = q - self._q_old
        self._q_old = int(not new_timestep.last()) * q

        # log
        self.logger.log({"loss": error})

        return error
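
For reference, the textbook update this example follows (true online Sarsa(λ) with linear function approximation and dutch traces) can be written as a standalone step. The function and variable names below are illustrative, not taken from the code above:

import numpy as np

def true_online_sarsa_step(w, z, q_old, x, x_next, r, gamma, lam, alpha):
    """One true online Sarsa(λ) update (Sutton & Barto, 2018) -- sketch."""
    q, q_next = np.dot(w, x), np.dot(w, x_next)
    delta = r + gamma * q_next - q
    z = gamma * lam * z + (1.0 - alpha * gamma * lam * np.dot(z, x)) * x
    w = w + alpha * (delta + q - q_old) * z - alpha * (q - q_old) * x
    return w, z, q_next  # q_next becomes q_old for the next step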
Example #10
    def update(
        self,
        timestep: dm_env.TimeStep,
        action: base.Action,
        new_timestep: dm_env.TimeStep,
    ):
        """Update the agent: add transition to replay and periodically do SGD."""
        if new_timestep.last():
            self._active_head = np.random.randint(self._num_ensemble)

        self._replay.add(
            TransitionWithMaskAndNoise(
                o_tm1=timestep.observation,
                a_tm1=action,
                r_t=np.float32(new_timestep.reward),
                d_t=np.float32(new_timestep.discount),
                o_t=new_timestep.observation,
                m_t=self._rng.binomial(1, self._mask_prob,
                                       self._num_ensemble).astype(np.float32),
                z_t=self._rng.randn(self._num_ensemble).astype(np.float32) *
                self._noise_scale,
            ))

        if self._replay.size < self._min_replay_size:
            return

        if tf.math.mod(self._total_steps, self._sgd_period) == 0:
            minibatch = self._replay.sample(self._batch_size)
            minibatch = [tf.convert_to_tensor(x) for x in minibatch]
            self._step(minibatch)
Example #11
    def add(self,
            action: types.NestedArray,
            next_timestep: dm_env.TimeStep,
            extras: types.NestedArray = ()):
        """Record an action and the following timestep."""
        if self._next_observation is None:
            raise ValueError(
                'adder.add_first must be called before adder.add.')

        # Add the timestep to the buffer.
        self._buffer.append(
            Step(
                observation=self._next_observation,
                action=action,
                reward=next_timestep.reward,
                discount=next_timestep.discount,
                extras=extras,
            ))

        # Record the next observation and write.
        self._next_observation = next_timestep.observation
        self._write()

        # Write the last "dangling" observation.
        if next_timestep.last():
            self._write_last()
            self.reset()
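
The adder examples all assume the same calling convention: add_first with the initial timestep, then add(action, next_timestep) for every subsequent step. A minimal sketch of the loop that drives such an adder (the actor and environment objects are assumptions, not part of the snippets above):

timestep = environment.reset()
adder.add_first(timestep)
while not timestep.last():
    action = actor.select_action(timestep.observation)
    timestep = environment.step(action)
    adder.add(action, timestep)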
Example #12
    def add(
        self,
        timestep: dm_env.TimeStep,
        action: int,
        new_timestep: dm_env.TimeStep,
        trace_decay: TraceDecay = 1.0,
        preprocess: Callable[[Observation], Observation] = lambda x: x,
    ) -> None:
        #  if buffer is full, prepare for new trajectory
        if self.full():
            self._reset()

        # add new transition to the trajectory
        self.trajectory.observations[self._t] = preprocess(
            jnp.array(timestep.observation, dtype=jnp.float32))
        self.trajectory.actions[self._t] = action
        self.trajectory.rewards[self._t] = new_timestep.reward
        self.trajectory.discounts[self._t] = new_timestep.discount
        self.trajectory.trace_decays[self._t] = trace_decay
        self._t += 1

        #  if we have enough transitions, add last obs and return
        if self.full():
            self.trajectory.observations[self._t] = preprocess(
                jnp.array(new_timestep.observation, dtype=jnp.float32))
        #  if we do not have enough transitions, and can't sample more, retry
        elif new_timestep.last():
            self._reset()
        return
Example #13
    def add(
        self,
        timestep: dm_env.TimeStep,
        action: int,
        new_timestep: dm_env.TimeStep,
        preprocess=lambda x: x,
    ) -> None:
        #  start of a new episode
        if not self.collecting():
            self.trajectories.append(self._current)
            self._reset()

        # add new transition to the trajectory
        store = lambda x: device_array(x, device=self.device)
        self._current.observations[self._t] = preprocess(
            store(timestep.observation))
        self._current.actions[self._t] = store(int(action))
        self._current.rewards[self._t] = store(float(new_timestep.reward))
        self._current.discounts[self._t] = store(float(new_timestep.discount))
        self._t += 1

        # ready to store, just add final observation
        if not self.collecting():
            self._current.observations[self._t] = preprocess(
                jnp.array(new_timestep.observation, dtype=jnp.float32))
        # if not enough samples, and we can't sample the env anymore, reset
        elif new_timestep.last():
            self._reset()
        return
Example #14
    def step(
        self,
        environment: Optional[dm_env.Environment],
        timestep_t: dm_env.TimeStep,
        agent: Optional[Agent],
        a_t: Optional[Action],
    ) -> None:
        """Accumulates statistics from timestep."""
        del (environment, agent, a_t)

        if timestep_t.first():
            if self._current_episode_rewards:
                raise ValueError(
                    'Current episode reward list should be empty.')
            if self._current_episode_step != 0:
                raise ValueError('Current episode step should be zero.')
        else:
            # First reward is invalid, all other rewards are appended.
            self._current_episode_rewards.append(timestep_t.reward)

        self._num_steps_since_reset += 1
        self._current_episode_step += 1

        if timestep_t.last():
            self._episode_returns.append(sum(self._current_episode_rewards))
            self._current_episode_rewards = []
            self._num_steps_over_episodes += self._current_episode_step
            self._current_episode_step = 0
Example #15
  def append(
      self,
      timestep: dm_env.TimeStep,
      action: base.Action,
      new_timestep: dm_env.TimeStep,
  ):
    """Appends an observation, action, reward, and discount to the buffer."""
    if self.full():
      raise ValueError('Cannot append; sequence buffer is full.')

    # Start a new sequence with an initial observation, if required.
    if self._needs_reset:
      self._t = 0
      self._observations[self._t] = timestep.observation
      self._needs_reset = False

    # Append (o, a, r, d) to the sequence buffer.
    self._observations[self._t + 1] = new_timestep.observation
    self._actions[self._t] = action
    self._rewards[self._t] = new_timestep.reward
    self._discounts[self._t] = new_timestep.discount
    self._t += 1

    # Don't accumulate sequences that cross episode boundaries.
    # It is up to the caller to drain the buffer in this case.
    if new_timestep.last():
      self._needs_reset = True
Example #16
File: base.py  Project: wzyxwqx/acme
    def add(self,
            action: types.NestedArray,
            next_timestep: dm_env.TimeStep,
            extras: types.NestedArray = ()):
        """Record an action and the following timestep."""
        if self._next_observation is None:
            raise ValueError(
                'adder.add_first must be called before adder.add.')

        discount = next_timestep.discount
        if next_timestep.last():
            # Terminal timesteps created by dm_env.termination() will have a scalar
            # discount of 0.0. This may not match the array shape / nested structure
            # of the previous timesteps' discounts. The below will match
            # next_timestep.discount's shape/structure to that of
            # self._buffer[-1].discount.
            if self._buffer and not tree.is_nested(next_timestep.discount):
                discount = tree.map_structure(
                    lambda d: np.broadcast_to(next_timestep.discount,
                                              np.shape(d)),
                    self._buffer[-1].discount)

        # Add the timestep to the buffer.
        self._buffer.append(
            Step(
                observation=self._next_observation,
                action=action,
                reward=next_timestep.reward,
                discount=discount,
                start_of_episode=self._start_of_episode,
                extras=extras,
            ))

        # Record the next observation and write.
        self._next_observation = next_timestep.observation
        self._start_of_episode = False
        self._write()

        # Write the last "dangling" observation.
        if next_timestep.last():
            self._write_last()
            self.reset()
Example #17
  def update(
      self,
      timestep: dm_env.TimeStep,
      action: base.Action,
      new_timestep: dm_env.TimeStep,
  ):
    """Receives a transition and performs a learning update."""
    self._buffer.append(timestep, action, new_timestep)

    if self._buffer.full() or new_timestep.last():
      trajectory = self._buffer.drain()
      trajectory = tree.map_structure(tf.convert_to_tensor, trajectory)
      self._rollout_initial_state = self._step(trajectory)
Example #18
    def update(
        self,
        timestep: dm_env.TimeStep,
        action: base.Action,
        new_timestep: dm_env.TimeStep,
    ):
        """Update the agent: add transition to replay and periodically do SGD."""

        # Thompson sampling: every episode pick a new Q-network as the policy.
        if new_timestep.last():
            k = np.random.randint(self._num_ensemble)
            self._active_head = self._ensemble[k]

        # Generate bootstrapping mask & reward noise.
        mask = np.random.binomial(1, self._mask_prob, self._num_ensemble)
        noise = np.random.randn(self._num_ensemble)

        # Make transition and add to replay.
        transition = [
            timestep.observation,
            action,
            np.float32(new_timestep.reward),
            np.float32(new_timestep.discount),
            new_timestep.observation,
            mask,
            noise,
        ]
        self._replay.add(transition)

        if self._replay.size < self._min_replay_size:
            return

        # Periodically sample from replay and do SGD for the whole ensemble.
        if self._total_steps % self._sgd_period == 0:
            transitions = self._replay.sample(self._batch_size)
            o_tm1, a_tm1, r_t, d_t, o_t, m_t, z_t = transitions
            for k, state in enumerate(self._ensemble):
                transitions = [
                    o_tm1, a_tm1, r_t, d_t, o_t, m_t[:, k], z_t[:, k]
                ]
                self._ensemble[k] = self._sgd_step(state, transitions)

        # Periodically update target parameters.
        for k, state in enumerate(self._ensemble):
            if state.step % self._target_update_period == 0:
                self._ensemble[k] = state._replace(target_params=state.params)
Example #19
    def step(self, timestep_t: dm_env.TimeStep,
             a_t: parts.Action) -> Iterable[Transition]:
        """Accumulates timestep and resulting action, maybe yields transitions."""
        if timestep_t.first():
            self.reset()

        # There are no transitions on the first timestep.
        if self._timestep_tm1 is None:
            assert self._a_tm1 is None
            if not timestep_t.first():
                raise ValueError('Expected FIRST timestep, got %s.' %
                                 str(timestep_t))
            self._timestep_tm1 = timestep_t
            self._a_tm1 = a_t
            return  # Empty iterable.

        self._transitions.append(
            Transition(
                s_tm1=self._timestep_tm1.observation,
                a_tm1=self._a_tm1,
                r_t=timestep_t.reward,
                discount_t=timestep_t.discount,
                s_t=timestep_t.observation,
                a_t=a_t,
                mc_return_tm1=None,
            ))

        self._timestep_tm1 = timestep_t
        self._a_tm1 = a_t

        if timestep_t.last():
            # Annotate all episode transitions with their MC returns.
            mc_return = 0
            mc_transitions = []
            while self._transitions:
                transition = self._transitions.pop()
                mc_return = transition.discount_t * mc_return + transition.r_t
                mc_transitions.append(
                    transition._replace(mc_return_tm1=mc_return))
            for transition in reversed(mc_transitions):
                yield transition

        else:
            # Wait for episode end before yielding anything.
            return
Example #20
    def add(
        self,
        timestep: dm_env.TimeStep,
        action: int,
        new_timestep: dm_env.TimeStep,
        preprocess: Callable = lambda x: x,
    ) -> None:
        #  if buffer is full, prepare for new trajectory
        if self.full():
            self._reset()

        #  collect experience
        self.states.append(preprocess(timestep.observation))
        self._terminal = new_timestep.last()

        #  if the buffer is now full, append the final state
        if self.full():
            self.states.append(preprocess(new_timestep.observation))
        return
Example #21
  def step(self, timestep: dm_env.TimeStep) -> None:
    """Accumulates statistics from timestep."""
    if timestep.first():
      if self._current_episode_rewards:
        raise ValueError('Current episode reward list should be empty.')
      if self._current_episode_step != 0:
        raise ValueError('Current episode step should be zero.')
    else:
      # First reward is invalid, all other rewards are appended.
      self._current_episode_rewards.append(timestep.reward)

    self._num_steps_since_reset += 1
    self._current_episode_step += 1

    if timestep.last():
      self._episode_returns.append(sum(self._current_episode_rewards))
      self._current_episode_rewards = []
      self._num_steps_over_episodes += self._current_episode_step
      self._current_episode_step = 0
Example #22
    def step(self, timestep_t: dm_env.TimeStep,
             a_t: parts.Action) -> Iterable[Transition]:
        """Accumulates timestep and resulting action, yields transitions."""
        if timestep_t.first():
            self.reset()

        # There are no transitions on the first timestep.
        if self._timestep_tm1 is None:
            assert self._a_tm1 is None
            if not timestep_t.first():
                raise ValueError('Expected FIRST timestep, got %s.' %
                                 str(timestep_t))
            self._timestep_tm1 = timestep_t
            self._a_tm1 = a_t
            return  # Empty iterator.

        self._transitions.append(
            Transition(
                s_tm1=self._timestep_tm1.observation,
                a_tm1=self._a_tm1,
                r_t=timestep_t.reward,
                discount_t=timestep_t.discount,
                s_t=timestep_t.observation,
            ))

        self._timestep_tm1 = timestep_t
        self._a_tm1 = a_t

        if timestep_t.last():
            # Yield any remaining n, n-1, ..., 1-step transitions at episode end.
            while self._transitions:
                yield _build_n_step_transition(self._transitions)
                self._transitions.popleft()
        else:
            # Wait for n transitions before yielding anything.
            if len(self._transitions) < self._transitions.maxlen:
                return  # Empty iterator.

            assert len(self._transitions) == self._transitions.maxlen

            # This is the typical case, yield a single n-step transition.
            yield _build_n_step_transition(self._transitions)
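
The example above delegates to a _build_n_step_transition helper that collapses up to n one-step transitions into a single n-step transition. A plausible sketch, assuming the Transition fields used above and the usual n-step return (rewards accumulated under the running product of discounts):

import collections

Transition = collections.namedtuple(
    'Transition', ['s_tm1', 'a_tm1', 'r_t', 'discount_t', 's_t'])

def _build_n_step_transition(transitions):
    """Folds consecutive 1-step transitions into one n-step transition (sketch)."""
    r_t, discount_t = 0.0, 1.0
    for t in transitions:
        r_t += discount_t * t.r_t
        discount_t *= t.discount_t
    return Transition(s_tm1=transitions[0].s_tm1,
                      a_tm1=transitions[0].a_tm1,
                      r_t=r_t,
                      discount_t=discount_t,
                      s_t=transitions[-1].s_t)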
Example #23
    def step(self, timestep: dm_env.TimeStep, loss, shaped_reward,
             penalties) -> None:
        """Accumulates statistics from timestep."""

        if timestep.first():
            if self._current_episode_rewards:
                raise ValueError(
                    'Current episode reward list should be empty.')
            if self._current_episode_step != 0:
                raise ValueError('Current episode step should be zero.')
        else:
            # First reward is invalid, all other rewards are appended.
            self._current_episode_rewards.append(timestep.reward)

        if shaped_reward is not None:
            if isinstance(shaped_reward, list):
                self._current_episode_shaped_rewards.extend(shaped_reward)
            else:
                self._current_episode_shaped_rewards.append(shaped_reward)
        if loss is not None:
            self._current_episode_loss += loss

        if penalties is not None:
            if isinstance(penalties, list):
                self._current_episode_penalties.extend(penalties)
            else:
                self._current_episode_penalties.append(penalties)

        self._num_steps_since_reset += 1
        self._current_episode_step += 1

        if timestep.last():
            self._episode_returns.append(sum(self._current_episode_rewards))
            self._current_episode_rewards = []
            self._current_episode_shaped_rewards = []
            self._current_episode_penalties = []
            self._current_episode_loss = 0
            self._num_steps_over_episodes += self._current_episode_step
            self._current_episode_step = 0
Example #24
    def add(self,
            action: types.NestedArray,
            next_timestep: dm_env.TimeStep,
            extras: types.NestedArray = ()):
        """Record an action and the following timestep."""

        if not self._add_first_called:
            raise ValueError(
                'adder.add_first must be called before adder.add.')

        # Add the timestep to the buffer.
        has_extras = (
            len(extras) > 0 if isinstance(extras, Sized)  # pylint: disable=g-explicit-length-test
            else extras is not None)
        current_step = dict(
            # Observation was passed at the previous add call.
            action=action,
            reward=next_timestep.reward,
            discount=next_timestep.discount,
            # Start of episode indicator was passed at the previous add call.
            **({
                'extras': extras
            } if has_extras else {}))
        self._writer.append(current_step)

        # Record the next observation and write.
        self._writer.append(dict(observation=next_timestep.observation,
                                 start_of_episode=next_timestep.first()),
                            partial_step=True)
        self._write()

        if next_timestep.last():
            # Complete the row by appending zeros to remaining open fields.
            # TODO(b/183945808): remove this when fields are no longer expected to be
            # of equal length on the learner side.
            dummy_step = tree.map_structure(np.zeros_like, current_step)
            self._writer.append(dummy_step)
            self._write_last()
            self.reset()
Example #25
    def update(self, timestep: dm_env.TimeStep, action: base.Action,
               new_timestep: dm_env.TimeStep):
        # Add this transition to replay.
        print("update: " + str(self._total_steps))
        self._memory.append(
            (timestep.observation, action, new_timestep.reward,
             new_timestep.discount, new_timestep.observation,
             new_timestep.last()))

        self._total_steps += 1

        if len(self._memory) > self._batch_size:
            self.replay(self._batch_size)
Example #26
    def update(self, timestep: dm_env.TimeStep, action: base.Action,
               new_timestep: dm_env.TimeStep):
        # Add this transition to replay.
        if self._total_steps % 50 == 0:
            print("update: " + str(self._total_steps) + " " +
                  datetime.now().strftime("%H:%M:%S"))
        self._memory.append(
            (timestep.observation, action, new_timestep.reward,
             new_timestep.discount, new_timestep.observation,
             new_timestep.last()))

        self._total_steps += 1

        if len(self._memory) > self._batch_size:
            self.replay(self._batch_size)
            self.target_train()  # iterates the target model