コード例 #1
0
 def test_index_dict(self):
     """Check the sizes of dict entries returned by ``PhotonDataHelper.index_dict``.

     A dict holding a 1-D and a 2-D random array (ten entries along the first
     axis) is indexed with the boolean masks ``labels == 0`` and
     ``labels == 1``; each mask selects five positions.
     """
     # First five samples carry label 0, the last five label 1.
     class_ids = np.asarray([0] * 5 + [1] * 5)
     payload = dict(
         variable_one=np.random.randn(10),
         variable_two=np.random.randn(10, 10),
     )

     zeros_subset = PhotonDataHelper.index_dict(payload, class_ids == 0)
     ones_subset = PhotonDataHelper.index_dict(payload, class_ids == 1)

     # Five of the ten entries are expected to survive each mask.
     self.assertEqual(len(zeros_subset["variable_one"]), 5)
     self.assertEqual(ones_subset["variable_two"].shape, (5, 10))
コード例 #2
0
    def transform(self, X, y=None, **kwargs):
        """
        Run ``self._return_samples`` separately for every class found in ``y`` and
        concatenate the per-class results.

        For each unique value of ``y`` (in the order returned by ``np.unique``),
        the matching rows of ``X``, the matching targets and the result of
        ``PhotonDataHelper.index_dict(kwargs, mask)`` are handed to
        ``self._return_samples`` together with ``self.generator``,
        ``self.distance_metric`` and ``self.random_state``.

        The draw limit passed for a class is ``self.draw_limit``; if
        ``self.balance_classes`` is truthy, the number of samples already in
        that class is subtracted from it first.

        :param X: data; indexed with a boolean mask derived from ``y``
        :param y: targets; the default is None, but the body indexes ``y``
            directly, so an array-like is presumably required - TODO confirm
        :param kwargs: extra keyword data; presumably per-sample arrays that
            ``PhotonDataHelper.index_dict`` subsets with the class mask - TODO confirm
        :return: (X_extended, y_extended, kwargs_extended) - two lists built by
            extending with each class's samples/targets, and the joined kwargs
            (an empty dict when no kwargs were given)
        """

        logger.debug(f"Pairing {self.draw_limit!s} samples...")

        # Phase 1: work out how many pairs to request per class.
        class_labels = np.unique(y)
        per_class_limits = [
            self.draw_limit - np.sum(y == cls)
            if self.balance_classes
            else self.draw_limit
            for cls in class_labels
        ]

        # Phase 2: sample each class on its own and collect the results.
        collected_X = []
        collected_y = []
        collected_kwargs = {}

        for cls, pair_limit in zip(class_labels, per_class_limits):
            in_class = y == cls
            class_X, class_y, class_kwargs = self._return_samples(
                X[in_class],
                y[in_class],
                PhotonDataHelper.index_dict(kwargs, in_class),
                generator=self.generator,
                distance_metric=self.distance_metric,
                draw_limit=pair_limit,
                rand_seed=self.random_state,
            )

            collected_X.extend(class_X)
            collected_y.extend(class_y)

            # Nothing to merge when the caller supplied no kwargs.
            if not kwargs:
                continue
            collected_kwargs = PhotonDataHelper.join_dictionaries(
                collected_kwargs, class_kwargs)

        return collected_X, collected_y, collected_kwargs