def preprocess_samples(
        self,
        samples: Samples,
        minibatch_size: int,
        use_gpu: bool = False,
        one_hot_action: bool = True,
        normalize_actions: bool = True,
    ) -> List[TrainingDataPage]:
        """Shuffle ``samples`` and convert them into fixed-size training pages.

        The sparse feature dicts for states, next states, actions, next
        actions and possible next actions are densified by a Caffe2 net,
        optionally normalized, and then sliced into ``TrainingDataPage``
        objects of exactly ``minibatch_size`` rows each.  A trailing partial
        batch is dropped.

        Args:
            samples: Logged transitions to preprocess.
            minibatch_size: Number of rows in every returned page.
            use_gpu: When true, pages are typed as ``torch.cuda.FloatTensor``,
                otherwise as ``torch.FloatTensor``.
            one_hot_action: Not used by this method.
            normalize_actions: When true, dense action matrices are passed
                through the action ``Preprocessor``; otherwise the dense
                matrices are used exactly as fetched from the workspace.

        Returns:
            One ``TrainingDataPage`` per complete minibatch.

        Raises:
            ValueError: If a sample lists more possible next actions than the
                fixed padding width used by this method.
        """
        logger.info("Shuffling...")
        samples = shuffle_samples(samples)

        logger.info("Sparse2Dense...")
        # Everything the sparse-to-dense processors emit below is attached to
        # this net; it is executed once, after all inputs are registered.
        net = core.Net("gridworld_preprocessing")
        C2.set_net(net)
        sorted_state_features, _ = sort_features_by_normalization(
            self.normalization)
        sorted_action_features, _ = sort_features_by_normalization(
            self.normalization_action)
        state_sparse_to_dense_processor = Caffe2SparseToDenseProcessor(
            sorted_state_features)
        action_sparse_to_dense_processor = Caffe2SparseToDenseProcessor(
            sorted_action_features)

        saa = StackedAssociativeArray.from_dict_list(samples.states, "states")
        state_matrix, state_matrix_presence, _ = (
            state_sparse_to_dense_processor(saa))
        saa = StackedAssociativeArray.from_dict_list(samples.next_states,
                                                     "next_states")
        # NOTE(review): the presence blob for next states is discarded; their
        # presence mask is derived from MISSING_VALUE further down, unlike the
        # current states, which use the processor's presence output.
        next_state_matrix, _, _ = state_sparse_to_dense_processor(saa)
        saa = StackedAssociativeArray.from_dict_list(  # type: ignore
            samples.actions, "action")
        action_matrix, action_matrix_presence, _ = (
            action_sparse_to_dense_processor(saa))
        saa = StackedAssociativeArray.from_dict_list(  # type: ignore
            samples.next_actions, "next_action")
        next_action_matrix, next_action_matrix_presence, _ = (
            action_sparse_to_dense_processor(saa))

        action_probabilities = torch.tensor(
            samples.action_probabilities, dtype=torch.float32).reshape(-1, 1)
        rewards = torch.tensor(samples.rewards,
                               dtype=torch.float32).reshape(-1, 1)

        # Every sample's possible next actions are padded to this many slots.
        max_action_size = 4

        pnas_mask_list: List[List[int]] = []
        pnas_flat: List[Dict[str, float]] = []
        for pnas in samples.possible_next_actions:
            num_fillers = max_action_size - len(pnas)
            if num_fillers < 0:
                # Without this check the padding silently becomes empty and
                # the failure only shows up later as a ragged-tensor or shape
                # mismatch error.
                raise ValueError(
                    "Sample has " + str(len(pnas)) +
                    " possible next actions; at most " +
                    str(max_action_size) + " are supported")
            pnas_mask_list.append([1] * len(pnas) + [0] * num_fillers)
            pnas_flat.extend(pnas)  # type: ignore
            # Empty dicts act as filler rows so each sample occupies exactly
            # max_action_size consecutive rows of the flattened list.
            pnas_flat.extend({} for _ in range(num_fillers))
        saa = StackedAssociativeArray.from_dict_list(  # type: ignore
            pnas_flat, "possible_next_actions")
        pnas_mask = torch.Tensor(pnas_mask_list)

        (possible_next_actions_matrix, possible_next_actions_matrix_presence,
         _) = action_sparse_to_dense_processor(saa)

        workspace.RunNetOnce(net)

        logger.info("Preprocessing...")
        state_preprocessor = Preprocessor(self.normalization, False)
        action_preprocessor = Preprocessor(self.normalization_action, False)

        def _fetch_action_features(matrix_blob, presence_blob):
            # Fetch a dense action matrix and normalize it when requested.
            dense = torch.from_numpy(workspace.FetchBlob(matrix_blob))
            if not normalize_actions:
                return dense
            presence = torch.from_numpy(
                workspace.FetchBlob(presence_blob)).float()
            return action_preprocessor(dense, presence)

        states_ndarray = state_preprocessor(
            torch.from_numpy(workspace.FetchBlob(state_matrix)),
            torch.from_numpy(
                workspace.FetchBlob(state_matrix_presence)).float(),
        )

        actions_ndarray = _fetch_action_features(action_matrix,
                                                 action_matrix_presence)

        next_states_ndarray = torch.from_numpy(
            workspace.FetchBlob(next_state_matrix))
        next_states_ndarray = state_preprocessor(
            next_states_ndarray,
            (next_states_ndarray != MISSING_VALUE).float())

        # Repeat each next-state row max_action_size times back to back so it
        # lines up row-for-row with the flattened possible-next-action matrix.
        state_pnas_tile = next_states_ndarray.repeat(
            1, max_action_size).reshape(-1, next_states_ndarray.shape[1])

        next_actions_ndarray = _fetch_action_features(
            next_action_matrix, next_action_matrix_presence)

        logged_possible_next_actions = _fetch_action_features(
            possible_next_actions_matrix,
            possible_next_actions_matrix_presence)

        assert state_pnas_tile.shape[0] == logged_possible_next_actions.shape[
            0], ("Invalid shapes: " + str(state_pnas_tile.shape) + " != " +
                 str(logged_possible_next_actions.shape))
        logged_possible_next_state_actions = torch.cat(
            (state_pnas_tile, logged_possible_next_actions), dim=1)

        logger.info("Reward Timeline to Torch...")
        time_diffs = torch.ones([len(samples.states), 1])

        tdps = []
        logger.info("Batching...")
        num_rows = states_ndarray.shape[0]
        for start in range(0, num_rows, minibatch_size):
            end = start + minibatch_size
            if end > num_rows:
                # Only complete minibatches are emitted.
                break
            # Each sample owns max_action_size consecutive rows of the
            # flattened possible-next-action matrix.
            pnas_start = start * max_action_size
            pnas_end = end * max_action_size
            batch_mask = pnas_mask[start:end, :]
            tdp = TrainingDataPage(
                states=states_ndarray[start:end],
                actions=actions_ndarray[start:end],
                propensities=action_probabilities[start:end],
                rewards=rewards[start:end],
                next_states=next_states_ndarray[start:end],
                next_actions=next_actions_ndarray[start:end],
                # A transition is non-terminal when at least one next action
                # is possible.
                not_terminal=(batch_mask.sum(dim=1, keepdim=True) > 0),
                time_diffs=time_diffs[start:end],
                possible_next_actions_mask=batch_mask,
                possible_next_actions_state_concat=
                logged_possible_next_state_actions[pnas_start:pnas_end, :],
            )
            tdp.set_type(torch.cuda.FloatTensor if use_gpu else torch.
                         FloatTensor  # type: ignore
                         )
            tdps.append(tdp)
        return tdps
Exemplo n.º 2
0
    def preprocess_samples_discrete(
        self,
        samples: Samples,
        minibatch_size: int,
        one_hot_action: bool = True,
        use_gpu: bool = False,
        do_shuffle: bool = True,
    ) -> List[TrainingDataPage]:
        """Convert discrete-action ``samples`` into fixed-size training pages.

        States and next states are densified through a Caffe2 net that is
        built on the first call and cached on ``self.sparse_to_dense_net``,
        then run through a ``Preprocessor``.  Actions are compared against
        ``self.ACTIONS`` to build one-hot encodings and availability masks.
        A trailing partial batch is dropped.

        Args:
            samples: Logged transitions to preprocess.
            minibatch_size: Number of rows in every returned page.
            one_hot_action: When true, pages carry one-hot actions; otherwise
                they carry the argmax index of the one-hot encoding.
            use_gpu: When true, pages are typed as ``torch.cuda.FloatTensor``,
                otherwise as ``torch.FloatTensor``.
            do_shuffle: When true, ``samples`` are shuffled first.

        Returns:
            One ``TrainingDataPage`` per complete minibatch.
        """

        def _one_hot(action_names):
            # Compare an (N, 1) column of logged actions against the action
            # vocabulary, giving one int64 indicator row per sample.
            matches = np.array(action_names).reshape(-1, 1) == np.array(
                self.ACTIONS
            )
            return torch.tensor(matches.astype(np.int64))

        def _availability_mask(action_lists, num_rows):
            # Build a (num_rows, len(ACTIONS)) mask with 1 where the action
            # appears in that sample's list of available actions.
            mask = torch.zeros([num_rows, len(self.ACTIONS)])
            # Pad the ragged lists with "" so they form a rectangular array
            # with one row per sample.
            padded = np.array(
                list(itertools.zip_longest(*action_lists, fillvalue=""))
            ).T
            if padded.ndim != 2 or padded.size == 0:
                # No sample lists any action (or there are no samples), so
                # there is nothing to reduce over; the mask stays all zero.
                return mask
            for i, action in enumerate(self.ACTIONS):
                mask[:, i] = torch.tensor(
                    np.max(padded == action, axis=1).astype(np.int64)
                )
            return mask

        if do_shuffle:
            logger.info("Shuffling...")
            samples = shuffle_samples(samples)

        logger.info("Preprocessing...")
        sparse_to_dense_processor = Caffe2SparseToDenseProcessor()

        if self.sparse_to_dense_net is None:
            # First call: build the net and remember its output blobs so that
            # later calls can reuse them.
            self.sparse_to_dense_net = core.Net("gridworld_sparse_to_dense")
            C2.set_net(self.sparse_to_dense_net)
            saa = StackedAssociativeArray.from_dict_list(samples.states, "states")
            sorted_features, _ = sort_features_by_normalization(self.normalization)
            self.state_matrix, _ = sparse_to_dense_processor(sorted_features, saa)
            saa = StackedAssociativeArray.from_dict_list(
                samples.next_states, "next_states"
            )
            self.next_state_matrix, _ = sparse_to_dense_processor(sorted_features, saa)
            C2.set_net(None)
        else:
            # Return values are intentionally discarded; presumably these
            # calls refresh the input blobs the cached net reads -- TODO
            # confirm against StackedAssociativeArray.from_dict_list.
            StackedAssociativeArray.from_dict_list(samples.states, "states")
            StackedAssociativeArray.from_dict_list(samples.next_states, "next_states")
        workspace.RunNetOnce(self.sparse_to_dense_net)

        logger.info("Converting to Torch...")
        actions_one_hot = _one_hot(samples.actions)
        actions = actions_one_hot.argmax(dim=1, keepdim=True)
        rewards = torch.tensor(samples.rewards, dtype=torch.float32).reshape(-1, 1)
        action_probabilities = torch.tensor(
            samples.action_probabilities, dtype=torch.float32
        ).reshape(-1, 1)
        next_actions_one_hot = _one_hot(samples.next_actions)

        logger.info("Converting PA to Torch...")
        possible_actions_mask = _availability_mask(
            samples.possible_actions, len(samples.actions)
        )

        logger.info("Converting PNA to Torch...")
        possible_next_actions_mask = _availability_mask(
            samples.possible_next_actions, len(samples.next_actions)
        )

        terminals = torch.tensor(samples.terminals, dtype=torch.int32).reshape(-1, 1)
        not_terminal = 1 - terminals
        logger.info("Converting RT to Torch...")

        time_diffs = torch.ones([len(samples.states), 1])

        logger.info("Preprocessing...")
        preprocessor = Preprocessor(self.normalization, False)

        states_ndarray = workspace.FetchBlob(self.state_matrix)
        states_ndarray = preprocessor.forward(states_ndarray)

        next_states_ndarray = workspace.FetchBlob(self.next_state_matrix)
        next_states_ndarray = preprocessor.forward(next_states_ndarray)

        logger.info("Batching...")
        tdps = []
        num_rows = states_ndarray.shape[0]
        for start in range(0, num_rows, minibatch_size):
            end = start + minibatch_size
            if end > num_rows:
                # Only complete minibatches are emitted.
                break
            batch_actions = (
                actions_one_hot[start:end] if one_hot_action else actions[start:end]
            )
            tdp = TrainingDataPage(
                states=states_ndarray[start:end],
                actions=batch_actions,
                propensities=action_probabilities[start:end],
                rewards=rewards[start:end],
                next_states=next_states_ndarray[start:end],
                not_terminal=not_terminal[start:end],
                next_actions=next_actions_one_hot[start:end],
                possible_actions_mask=possible_actions_mask[start:end],
                possible_next_actions_mask=possible_next_actions_mask[start:end],
                time_diffs=time_diffs[start:end],
            )
            tdp.set_type(torch.cuda.FloatTensor if use_gpu else torch.FloatTensor)
            tdps.append(tdp)
        return tdps