Esempio n. 1
0
    def test__conditionally_sample_rows_graceful_reject_sampling_true(self):
        """Test the `BaseTabularModel._conditionally_sample_rows` method.

        When `_conditionally_sample_rows` is called with `graceful_reject_sampling`
        as True, expect that no error is raised if no valid rows are generated.

        Setup:
            - `_sample_batch` is mocked to return an empty DataFrame, simulating
              a condition for which no valid rows can be sampled.
        Input:
            - An impossible condition
        Returns:
            - Empty DataFrame
        Side Effects:
            - `_sample_batch` is called exactly once.
        """
        # Setup
        model = Mock(spec_set=CTGAN)
        model._validate_file_path.return_value = None

        condition_values = {'cola': 'c'}
        # Kept as two distinct objects, exactly like the values passed positionally
        # to `_conditionally_sample_rows` (the rows to sample and their transformed
        # counterpart).
        dataframe = pd.DataFrame([condition_values] * 2)
        transformed_conditions = pd.DataFrame([condition_values] * 2)
        condition = Condition(condition_values, num_rows=2)

        # An empty batch stands in for "no sampled row satisfied the condition".
        model._sample_batch.return_value = pd.DataFrame()

        # Run
        sampled = BaseTabularModel._conditionally_sample_rows(
            model,
            dataframe,
            condition,
            transformed_conditions,
            graceful_reject_sampling=True,
        )

        # Assert
        assert len(sampled) == 0
        # `2` matches the number of rows in `dataframe`; the `None` / `0.01`
        # values are presumably the method's defaults for the arguments not
        # passed above -- confirm against its signature if it changes.
        model._sample_batch.assert_called_once_with(2, None, None, condition,
                                                    transformed_conditions,
                                                    0.01, None, None)
Esempio n. 2
0
    def test__conditionally_sample_rows_graceful_reject_sampling_false(self):
        """Test the `BaseTabularModel._conditionally_sample_rows` method.

        When `_conditionally_sample_rows` is called with `graceful_reject_sampling`
        as False, expect that an error is raised if no valid rows are generated.

        Setup:
            - `_sample_batch` is mocked to return an empty DataFrame, simulating
              a condition for which no valid rows can be sampled.
        Input:
            - An impossible condition
        Side Effects:
            - A ValueError is raised
            - `_sample_batch` is called exactly once before the error is raised.
        """
        # Setup
        model = Mock(spec_set=CTGAN)
        model._validate_file_path.return_value = None

        condition_values = {'cola': 'c'}
        # Kept as two distinct objects, exactly like the values passed positionally
        # to `_conditionally_sample_rows` (the rows to sample and their transformed
        # counterpart).
        dataframe = pd.DataFrame([condition_values] * 2)
        transformed_conditions = pd.DataFrame([condition_values] * 2)
        condition = Condition(condition_values, num_rows=2)

        # An empty batch stands in for "no sampled row satisfied the condition".
        model._sample_batch.return_value = pd.DataFrame()

        # Run and assert
        with pytest.raises(
                ValueError,
                match='Unable to sample any rows for the given conditions'):
            BaseTabularModel._conditionally_sample_rows(
                model,
                dataframe,
                condition,
                transformed_conditions,
                graceful_reject_sampling=False,
            )

        # The batch must still have been requested once before the failure.
        # `2` matches the number of rows in `dataframe`; the `None` / `0.01`
        # values are presumably the method's defaults -- confirm against its
        # signature if it changes.
        model._sample_batch.assert_called_once_with(2, None, None, condition,
                                                    transformed_conditions,
                                                    0.01, None, None)