def store_episode(self, episode_batch, update_stats=True):
    """Store the episode transitions in the replay buffer.

    :param episode_batch: (numpy Number) array of batch_size x (T or T+1) x dim_key
        'o' is of size T+1, others are of size T
    :param update_stats: (bool) whether to update normalizer stats or not
    """
    self.buffer.store_episode(episode_batch)

    if not update_stats:
        return

    # Add transitions to the normalizer: build one-step-shifted successor views.
    episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
    episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
    n_transitions = transitions_in_episode_batch(episode_batch)
    transitions = self.sample_transitions(episode_batch, n_transitions)

    obs = transitions['o']
    goal = transitions['g']
    achieved_goal = transitions['ag']
    transitions['o'], transitions['g'] = self._preprocess_obs_goal(
        obs, achieved_goal, goal)

    # o_2 / g_2 need no preprocessing: they are only used for the stats below.
    self.o_stats.update(transitions['o'])
    self.g_stats.update(transitions['g'])
    self.o_stats.recompute_stats()
    self.g_stats.recompute_stats()
def store_episode(self, episode_batch, update_stats=True, if_clear_buffer_first=False):
    """Store an episode batch and optionally refresh normalizer statistics.

    episode_batch: array of batch_size x (T or T+1) x dim_key
                   'o' is of size T+1, others are of size T
    """
    # NOTE(review): `if_clear_buffer_first` is currently a no-op (the
    # buffer-clearing call was commented out upstream) — confirm intent.
    self.buffer.store_episode(episode_batch)

    if not update_stats:
        return

    # Successor views, shifted one step; 'ag_2' feeds reward recomputation
    # downstream (per the upstream comment).
    episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
    episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
    n_norm = transitions_in_episode_batch(episode_batch)

    # KER/GER augmentation happens inside the sampler.
    transitions = self.sample_transitions(
        episode_batch, n_norm,
        env_name=self.env_name,
        n_GER=self.n_GER,
        err_distance=self.err_distance)

    # Clip observations and goals before feeding the normalizers.
    o, g, ag = transitions['o'], transitions['g'], transitions['ag']
    transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
    # o_2 / g_2 are left untouched: they only feed the stats below.

    # Accumulate running sums / sum-squares, then recompute mean and std.
    self.o_stats.update(transitions['o'])
    self.g_stats.update(transitions['g'])
    self.o_stats.recompute_stats()
    self.g_stats.recompute_stats()
def store_episode(self, episode_batch, update_stats=True):
    """Add an episode batch to the replay buffer and update o/g normalizers.

    episode_batch: array of batch_size x (T or T+1) x dim_key
                   'o' is of size T+1, others are of size T
    """
    self.buffer.store_episode(episode_batch)

    if not update_stats:
        return

    # add transitions to normalizer
    episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
    episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
    n_transitions = transitions_in_episode_batch(episode_batch)
    transitions = self.sample_transitions(episode_batch, n_transitions)

    o = transitions['o']
    g = transitions['g']
    ag = transitions['ag']
    transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
    # No need to preprocess o_2 and g_2: they are only used for the stats.

    self.o_stats.update(transitions['o'])
    self.g_stats.update(transitions['g'])
    self.o_stats.recompute_stats()
    self.g_stats.recompute_stats()
def store_episode(self, episode_batch, update_stats=True, if_clear_buffer_first=False):
    """Store an episode batch; optionally update normalizer statistics.

    episode_batch: array of batch_size x (T or T+1) x dim_key
                   'o' is of size T+1, others are of size T
    if_clear_buffer_first: currently unused (clearing call was disabled
                           upstream) — kept for interface compatibility.
    """
    self.buffer.store_episode(episode_batch)

    if update_stats:
        # One-step-shifted successor observations / achieved goals.
        episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
        episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
        n_norm = transitions_in_episode_batch(episode_batch)
        transitions = self.sample_transitions(
            episode_batch,
            n_norm,
            env_name=self.env_name,
            n_GER=self.n_GER,
            err_distance=self.err_distance)

        obs = transitions['o']
        goal = transitions['g']
        achieved = transitions['ag']
        transitions['o'], transitions['g'] = self._preprocess_og(obs, achieved, goal)
        # o_2 / g_2 need no preprocessing: only used for the stats below.

        self.o_stats.update(transitions['o'])
        self.g_stats.update(transitions['g'])
        self.o_stats.recompute_stats()
        self.g_stats.recompute_stats()
def store_episode(self, episode_batch, update_stats=True): """ episode_batch: array of batch_size x (T or T+1) x dim_key 'o' is of size T+1, others are of size T """ # update the mutual information reward into the episode batch episode_batch['m'] = np.empty([episode_batch['o'].shape[0], 1]) episode_batch['s'] = np.empty([episode_batch['o'].shape[0], 1]) # # self.buffer.store_episode(episode_batch, self) if update_stats: # add transitions to normalizer episode_batch['o_2'] = episode_batch['o'][:, 1:, :] episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :] num_normalizing_transitions = transitions_in_episode_batch( episode_batch) transitions = self.sample_transitions(self, False, episode_batch, num_normalizing_transitions, 0, 0, 0) o, o_2, g, ag = transitions['o'], transitions['o_2'], transitions[ 'g'], transitions['ag'] transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g) self.o_stats.update(transitions['o']) self.g_stats.update(transitions['g']) self.o_stats.recompute_stats() self.g_stats.recompute_stats()
def ddpg_store_episode(self, episode_batch, dump_buffer, w_potential,
                       w_linear, w_rotational, rank_method, clip_energy,
                       update_stats=True):
    """Store an episode batch in the replay buffer and refresh normalizer stats.

    episode_batch: array of batch_size x (T or T+1) x dim_key
                   'o' is of size T+1, others are of size T
    dump_buffer: unused here; kept for interface compatibility (it was only
                 consumed by the disabled 'tderror' path).
    w_potential, w_linear, w_rotational, rank_method, clip_energy:
                 forwarded to the energy-based prioritized buffer.
    update_stats: whether to refresh the o/g normalizer statistics.
    """
    if self.prioritization == 'energy':
        self.buffer.store_episode(episode_batch, w_potential, w_linear,
                                  w_rotational, rank_method, clip_energy)
    else:
        self.buffer.store_episode(episode_batch)

    if update_stats:
        # add transitions to normalizer: one-step-shifted successor views
        episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
        episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
        num_normalizing_transitions = transitions_in_episode_batch(
            episode_batch)

        # BUGFIX: `transitions` was referenced unconditionally below, raising
        # UnboundLocalError when the energy branch skipped sampling (empty
        # buffer or empty 'ag'). Initialize it and skip the stats update
        # instead of crashing.
        transitions = None
        if self.prioritization == 'energy':
            if self.buffer.current_size != 0 and len(episode_batch['ag']) != 0:
                transitions = self.sample_transitions(
                    episode_batch, num_normalizing_transitions, 'none', 1.0,
                    self.sample_count, self.cycle_count, True)
        else:
            transitions = self.sample_transitions(
                episode_batch, num_normalizing_transitions)
        if transitions is None:
            return

        o, g, ag = transitions['o'], transitions['g'], transitions['ag']
        transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
        # No need to preprocess o_2 and g_2: they are only used for stats.

        self.o_stats.update(transitions['o'])
        self.g_stats.update(transitions['g'])
        self.o_stats.recompute_stats()
        self.g_stats.recompute_stats()
def init_demo_buffer(self, demoDataFile, update_stats=True):
    """Fill the global demonstration buffer from a recorded-demo data file.

    demoDataFile: path to an npz/np file with 'obs', 'acs' and 'info' entries
                  (one sequence per demonstration episode).
    update_stats: if True, also feed the demo transitions into the o/g
                  normalizers.
    """
    # NOTE(review): depending on the numpy version, loading object arrays may
    # require allow_pickle=True — confirm against the data file format.
    demoData = np.load(demoDataFile)
    info_keys = [k.replace('info_', '')
                 for k in self.input_dims.keys() if k.startswith('info_')]
    info_values = [np.empty((self.T - 1, 1, self.input_dims['info_' + k]),
                            np.float32) for k in info_keys]
    demo_data_obs = demoData['obs']
    demo_data_acs = demoData['acs']
    demo_data_info = demoData['info']

    # Initialize the whole demo buffer at the start of training.
    for ep_idx in range(self.num_demo):
        obs, acts, goals, achieved_goals = [], [], [], []
        for t in range(self.T - 1):
            step_obs = demo_data_obs[ep_idx][t]
            obs.append([step_obs.get('observation')])
            acts.append([demo_data_acs[ep_idx][t]])
            goals.append([step_obs.get('desired_goal')])
            achieved_goals.append([step_obs.get('achieved_goal')])
            for idx, key in enumerate(info_keys):
                info_values[idx][t, 0] = demo_data_info[ep_idx][t][key]

        # Terminal observation / achieved goal ('o' has length T+1).
        last_obs = demo_data_obs[ep_idx][self.T - 1]
        obs.append([last_obs.get('observation')])
        achieved_goals.append([last_obs.get('achieved_goal')])

        episode = dict(o=obs, u=acts, g=goals, ag=achieved_goals)
        for key, value in zip(info_keys, info_values):
            episode['info_{}'.format(key)] = value
        episode = convert_episode_to_batch_major(episode)

        # Append the observation dict into the demonstration buffer.
        global DEMO_BUFFER
        DEMO_BUFFER.store_episode(episode)
        logger.debug("Demo buffer size currently ", DEMO_BUFFER.get_current_size())

        if update_stats:
            # Add demo transitions to the normalizers as well.
            episode['o_2'] = episode['o'][:, 1:, :]
            episode['ag_2'] = episode['ag'][:, 1:, :]
            n_norm = transitions_in_episode_batch(episode)
            transitions = self.sample_transitions(episode, n_norm)

            o, g, ag = transitions['o'], transitions['g'], transitions['ag']
            transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
            # o_2 / g_2 need no preprocessing: only used for stats.

            self.o_stats.update(transitions['o'])
            self.g_stats.update(transitions['g'])
            self.o_stats.recompute_stats()
            self.g_stats.recompute_stats()
        episode.clear()

    logger.info("Demo buffer size: ", DEMO_BUFFER.get_current_size())
def init_demo_buffer(self, demoDataFile, update_stats=True):
    """Initialize the global demonstration buffer from a recorded-demo file.

    demoDataFile: path to a file loadable by np.load with 'obs', 'acs' and
                  'info' entries (one sequence per demonstration episode).
    update_stats: if True, also feed the demo transitions into the o/g
                  normalizers.
    """
    # Load the demonstration data from the data file.
    # NOTE(review): newer numpy versions need allow_pickle=True for object
    # arrays — confirm against the demo file format.
    demoData = np.load(demoDataFile)
    info_keys = [key.replace('info_', '') for key in self.input_dims.keys() if key.startswith('info_')]
    info_values = [np.empty((self.T - 1, 1, self.input_dims['info_' + key]), np.float32) for key in info_keys]
    demo_data_obs = demoData['obs']
    demo_data_acs = demoData['acs']
    demo_data_info = demoData['info']

    # We initialize the whole demo buffer at the start of the training.
    for epsd in range(self.num_demo):
        obs, acts, goals, achieved_goals = [], [], [], []
        i = 0  # second-axis index into info_values (single rollout per episode)
        for transition in range(self.T - 1):
            obs.append([demo_data_obs[epsd][transition].get('observation')])
            acts.append([demo_data_acs[epsd][transition]])
            goals.append([demo_data_obs[epsd][transition].get('desired_goal')])
            achieved_goals.append([demo_data_obs[epsd][transition].get('achieved_goal')])
            for idx, key in enumerate(info_keys):
                info_values[idx][transition, i] = demo_data_info[epsd][transition][key]

        # 'o' and 'ag' carry the terminal step too (length T+1 vs T).
        obs.append([demo_data_obs[epsd][self.T - 1].get('observation')])
        achieved_goals.append([demo_data_obs[epsd][self.T - 1].get('achieved_goal')])

        episode = dict(o=obs, u=acts, g=goals, ag=achieved_goals)
        for key, value in zip(info_keys, info_values):
            episode['info_{}'.format(key)] = value
        episode = convert_episode_to_batch_major(episode)

        # Create the observation dict and append it to the demonstration buffer.
        global DEMO_BUFFER
        DEMO_BUFFER.store_episode(episode)
        # NOTE(review): passing the size as a second positional arg assumes a
        # baselines-style logger that joins its args; the stdlib logger would
        # not render it — confirm which logger is in scope.
        logger.debug("Demo buffer size currently ", DEMO_BUFFER.get_current_size())

        if update_stats:
            # Add transitions to the normalizer to normalize the demo data too.
            episode['o_2'] = episode['o'][:, 1:, :]
            episode['ag_2'] = episode['ag'][:, 1:, :]
            num_normalizing_transitions = transitions_in_episode_batch(episode)
            transitions = self.sample_transitions(episode, num_normalizing_transitions)

            o, g, ag = transitions['o'], transitions['g'], \
                transitions['ag']
            transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
            # No need to preprocess o_2 and g_2: only used for stats.

            self.o_stats.update(transitions['o'])
            self.g_stats.update(transitions['g'])
            self.o_stats.recompute_stats()
            self.g_stats.recompute_stats()
        episode.clear()

    # Print out the final demonstration buffer size.
    logger.info("Demo buffer size: ", DEMO_BUFFER.get_current_size())
def store_episode(self, episode_batch, dump_buffer, w_potential, w_linear,
                  w_rotational, rank_method, clip_energy, update_stats=True):
    """Store an episode batch using the configured prioritization scheme.

    episode_batch: array of batch_size x (T or T+1) x dim_key
                   'o' is of size T+1, others are of size T
    dump_buffer: forwarded to the TD-error prioritized buffer.
    w_potential, w_linear, w_rotational, rank_method, clip_energy:
                 forwarded to the energy-based prioritized buffer.
    update_stats: whether to refresh the o/g normalizer statistics.
    """
    if self.prioritization == 'tderror':
        self.buffer.store_episode(episode_batch, dump_buffer)
    elif self.prioritization == 'energy':
        self.buffer.store_episode(episode_batch, w_potential, w_linear,
                                  w_rotational, rank_method, clip_energy)
    else:
        self.buffer.store_episode(episode_batch)

    if update_stats:
        # add transitions to normalizer: one-step-shifted successor views
        episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
        episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
        num_normalizing_transitions = transitions_in_episode_batch(
            episode_batch)

        # BUGFIX: `transitions` was referenced unconditionally below, raising
        # UnboundLocalError when the energy branch skipped sampling (empty
        # buffer or empty 'ag'). Initialize it and skip the stats update
        # instead of crashing.
        transitions = None
        if self.prioritization == 'energy':
            if self.buffer.current_size != 0 and len(episode_batch['ag']) != 0:
                transitions = self.sample_transitions(
                    episode_batch, num_normalizing_transitions, 'none', 1.0,
                    True)
        elif self.prioritization == 'tderror':
            transitions, weights, episode_idxs = \
                self.sample_transitions(self.buffer, episode_batch,
                                        num_normalizing_transitions, beta=0)
        else:
            transitions = self.sample_transitions(
                episode_batch, num_normalizing_transitions)
        if transitions is None:
            return

        o, o_2, g, ag = transitions['o'], transitions['o_2'], transitions[
            'g'], transitions['ag']
        transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
        # No need to preprocess o_2 and g_2: they are only used for stats.

        self.o_stats.update(transitions['o'])
        self.g_stats.update(transitions['g'])
        self.o_stats.recompute_stats()
        self.g_stats.recompute_stats()
def store_episode(self, episode_batch, update_stats=True):
    """Store an episode batch (no-op on None) and update normalizer stats.

    episode_batch: array of batch_size x (T or T+1) x dim_key
                   'o' is of size T+1, others are of size T
    """
    # Guard: some callers may pass None when no rollouts were collected.
    if episode_batch is None:
        return
    self.buffer.store_episode(episode_batch)

    if not update_stats:
        return

    # (A disabled experiment tracking the variance of successful goals used
    # to live here.)

    # Add transitions to the normalizer: one-step-shifted successor views.
    episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
    episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
    n_transitions = transitions_in_episode_batch(episode_batch)
    transitions = self.sample_transitions(episode_batch, n_transitions)

    obs = transitions['o']
    goal = transitions['g']
    achieved = transitions['ag']
    transitions['o'], transitions['g'] = self._preprocess_og(obs, achieved, goal)
    # o_2 / g_2 need no preprocessing: only used for the stats below.

    self.o_stats.update(transitions['o'])
    self.g_stats.update(transitions['g'])
    self.o_stats.recompute_stats()
    self.g_stats.recompute_stats()
def store_episode(self, episode_batch, update_stats=True):
    """Store an episode batch in the replay buffer; refresh o/g normalizers.

    episode_batch: array of batch_size x (T or T+1) x dim_key
                   'o' is of size T+1, others are of size T
    """
    # (Disabled variants existed upstream: storing a 'u_2' next-action field
    # when Q-learning targets were used, and normalizing over the flattened
    # episode batch instead of sampled transitions.)
    self.buffer.store_episode(episode_batch)

    if update_stats:
        # Add transitions to the normalizer via the HER sampler.
        episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
        episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
        n_norm = transitions_in_episode_batch(episode_batch)
        transitions = self.sample_transitions(episode_batch, n_norm)

        obs = transitions['o']
        goal = transitions['g']
        achieved = transitions['ag']
        transitions['o'], transitions['g'] = self._preprocess_og(obs, achieved, goal)
        # o_2 / g_2 need no preprocessing: only used for the stats below.

        self.o_stats.update(transitions['o'])
        self.g_stats.update(transitions['g'])
        self.o_stats.recompute_stats()
        self.g_stats.recompute_stats()
def store_episode(self, episode_batch, update_stats=True):
    """Store an episode batch and refresh the o / G / sigma normalizers.

    episode_batch: array of batch_size x (T or T+1) x dim_key
                   'o' is of size T+1, others are of size T
    """
    self.buffer.store_episode(episode_batch)

    if not update_stats:
        return

    # Add transitions to the normalizer: one-step-shifted successor views.
    episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
    episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
    n_transitions = transitions_in_episode_batch(episode_batch)
    transitions = self.sample_transitions(episode_batch, n_transitions)

    obs = transitions['o']
    goal = transitions['g']
    achieved = transitions['ag']
    transitions['o'], transitions['g'] = self._preprocess_og(obs, achieved, goal)
    # o_2 / g_2 need no preprocessing: only used for the stats below.

    if 'Variation' in self.kwargs['info']['env_name']:
        # 'Variation' envs carry an extra leading feature in the observation;
        # drop it before computing normalization stats.
        obs_for_stats = transitions['o'][:, 1:]
    else:
        obs_for_stats = transitions['o']
    self.o_stats.update(obs_for_stats)
    # NOTE(review): this variant normalizes transitions['G'] and
    # transitions['sigma'] instead of the goal 'g' — the sampler must
    # provide those keys; confirm against sample_transitions.
    self.G_stats.update(transitions['G'])
    self.sigma_stats.update(transitions['sigma'])
    self.o_stats.recompute_stats()
    self.G_stats.recompute_stats()
    self.sigma_stats.recompute_stats()
def store_episode(self, episode_batch, update_stats=True):
    """Store an episode batch with verbose debug tracing of every step.

    episode_batch: array of batch_size x (T or T+1) x dim_key
                   'o' is of size T+1, others are of size T
    """
    # Each episode_batch entry is sized [rollouts_size, T, ...].
    print('calling init of store_episode', episode_batch.keys())
    self.buffer.store_episode(episode_batch)

    print('update_stats of ddpg', update_stats)
    if not update_stats:
        return

    # Add transitions to the normalizer: drop the first frame of each batch
    # to build the successor views.
    episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
    episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
    print('episode_batch is updated', episode_batch['o'].shape,
          episode_batch['o'][:, 1:, :].shape)

    # Count derives from episode_batch['u'].shape.
    n_norm = transitions_in_episode_batch(episode_batch)
    print('num_normalizing_transitions of ddpg', n_norm)

    transitions = self.sample_transitions(episode_batch, n_norm)
    for key, value in transitions.items():
        print('key, value of transitions in ddpg', key, value.shape)

    obs = transitions['o']
    goal = transitions['g']
    achieved = transitions['ag']
    # Clip values before feeding the normalizers.
    transitions['o'], transitions['g'] = self._preprocess_og(obs, achieved, goal)
    # o_2 / g_2 need no preprocessing: only used for the stats below.

    self.o_stats.update(transitions['o'])
    self.g_stats.update(transitions['g'])
    self.o_stats.recompute_stats()
    self.g_stats.recompute_stats()
def store_episode(self, episode_batch, update_stats=True): """ episode_batch: array of batch_size x (T or T+1) x dim_key 'o' is of size T+1, others are of size T """ ###### Remove the l value - Supposed to be a list of length 2 # First entry consists of transitions with actual goals and second is alternate goals self.buffer.store_episode(episode_batch) # ###### Debug # # This functions was used to check the hypothesis that if TD error is high # # for a state with some goal, it is high for that states with all other goals # self.debug_td_error_alternate_actual(debug_transitions) # Updating stats ## Change this-------------- update_stats = False ###-------------------------- if update_stats: # add transitions to normalizer episode_batch['o_2'] = episode_batch['o'][:, 1:, :] episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :] num_normalizing_transitions = transitions_in_episode_batch(episode_batch) transitions = self.sample_transitions(episode_batch, num_normalizing_transitions) o, o_2, g, ag = transitions['o'], transitions['o_2'], transitions['g'], transitions['ag'] transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g) # No need to preprocess the o_2 and g_2 since this is only used for stats self.o_stats.update(transitions['o']) self.g_stats.update(transitions['g']) self.o_stats.recompute_stats() self.g_stats.recompute_stats()
def store_episode(self, episode_batch, update_stats=True):
    """Store an episode batch; refresh only the observation normalizer.

    episode_batch: array of batch_size x (T or T+1) x dim_key
                   'o' is of size T+1, others are of size T
    """
    self.buffer.store_episode(episode_batch)

    if not update_stats:
        return

    # Add transitions to the normalizer: one-step-shifted successor views.
    episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
    episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
    n_transitions = transitions_in_episode_batch(episode_batch)
    transitions = self.sample_transitions(episode_batch, n_transitions)

    obs = transitions['o']
    goal = transitions['g']
    achieved = transitions['ag']
    transitions['o'], transitions['g'] = self._preprocess_og(obs, achieved, goal)
    # o_2 / g_2 need no preprocessing: only used for the stats below.

    # 'Variation' environments carry an extra dimension in the observation
    # that tells the agent how many blocks there are; strip it before
    # computing normalization stats.
    if 'Variation' in self.kwargs['info']['env_name']:
        obs_for_stats = transitions['o'][:, 1:]
    else:
        obs_for_stats = transitions['o']
    self.o_stats.update(obs_for_stats)
    # The goal normalizer is not updated in this variant (call disabled
    # upstream).
    self.o_stats.recompute_stats()
def store_episode(self, episode_batch, update_stats=True):
    """Append an episode batch to the replay buffer and update normalizers.

    episode_batch: array of batch_size x (T or T+1) x dim_key
                   'o' is of size T+1, others are of size T
    """
    self.buffer.store_episode(episode_batch)

    if not update_stats:
        return

    # Add transitions to the normalizer: one-step-shifted successor views.
    episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
    episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
    n_transitions = transitions_in_episode_batch(episode_batch)
    transitions = self.sample_transitions(episode_batch, n_transitions)

    obs = transitions['o']
    goal = transitions['g']
    achieved = transitions['ag']
    transitions['o'], transitions['g'] = self._preprocess_og(obs, achieved, goal)
    # o_2 / g_2 need no preprocessing: only used for the stats below.

    self.o_stats.update(transitions['o'])
    self.g_stats.update(transitions['g'])
    self.o_stats.recompute_stats()
    self.g_stats.recompute_stats()
def store_episode(self, episode_batch, cp, n_ep, update_stats=True):
    """Store an episode batch, routed per-episode into task-specific buffers
    for multi-task structures, then refresh the o/g normalizers.

    episode_batch: array of batch_size x (T or T+1) x dim_key
                   'o' is of size T+1, others are of size T
    cp: stored on self.cp (semantics not visible from here — TODO confirm)
    n_ep: stored on self.n_episodes
    update_stats: whether to refresh the o/g normalizer statistics
    """
    # decompose episode_batch in episodes
    batch_size = episode_batch['ag'].shape[0]
    # addition in the case of curious goals: count achieved goals that moved
    # in each of the n modules
    self.cp = cp
    self.n_episodes = n_ep

    # addition for multi-task structures
    if self.structure == 'curious' or self.structure == 'task_experts':
        new_count_local = np.zeros([self.nb_tasks])
        new_count_total = np.zeros([self.nb_tasks])
        # add a new transition in a buffer only if the corresponding outcome
        # has changed compared to the initial outcome
        for b in range(batch_size):
            active_tasks = []
            for j in range(self.nb_tasks):
                if any(episode_batch['change']
                       [b, -1, self.tasks_ag_id[j][:len(self.tasks_g_id[j])]]):
                    new_count_local[j] += 1
                    if self.nb_tasks < 5 or j < 5:
                        active_tasks.append(j)
            # NOTE(review): an MPI collective inside the per-episode loop is
            # costly, and new_count_total is never read afterwards — confirm
            # whether this is intentional bookkeeping.
            MPI.COMM_WORLD.Allreduce(new_count_local, new_count_total,
                                     op=MPI.SUM)
            # Re-wrap this single episode as a batch of size 1.
            ep = dict()
            for key in episode_batch.keys():
                ep[key] = episode_batch[key][b].reshape([
                    1, episode_batch[key].shape[1],
                    episode_batch[key].shape[2]
                ])
            if 'buffer' in self.task_replay or self.task_replay == 'hand_designed':
                if len(active_tasks) == 0:
                    # NOTE(review): ind_buffer is assigned but never used —
                    # episodes with no active task are silently dropped here.
                    ind_buffer = [0]
                else:
                    # Route the episode to every buffer of an active task
                    # (buffer 0 appears reserved — tasks are offset by 1).
                    for task in active_tasks:
                        self.buffer[task + 1].store_episode(ep)
            else:
                self.buffer.store_episode(ep)
    elif self.structure == 'flat' or self.structure == 'task_experts':
        # NOTE(review): 'task_experts' can never reach this branch (already
        # matched above) — verify which behavior is intended.
        for b in range(batch_size):
            ep = dict()
            for key in episode_batch.keys():
                ep[key] = episode_batch[key][b].reshape([
                    1, episode_batch[key].shape[1],
                    episode_batch[key].shape[2]
                ])
            self.buffer.store_episode(ep)

    # update statistics for goal and observation normalizations
    if update_stats:
        # add transitions to normalizer: one-step-shifted successor views
        episode_batch['o_2'] = episode_batch['o'][:, 1:, :]
        episode_batch['ag_2'] = episode_batch['ag'][:, 1:, :]
        num_normalizing_transitions = transitions_in_episode_batch(
            episode_batch)
        if self.structure == 'curious' or self.structure == 'task_experts':
            transitions = self.sample_transitions(
                episode_batch, num_normalizing_transitions,
                task_to_replay=None)
        else:
            transitions = self.sample_transitions(
                episode_batch, num_normalizing_transitions)
        o, o_2, g, ag = transitions['o'], transitions['o_2'], transitions[
            'g'], transitions['ag']
        transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
        # No need to preprocess o_2 and g_2: only used for stats.
        self.o_stats.update(transitions['o'])
        self.g_stats.update(transitions['g'])
        self.o_stats.recompute_stats()
        self.g_stats.recompute_stats()