Exemplo n.º 1
0
    def __init__(
        self,
        train_df,
        categorical_input: List,
        numerical_input: List,
        target: str,
        valid_df=None,
        test_df=None,
        batch_size=2,
        num_workers: Optional[int] = None,
    ):
        """Build train/valid/test datasets from pandas DataFrames.

        Missing values are imputed with ``_impute`` (on the numerical
        columns). Normalization statistics (``self.mean``, ``self.std``) are
        computed on the training split only and then passed, together with
        every split, to ``_pre_transform``.

        Args:
            train_df: Training DataFrame; must contain the ``target`` column.
            categorical_input: Names of the categorical input columns.
            numerical_input: Names of the numerical input columns.
            target: Name of the target column.
            valid_df: Optional validation DataFrame.
            test_df: Optional test DataFrame. A copy with the ``target``
                column removed is stored in ``self._test_df``.
            batch_size: Batch size forwarded to the parent constructor.
            num_workers: Worker count forwarded to the parent constructor.

        Raises:
            KeyError: If ``target`` is missing from ``train_df``, or if
                ``test_df`` is given without a ``target`` column (raised by
                ``DataFrame.drop``).
        """
        dfs = [train_df]
        self._test_df = None

        if valid_df is not None:
            dfs.append(valid_df)

        if test_df is not None:
            # Keep a target-free copy for the predict function.
            # DataFrame.drop returns a new frame rather than mutating in
            # place, so its result must be assigned to take effect.
            self._test_df = test_df.drop(target, axis=1)
            dfs.append(test_df)

        # impute missing values
        dfs = _impute(dfs, numerical_input)

        # compute train dataset stats (train split only, reused for all splits)
        self.mean, self.std = _compute_normalization(dfs[0], numerical_input)

        if dfs[0][target].dtype == object:
            # if the target is a category, not an int
            self.target_codes = _generate_codes(dfs, [target])
        else:
            self.target_codes = None

        self.codes = _generate_codes(dfs, categorical_input)

        # presumably applies the category codes and mean/std normalization
        # to every split -- see _pre_transform
        dfs = _pre_transform(dfs, numerical_input, categorical_input,
                             self.codes, self.mean, self.std, target,
                             self.target_codes)

        self.cat_cols = categorical_input
        self.num_cols = numerical_input

        # number of distinct target values in the raw training frame
        self._num_classes = len(train_df[target].unique())

        # dfs is [train, valid?, test?]: valid sits at index 1 only when
        # given, and test is always last when given.
        train_ds = PandasDataset(dfs[0], categorical_input, numerical_input,
                                 target)
        valid_ds = PandasDataset(dfs[1], categorical_input, numerical_input,
                                 target) if valid_df is not None else None
        test_ds = PandasDataset(dfs[-1], categorical_input, numerical_input,
                                target) if test_df is not None else None
        super().__init__(train_ds,
                         valid_ds,
                         test_ds,
                         batch_size=batch_size,
                         num_workers=num_workers)
Exemplo n.º 2
0
def test_pandas_no_num():
    """PandasDataset works when only categorical columns are given."""
    df = TEST_DF.copy()
    ds = PandasDataset(
        df,
        cat_cols=["category"],
        num_cols=[],
        target_col="label",
        is_regression=False,
    )
    assert len(ds) == 6
    (cat, num), target = ds[0]
    # np.array_equal compares shape and values. A bare `cat == np.array([0])`
    # only yields a usable truth value while `cat` has exactly one element.
    assert np.array_equal(cat, np.array([0]))
    assert num.size == 0
    assert target == 0
Exemplo n.º 3
0
def test_pandas_no_cat():
    """PandasDataset works when only numerical columns are given."""
    frame = TEST_DF.copy()
    dataset = PandasDataset(
        frame,
        cat_cols=[],
        num_cols=["scalar_a", "scalar_b"],
        target_col="label",
        is_regression=False,
    )
    assert len(dataset) == 6

    # Each item is ((categorical, numerical), target).
    first_item = dataset[0]
    (cat_values, num_values), label = first_item
    assert cat_values.size == 0
    expected_num = np.array([0.0, 5.0])
    assert np.allclose(num_values, expected_num)
    assert label == 0