def preprocess_samples_discrete(
    self, samples: Samples, minibatch_size: int, one_hot_action: bool = True
) -> List[TrainingDataPage]:
    """Convert raw gridworld samples into minibatched TrainingDataPages.

    States and next-states are normalized through a Caffe2 preprocessing
    net, actions are one-hot encoded against ``self.ACTIONS``, and the
    result is sliced into batches of exactly ``minibatch_size`` rows
    (a trailing partial batch is dropped).

    :param samples: experience data; shuffled in place before processing
    :param minibatch_size: number of rows per TrainingDataPage
    :param one_hot_action: if True, pages carry one-hot action matrices;
        otherwise integer action indices
    :return: list of TrainingDataPages, each set to torch.FloatTensor type
    """
    logger.info("Shuffling...")
    samples.shuffle()

    # Build a Caffe2 net that sparse-normalizes states and next_states,
    # then run it once so the normalized blobs land in the workspace.
    logger.info("Preprocessing...")
    net = core.Net("gridworld_preprocessing")
    C2.set_net(net)
    preprocessor = PreprocessorNet(True)
    saa = StackedAssociativeArray.from_dict_list(samples.states, "states")
    state_matrix, _ = preprocessor.normalize_sparse_matrix(
        saa.lengths,
        saa.keys,
        saa.values,
        self.normalization,
        "state_norm",
        False,
        False,
        False,
    )
    saa = StackedAssociativeArray.from_dict_list(samples.next_states, "next_states")
    next_state_matrix, _ = preprocessor.normalize_sparse_matrix(
        saa.lengths,
        saa.keys,
        saa.values,
        self.normalization,
        "next_state_norm",
        False,
        False,
        False,
    )
    workspace.RunNetOnce(net)

    logger.info("Converting to Torch...")
    # Broadcast string equality against the action vocabulary yields the
    # one-hot matrix; argmax recovers the integer action index.
    actions_one_hot = torch.tensor(
        (np.array(samples.actions).reshape(-1, 1) == np.array(self.ACTIONS)).astype(
            np.int64
        )
    )
    actions = actions_one_hot.argmax(dim=1, keepdim=True)
    rewards = torch.tensor(samples.rewards, dtype=torch.float32).reshape(-1, 1)
    action_probabilities = torch.tensor(
        samples.action_probabilities, dtype=torch.float32
    ).reshape(-1, 1)
    next_actions_one_hot = torch.tensor(
        (
            np.array(samples.next_actions).reshape(-1, 1) == np.array(self.ACTIONS)
        ).astype(np.int64)
    )

    logger.info("Converting PNA to Torch...")
    # Ragged lists of possible next actions -> rectangular string array,
    # padded with "" (which never matches a real action name).
    possible_next_action_strings = np.array(
        list(itertools.zip_longest(*samples.possible_next_actions, fillvalue=""))
    ).T
    possible_next_actions_mask = torch.zeros(
        [len(samples.next_actions), len(self.ACTIONS)]
    )
    for i, action in enumerate(self.ACTIONS):
        possible_next_actions_mask[:, i] = torch.tensor(
            np.max(possible_next_action_strings == action, axis=1).astype(np.int64)
        )

    terminals = torch.tensor(samples.terminals, dtype=torch.int32).reshape(-1, 1)
    not_terminals = 1 - terminals

    logger.info("Converting RT to Torch...")
    # Fix: removed the dead `episode_values = None` that was immediately
    # overwritten here; the assignment below is unconditional.
    episode_values = torch.tensor(
        samples.episode_values, dtype=torch.float32
    ).reshape(-1, 1)

    time_diffs = torch.ones([len(samples.states), 1])

    logger.info("Preprocessing...")
    preprocessor = Preprocessor(self.normalization, False)
    states_ndarray = workspace.FetchBlob(state_matrix)
    states_ndarray = preprocessor.forward(states_ndarray)
    next_states_ndarray = workspace.FetchBlob(next_state_matrix)
    next_states_ndarray = preprocessor.forward(next_states_ndarray)

    logger.info("Batching...")
    tdps = []
    for start in range(0, states_ndarray.shape[0], minibatch_size):
        end = start + minibatch_size
        if end > states_ndarray.shape[0]:
            break  # drop the trailing partial minibatch
        tdp = TrainingDataPage(
            states=states_ndarray[start:end],
            actions=actions_one_hot[start:end]
            if one_hot_action
            else actions[start:end],
            propensities=action_probabilities[start:end],
            rewards=rewards[start:end],
            next_states=next_states_ndarray[start:end],
            not_terminals=not_terminals[start:end],
            next_actions=next_actions_one_hot[start:end],
            possible_next_actions=possible_next_actions_mask[start:end],
            # episode_values is always a tensor here, so no None guard needed
            episode_values=episode_values[start:end],
            time_diffs=time_diffs[start:end],
        )
        tdp.set_type(torch.FloatTensor)
        tdps.append(tdp)
    return tdps
def preprocess_samples_discrete(
    self, samples: Samples, minibatch_size: int, one_hot_action: bool = True
) -> List[TrainingDataPage]:
    """Convert raw gridworld samples into minibatched TrainingDataPages.

    NumPy variant: states and next-states are normalized through a Caffe2
    preprocessing net, actions are one-hot encoded via
    ``self.action_to_index``, and episode values are accumulated from
    reward timelines with discount ``DISCOUNT``. Data is sliced into
    batches of exactly ``minibatch_size`` rows (a trailing partial batch
    is dropped).

    :param samples: experience data; shuffled in place before processing
    :param minibatch_size: number of rows per TrainingDataPage
    :param one_hot_action: if True, pages carry one-hot action matrices;
        otherwise integer action indices
    :return: list of TrainingDataPages backed by numpy arrays
    """
    samples.shuffle()

    # Build a Caffe2 net that sparse-normalizes states and next_states,
    # then run it once so the normalized blobs land in the workspace.
    net = core.Net("gridworld_preprocessing")
    C2.set_net(net)
    preprocessor = PreprocessorNet(True)
    saa = StackedAssociativeArray.from_dict_list(samples.states, "states")
    state_matrix, _ = preprocessor.normalize_sparse_matrix(
        saa.lengths,
        saa.keys,
        saa.values,
        self.normalization,
        "state_norm",
        False,
        False,
        False,
    )
    saa = StackedAssociativeArray.from_dict_list(samples.next_states, "next_states")
    next_state_matrix, _ = preprocessor.normalize_sparse_matrix(
        saa.lengths,
        saa.keys,
        saa.values,
        self.normalization,
        "next_state_norm",
        False,
        False,
        False,
    )
    workspace.RunNetOnce(net)

    actions_one_hot = np.zeros(
        [len(samples.actions), len(self.ACTIONS)], dtype=np.float32
    )
    for i, action in enumerate(samples.actions):
        actions_one_hot[i, self.action_to_index(action)] = 1
    actions = np.array(
        [self.action_to_index(action) for action in samples.actions], dtype=np.int64
    )
    rewards = np.array(samples.rewards, dtype=np.float32).reshape(-1, 1)
    propensities = np.array(samples.propensities, dtype=np.float32).reshape(-1, 1)

    next_actions_one_hot = np.zeros(
        [len(samples.next_actions), len(self.ACTIONS)], dtype=np.float32
    )
    for i, action in enumerate(samples.next_actions):
        # "" marks a terminal transition with no next action
        if action == "":
            continue
        next_actions_one_hot[i, self.action_to_index(action)] = 1

    possible_next_actions_mask = []
    for pna in samples.possible_next_actions:
        pna_mask = [0] * self.num_actions
        for action in pna:
            pna_mask[self.action_to_index(action)] = 1
        possible_next_actions_mask.append(pna_mask)
    possible_next_actions_mask = np.array(
        possible_next_actions_mask, dtype=np.float32
    )

    # Fix: `np.bool` was a deprecated alias for the builtin `bool`
    # (removed in NumPy >= 1.24); use `bool` directly.
    terminals = np.array(samples.terminals, dtype=bool).reshape(-1, 1)
    not_terminals = np.logical_not(terminals)

    # Discounted return per sample, only when reward timelines are present.
    episode_values = None
    if samples.reward_timelines is not None:
        episode_values = np.zeros(rewards.shape, dtype=np.float32)
        for i, reward_timeline in enumerate(samples.reward_timelines):
            for time_diff, reward in reward_timeline.items():
                episode_values[i, 0] += reward * (DISCOUNT ** time_diff)

    preprocessor = Preprocessor(self.normalization, False)
    states_ndarray = workspace.FetchBlob(state_matrix)
    states_ndarray = preprocessor.forward(states_ndarray).numpy()
    next_states_ndarray = workspace.FetchBlob(next_state_matrix)
    next_states_ndarray = preprocessor.forward(next_states_ndarray).numpy()
    time_diffs = np.ones(len(states_ndarray))

    tdps = []
    for start in range(0, states_ndarray.shape[0], minibatch_size):
        end = start + minibatch_size
        if end > states_ndarray.shape[0]:
            break  # drop the trailing partial minibatch
        tdps.append(
            TrainingDataPage(
                states=states_ndarray[start:end],
                actions=actions_one_hot[start:end]
                if one_hot_action
                else actions[start:end],
                propensities=propensities[start:end],
                rewards=rewards[start:end],
                next_states=next_states_ndarray[start:end],
                not_terminals=not_terminals[start:end],
                next_actions=next_actions_one_hot[start:end],
                possible_next_actions=possible_next_actions_mask[start:end],
                episode_values=episode_values[start:end]
                if episode_values is not None
                else None,
                time_diffs=time_diffs[start:end],
            )
        )
    return tdps