def get_max_q_values(
    self,
    next_states: str,
    possible_next_actions: StackedArray,
    use_target_network: bool,
) -> str:
    """
    Takes in an array of next_states and outputs an array with one entry per
    next_state whose ith entry = max_{pna} Q(state_i, pna). If
    use_target_network is True, the target network is used to approximate
    Q(state_i, pna).

    :param next_states: Blob containing state features. Each row contains a
        representation of a state.
    :param possible_next_actions: StackedArray of possible next actions. The
        ith segment is a matrix PNA_i such that PNA_i[j] is the parametric
        representation of the jth possible action from the ith next_state.
        These have not been normalized.
    """
    stacked_states = C2.LengthsTile(next_states, possible_next_actions.lengths)
    all_q_values = self.get_q_values(
        stacked_states,
        possible_next_actions.values,
        use_target_network,
    )
    max_q_values = C2.LengthsMax(
        all_q_values,
        possible_next_actions.lengths,
    )
    return max_q_values
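For intuition, here is a minimal sketch, in plain PyTorch rather than the Caffe2 blob operators used above, of the tile-then-segment-max pattern that C2.LengthsTile and C2.LengthsMax implement. The helper q_fn is a hypothetical stand-in for get_q_values; names and toy data are assumptions for illustration only.

import torch


def segment_max_q(next_states, pna_values, pna_lengths, q_fn):
    # Tile: repeat state i once per candidate action (what LengthsTile does).
    stacked_states = torch.repeat_interleave(next_states, pna_lengths, dim=0)
    # Score every (state, possible_action) pair.
    all_q_values = q_fn(stacked_states, pna_values)
    # Segment max: reduce each length-delimited run to its maximum (LengthsMax).
    return torch.stack(
        [segment.max() for segment in torch.split(all_q_values, pna_lengths.tolist())]
    )


# Toy check: 2 states with 2 and 3 candidate actions respectively.
states = torch.tensor([[0.0], [1.0]])
pna_values = torch.tensor([[1.0], [2.0], [0.5], [1.5], [3.0]])
pna_lengths = torch.tensor([2, 3])
max_q = segment_max_q(states, pna_values, pna_lengths, lambda s, a: (s + a).squeeze(1))
# max_q == tensor([2.0, 4.0])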
def concat_states_and_possible_next_actions(
    self,
    next_state_preprocessed_matrix_blob: str,
    possible_next_actions_blob: str,
    possible_next_actions_lengths_blob: str,
) -> str:
    stacked_states = C2.LengthsTile(
        next_state_preprocessed_matrix_blob, possible_next_actions_lengths_blob
    )
    state_action_pairs, _ = C2.Concat(
        stacked_states, possible_next_actions_blob, axis=1
    )
    return state_action_pairs
def preprocess_samples(
    self,
    samples: Samples,
    minibatch_size: int,
    use_gpu: bool = False,
    one_hot_action: bool = True,
    normalize_actions: bool = True,
) -> List[TrainingDataPage]:
    logger.info("Shuffling...")
    samples.shuffle()

    logger.info("Sparse2Dense...")
    net = core.Net("gridworld_preprocessing")
    C2.set_net(net)
    saa = StackedAssociativeArray.from_dict_list(samples.states, "states")
    sorted_state_features, _ = sort_features_by_normalization(self.normalization)
    state_matrix, _ = sparse_to_dense(
        saa.lengths, saa.keys, saa.values, sorted_state_features
    )
    saa = StackedAssociativeArray.from_dict_list(samples.next_states, "next_states")
    next_state_matrix, _ = sparse_to_dense(
        saa.lengths, saa.keys, saa.values, sorted_state_features
    )
    sorted_action_features, _ = sort_features_by_normalization(
        self.normalization_action
    )
    saa = StackedAssociativeArray.from_dict_list(samples.actions, "action")
    action_matrix, _ = sparse_to_dense(
        saa.lengths, saa.keys, saa.values, sorted_action_features
    )
    saa = StackedAssociativeArray.from_dict_list(
        samples.next_actions, "next_action"
    )
    next_action_matrix, _ = sparse_to_dense(
        saa.lengths, saa.keys, saa.values, sorted_action_features
    )
    action_probabilities = torch.tensor(
        samples.action_probabilities, dtype=torch.float32
    ).reshape(-1, 1)
    rewards = torch.tensor(samples.rewards, dtype=torch.float32).reshape(-1, 1)

    # Flatten the ragged possible-next-action sets, remembering per-row lengths.
    pnas_lengths_list = []
    pnas_flat: List[Dict[str, float]] = []
    for pnas in samples.possible_next_actions:
        pnas_lengths_list.append(len(pnas))
        pnas_flat.extend(pnas)
    saa = StackedAssociativeArray.from_dict_list(pnas_flat, "possible_next_actions")
    pnas_lengths = torch.tensor(pnas_lengths_list, dtype=torch.int32)
    pna_lens_blob = "pna_lens_blob"
    workspace.FeedBlob(pna_lens_blob, pnas_lengths.numpy())
    possible_next_actions_matrix, _ = sparse_to_dense(
        saa.lengths, saa.keys, saa.values, sorted_action_features
    )
    state_pnas_tile_blob = C2.LengthsTile(next_state_matrix, pna_lens_blob)

    workspace.RunNetOnce(net)

    logger.info("Preprocessing...")
    state_preprocessor = Preprocessor(self.normalization, False)
    action_preprocessor = Preprocessor(self.normalization_action, False)

    states_ndarray = workspace.FetchBlob(state_matrix)
    states_ndarray = state_preprocessor.forward(states_ndarray)

    actions_ndarray = torch.from_numpy(workspace.FetchBlob(action_matrix))
    if normalize_actions:
        actions_ndarray = action_preprocessor.forward(actions_ndarray)

    next_states_ndarray = workspace.FetchBlob(next_state_matrix)
    next_states_ndarray = state_preprocessor.forward(next_states_ndarray)

    next_actions_ndarray = torch.from_numpy(workspace.FetchBlob(next_action_matrix))
    if normalize_actions:
        next_actions_ndarray = action_preprocessor.forward(next_actions_ndarray)

    logged_possible_next_actions = action_preprocessor.forward(
        workspace.FetchBlob(possible_next_actions_matrix)
    )
    state_pnas_tile = state_preprocessor.forward(
        workspace.FetchBlob(state_pnas_tile_blob)
    )
    logged_possible_next_state_actions = torch.cat(
        (state_pnas_tile, logged_possible_next_actions), dim=1
    )

    logger.info("Reward Timeline to Torch...")
    possible_next_actions_ndarray = logged_possible_next_actions
    possible_next_actions_state_concat = logged_possible_next_state_actions
    time_diffs = torch.ones([len(samples.states), 1])

    tdps = []
    pnas_start = 0
    logger.info("Batching...")
    for start in range(0, states_ndarray.shape[0], minibatch_size):
        end = start + minibatch_size
        if end > states_ndarray.shape[0]:
            break
        pnas_end = pnas_start + torch.sum(pnas_lengths[start:end])
        pnas = possible_next_actions_ndarray[pnas_start:pnas_end]
        pnas_concat = possible_next_actions_state_concat[pnas_start:pnas_end]
        pnas_start = pnas_end
        tdp = TrainingDataPage(
            states=states_ndarray[start:end],
            actions=actions_ndarray[start:end],
            propensities=action_probabilities[start:end],
            rewards=rewards[start:end],
            next_states=next_states_ndarray[start:end],
            next_actions=next_actions_ndarray[start:end],
            possible_next_actions=None,
            not_terminals=(pnas_lengths[start:end] > 0).reshape(-1, 1),
            time_diffs=time_diffs[start:end],
            possible_next_actions_lengths=pnas_lengths[start:end],
            possible_next_actions_state_concat=pnas_concat,
        )
        tdp.set_type(torch.cuda.FloatTensor if use_gpu else torch.FloatTensor)
        tdps.append(tdp)
    return tdps
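The batching loop above uses a running offset (pnas_start/pnas_end) to slice the flat possible-next-action tensors so that each minibatch carries only its own ragged segments. A small standalone illustration of that indexing, with toy lengths unrelated to the gridworld data:

import torch

pna_lengths = torch.tensor([2, 0, 3, 1], dtype=torch.int32)  # actions per transition
flat_pna_rows = torch.arange(int(pna_lengths.sum()))  # stand-in for the flat pna rows

minibatch_size = 2
offsets = []
pnas_start = 0
for start in range(0, len(pna_lengths), minibatch_size):
    end = start + minibatch_size
    pnas_end = pnas_start + int(torch.sum(pna_lengths[start:end]))
    segment = flat_pna_rows[pnas_start:pnas_end]  # rows belonging to this minibatch
    offsets.append((pnas_start, pnas_end))
    pnas_start = pnas_end
# offsets == [(0, 2), (2, 6)]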
def preprocess_samples(
    self, samples: Samples, minibatch_size: int
) -> List[TrainingDataPage]:
    samples.shuffle()

    net = core.Net("gridworld_preprocessing")
    C2.set_net(net)
    preprocessor = PreprocessorNet(True)
    saa = StackedAssociativeArray.from_dict_list(samples.states, "states")
    state_matrix, _ = preprocessor.normalize_sparse_matrix(
        saa.lengths,
        saa.keys,
        saa.values,
        self.normalization,
        "state_norm",
        False,
        False,
        False,
    )
    saa = StackedAssociativeArray.from_dict_list(samples.next_states, "next_states")
    next_state_matrix, _ = preprocessor.normalize_sparse_matrix(
        saa.lengths,
        saa.keys,
        saa.values,
        self.normalization,
        "next_state_norm",
        False,
        False,
        False,
    )
    saa = StackedAssociativeArray.from_dict_list(samples.actions, "action")
    action_matrix, _ = preprocessor.normalize_sparse_matrix(
        saa.lengths,
        saa.keys,
        saa.values,
        self.normalization_action,
        "action_norm",
        False,
        False,
        False,
    )
    saa = StackedAssociativeArray.from_dict_list(samples.next_actions, "next_action")
    next_action_matrix, _ = preprocessor.normalize_sparse_matrix(
        saa.lengths,
        saa.keys,
        saa.values,
        self.normalization_action,
        "next_action_norm",
        False,
        False,
        False,
    )
    propensities = np.array(samples.propensities, dtype=np.float32).reshape(-1, 1)
    rewards = np.array(samples.rewards, dtype=np.float32).reshape(-1, 1)

    # Flatten the ragged possible-next-action sets, remembering per-row lengths.
    pnas_lengths_list = []
    pnas_flat: List[Dict[str, float]] = []
    for pnas in samples.possible_next_actions:
        pnas_lengths_list.append(len(pnas))
        pnas_flat.extend(pnas)
    saa = StackedAssociativeArray.from_dict_list(pnas_flat, "possible_next_actions")
    pnas_lengths = np.array(pnas_lengths_list, dtype=np.int32)
    pna_lens_blob = "pna_lens_blob"
    workspace.FeedBlob(pna_lens_blob, pnas_lengths)
    possible_next_actions_matrix, _ = preprocessor.normalize_sparse_matrix(
        saa.lengths,
        saa.keys,
        saa.values,
        self.normalization_action,
        "possible_next_action_norm",
        False,
        False,
        False,
    )
    state_pnas_tile_blob = C2.LengthsTile(next_state_matrix, pna_lens_blob)

    workspace.RunNetOnce(net)

    state_preprocessor = Preprocessor(self.normalization, False)
    action_preprocessor = Preprocessor(self.normalization_action, False)

    states_ndarray = workspace.FetchBlob(state_matrix)
    states_ndarray = state_preprocessor.forward(states_ndarray).numpy()

    actions_ndarray = workspace.FetchBlob(action_matrix)
    actions_ndarray = action_preprocessor.forward(actions_ndarray).numpy()

    next_states_ndarray = workspace.FetchBlob(next_state_matrix)
    next_states_ndarray = state_preprocessor.forward(next_states_ndarray).numpy()

    next_actions_ndarray = workspace.FetchBlob(next_action_matrix)
    next_actions_ndarray = action_preprocessor.forward(next_actions_ndarray).numpy()

    logged_possible_next_actions = action_preprocessor.forward(
        workspace.FetchBlob(possible_next_actions_matrix)
    )
    state_pnas_tile = state_preprocessor.forward(
        workspace.FetchBlob(state_pnas_tile_blob)
    )
    logged_possible_next_state_actions = torch.cat(
        (state_pnas_tile, logged_possible_next_actions), dim=1
    )

    possible_next_actions_ndarray = logged_possible_next_actions.cpu().numpy()
    next_state_pnas_concat = logged_possible_next_state_actions.cpu().numpy()

    time_diffs = np.ones(len(states_ndarray))

    # Discounted episode values from the logged reward timelines, if present.
    episode_values = None
    if samples.reward_timelines is not None:
        episode_values = np.zeros(rewards.shape, dtype=np.float32)
        for i, reward_timeline in enumerate(samples.reward_timelines):
            for time_diff, reward in reward_timeline.items():
                episode_values[i, 0] += reward * (DISCOUNT ** time_diff)

    tdps = []
    pnas_start = 0
    for start in range(0, states_ndarray.shape[0], minibatch_size):
        end = start + minibatch_size
        if end > states_ndarray.shape[0]:
            break
        pnas_end = pnas_start + np.sum(pnas_lengths[start:end])
        pnas = possible_next_actions_ndarray[pnas_start:pnas_end]
        pnas_concat = next_state_pnas_concat[pnas_start:pnas_end]
        pnas_start = pnas_end
        tdps.append(
            TrainingDataPage(
                states=states_ndarray[start:end],
                actions=actions_ndarray[start:end],
                propensities=propensities[start:end],
                rewards=rewards[start:end],
                next_states=next_states_ndarray[start:end],
                next_actions=next_actions_ndarray[start:end],
                possible_next_actions=StackedArray(pnas_lengths[start:end], pnas),
                not_terminals=(pnas_lengths[start:end] > 0).reshape(-1, 1),
                episode_values=episode_values[start:end]
                if episode_values is not None
                else None,
                time_diffs=time_diffs[start:end],
                possible_next_actions_lengths=pnas_lengths[start:end],
                next_state_pnas_concat=pnas_concat,
            )
        )
    return tdps
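For clarity on the episode_values computation above: each entry of samples.reward_timelines maps a time offset (steps after the logged transition) to the reward observed there, and the episode value is their discounted sum. A tiny standalone check, assuming DISCOUNT = 0.9 purely for illustration:

DISCOUNT = 0.9  # assumed value, for illustration only

# Timeline for one transition: reward 1.0 now, reward 2.0 three steps later.
reward_timeline = {0: 1.0, 3: 2.0}
episode_value = sum(r * DISCOUNT ** t for t, r in reward_timeline.items())
# 1.0 * 0.9**0 + 2.0 * 0.9**3 = 2.458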