def construct_episodes(actions, rewards, **kwargs):
    """Constructs episodes from nested lists of actions and rewards.

    Args:
        actions (list): Per-episode actions, example:
            [
                [a00, a01, a02, ...],  # Actions in the first episode.
                [a10, a11, a12, ...],  # Actions in the second episode.
                ...
            ]
        rewards (list): Per-episode rewards, example:
            [
                [r00, r01, r02, ...],  # Rewards in the first episode.
                [r10, r11, r12, ...],  # Rewards in the second episode.
                ...
            ]
        **kwargs (dict): Keyword arguments passed to Episode.

    Returns:
        List of Episodes where:
            - Transition observations and next observations are set to None.
            - The done flag is True only for the last transition in the
              episode.
            - Episode.return_ is calculated as the undiscounted sum of
              rewards.
    """
    episodes = []
    for acts, rews in zip(actions, rewards):
        # All but the last transition are built with done=False.
        transitions = [
            data.Transition(None, act, rew, False, None, {}, {})
            for act, rew in zip(acts[:-1], rews[:-1])
        ]
        # The last transition in the episode gets done=True.
        transitions.append(
            data.Transition(None, acts[-1], rews[-1], True, None, {}, {}))
        transition_batch = data.nested_stack(transitions)
        episodes.append(data.Episode(transition_batch, sum(rews), **kwargs))
    return episodes
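# A minimal usage sketch of construct_episodes, assuming data.Episode exposes
# a `return_` attribute as documented above. The nested lists and the helper
# name `_example_construct_episodes` are illustrative only, not part of the
# original module.
def _example_construct_episodes():
    episodes = construct_episodes(
        actions=[[0, 1], [1, 0, 1]],            # Two episodes: lengths 2 and 3.
        rewards=[[0.0, 1.0], [0.0, 0.0, 0.5]],
    )
    # Returns are undiscounted reward sums; only the last transition of each
    # episode has done=True.
    assert episodes[0].return_ == 1.0
    assert episodes[1].return_ == 0.5
    return episodes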
def construct_episodes(actions, rewards):
    """Constructs episodes from nested lists of actions and rewards."""
    episodes = []
    for acts, rews in zip(actions, rewards):
        transitions = [
            # TODO(koz4k): Initialize using kwargs.
            data.Transition(None, act, rew, False, None, {})
            for act, rew in zip(acts[:-1], rews[:-1])
        ]
        transitions.append(
            data.Transition(None, acts[-1], rews[-1], True, None, {}))
        transition_batch = data.nested_stack(transitions)
        episodes.append(data.Episode(transition_batch, sum(rews)))
    return episodes
def solve(self, env, epoch=None, init_state=None, time_limit=None):
    """Solves a given environment using OnlineAgent.act().

    Args:
        env (gym.Env): Environment to solve.
        epoch (int): Current training epoch, or None if not training.
        init_state (object): Reset the environment to this state. If None,
            do a normal gym.Env.reset().
        time_limit (int or None): Maximum number of steps to make on the
            solved environment. None means no time limit.

    Yields:
        Network-dependent: A stream of Network inputs requested for
        inference.

    Returns:
        data.Episode: Episode object containing a batch of collected
        transitions and the return for the episode.
    """
    yield from super().solve(env, epoch, init_state, time_limit)

    self._epoch = epoch

    model_env = env
    if time_limit is not None:
        # Add the TimeLimitWrapper _after_ passing the model env to the
        # agent, so the states cloned/restored by the agent do not contain
        # the number of steps made so far - this would break state lookup
        # in some Agents.
        env = envs.TimeLimitWrapper(env, time_limit)

    if init_state is None:
        # Model-free case: reset the environment normally.
        full_observation = env.reset()
    else:
        # Model-based case: restore the environment to the given state.
        # Assumes restore_state() returns the same goal-dict observation
        # format as reset(); the original code left full_observation
        # undefined on this path.
        full_observation = env.restore_state(init_state)
    # Flatten the goal-based observation into a single array.
    observation = np.concatenate([
        full_observation['observation'], full_observation['desired_goal']
    ], axis=-1)

    yield from self.reset(model_env, observation)

    transitions = []
    done = False
    info = {}
    # Track distinct observations to measure move diversity.
    places = {tuple(observation.flatten())}
    while not done:
        # Forward network prediction requests to BatchStepper.
        (action, agent_info) = yield from self.act(observation)
        (full_next_observation, reward, done, info) = env.step(action)
        next_observation = np.concatenate([
            full_next_observation['observation'],
            full_next_observation['desired_goal']
        ], axis=-1)
        places.add(tuple(next_observation.flatten()))
        transitions.append(
            data.Transition(
                observation=full_observation,
                action=action,
                reward=reward,
                done=done,
                next_observation=full_next_observation,
                agent_info=agent_info,
            ))
        full_observation = full_next_observation
        observation = next_observation

    return_ = sum(transition.reward for transition in transitions)
    transitions = self.postprocess_transitions(transitions)

    solved = info['solved'] if 'solved' in info else None
    truncated = (info['TimeLimit.truncated']
                 if 'TimeLimit.truncated' in info else None)
    transition_batch = data.nested_stack(transitions)
    info = {'move_diversity': len(places)}
    return data.Episode(
        transition_batch=transition_batch,
        return_=return_,
        solved=solved,
        truncated=truncated,
        info=info,
    )
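# A hedged sketch of how a solve() coroutine like the one above could be
# driven outside of BatchStepper: the generator yields network inputs,
# expects predictions back via .send(), and delivers the finished
# data.Episode through StopIteration.value. The names `agent`, `env` and
# `predict_fn`, and the helper `_drive_solve`, are placeholders, not part
# of this module.
def _drive_solve(agent, env, predict_fn, time_limit=100):
    coroutine = agent.solve(env, epoch=0, time_limit=time_limit)
    try:
        request = next(coroutine)
        while True:
            # Answer each network input request with a prediction.
            request = coroutine.send(predict_fn(request))
    except StopIteration as e:
        return e.value  # The data.Episode returned by solve().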
def solve(self, env, epoch=None, init_state=None, time_limit=None):
    """Solves a given environment using OnlineAgent.act().

    Args:
        env (gym.Env): Environment to solve.
        epoch (int): Current training epoch, or None if not training.
        init_state (object): Reset the environment to this state. If None,
            do a normal gym.Env.reset().
        time_limit (int or None): Maximum number of steps to make on the
            solved environment. None means no time limit.

    Yields:
        Network-dependent: A stream of Network inputs requested for
        inference.

    Returns:
        data.Episode: Episode object containing a batch of collected
        transitions and the return for the episode.
    """
    yield from super().solve(env, epoch, init_state, time_limit)

    self._epoch = epoch

    model_env = env
    if time_limit is not None:
        # Add the TimeLimitWrapper _after_ passing the model env to the
        # agent, so the states cloned/restored by the agent do not contain
        # the number of steps made so far - this would break state lookup
        # in some Agents.
        env = envs.TimeLimitWrapper(env, time_limit)

    if init_state is None:
        # Model-free case: reset the environment normally.
        observation = env.reset()
    else:
        # Model-based case: restore the environment to the given state.
        observation = env.restore_state(init_state)

    yield from self.reset(model_env, observation)

    for callback in self._callbacks:
        callback.on_episode_begin(env, observation, epoch)

    transitions = []
    done = False
    info = {}
    while not done:
        # Forward network prediction requests to BatchStepper.
        (action, agent_info) = yield from self.act(observation)
        (next_observation, reward, done, info) = env.step(action)
        for callback in self._callbacks:
            callback.on_real_step(agent_info, action, next_observation,
                                  reward, done)
        transitions.append(
            data.Transition(
                observation=observation,
                action=action,
                reward=reward,
                done=done,
                next_observation=next_observation,
                agent_info=agent_info,
            ))
        observation = next_observation

    for callback in self._callbacks:
        callback.on_episode_end()

    transitions = self.postprocess_transitions(transitions)
    return_ = sum(transition.reward for transition in transitions)
    solved = info['solved'] if 'solved' in info else None
    truncated = (info['TimeLimit.truncated']
                 if 'TimeLimit.truncated' in info else None)
    transition_batch = data.nested_stack(transitions)
    additional_info = (info['additional_info']
                       if 'additional_info' in info else None)
    return data.Episode(
        transition_batch=transition_batch,
        return_=return_,
        solved=solved,
        truncated=truncated,
        additional_info=additional_info,
    )
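# A minimal callback sketch, assuming the three hooks invoked in solve() above
# (on_episode_begin, on_real_step, on_episode_end) are the interface expected
# from entries of self._callbacks. The class name and the counting logic are
# illustrative only.
class _EpisodeLengthCallback:
    """Counts real environment steps taken within one episode."""

    def __init__(self):
        self.n_steps = 0

    def on_episode_begin(self, env, observation, epoch):
        # Called once before the first real step of the episode.
        self.n_steps = 0

    def on_real_step(self, agent_info, action, observation, reward, done):
        # Called after every real environment step.
        self.n_steps += 1

    def on_episode_end(self):
        # In a real callback this count could be logged to a metrics sink.
        pass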
def solve(self, env, epoch=None, init_state=None, time_limit=None):
    """Solves a given environment using OnlineAgent.act()."""
    yield from super().solve(env, epoch, init_state, time_limit)

    self._epoch = epoch

    model_env = env
    if time_limit is not None:
        # Wrap _after_ passing the model env to the agent, so the states
        # cloned/restored by the agent do not contain the step counter.
        env = envs.TimeLimitWrapper(env, time_limit)

    if init_state is None:
        # Model-free case: reset the environment normally.
        observation = env.reset()
    else:
        # Model-based case: restore the environment to the given state.
        observation = env.restore_state(init_state)

    yield from self.reset(model_env, observation)

    for callback in self._callbacks:
        callback.on_episode_begin(env, observation, epoch)

    transitions = []
    done = False
    info = {}
    while not done:
        # Forward network prediction requests to BatchStepper.
        (action, agent_info) = yield from self.act(observation)
        (next_observation, reward, done, info) = env.step(action)
        for callback in self._callbacks:
            callback.on_real_step(agent_info, action, next_observation,
                                  reward, done)
        transitions.append(
            data.Transition(
                observation=observation,
                action=action,
                reward=reward,
                done=done,
                next_observation=next_observation,
                agent_info=agent_info,
            ))
        observation = next_observation

    for callback in self._callbacks:
        callback.on_episode_end()

    transitions = self.postprocess_transitions(transitions)
    return_ = sum(transition.reward for transition in transitions)
    solved = info['solved'] if 'solved' in info else None
    truncated = (info['TimeLimit.truncated']
                 if 'TimeLimit.truncated' in info else None)
    transition_batch = data.nested_stack(transitions)
    action_space_size = space.max_size(model_env.action_space)
    return data.Episode(
        transition_batch=transition_batch,
        return_=return_,
        solved=solved,
        truncated=truncated,
        action_space_size=action_space_size,
    )