def test_cross_entropy_method_assert_step_index(observation_space, action_space, horizon):
    env_model, policy = get_cross_entropy_policy(observation_space, action_space, horizon, 2)
    initial_time_step = env_model.reset()
    initial_policy_state = policy.get_initial_state(env_model.batch_size)
    policy_step = policy.action(initial_time_step, initial_policy_state)
    mid_time_step = TimeStep(
        np.array([StepType.MID, StepType.LAST]),
        initial_time_step.reward,
        initial_time_step.discount,
        initial_time_step.observation,
    )
    for _ in range(horizon):
        policy_step = policy.action(mid_time_step, policy_step.state)
    final_time_step = TimeStep(
        np.array([StepType.MID, StepType.LAST]),
        initial_time_step.reward,
        initial_time_step.discount,
        initial_time_step.observation,
    )
    with pytest.raises(AssertionError) as excinfo:
        policy.action(final_time_step, policy_step.state)
    # Check the raised AssertionError's message itself.
    assert f"Max step index {horizon + 1} is out of range (> {horizon})" in str(
        excinfo.value)
def _step(self, action: np.ndarray):
    if self._current_time_step.is_last():
        return self._reset()

    action = tuple(action)
    if self._states[action] != 0:
        return TimeStep(StepType.LAST, TicTacToeEnvironment.REWARD_ILLEGAL_MOVE,
                        self._discount, self._states)

    self._states[action] = 1

    is_final, reward = self._check_states(self._states)
    if is_final:
        return TimeStep(StepType.LAST, reward, self._discount, self._states)

    # TODO(b/152638947): handle multiple agents properly.
    # Opponent places '2' on the board.
    opponent_action = self._opponent_play(self._states)
    self._states[opponent_action] = 2

    is_final, reward = self._check_states(self._states)

    step_type = StepType.MID
    if np.all(self._states == 0):
        step_type = StepType.FIRST
    elif is_final:
        step_type = StepType.LAST

    return TimeStep(step_type, reward, self._discount, self._states)
def make_timestep_mask(batched_next_time_step: ts.TimeStep,
                       allow_partial_episodes: bool = False) -> types.Tensor:
    """Create a mask for transitions and, optionally, final incomplete episodes.

    Args:
      batched_next_time_step: Next timestep, doubly-batched
        [batch_dim, time_dim, ...].
      allow_partial_episodes: If True, steps from incomplete episodes are
        allowed.

    Returns:
      A mask of type tf.float32 that is 0.0 for all between-episode timesteps
      (i.e. where batched_next_time_step is FIRST). If allow_partial_episodes
      is False, the mask is also 0.0 for the incomplete episode at the end of
      the sequence.
    """
    if allow_partial_episodes:
        episode_is_complete = None
    else:
        # 1.0 for timesteps of all complete episodes, 0.0 for the incomplete
        # episode at the end of the sequence.
        episode_is_complete = tf.cumsum(
            tf.cast(batched_next_time_step.is_last(), tf.float32),
            axis=1,
            reverse=True) > 0

    # 1.0 for all valid timesteps, 0.0 where between episodes.
    not_between_episodes = ~batched_next_time_step.is_first()

    if allow_partial_episodes:
        return tf.cast(not_between_episodes, tf.float32)
    return tf.cast(episode_is_complete & not_between_episodes, tf.float32)
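# A minimal usage sketch (an assumption, not from the original source): build a
# doubly-batched TimeStep whose step types are MID, LAST, FIRST, MID, MID and
# inspect the resulting mask. Assumes `ts` is tf_agents.trajectories.time_step
# and `make_timestep_mask` above is in scope.
import tensorflow as tf

next_steps = ts.TimeStep(
    step_type=tf.constant([[1, 2, 0, 1, 1]], dtype=tf.int32),
    reward=tf.zeros((1, 5)),
    discount=tf.ones((1, 5)),
    observation=tf.zeros((1, 5, 3)),
)
# Index 2 is FIRST (between episodes) -> masked out; indexes 3-4 belong to a
# trailing incomplete episode -> masked out unless allow_partial_episodes=True.
print(make_timestep_mask(next_steps))        # [[1., 1., 0., 0., 0.]]
print(make_timestep_mask(next_steps, True))  # [[1., 1., 0., 1., 1.]]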
def _step(self, action):
    self._num_steps.assign_add(tf.ones_like(self._num_steps))
    time_step = super()._step(action)
    time_limit_terminations = tf.math.greater_equal(self._num_steps, self._duration)
    step_types = tf.where(condition=time_limit_terminations,
                          x=StepType.LAST,
                          y=time_step.step_type)
    discounts = tf.where(condition=time_limit_terminations,
                         x=0,
                         y=time_step.discount)
    new_time_step = TimeStep(step_types, time_step.reward, discounts,
                             time_step.observation)
    self._env._time_step = new_time_step  # pylint: disable=protected-access

    # We convert the TF tensors to numpy first for performance reasons.
    if any(new_time_step.is_last().numpy()):
        terminates = step_types == StepType.LAST
        termination_indexes = tf.where(terminates)
        number_terminations = tf.math.count_nonzero(terminates)
        # We use dtype tf.int32 because this avoids a GPU bug detected by Dongho.
        self._num_steps.scatter_nd_update(
            termination_indexes,
            tf.constant(-1, shape=(number_terminations,), dtype=tf.int32),
        )
    return new_time_step
def compute_avg_return(environment, policies, num_episodes=1):
    total_return = [0 for _ in AGENT_NAMES]
    for _ in range(num_episodes):
        aggregate_time_step = environment.reset()
        episode_return = [0 for _ in AGENT_NAMES]
        while not aggregate_time_step.is_last().numpy().all():
            # On the first step the reward has shape (1,); afterwards it is
            # batched per agent.
            is_first_step = aggregate_time_step.reward.shape == (1,)
            aggregate_action = {}
            for i, name in enumerate(AGENT_NAMES):
                observation = tf.convert_to_tensor(
                    aggregate_time_step.observation[name].numpy(),
                    dtype='float32')
                if is_first_step:
                    time_step = TimeStep(aggregate_time_step.step_type[0],
                                         aggregate_time_step.reward[0],
                                         aggregate_time_step.discount[0],
                                         observation)
                else:
                    time_step = TimeStep(aggregate_time_step.step_type[0][0],
                                         aggregate_time_step.reward[0][i],
                                         aggregate_time_step.discount[0],
                                         observation)
                action_step = policies[i].action(time_step)
                aggregate_action[name] = action_step
            if not is_first_step:
                for i in range(len(AGENT_NAMES)):
                    episode_return[i] += aggregate_time_step.reward[0][i]
            aggregate_time_step = environment.step(aggregate_action)
        for i in range(len(AGENT_NAMES)):
            total_return[i] += episode_return[i]
    avg_return = [total_return[i] / num_episodes for i in range(len(AGENT_NAMES))]
    return [r.numpy() for r in avg_return]
def _step(self, action):
    """Must return a tf_agents.trajectories.time_step.TimeStep namedtuple obj."""
    if self._episode_ended:
        # The last action ended the episode. Ignore the current action and
        # start a new episode.
        return self.reset()

    # Convert a possible numpy scalar action to a plain int for the underlying env.
    action = int(action)

    observations, reward, done, info = self._env.step(action)
    observation = observations['player_observations'][
        observations['current_player']]
    reward = np.asarray(reward, dtype=np.float32)
    obs_vec = np.array(observation['vectorized'], dtype=dtype_vectorized)
    mask_valid_actions = self.get_mask_legal_moves(observation)
    obs = {'state': obs_vec, 'mask': mask_valid_actions}

    if done:
        self._episode_ended = True
        step_type = StepType.LAST
    else:
        step_type = StepType.MID
    return TimeStep(step_type, reward, discount, obs)
def _reset(self, pbt_config=None):
    """Must return a tf_agents.trajectories.time_step.TimeStep namedtuple obj,
    i.e. ['step_type', 'reward', 'discount', 'observation']."""
    self._episode_ended = False
    observations = self._env.reset(pbt_config)
    observation = observations['player_observations'][
        observations['current_player']]
    # Reward is 0 on reset.
    reward = np.asarray(0, dtype=np.float32)
    # The observation is currently a dict; extract the 'vectorized' object.
    obs_vec = np.array(observation['vectorized'], dtype=dtype_vectorized)
    mask_valid_actions = self.get_mask_legal_moves(observation)
    info = self._env.state.score()

    # Used for the two-player curiosity implementation.
    otherplayer_id = 0 if observations['current_player'] == 1 else 1
    state2 = observations['player_observations'][otherplayer_id]
    state2_vec = np.array(state2['vectorized'], dtype=dtype_vectorized)
    obs = {
        'state': obs_vec,
        'mask': mask_valid_actions,
        'info': info,
        'state2': state2_vec,
    }
    return TimeStep(StepType.FIRST, reward, discount, obs)
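# Quick sanity check of the namedtuple fields referenced in the docstring above
# (this snippet is illustrative, not part of the original environment code):
from tf_agents.trajectories.time_step import TimeStep

assert TimeStep._fields == ('step_type', 'reward', 'discount', 'observation')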
def call(self, trajectory: Trajectory) -> Trajectory:
    time_step = TimeStep(trajectory.step_type, trajectory.reward,
                         trajectory.discount, trajectory.observation)
    action_dist = self._policy.distribution(time_step).action

    # If the action distribution is in fact a tuple of distributions (one for
    # each resource set), we need to index into it to reach the underlying
    # distribution from which probabilities can be read. This is only the case
    # when there are multiple resource sets.
    for i in self._action_indices[:-1]:
        action_dist = action_dist[i]
    action_probs = action_dist.probs_parameter()

    # Zero out batch indices where a new episode is starting.
    self._probability_accumulator.assign(
        tf.where(trajectory.is_first(),
                 tf.zeros_like(self._probability_accumulator),
                 self._probability_accumulator))
    self._count_accumulator.assign(
        tf.where(trajectory.is_first(),
                 tf.zeros_like(self._count_accumulator),
                 self._count_accumulator))

    # Update accumulators with probability and count increments.
    self._probability_accumulator.assign_add(
        action_probs[..., 0, self._action_indices[-1]])
    self._count_accumulator.assign_add(tf.ones_like(self._count_accumulator))

    # Add final cumulants to the buffer at the end of episodes.
    last_episode_indices = tf.squeeze(tf.where(trajectory.is_last()), axis=-1)
    for idx in last_episode_indices:
        self._buffer.add(self._probability_accumulator[idx] /
                         self._count_accumulator[idx])

    return trajectory
def _gen_time_step(self, s, action):
    step_type = StepType.MID
    discount = 1.0

    if s == 0:
        step_type = StepType.FIRST
    elif s == self._episode_length - 1:
        step_type = StepType.LAST
        discount = 0.0

    if s == 0:
        reward = tf.constant([0.] * self.batch_size)
    else:
        prev_observation = self._current_time_step.observation
        reward = 1.0 - tf.abs(prev_observation - tf.cast(action, tf.float32))
        reward = tf.reshape(reward, shape=(self.batch_size,))

    observation = tf.constant(
        np.random.randint(2, size=(self.batch_size, 1)), dtype=tf.float32)

    return TimeStep(step_type=tf.constant([step_type] * self.batch_size),
                    reward=reward,
                    discount=tf.constant([discount] * self.batch_size),
                    observation=observation)
def _reset(self):
    # NOTE: reseeding on every reset makes the random number stream, and hence
    # each episode's tile placement, deterministic.
    np.random.seed(123)
    self._total_score = 0.
    self._states = np.zeros(self._shape, np.float32)
    self._add_numbers()
    return TimeStep(StepType.FIRST, np.asarray(0.0, dtype=np.float32),
                    self._discount, self._states)
def test_rl_simulation_agent_normalise_obs_usage_with_normalisation():
    """Ensure that the _normalise_obs property of RLSimulationAgent is used correctly."""
    # Set up the agent as before.
    seed = 72
    state = np.array([100, 100, 100, 100])
    env = load_scenario("klimov_model",
                        job_gen_seed=seed,
                        override_env_params={"initial_state": state}).env
    rl_env, _ = rl_env_from_snc_env(env,
                                    discount_factor=0.99,
                                    normalise_observations=True)
    ppo_agent = MagicMock()
    ppo_agent.discount_factor = 0.99
    ppo_agent._gamma = 0.99
    policy = MagicMock()
    ppo_agent.collect_policy = policy
    del rl_env

    ppo_sim_agent = RLSimulationAgent(env, ppo_agent, normalise_obs=True)
    ppo_sim_agent._rl_env.preprocess_action = MagicMock()
    ppo_sim_agent.map_state_to_actions(state)

    expected_timestep = TimeStep(step_type=StepType(0),
                                 reward=None,
                                 discount=0.99,
                                 observation=state.reshape(1, -1) / state.sum())

    assert policy.action.call_count == 1
    call_timestep = policy.action.call_args[0][0]
    assert (call_timestep.observation == expected_timestep.observation).all()
def _override_step_type(self, time_step, counter):
    policy_time_step = tf.cond(
        # True iff this is the first episode in a trial.
        tf.equal(tf.math.floormod(counter[0], self.episodes_per_trial), 0),
        lambda: time_step,
        lambda: TimeStep(step_type=np.expand_dims(StepType.MID, axis=0),
                         reward=time_step.reward,
                         discount=time_step.discount,
                         observation=time_step.observation),
    )
    return policy_time_step
def _step(self, action: np.ndarray):
    if self._current_time_step.is_last():
        return self._reset()

    action = self._actions[action[0]]
    score = 0.
    if action[1]:
        it_list = zip(range(self._n - 1, 0, -1), range(self._n - 2, -1, -1))
    else:
        it_list = zip(range(self._n - 1), range(1, self._n))

    for i, next_i in it_list:
        index_base = np.ones(shape=[1 for _ in range(self._dims)], dtype=np.int32)
        xi = np.take_along_axis(self._states, index_base * i, axis=action[0])
        xiplus = np.take_along_axis(self._states, index_base * next_i, axis=action[0])
        with np.nditer([xi, xiplus], flags=[],
                       op_flags=[['readwrite'], ['readwrite']]) as it:
            for j, next_j in it:
                if j[...] != 0. and next_j[...] == 0.:
                    next_j[...] = j[...]
                    j[...] = 0.
                elif j[...] == next_j[...]:
                    next_j *= 2.
                    j[...] = 0.
                    score += next_j[...]
        np.put_along_axis(self._states, index_base * i, xi, axis=action[0])
        np.put_along_axis(self._states, index_base * next_i, xiplus, axis=action[0])

    self._total_score += score
    is_final = self._add_numbers()
    return TimeStep(
        StepType.LAST if is_final else StepType.MID,
        np.asarray(self._total_score, dtype=np.float32)
        if is_final else np.asarray(0., dtype=np.float32),
        self._discount,
        self._states,
    )

# def show(self, video=None):
#     n = self.n
#     block_size = 50
#     img = (self.board * (np.iinfo(np.uint8).max / np.max(self.board))).astype(np.uint8)
#     img = Image.fromarray(img, "L")
#     img_reshaped = img.resize((n * block_size, n * block_size), resample=Image.BOX)
#     img_reshaped = np.float32(img_reshaped)
#     text_color = (200, 20, 220)
#     for i in range(n):
#         for j in range(n):
#             block_value = str(int(self.board[i, j]))
#             cv2.putText(
#                 img_reshaped,
#                 block_value,
#                 (j * block_size + int(block_size / 2) - 10 * len(block_value),
#                  i * block_size + int(block_size / 2) + 10),
#                 fontFace=cv2.FONT_HERSHEY_DUPLEX,
#                 fontScale=1,
#                 color=text_color,
#             )
#     if video is not None:
#         video.write(np.uint8(img_reshaped))
#     else:
#         cv2.imshow('2048', img_reshaped)
def prepare(time_step):
    step_type = fix_tensor(time_step.step_type, spec.step_type)
    reward = fix_tensor(time_step.reward, spec.reward)
    discount = fix_tensor(time_step.discount, spec.discount)
    observation = fix_tensor(time_step.observation, spec.observation)
    return TimeStep(step_type=step_type,
                    reward=reward,
                    discount=discount,
                    observation=observation)
def _step(self, action):
    """Act -> reward -> update NAV -> next observation.

    Returns
    -------
    observation
    reward : float
    episode_over : bool
        Set when the agent loses all its money or runs through the data set;
        without a stopping point, episode-based updates could not be done.
    info : dict
        Diagnostics.
    """
    if self.current_step >= self.data.shape[0]:
        return TimeStep(
            StepType.LAST, np.float32(0), self._discount,
            np.zeros(shape=(self.back_looking, self.num_features + 2),
                     dtype=np.float64))

    self._execute_trade(action)
    reward = self._calculate_reward()
    self.current_step += 1
    obs = self._observe()

    done = self.nav <= 0 or self.current_step >= self.data.shape[0]
    step_type = StepType.MID
    if self.current_step == self.back_looking:
        step_type = StepType.FIRST
    elif done:
        step_type = StepType.LAST
    return TimeStep(step_type, reward.astype('float32'), self._discount, obs)
def should_reset(self, current_time_step: ts.TimeStep) -> bool:
    """Whether the environment should reset, given the current timestep.

    By default it only resets when all time_steps are `LAST`.

    Args:
      current_time_step: The current `TimeStep`.

    Returns:
      A bool indicating whether the environment should reset or not.
    """
    handle_auto_reset = getattr(self, '_handle_auto_reset', False)
    return handle_auto_reset and np.all(current_time_step.is_last())
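# Illustrative sketch (not from the original source): with a batched
# environment, should_reset only fires once every batch element reports LAST.
# `env` is a hypothetical instance of the class above with _handle_auto_reset
# set to True.
import numpy as np
from tf_agents.trajectories import time_step as ts

mixed = ts.TimeStep(step_type=np.array([ts.StepType.LAST, ts.StepType.MID]),
                    reward=np.zeros(2, np.float32),
                    discount=np.ones(2, np.float32),
                    observation=np.zeros((2, 3), np.float32))
env.should_reset(mixed)  # False: one batch element is still mid-episode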
def take_action(self, observation):
    """Returns the percentage of equity to invest. If negative, that is the
    amount to sell."""
    # step_type/reward/discount here are fixed dummy values; only the
    # observation varies between calls.
    time_step = TimeStep(step_type=np.array([1], dtype=np.int32),
                         reward=np.array([1.0], dtype=np.float32),
                         discount=np.array([1.001], dtype=np.float32),
                         observation=np.array([observation], dtype=np.float32))
    return float(self.agent.policy.action(time_step).action)
def _reset(self):
    self.past_holdings = np.zeros(self.back_looking)
    self.past_nav = np.ones(self.back_looking) * self.initial_capital
    self.holding = 0
    self.cash = self.initial_capital
    self.nav = self.initial_capital
    self.current_step = self.back_looking
    obs = self._observe()
    return TimeStep(StepType.FIRST, np.asarray(0.0, dtype=np.float32),
                    self._discount, obs)
def take_action_for_stock(self, stock):
    feature_generator = self.feature_generators[stock.ticker]
    price = stock.price()
    feature_generator.append_price(price)
    observation = feature_generator.get_features()
    time_step = TimeStep(step_type=np.array([1], dtype=np.int32),
                         reward=np.array([1.0], dtype=np.float32),
                         discount=np.array([1.001], dtype=np.float32),
                         observation=np.array([observation], dtype=np.float32))
    action = self.convert_action(self.eval_policy.action(time_step)[0],
                                 stock.ticker, price)
    return action
def _gen_time_step(self, s, action):
    """Return the current `TimeStep`."""
    step_type = StepType.MID
    discount = 1.0

    if s == 0:
        step_type = StepType.FIRST
    elif s == self._episode_length - 1:
        step_type = StepType.LAST
        discount = 0.0

    return TimeStep(step_type=tf.constant([step_type] * self.batch_size),
                    reward=tf.constant([1.] * self.batch_size),
                    discount=tf.constant([discount] * self.batch_size),
                    observation=tf.constant([[1.]] * self.batch_size))
def act(self, obs):
    batch_obs = {}
    for key in obs:
        batch_obs[key] = np.expand_dims(obs[key], axis=0)
    # step_type, reward and discount are dummy ones here; only the observation
    # is meaningful to the policy.
    time_step = TimeStep(
        np.ones(1),
        np.ones(1),
        np.ones(1),
        batch_obs,
    )
    policy_state = ()
    with self.sess.as_default():
        action_step = self.eval_py_policy.action(time_step, policy_state)
    action = action_step.action[0]
    return action
def run(
    self,
    time_step: ts.TimeStep,
    policy_state: types.NestedArray = ()
) -> Tuple[ts.TimeStep, types.NestedArray]:
    """Run policy in environment given initial time_step and policy_state.

    Args:
      time_step: The initial time_step.
      policy_state: The initial policy_state.

    Returns:
      A tuple (final time_step, final policy_state).
    """
    num_steps = 0
    num_episodes = 0
    while num_steps < self._max_steps and num_episodes < self._max_episodes:
        # For now we reset the policy_state for non-batched envs.
        if not self.env.batched and time_step.is_first() and num_episodes > 0:
            policy_state = self._policy.get_initial_state(self.env.batch_size or 1)

        action_step = self.policy.action(time_step, policy_state)
        next_time_step = self.env.step(action_step.action)

        # When using an observer (for the purpose of training), only the
        # previous policy_state is useful. Therefore substitute it into the
        # PolicyStep and consume it with the observer.
        action_step_with_previous_state = action_step._replace(state=policy_state)
        traj = trajectory.from_transition(time_step,
                                          action_step_with_previous_state,
                                          next_time_step)
        for observer in self._transition_observers:
            observer((time_step, action_step_with_previous_state, next_time_step))
        for observer in self.observers:
            observer(traj)

        num_episodes += np.sum(traj.is_boundary())
        num_steps += np.sum(~traj.is_boundary())

        time_step = next_time_step
        policy_state = action_step.state

    return time_step, policy_state
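# Hedged usage sketch: this run loop mirrors TF-Agents' PyDriver API, so it can
# presumably be driven the same way. `env`, `policy`, and `driver` are
# hypothetical instances, not names from the original source.
time_step = env.reset()
policy_state = policy.get_initial_state(env.batch_size or 1)
final_time_step, final_policy_state = driver.run(time_step, policy_state)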
def _reset(self):
    """Must return a tf_agents.trajectories.time_step.TimeStep namedtuple obj,
    i.e. ['step_type', 'reward', 'discount', 'observation']."""
    self._episode_ended = False
    observations = self._env.reset()
    observation = observations['player_observations'][
        observations['current_player']]
    # Reward is 0 on reset.
    reward = np.asarray(0, dtype=np.float32)
    # The observation is currently a dict; extract the 'vectorized' object.
    obs_vec = np.array(observation['vectorized'], dtype=dtype_vectorized)
    mask_valid_actions = self.get_mask_legal_moves(observation)  # (48,) int64
    obs = {'state': obs_vec, 'mask': mask_valid_actions}
    return TimeStep(StepType.FIRST, reward, discount, obs)
def test_step_win(self):
    self.env.set_state(
        TimeStep(StepType.MID, TicTacToeEnvironment.REWARD_DRAW_OR_NOT_FINAL,
                 self.discount, np.array([[2, 2, 0], [0, 1, 1], [0, 0, 0]])))
    current_time_step = self.env.current_time_step()
    self.assertEqual(StepType.MID, current_time_step.step_type)

    ts = self.env.step(np.array([1, 0]))
    np.testing.assert_array_equal([[2, 2, 0], [1, 1, 1], [0, 0, 0]],
                                  ts.observation)
    self.assertEqual(StepType.LAST, ts.step_type)
    self.assertEqual(1., ts.reward)

    # Reset if an action is taken after the final state is reached.
    ts = self.env.step(np.array([2, 0]))
    self.assertEqual(StepType.FIRST, ts.step_type)
    self.assertEqual(0., ts.reward)
def map_state_to_actions(self, state: snc_types.StateSpace,
                         **override_args: Any) -> snc_types.ActionProcess:
    """
    The action function taking in the observed state and returning an action vector.

    :param state: The observed state of the environment.
    :param override_args: Dictionary of additional keyword arguments (in this
        case all additional arguments are ignored).
    :return: An action vector in the format expected by the SNC simulator.
    """
    # To be compatible with the TensorFlow agent we must form a TimeStep object
    # to pass the data to the agent.
    # Note the step type encoding: 0 is the initial time step, 1 is any
    # non-terminal time step, and 2 represents the final time step.
    # The reward is not provided. This code will need to be refactored if the
    # agent's decision making is based on the reward.
    if self.env.is_at_final_state:
        step_type = StepType(2)
    else:
        step_type = StepType(int(1 - self.env.is_at_initial_state))

    # Scale the state, as is done in RLControlledRandomWalk, to allow RL to work.
    if self._normalise_obs:
        scaled_state = self._rl_env.normalise_state(state)
    else:
        scaled_state = state.astype(np.float32)

    time_step = TimeStep(
        step_type=step_type,
        reward=None,
        discount=override_args.get("discount_factor", self.discount_factor),
        observation=scaled_state.reshape(1, state.shape[0]))

    # The action provided by the TensorFlow agent is in a form suitable for the
    # TensorFlow environment. We therefore use the environment's action
    # processing method to convert the action into a form suitable for the SNC
    # simulator.
    rl_action = self._policy.action(time_step)
    snc_action = self._rl_env.preprocess_action(rl_action.action)
    return snc_action
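# The 0/1/2 encoding used above matches TF-Agents' StepType constants; a quick
# illustrative check (not part of the original source):
from tf_agents.trajectories.time_step import StepType

assert StepType(0) == StepType.FIRST
assert StepType(1) == StepType.MID
assert StepType(2) == StepType.LAST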
def dummy_trajectory_batch(batch_size=2, n_steps=5, obs_dim=2):
    observations = tf.reshape(
        tf.constant(np.arange(batch_size * n_steps * obs_dim), dtype=tf.float32),
        (batch_size, n_steps, obs_dim),
    )
    time_steps = TimeStep(
        step_type=tf.constant([[1] * (n_steps - 2) + [2] * 2] * batch_size,
                              dtype=tf.int32),
        reward=tf.constant([[1] * n_steps] * batch_size, dtype=tf.float32),
        discount=tf.constant([[1.0] * n_steps] * batch_size, dtype=tf.float32),
        observation=observations,
    )
    actions = tf.ones((batch_size, n_steps, 1), dtype=tf.float32)
    action_distribution_parameters = {
        "dist_params": {
            "loc": tf.constant([[[10.0]] * n_steps] * batch_size, dtype=tf.float32),
            "scale": tf.constant([[[10.0]] * n_steps] * batch_size, dtype=tf.float32),
        },
        "value_prediction": tf.constant([[0.0] * n_steps] * batch_size,
                                        dtype=tf.float32),
    }
    policy_info = action_distribution_parameters
    return Trajectory(
        time_steps.step_type,
        observations,
        actions,
        policy_info,
        time_steps.step_type,
        time_steps.reward,
        time_steps.discount,
    )
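# Hedged usage sketch: the shapes produced by the helper above with its
# defaults (batch_size=2, n_steps=5, obs_dim=2), assuming TF-Agents' standard
# Trajectory field order (step_type, observation, action, policy_info,
# next_step_type, reward, discount).
traj = dummy_trajectory_batch()
assert traj.observation.shape == (2, 5, 2)
assert traj.action.shape == (2, 5, 1)
assert traj.reward.shape == (2, 5)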
def _gen_time_step(self, s, action):
    step_type = StepType.MID
    discount = 1.0
    obs_dim = self._obs_dim

    if s == 0:
        self._observation0 = tf.constant(
            2 * np.random.randint(2, size=(self.batch_size, 1)) - 1,
            dtype=tf.float32)
        if obs_dim > 1:
            self._observation0 = tf.concat(
                [self._observation0,
                 tf.ones((self.batch_size, obs_dim - 1))],
                axis=-1)
        step_type = StepType.FIRST
    elif s == self._episode_length - 1:
        step_type = StepType.LAST
        discount = 0.0

    if s <= self._gap:
        reward = tf.constant([0.] * self.batch_size)
    else:
        obs0 = tf.reshape(tf.cast(self._observation0[:, 0], tf.int64),
                          shape=(self.batch_size, 1))
        reward = tf.cast(tf.equal(action * 2 - 1, obs0), tf.float32)
        reward = tf.reshape(reward, shape=(self.batch_size,))

    if s == 0:
        observation = self._observation0
    else:
        observation = tf.zeros((self.batch_size, obs_dim))

    return TimeStep(step_type=tf.constant([step_type] * self.batch_size),
                    reward=reward,
                    discount=tf.constant([discount] * self.batch_size),
                    observation=observation)
def run(
    self,
    time_step: ts.TimeStep,
    policy_state: types.NestedArray = ()
) -> Tuple[ts.TimeStep, types.NestedArray]:
    """Run policy in environment given initial time_step and policy_state.

    Args:
      time_step: The initial time_step.
      policy_state: The initial policy_state.

    Returns:
      A tuple (final time_step, final policy_state).
    """
    num_steps = 0
    num_episodes = 0
    while num_steps < self._max_steps and num_episodes < self._max_episodes:
        # For now we reset the policy_state for non-batched envs.
        if not self.env.batched and time_step.is_first() and num_episodes > 0:
            policy_state = self._policy.get_initial_state(self.env.batch_size or 1)

        action_step = self.policy.action(time_step, policy_state)
        next_time_step = self.env.step(action_step.action)

        traj = trajectory.from_transition(time_step, action_step, next_time_step)
        for observer in self._transition_observers:
            observer((time_step, action_step, next_time_step))
        for observer in self.observers:
            observer(traj)

        num_episodes += np.sum(traj.is_boundary())
        num_steps += np.sum(~traj.is_boundary())

        time_step = next_time_step
        policy_state = action_step.state

    return time_step, policy_state
def _reset(self):
    self._states = np.zeros((3, 3), np.int32)
    return TimeStep(StepType.FIRST, np.asarray(0.0, dtype=np.float32),
                    self._discount, self._states)
def _step(self, action):
    """
    Return predictions of next states for each member of the batch.

    :param action: A batch of actions (the batch size is the first dimension)
    :return: A batch of next state predictions in the form of a `TimeStep` object
    """
    # Make sure that the action shape is as expected.
    batch_size = get_outer_shape(action, self._transition_model.action_space_spec)
    assert batch_size == self._batch_size

    # Get the observation from the current time step.
    observation = self._time_step.observation

    # Identify observation batch elements that terminated in the previous time
    # step. Note that the conversion to numpy is for performance reasons.
    is_last = self._time_step.is_last()
    is_any_last = any(is_last.numpy())

    # Elements of the observation batch that terminated on the previous time
    # step require a reset.
    if is_any_last:
        # Identify the number of elements to be reset and their indexes.
        number_resets = tf.math.count_nonzero(is_last)
        reset_indexes = tf.where(is_last)

        # Sample reset observations from the initial state distribution.
        reset_observation = self._initial_state_distribution_model.sample(
            (number_resets,))

        # Raise an error if any terminal observations remain after re-initialization.
        self._ensure_no_terminal_observations(reset_observation)

    # Get batches of next observations; update observations that were reset.
    next_observation = self._transition_model.step(observation, action)
    if is_any_last:
        next_observation = tf.tensor_scatter_nd_update(
            next_observation, reset_indexes, reset_observation)

    # Get batches of rewards; set rewards from reset batch elements to 0.
    reward = self._reward_model.step_reward(observation, action, next_observation)
    if is_any_last:
        reward = tf.where(condition=is_last, x=tf.constant(0.0), y=reward)

    # Get batches of termination flags.
    has_terminated = self._termination_model.terminates(next_observation)

    # Get batches of step types; set step types from reset batch elements to FIRST.
    step_type = tf.where(condition=has_terminated, x=StepType.LAST, y=StepType.MID)
    if is_any_last:
        step_type = tf.where(condition=is_last, x=StepType.FIRST, y=step_type)

    # Get batches of discounts; set discounts from reset batch elements to 1.
    discount = tf.where(condition=has_terminated,
                        x=tf.constant(0.0),
                        y=tf.constant(1.0))
    if is_any_last:
        discount = tf.where(condition=is_last, x=tf.constant(1.0), y=discount)

    # Create the TimeStep object and return it.
    self._time_step = TimeStep(step_type, reward, discount, next_observation)
    return self._time_step
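# Isolated sketch of the per-element reset pattern used above (illustrative
# values, not from the original source): rows whose previous step was LAST are
# restarted with FIRST / zero reward / unit discount, while the rest continue.
import tensorflow as tf
from tf_agents.trajectories.time_step import StepType

is_last = tf.constant([True, False])
step_type = tf.where(is_last, StepType.FIRST, StepType.MID)               # [FIRST, MID]
reward = tf.where(is_last, tf.constant(0.0), tf.constant([0.5, 0.7]))     # [0.0, 0.7]
discount = tf.where(is_last, tf.constant(1.0), tf.constant([0.0, 1.0]))   # [1.0, 1.0]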