Example #1
0
    def get_max_q_values(
        self,
        next_states: str,
        possible_next_actions: StackedArray,
        use_target_network: bool,
    ) -> str:
        """
        For every next state, compute max_{pna} Q(state_i, pna) over that
        state's possible next actions.

        Each next state is repeated once per possible action with
        ``C2.LengthsTile``, the tiled states and the flattened actions are
        scored by ``self.get_q_values``, and the scores are reduced back to
        one value per state with ``C2.LengthsMax``.

        :param next_states: Blob containing state features.  Each row contains
            a representation of a state.
        :param possible_next_actions: Possible next actions for all states.
            ``lengths`` is used as the per-state repeat/segment count and
            ``values`` as the action rows passed to ``self.get_q_values``.
            These have not been normalized.
        :param use_target_network: Forwarded unchanged to
            ``self.get_q_values``; presumably selects the target network for
            the Q approximation -- confirm against ``get_q_values``.
        :return: The result of ``C2.LengthsMax``, whose ith entry corresponds
            to the ith next state.
        """

        action_counts = possible_next_actions.lengths
        tiled_states = C2.LengthsTile(next_states, action_counts)
        q_per_pair = self.get_q_values(
            tiled_states, possible_next_actions.values, use_target_network
        )
        return C2.LengthsMax(q_per_pair, action_counts)
Example #2
0
 def concat_states_and_possible_next_actions(
     self,
     next_state_preprocessed_matrix_blob: str,
     possible_next_actions_blob: str,
     possible_next_actions_lengths_blob: str,
 ) -> str:
     """
     Pair every next state with each of its possible next actions.

     The next-state blob is tiled with ``C2.LengthsTile`` using the lengths
     blob, and the tiled result is concatenated with the possible-next-actions
     blob along axis 1.

     :param next_state_preprocessed_matrix_blob: Blob of next states to tile.
     :param possible_next_actions_blob: Blob of possible next actions that is
         concatenated after the tiled states.
     :param possible_next_actions_lengths_blob: Blob of lengths passed to
         ``C2.LengthsTile``.
     :return: The first of the two results produced by ``C2.Concat``.
     """
     tiled_states = C2.LengthsTile(
         next_state_preprocessed_matrix_blob,
         possible_next_actions_lengths_blob,
     )
     concat_outputs = C2.Concat(
         tiled_states,
         possible_next_actions_blob,
         axis=1,
     )
     # Concat is unpacked into exactly two results; only the first is needed.
     pairs, _unused = concat_outputs
     return pairs
Example #3
0
    def preprocess_samples(
        self,
        samples: Samples,
        minibatch_size: int,
        use_gpu: bool = False,
        one_hot_action: bool = True,
        normalize_actions: bool = True,
    ) -> List[TrainingDataPage]:
        """
        Shuffle ``samples``, turn their sparse features into dense
        preprocessed tensors and split them into ``TrainingDataPage``
        minibatches.

        A new ``gridworld_preprocessing`` net is installed with ``C2.set_net``
        and run once, and the per-sample possible-next-action counts are fed
        to the workspace as ``pna_lens_blob``.

        :param samples: Source data.  ``samples.shuffle()`` is called first;
            afterwards ``states``, ``next_states``, ``actions``,
            ``next_actions``, ``action_probabilities``, ``rewards`` and
            ``possible_next_actions`` are read.
        :param minibatch_size: Number of samples per returned page.  Trailing
            samples that do not fill a whole page are dropped.
        :param use_gpu: If true each page is converted with
            ``torch.cuda.FloatTensor``, otherwise with ``torch.FloatTensor``.
        :param one_hot_action: Not read by this method.
        :param normalize_actions: If true, ``actions`` and ``next_actions``
            are passed through the action preprocessor.
        :return: The list of pages.  Every page has
            ``possible_next_actions=None``; the tiled next states concatenated
            with their possible next actions are supplied through
            ``possible_next_actions_state_concat``.
        """
        logger.info("Shuffling...")
        samples.shuffle()

        logger.info("Sparse2Dense...")
        net = core.Net("gridworld_preprocessing")
        C2.set_net(net)
        saa = StackedAssociativeArray.from_dict_list(samples.states, "states")
        sorted_state_features, _ = sort_features_by_normalization(self.normalization)
        state_matrix, _ = sparse_to_dense(
            saa.lengths, saa.keys, saa.values, sorted_state_features
        )
        saa = StackedAssociativeArray.from_dict_list(samples.next_states, "next_states")
        next_state_matrix, _ = sparse_to_dense(
            saa.lengths, saa.keys, saa.values, sorted_state_features
        )
        sorted_action_features, _ = sort_features_by_normalization(
            self.normalization_action
        )
        saa = StackedAssociativeArray.from_dict_list(samples.actions, "action")
        action_matrix, _ = sparse_to_dense(
            saa.lengths, saa.keys, saa.values, sorted_action_features
        )
        saa = StackedAssociativeArray.from_dict_list(
            samples.next_actions, "next_action"
        )
        next_action_matrix, _ = sparse_to_dense(
            saa.lengths, saa.keys, saa.values, sorted_action_features
        )
        action_probabilities = torch.tensor(
            samples.action_probabilities, dtype=torch.float32
        ).reshape(-1, 1)
        rewards = torch.tensor(samples.rewards, dtype=torch.float32).reshape(-1, 1)

        # Flatten the per-sample action lists, remembering how many actions
        # belong to each sample so the flat rows can be re-segmented later.
        pnas_lengths_list = []
        pnas_flat: List[List[str]] = []
        for pnas in samples.possible_next_actions:
            pnas_lengths_list.append(len(pnas))
            pnas_flat.extend(pnas)
        saa = StackedAssociativeArray.from_dict_list(pnas_flat, "possible_next_actions")

        pnas_lengths = torch.tensor(pnas_lengths_list, dtype=torch.int32)
        pna_lens_blob = "pna_lens_blob"
        workspace.FeedBlob(pna_lens_blob, pnas_lengths.numpy())

        possible_next_actions_matrix, _ = sparse_to_dense(
            saa.lengths, saa.keys, saa.values, sorted_action_features
        )

        # Repeat each next state once per possible next action.
        state_pnas_tile_blob = C2.LengthsTile(next_state_matrix, pna_lens_blob)

        workspace.RunNetOnce(net)

        logger.info("Preprocessing...")
        state_preprocessor = Preprocessor(self.normalization, False)
        action_preprocessor = Preprocessor(self.normalization_action, False)

        states_ndarray = workspace.FetchBlob(state_matrix)
        states_ndarray = state_preprocessor.forward(states_ndarray)

        actions_ndarray = torch.from_numpy(workspace.FetchBlob(action_matrix))
        if normalize_actions:
            actions_ndarray = action_preprocessor.forward(actions_ndarray)

        next_states_ndarray = workspace.FetchBlob(next_state_matrix)
        next_states_ndarray = state_preprocessor.forward(next_states_ndarray)

        next_actions_ndarray = torch.from_numpy(workspace.FetchBlob(next_action_matrix))
        if normalize_actions:
            next_actions_ndarray = action_preprocessor.forward(next_actions_ndarray)

        # NOTE(review): possible next actions are preprocessed even when
        # normalize_actions is False -- confirm this asymmetry is intended.
        logged_possible_next_actions = action_preprocessor.forward(
            workspace.FetchBlob(possible_next_actions_matrix)
        )

        state_pnas_tile = state_preprocessor.forward(
            workspace.FetchBlob(state_pnas_tile_blob)
        )
        logged_possible_next_state_actions = torch.cat(
            (state_pnas_tile, logged_possible_next_actions), dim=1
        )

        logger.info("Reward Timeline to Torch...")
        time_diffs = torch.ones([len(samples.states), 1])

        num_samples = states_ndarray.shape[0]
        tdps: List[TrainingDataPage] = []
        pnas_start = 0
        logger.info("Batching...")
        for start in range(0, num_samples, minibatch_size):
            end = start + minibatch_size
            if end > num_samples:
                # Only full minibatches are emitted.
                break
            batch_pnas_lengths = pnas_lengths[start:end]
            # Keep the running row offset a plain int so the slice bounds
            # below are ordinary Python integers rather than 0-dim tensors.
            pnas_end = pnas_start + int(torch.sum(batch_pnas_lengths))
            pnas_concat = logged_possible_next_state_actions[pnas_start:pnas_end]
            pnas_start = pnas_end
            tdp = TrainingDataPage(
                states=states_ndarray[start:end],
                actions=actions_ndarray[start:end],
                propensities=action_probabilities[start:end],
                rewards=rewards[start:end],
                next_states=next_states_ndarray[start:end],
                next_actions=next_actions_ndarray[start:end],
                possible_next_actions=None,
                # A sample with no possible next action is treated as terminal.
                not_terminals=(batch_pnas_lengths > 0).reshape(-1, 1),
                time_diffs=time_diffs[start:end],
                possible_next_actions_lengths=batch_pnas_lengths,
                possible_next_actions_state_concat=pnas_concat,
            )
            tdp.set_type(torch.cuda.FloatTensor if use_gpu else torch.FloatTensor)
            tdps.append(tdp)
        return tdps
Example #4
0
    def preprocess_samples(self, samples: Samples,
                           minibatch_size: int) -> List[TrainingDataPage]:
        """
        Shuffle ``samples``, densify and preprocess their features and cut
        them into ``TrainingDataPage`` minibatches of numpy arrays.

        A new ``gridworld_preprocessing`` net is installed with ``C2.set_net``
        and run once, and the per-sample possible-next-action counts are fed
        to the workspace as ``pna_lens_blob``.

        :param samples: Source data.  ``samples.shuffle()`` is called first;
            afterwards ``states``, ``next_states``, ``actions``,
            ``next_actions``, ``propensities``, ``rewards``,
            ``possible_next_actions`` and ``reward_timelines`` are read.
        :param minibatch_size: Number of samples per returned page.  Trailing
            samples that do not fill a whole page are dropped.
        :return: The list of pages.  ``episode_values`` is ``None`` on every
            page when ``samples.reward_timelines`` is ``None``; otherwise it
            holds, per sample, the sum of ``reward * DISCOUNT**time_diff``
            over that sample's reward timeline.
        """
        samples.shuffle()

        net = core.Net("gridworld_preprocessing")
        C2.set_net(net)
        preprocessor = PreprocessorNet(True)

        def dense_blob(dict_list, stack_name, normalization, norm_name):
            # Stack the sparse dicts, then emit the sparse-to-dense op for
            # them; only the first result of the op is kept.
            stacked = StackedAssociativeArray.from_dict_list(
                dict_list, stack_name)
            matrix, _ = preprocessor.normalize_sparse_matrix(
                stacked.lengths,
                stacked.keys,
                stacked.values,
                normalization,
                norm_name,
                False,
                False,
                False,
            )
            return matrix

        state_matrix = dense_blob(
            samples.states, "states", self.normalization, "state_norm")
        next_state_matrix = dense_blob(
            samples.next_states, "next_states", self.normalization,
            "next_state_norm")
        action_matrix = dense_blob(
            samples.actions, "action", self.normalization_action,
            "action_norm")
        next_action_matrix = dense_blob(
            samples.next_actions, "next_action", self.normalization_action,
            "next_action_norm")

        propensities = np.array(samples.propensities, dtype=np.float32)
        propensities = propensities.reshape(-1, 1)
        rewards = np.array(samples.rewards, dtype=np.float32)
        rewards = rewards.reshape(-1, 1)

        # Flatten the per-sample action lists, remembering how many actions
        # belong to each sample so the flat rows can be re-segmented later.
        action_counts = []
        flat_actions: List[List[str]] = []
        for sample_actions in samples.possible_next_actions:
            action_counts.append(len(sample_actions))
            flat_actions.extend(sample_actions)
        pna_stacked = StackedAssociativeArray.from_dict_list(
            flat_actions, "possible_next_actions")

        pnas_lengths = np.array(action_counts, dtype=np.int32)
        pna_lens_blob = "pna_lens_blob"
        workspace.FeedBlob(pna_lens_blob, pnas_lengths)

        possible_next_actions_matrix, _ = preprocessor.normalize_sparse_matrix(
            pna_stacked.lengths,
            pna_stacked.keys,
            pna_stacked.values,
            self.normalization_action,
            "possible_next_action_norm",
            False,
            False,
            False,
        )

        # Repeat each next state once per possible next action.
        state_pnas_tile_blob = C2.LengthsTile(next_state_matrix, pna_lens_blob)

        workspace.RunNetOnce(net)

        state_preprocessor = Preprocessor(self.normalization, False)
        action_preprocessor = Preprocessor(self.normalization_action, False)

        states_ndarray = state_preprocessor.forward(
            workspace.FetchBlob(state_matrix)).numpy()
        actions_ndarray = action_preprocessor.forward(
            workspace.FetchBlob(action_matrix)).numpy()
        next_states_ndarray = state_preprocessor.forward(
            workspace.FetchBlob(next_state_matrix)).numpy()
        next_actions_ndarray = action_preprocessor.forward(
            workspace.FetchBlob(next_action_matrix)).numpy()

        pna_tensor = action_preprocessor.forward(
            workspace.FetchBlob(possible_next_actions_matrix))
        tiled_state_tensor = state_preprocessor.forward(
            workspace.FetchBlob(state_pnas_tile_blob))
        state_pna_tensor = torch.cat((tiled_state_tensor, pna_tensor), dim=1)

        possible_next_actions_ndarray = pna_tensor.cpu().numpy()
        next_state_pnas_concat = state_pna_tensor.cpu().numpy()
        time_diffs = np.ones(len(states_ndarray))

        episode_values = None
        if samples.reward_timelines is not None:
            episode_values = np.zeros(rewards.shape, dtype=np.float32)
            for row, timeline in enumerate(samples.reward_timelines):
                for time_diff, reward in timeline.items():
                    episode_values[row, 0] += reward * (DISCOUNT**time_diff)

        num_samples = states_ndarray.shape[0]
        tdps = []
        pnas_start = 0
        for start in range(0, num_samples, minibatch_size):
            end = start + minibatch_size
            if end > num_samples:
                # Only full minibatches are emitted.
                break
            batch = slice(start, end)
            batch_lengths = pnas_lengths[batch]

            # Rows of the flattened possible-next-action arrays that belong
            # to the samples of this minibatch.
            pnas_end = pnas_start + np.sum(batch_lengths)
            pna_rows = slice(pnas_start, pnas_end)
            pnas_start = pnas_end

            batch_episode_values = None
            if episode_values is not None:
                batch_episode_values = episode_values[batch]

            page = TrainingDataPage(
                states=states_ndarray[batch],
                actions=actions_ndarray[batch],
                propensities=propensities[batch],
                rewards=rewards[batch],
                next_states=next_states_ndarray[batch],
                next_actions=next_actions_ndarray[batch],
                possible_next_actions=StackedArray(
                    batch_lengths, possible_next_actions_ndarray[pna_rows]),
                # A sample with no possible next action is treated as terminal.
                not_terminals=(batch_lengths > 0).reshape(-1, 1),
                episode_values=batch_episode_values,
                time_diffs=time_diffs[batch],
                possible_next_actions_lengths=batch_lengths,
                next_state_pnas_concat=next_state_pnas_concat[pna_rows],
            )
            tdps.append(page)
        return tdps