コード例 #1
0
ファイル: test_base.py プロジェクト: sdv-dev/SDV
    def test__validate_conditions_with_conditions_invalid_column(self):
        """Test that `BaseTabularModel._validate_conditions` rejects unknown columns.

        The metadata only knows about `cola`, so a conditions DataFrame that refers
        to `colb` must make the validation fail with a ValueError.

        Input:
            - Conditions DataFrame whose only column (`colb`) is not in the metadata.
        Side Effects:
            - A ValueError is raised.
        """
        # Setup
        fields_metadata = Mock()
        fields_metadata.get_fields.return_value = {'cola': {}}
        model = Mock(spec_set=CTGAN)
        model._metadata = fields_metadata

        bad_conditions = pd.DataFrame({'colb': ['a'] * 5})
        expected_message = (
            'Unexpected column name `colb`. '
            'Use a column name that was present in the original data.'
        )

        # Run and Assert
        with pytest.raises(ValueError, match=expected_message):
            BaseTabularModel._validate_conditions(model, bad_conditions)
コード例 #2
0
ファイル: test_base.py プロジェクト: sdv-dev/SDV
    def test__sample_conditions_no_rows(self):
        """Test `BaseTabularModel._sample_conditions` with a condition that cannot be met.

        When sampling yields no rows at all for the given conditions, a ValueError
        is expected.

        Input:
            - A single condition that is impossible to satisfy.
        Side Effects:
            - A ValueError is raised.
        """
        # Setup
        impossible = Condition({'column1': 'b'}, num_rows=5)

        model = Mock(spec_set=CTGAN)
        model._make_condition_dfs.return_value = pd.DataFrame({'column1': ['b'] * 5})
        model._sample_with_conditions.return_value = pd.DataFrame()

        expected_error = 'Unable to sample any rows for the given conditions.'

        # Run and assert
        with pytest.raises(ValueError, match=expected_error):
            BaseTabularModel._sample_conditions(
                model, [impossible], 100, None, True, None)
コード例 #3
0
ファイル: test_base.py プロジェクト: sdv-dev/SDV
    def test_sample_with_default_file_path_error(self, os_mock):
        """Test the `BaseTabularModel.sample` method with the default file path.

        Expect that the file is not removed if there is an error with sampling.

        Input:
            - output_file_path=None.
        Side Effects:
            - ValueError is thrown.
            - `_sample_batch` is called once with the temporary file path.
            - The temporary file is not removed.
        """
        # Setup
        model = Mock()
        model._validate_file_path.return_value = TMP_FILE_NAME
        model._sample_batch.side_effect = ValueError('test error')

        # Run
        with pytest.raises(ValueError, match='test error'):
            BaseTabularModel.sample(model, 1, output_file_path=None)

        # Assert
        # ``called_once_with`` (without ``assert_``) is just an auto-created Mock
        # attribute and verifies nothing; ``assert_called_once_with`` really checks.
        model._sample_batch.assert_called_once_with(
            1,
            batch_size_per_try=1,
            progress_bar=ANY,
            output_file_path=TMP_FILE_NAME,
        )
        assert os_mock.remove.call_count == 0
コード例 #4
0
ファイル: test_base.py プロジェクト: sdv-dev/SDV
    def test_sample_with_default_file_path(self, os_mock):
        """Test the `BaseTabularModel.sample` method with the default file path.

        Expect that the file is removed after successfully sampling.

        Input:
            - output_file_path=None.
        Side Effects:
            - `_sample_batch` is called once with the temporary file path.
            - The file is removed.
        """
        # Setup
        model = Mock()
        model._validate_file_path.return_value = TMP_FILE_NAME
        model._sample_batch.return_value = pd.DataFrame({'test': [1]})
        os_mock.path.exists.return_value = True

        # Run
        BaseTabularModel.sample(model, 1, output_file_path=None)

        # Assert
        # Use the real ``assert_*`` helpers: plain ``called_once_with`` is an
        # auto-created Mock attribute and would never fail.
        model._sample_batch.assert_called_once_with(
            1,
            batch_size_per_try=1,
            progress_bar=ANY,
            output_file_path=TMP_FILE_NAME,
        )
        os_mock.remove.assert_called_once_with(TMP_FILE_NAME)
コード例 #5
0
ファイル: test_base.py プロジェクト: sdv-dev/SDV
    def test_sample_no_num_rows(self):
        """Test that `BaseTabularModel.sample` requires the `num_rows` argument.

        Calling `sample` with no arguments must fail with a TypeError that names
        the missing `num_rows` parameter.
        """
        # Setup
        instance = BaseTabularModel()
        missing_arg_pattern = r'sample\(\) missing 1 required positional argument: \'num_rows\''

        # Run and assert
        with pytest.raises(TypeError, match=missing_arg_pattern):
            instance.sample()
コード例 #6
0
ファイル: test_base.py プロジェクト: sdv-dev/SDV
def test__randomize_samples_false():
    """Test that ``_randomize_samples`` fixes the seed when samples must not be random.

    Input:
        - randomize_samples as False

    Side Effect:
        - ``_set_random_state`` is called once with ``FIXED_RNG_SEED``.
    """
    # Setup
    instance = Mock()
    randomize_samples = False

    # Run
    BaseTabularModel._randomize_samples(instance, randomize_samples)

    # Assert
    # The previous ``assert ..._set_random_state.called_once_with(None)`` was always
    # truthy (it merely created a child Mock), so the seed was never checked.
    # NOTE(review): a fixed seed for ``randomize_samples=False`` follows from the
    # constant's name; confirm against ``BaseTabularModel._randomize_samples``.
    instance._set_random_state.assert_called_once_with(FIXED_RNG_SEED)
コード例 #7
0
ファイル: test_base.py プロジェクト: sdv-dev/SDV
    def test__conditionally_sample_rows_graceful_reject_sampling_true(self):
        """Test the `BaseTabularModel._conditionally_sample_rows` method.

        When `_conditionally_sample_rows` is called with `graceful_reject_sampling` as True,
        expect that there are no errors if no valid rows are generated.

        Input:
            - An impossible condition
        Returns:
            - Empty DataFrame
        Side Effects:
            - `_sample_batch` is called once with the condition and its transformed values.
        """
        # Setup
        model = Mock(spec_set=CTGAN)
        model._validate_file_path.return_value = None

        condition_values = {'cola': 'c'}
        transformed_conditions = pd.DataFrame([condition_values] * 2)
        condition = Condition(condition_values, num_rows=2)

        model._sample_batch.return_value = pd.DataFrame()

        # Run
        sampled = BaseTabularModel._conditionally_sample_rows(
            model,
            pd.DataFrame([condition_values] * 2),
            condition,
            transformed_conditions,
            graceful_reject_sampling=True,
        )

        # Assert
        assert len(sampled) == 0
        model._sample_batch.assert_called_once_with(2, None, None, condition,
                                                    transformed_conditions,
                                                    0.01, None, None)
コード例 #8
0
ファイル: test_base.py プロジェクト: sdv-dev/SDV
    def test__sample_remaining_columns(self):
        """Test the `BaseTabularModel._sample_remaining_columns` method.

        When a valid DataFrame is given, expect `_sample_with_conditions` to be called
        with the input DataFrame.

        Input:
            - DataFrame with condition column values populated.
        Output:
            - The expected sampled rows.
        Side Effects:
            - `_sample_with_conditions` is called once.
        """
        # Setup
        model = Mock(spec_set=CTGAN)
        model._validate_file_path.return_value = None

        conditions = pd.DataFrame([{'cola': 'a'}] * 5)

        sampled = pd.DataFrame({
            'cola': ['a', 'a', 'a', 'a', 'a'],
            'colb': [1, 2, 1, 1, 1],
        })
        model._sample_with_conditions.return_value = sampled

        # Run
        out = BaseTabularModel._sample_remaining_columns(
            model, conditions, 100, None, True, None)

        # Asserts
        model._sample_with_conditions.assert_called_once_with(
            DataFrameMatcher(conditions), 100, None, ANY, None)
        pd.testing.assert_frame_equal(out, sampled)
コード例 #9
0
ファイル: test_base.py プロジェクト: sdv-dev/SDV
    def test_sample_valid_num_rows(self, tqdm_mock):
        """Test the `BaseTabularModel.sample` method with a valid `num_rows` argument.

        Expect that the expected call to `_sample_batch` is made.

        Input:
            - num_rows = 5
        Output:
            - The requested number of sampled rows.
        Side Effect:
            - Call `_sample_batch` method once with the expected number of rows.
            - No progress bar is created, since everything fits in one batch.
        """
        # Setup
        model = Mock(spec_set=CTGAN)
        # Pin the validated path so the `output_file_path` forwarded to
        # `_sample_batch` is a known value rather than an auto-created Mock.
        model._validate_file_path.return_value = None
        valid_sampled_data = pd.DataFrame({
            'column1': [28, 28, 21, 1, 2],
            'column2': [37, 37, 1, 4, 5],
            'column3': [93, 93, 6, 4, 12],
        })
        model._sample_batch.return_value = valid_sampled_data

        # Run
        output = BaseTabularModel.sample(model, 5)

        # Assert
        # `assert model._sample_batch.called_once_with(5)` was always truthy;
        # the real assertion checks the call like the sibling `sample` tests do.
        model._sample_batch.assert_called_once_with(
            5,
            batch_size_per_try=5,
            progress_bar=ANY,
            output_file_path=None,
        )
        assert tqdm_mock.call_count == 0
        assert len(output) == 5
コード例 #10
0
ファイル: test_base.py プロジェクト: sdv-dev/SDV
    def test__sample_batch_output_file_path(self, path_mock):
        """Test the `BaseTabularModel._sample_batch` method with a valid output file path.

        Expect that if the output file is empty, the sampled rows are written to the file
        with the header included in the first batch write.

        Input:
            - num_rows = 4
            - output_file_path = temp file
        Output:
            - The requested number of sampled rows (4).
        Side Effects:
            - The new rows are written once with `to_csv(output_file_path, index=False)`.
        """
        # Setup
        model = Mock(spec_set=CTGAN)
        sampled_mock = MagicMock()
        sampled_mock.__len__.return_value = 4
        model._sample_rows.return_value = (sampled_mock, 4)
        output_file_path = 'test.csv'
        path_mock.getsize.return_value = 0

        # Run
        output = BaseTabularModel._sample_batch(
            model, num_rows=4, output_file_path=output_file_path)

        # Assert
        assert model._sample_rows.call_count == 1
        assert output == sampled_mock.head.return_value
        # The previous `assert ...to_csv.called_once_with(...)` was always truthy.
        # For an empty file the write must carry no append/`header=False` kwargs,
        # i.e. the header is written along with the first batch.
        to_csv_mock = sampled_mock.head.return_value.tail.return_value.to_csv
        to_csv_mock.assert_called_once_with(output_file_path, index=False)
コード例 #11
0
ファイル: test_base.py プロジェクト: sdv-dev/SDV
    def test__sample_with_conditions_no_transformed_columns(self):
        """Test ``BaseTabularModel._sample_with_conditions`` with no transformed columns.

        When the transformed conditions DataFrame has no columns, expect that sampling
        does not pass through any transformed conditions when conditionally sampling.

        Setup:
            - Mock the ``_make_condition_dfs`` method to return a dataframe representing
              the expected conditions, and the ``get_fields`` method to return metadata
              fields containing the expected conditioned column.
            - Mock the ``_metadata.transform`` method to return an empty transformed
              conditions dataframe.
            - Mock the ``_conditionally_sample_rows`` method to return the expected
              sampled rows.
            - Mock ``_metadata.make_ids_unique`` to return the expected sampled rows.
        Input:
            - number of rows
            - one set of conditions
        Output:
            - the expected sampled rows
        Side Effects:
            - Expect ``_conditionally_sample_rows`` to be called with the given condition
              and a transformed_condition of None.
        """
        # Setup
        model = Mock(spec_set=CTGAN)
        expected = pd.DataFrame(['a', 'a', 'a'])

        condition_dataframe = pd.DataFrame({'a': ['a', 'a', 'a']})
        model._make_condition_dfs.return_value = condition_dataframe
        model._metadata.get_fields.return_value = ['a']
        model._metadata.transform.return_value = pd.DataFrame({},
                                                              index=[0, 1, 2])
        model._conditionally_sample_rows.return_value = pd.DataFrame({
            'a': ['a', 'a', 'a'],
            COND_IDX: [0, 1, 2]
        })
        model._metadata.make_ids_unique.return_value = expected

        # Run
        out = BaseTabularModel._sample_with_conditions(model,
                                                       condition_dataframe,
                                                       100, None)

        # Asserts
        model._conditionally_sample_rows.assert_called_once_with(
            DataFrameMatcher(
                pd.DataFrame({
                    COND_IDX: [0, 1, 2],
                    'a': ['a', 'a', 'a']
                })),
            {'a': 'a'},
            None,
            100,
            None,
            progress_bar=None,
            output_file_path=None,
        )
        pd.testing.assert_frame_equal(out, expected)
コード例 #12
0
ファイル: test_base.py プロジェクト: sdv-dev/SDV
def test__randomize_samples_true():
    """Test that ``_randomize_samples`` unseeds the random state for random samples.

    Input:
        - randomize_samples as True

    Side Effect:
        - ``_set_random_state`` is called once with ``None``.
    """
    # Setup
    instance = Mock()
    randomize_samples = True

    # Run
    BaseTabularModel._randomize_samples(instance, randomize_samples)

    # Assert
    # The previous ``assert ..._set_random_state.called_once_with(FIXED_RNG_SEED)`` was
    # always truthy (it merely created a child Mock), so the seed was never checked.
    # NOTE(review): randomizing is expected to clear the seed (``None``), while the
    # fixed seed is reserved for ``randomize_samples=False``; confirm against
    # ``BaseTabularModel._randomize_samples``.
    instance._set_random_state.assert_called_once_with(None)
コード例 #13
0
ファイル: test_base.py プロジェクト: sdv-dev/SDV
    def test__validate_conditions_with_conditions_valid_columns(self):
        """Test `BaseTabularModel._validate_conditions` when every column is known.

        All condition columns exist in the metadata, so validation must pass without
        raising.

        Input:
            - Conditions DataFrame using only `cola`, which the metadata knows about.
        """
        # Setup
        known_fields = Mock()
        known_fields.get_fields.return_value = {'cola': {}}
        model = Mock(spec_set=CTGAN)
        model._metadata = known_fields

        valid_conditions = pd.DataFrame({'cola': ['a'] * 5})

        # Run and Assert (reaching the end without an exception is the assertion)
        BaseTabularModel._validate_conditions(model, valid_conditions)
コード例 #14
0
ファイル: test_base.py プロジェクト: sdv-dev/SDV
    def test__validate_file_path(self, path_mock):
        """Test that `BaseTabularModel._validate_file_path` refuses an existing path.

        The patched path helpers report that the file already exists, so validation
        must fail with an AssertionError that mentions the absolute path.

        Input:
            - A file path that already exists.
        Side Effects:
            - An AssertionError.
        """
        # Setup
        model = Mock(spec_set=CTGAN)
        path_mock.abspath.return_value = 'path/to/file'
        path_mock.exists.return_value = True

        # Run and Assert
        with pytest.raises(AssertionError, match='path/to/file already exists'):
            BaseTabularModel._validate_file_path(model, 'file_path')
コード例 #15
0
ファイル: test_base.py プロジェクト: sdv-dev/SDV
    def test_sample_num_rows_none(self):
        """Test that `BaseTabularModel.sample` rejects `num_rows=None`.

        Passing `None` explicitly for `num_rows` must fail with a ValueError that
        tells the user to provide a row count.

        Input:
            - num_rows = None
        Side Effect:
            - ValueError
        """
        # Setup
        instance = BaseTabularModel()
        expected_error = r'You must specify the number of rows to sample \(e.g. num_rows=100\)'

        # Run and assert
        with pytest.raises(ValueError, match=expected_error):
            instance.sample(None)
コード例 #16
0
ファイル: test_base.py プロジェクト: sdv-dev/SDV
    def test__sample_conditions_with_multiple_conditions(self):
        """Test the `BaseTabularModel._sample_conditions` method with multiple conditions.

        When multiple condition dataframes are returned by `_make_condition_dfs`,
        expect `_sample_with_conditions` to be called for each condition dataframe.

        Input:
            - 2 conditions with different columns
        Output:
            - The expected sampled rows, concatenated in condition order
        """
        # Setup
        model = Mock(spec_set=CTGAN)
        model._validate_file_path.return_value = None

        condition_values1 = {'cola': 'a'}
        condition1 = Condition(condition_values1, num_rows=2)
        sampled1 = pd.DataFrame({'a': ['a', 'a'], 'b': [1, 2]})

        condition_values2 = {'colb': 1}
        condition2 = Condition(condition_values2, num_rows=3)
        sampled2 = pd.DataFrame({'a': ['b', 'c', 'a'], 'b': [1, 1, 1]})

        expected = pd.DataFrame({
            'a': ['a', 'a', 'b', 'c', 'a'],
            'b': [1, 2, 1, 1, 1],
        })

        model._make_condition_dfs.return_value = [
            pd.DataFrame([condition_values1] * 2),
            pd.DataFrame([condition_values2] * 3),
        ]
        model._sample_with_conditions.side_effect = [
            sampled1,
            sampled2,
        ]

        # Run
        out = BaseTabularModel._sample_conditions(model,
                                                  [condition1, condition2],
                                                  100, None, True, None)

        # Asserts
        model._sample_with_conditions.assert_has_calls([
            call(DataFrameMatcher(pd.DataFrame([condition_values1] * 2)), 100,
                 None, ANY, None),
            call(DataFrameMatcher(pd.DataFrame([condition_values2] * 3)), 100,
                 None, ANY, None),
        ])
        pd.testing.assert_frame_equal(out, expected)
コード例 #17
0
ファイル: test_base.py プロジェクト: sdv-dev/SDV
    def test_sample_with_custom_file_path(self, os_mock):
        """Test the `BaseTabularModel.sample` method with a custom file path.

        Expect that the file is not removed if a custom file path is given.

        Input:
            - output_file_path='temp.csv'.
        Side Effects:
            - `_sample_batch` is called once with the custom file path.
            - The file is not removed.
        """
        # Setup
        model = Mock()
        model._validate_file_path.return_value = 'temp.csv'
        model._sample_batch.return_value = pd.DataFrame({'test': [1]})

        # Run
        BaseTabularModel.sample(model, 1, output_file_path='temp.csv')

        # Assert
        # ``called_once_with`` (without ``assert_``) never fails; use the real helper.
        model._sample_batch.assert_called_once_with(
            1,
            batch_size_per_try=1,
            progress_bar=ANY,
            output_file_path='temp.csv',
        )
        assert os_mock.remove.call_count == 0
コード例 #18
0
ファイル: test_base.py プロジェクト: sdv-dev/SDV
    def test__conditionally_sample_rows_graceful_reject_sampling_false(self):
        """Test the `BaseTabularModel._conditionally_sample_rows` method.

        When `_conditionally_sample_rows` is called with `graceful_reject_sampling` as False,
        expect that an error is thrown if no valid rows are generated.

        Input:
            - An impossible condition
        Side Effect:
            - A ValueError is thrown
            - `_sample_batch` is called once with the condition and its transformed values.
        """
        # Setup
        model = Mock(spec_set=CTGAN)
        model._validate_file_path.return_value = None

        condition_values = {'cola': 'c'}
        transformed_conditions = pd.DataFrame([condition_values] * 2)
        condition = Condition(condition_values, num_rows=2)

        model._sample_batch.return_value = pd.DataFrame()

        # Run and assert
        with pytest.raises(
                ValueError,
                match='Unable to sample any rows for the given conditions'):
            BaseTabularModel._conditionally_sample_rows(
                model,
                pd.DataFrame([condition_values] * 2),
                condition,
                transformed_conditions,
                graceful_reject_sampling=False,
            )

        model._sample_batch.assert_called_once_with(2, None, None, condition,
                                                    transformed_conditions,
                                                    0.01, None, None)
コード例 #19
0
ファイル: test_base.py プロジェクト: sdv-dev/SDV
    def test_sample_remaining_columns(self):
        """Test that `BaseTabularModel.sample_remaining_columns` delegates correctly.

        The public method should forward the known columns together with the default
        sampling arguments to `_sample_remaining_columns` and return its result.

        Input:
            - valid DataFrame
        Side Effects:
            - The expected `_sample_remaining_columns` call.
        """
        # Setup
        model = Mock(spec_set=CTGAN)
        known_values = pd.DataFrame({'cola': ['a'] * 5})

        # Run
        result = BaseTabularModel.sample_remaining_columns(model, known_values)

        # Assert
        delegate = model._sample_remaining_columns
        delegate.assert_called_once_with(known_values, 100, None, True, None)
        assert result == delegate.return_value
コード例 #20
0
ファイル: test_base.py プロジェクト: sdv-dev/SDV
    def test__sample_batch_zero_valid(self):
        """Test the `BaseTabularModel._sample_batch` method with zero valid rows.

        Expect that the requested number of rows are returned, if the first `_sample_rows` call
        returns zero valid rows, and the second one returns enough valid rows.
        See https://github.com/sdv-dev/SDV/issues/285.

        Input:
            - num_rows = 5
            - condition on `column1` = 2
        Output:
            - The requested number of sampled rows (5).
        """
        # Setup
        model = Mock(spec_set=CTGAN)
        valid_sampled_data = pd.DataFrame({
            'column1': [28, 28, 21, 1, 2],
            'column2': [37, 37, 1, 4, 5],
            'column3': [93, 93, 6, 4, 12],
        })
        model._sample_rows.side_effect = [(pd.DataFrame({}), 0),
                                          (valid_sampled_data, 5)]

        # The original literal repeated the key 'column1' five times, which Python
        # silently collapses into this single entry.
        conditions = {'column1': 2}

        # Run
        output = BaseTabularModel._sample_batch(model,
                                                num_rows=5,
                                                conditions=conditions)

        # Assert
        assert model._sample_rows.call_count == 2
        assert len(output) == 5
コード例 #21
0
ファイル: test_base.py プロジェクト: sdv-dev/SDV
    def test_sample_batch_size(self, tqdm_mock):
        """Test the `BaseTabularModel.sample` method with a valid `batch_size` argument.

        Expect that the expected calls to `_sample_batch` are made.

        Input:
            - num_rows = 10
            - batch_size = 5
        Output:
            - The requested number of sampled rows.
        Side Effect:
            - Call `_sample_batch` method twice with the expected number of rows.
            - A progress bar is created for the total number of rows.
        """
        # Setup
        model = Mock(spec_set=CTGAN)
        # Pin the validated path so `output_file_path=None` below is actually true
        # rather than an auto-created Mock return value.
        model._validate_file_path.return_value = None
        sampled_data = pd.DataFrame({
            'column1': [28, 28, 21, 1, 2],
            'column2': [37, 37, 1, 4, 5],
            'column3': [93, 93, 6, 4, 12],
        })
        model._sample_batch.side_effect = [sampled_data, sampled_data]

        # Run
        output = BaseTabularModel.sample(model, 10, batch_size=5)

        # Assert
        # `assert model._sample_batch.has_calls(...)` was always truthy;
        # `assert_has_calls` actually compares the recorded calls.
        model._sample_batch.assert_has_calls([
            call(5,
                 batch_size_per_try=5,
                 progress_bar=ANY,
                 output_file_path=None),
            call(5,
                 batch_size_per_try=5,
                 progress_bar=ANY,
                 output_file_path=None),
        ])
        tqdm_mock.assert_has_calls([call(total=10)])
        assert len(output) == 10
コード例 #22
0
ファイル: test_base.py プロジェクト: sdv-dev/SDV
    def test_sample_conditions(self):
        """Test that `BaseTabularModel.sample_conditions` delegates correctly.

        The public method should forward the conditions together with the default
        sampling arguments to `_sample_conditions` and return its result.

        Input:
            - valid conditions
        Side Effects:
            - The expected `_sample_conditions` call.
        """
        # Setup
        model = Mock(spec_set=CTGAN)
        conditions = [Condition({'column1': 'b'}, num_rows=5)]

        # Run
        result = BaseTabularModel.sample_conditions(model, conditions)

        # Assert
        delegate = model._sample_conditions
        delegate.assert_called_once_with(conditions, 100, None, True, None)
        assert result == delegate.return_value
コード例 #23
0
ファイル: test_base.py プロジェクト: sdv-dev/SDV
    def test__sample_batch_with_batch_size_per_try(self):
        """Test the `BaseTabularModel._sample_batch` method with `batch_size_per_try`.

        Expect that the expected calls to `_sample_rows` are made.

        Input:
            - num_rows = 10
            - batch_size_per_try = 5
        Output:
            - The requested number of sampled rows.
        Side Effect:
            - Call `_sample_rows` method twice with the expected number of rows, passing
              the rows accumulated so far (empty on the first try).
        """
        # Setup
        model = Mock(spec_set=CTGAN)
        sampled_data = pd.DataFrame({
            'column1': [28, 28, 21, 1, 2],
            'column2': [37, 37, 1, 4, 5],
            'column3': [93, 93, 6, 4, 12],
        })
        model._sample_rows.side_effect = [
            (sampled_data, 5),
            # `DataFrame.append` was removed in pandas 2.0; `pd.concat` is equivalent.
            (pd.concat([sampled_data, sampled_data], ignore_index=False), 10),
        ]

        # Run
        output = BaseTabularModel._sample_batch(model,
                                                num_rows=10,
                                                batch_size_per_try=5)

        # Assert
        # `assert model._sample_rows.has_calls(...)` was always truthy;
        # `assert_has_calls` actually compares the recorded calls.
        model._sample_rows.assert_has_calls([
            call(5, None, None, 0.01, DataFrameMatcher(pd.DataFrame())),
            call(5, None, None, 0.01, DataFrameMatcher(sampled_data)),
        ])
        assert len(output) == 10