def test__validate_conditions_with_conditions_invalid_column(self):
    """Check that `BaseTabularModel._validate_conditions` rejects unknown columns.

    The metadata only declares `cola`, so a condition on `colb` must be refused.

    Input:
    - Conditions DataFrame referencing a column missing from the metadata.

    Side Effects:
    - A ValueError is thrown.
    """
    # Setup
    metadata_mock = Mock()
    metadata_mock.get_fields.return_value = {'cola': {}}
    model = Mock(spec_set=CTGAN)
    model._metadata = metadata_mock
    bad_conditions = pd.DataFrame([{'colb': 'a'}] * 5)
    expected_message = (
        'Unexpected column name `colb`. '
        'Use a column name that was present in the original data.'
    )

    # Run and Assert
    with pytest.raises(ValueError, match=expected_message):
        BaseTabularModel._validate_conditions(model, bad_conditions)
def test__sample_conditions_no_rows(self):
    """Check `BaseTabularModel._sample_conditions` with an unsatisfiable condition.

    When sampling yields no valid rows for any condition, a ValueError is expected.

    Input:
    - A condition that cannot be satisfied.

    Side Effects:
    - ValueError is thrown.
    """
    # Setup
    impossible = Condition({'column1': 'b'}, num_rows=5)
    model = Mock(spec_set=CTGAN)
    model._make_condition_dfs.return_value = pd.DataFrame([{'column1': 'b'}] * 5)
    model._sample_with_conditions.return_value = pd.DataFrame()
    expected_message = 'Unable to sample any rows for the given conditions.'

    # Run and assert
    with pytest.raises(ValueError, match=expected_message):
        BaseTabularModel._sample_conditions(model, [impossible], 100, None, True, None)
def test_sample_with_default_file_path_error(self, os_mock):
    """Test the `BaseTabularModel.sample` method with the default file path.

    Expect that the file is not removed if there is an error with sampling.

    Input:
    - output_file_path=None.

    Side Effects:
    - ValueError is thrown.
    - The temporary file is not removed.
    """
    # Setup
    model = Mock()
    model._validate_file_path.return_value = TMP_FILE_NAME
    model._sample_batch.side_effect = ValueError('test error')

    # Run
    with pytest.raises(ValueError, match='test error'):
        BaseTabularModel.sample(model, 1, output_file_path=None)

    # Assert
    # ``called_once_with`` (without ``assert_``) is just an auto-created Mock attribute
    # and verifies nothing; ``assert_called_once_with`` actually checks the call.
    model._sample_batch.assert_called_once_with(
        1, batch_size_per_try=1, progress_bar=ANY, output_file_path=TMP_FILE_NAME)
    assert os_mock.remove.call_count == 0
def test_sample_with_default_file_path(self, os_mock):
    """Test the `BaseTabularModel.sample` method with the default file path.

    Expect that the file is removed after successfully sampling.

    Input:
    - output_file_path=None.

    Side Effects:
    - The file is removed.
    """
    # Setup
    model = Mock()
    model._validate_file_path.return_value = TMP_FILE_NAME
    model._sample_batch.return_value = pd.DataFrame({'test': [1]})
    os_mock.path.exists.return_value = True

    # Run
    BaseTabularModel.sample(model, 1, output_file_path=None)

    # Assert
    # Use the real ``assert_*`` helpers: bare ``called_once_with`` verifies nothing.
    model._sample_batch.assert_called_once_with(
        1, batch_size_per_try=1, progress_bar=ANY, output_file_path=TMP_FILE_NAME)
    os_mock.remove.assert_called_once_with(TMP_FILE_NAME)
def test_sample_no_num_rows(self):
    """Calling `BaseTabularModel.sample` without `num_rows` must fail.

    The missing required positional argument surfaces as a TypeError.
    """
    # Setup
    instance = BaseTabularModel()
    expected = r'sample\(\) missing 1 required positional argument: \'num_rows\''

    # Run and assert
    with pytest.raises(TypeError, match=expected):
        instance.sample()
def test__randomize_samples_false():
    """Test that ``_randomize_samples`` fixes the random state when randomization is off.

    ``randomize_samples=False`` asks for reproducible samples, so the random state
    is expected to be set to ``FIXED_RNG_SEED``.

    Input:
    - randomize_samples as False

    Side Effect:
    - ``_set_random_state`` is called once with ``FIXED_RNG_SEED``
    """
    # Setup
    instance = Mock()
    randomize_samples = False

    # Run
    BaseTabularModel._randomize_samples(instance, randomize_samples)

    # Assert
    # NOTE(review): the previous ``assert ...called_once_with(None)`` was always truthy,
    # which hid that it expected the seed of the ``True`` case.
    instance._set_random_state.assert_called_once_with(FIXED_RNG_SEED)
def test__conditionally_sample_rows_graceful_reject_sampling_true(self):
    """Check `BaseTabularModel._conditionally_sample_rows` with graceful reject sampling.

    With `graceful_reject_sampling=True`, producing zero valid rows must not raise;
    an empty result is returned instead.

    Input:
    - An impossible condition

    Returns:
    - Empty DataFrame
    """
    # Setup
    values = {'cola': 'c'}
    condition = Condition(values, num_rows=2)
    transformed = pd.DataFrame([values] * 2)
    model = Mock(spec_set=CTGAN)
    model._validate_file_path.return_value = None
    model._sample_batch.return_value = pd.DataFrame()

    # Run
    result = BaseTabularModel._conditionally_sample_rows(
        model,
        pd.DataFrame([values] * 2),
        condition,
        transformed,
        graceful_reject_sampling=True,
    )

    # Assert
    assert len(result) == 0
    model._sample_batch.assert_called_once_with(
        2, None, None, condition, transformed, 0.01, None, None)
def test__sample_remaining_columns(self):
    """Check the `BaseTabularModel._sample_remaining_columns` method.

    A valid conditions DataFrame should be forwarded to `_sample_with_conditions`,
    whose output is returned unchanged.

    Input:
    - DataFrame with condition column values populated.

    Output:
    - The expected sampled rows.

    Side Effects:
    - `_sample_with_conditions` is called once.
    """
    # Setup
    conditions = pd.DataFrame([{'cola': 'a'}] * 5)
    expected = pd.DataFrame({
        'cola': ['a'] * 5,
        'colb': [1, 2, 1, 1, 1],
    })
    model = Mock(spec_set=CTGAN)
    model._validate_file_path.return_value = None
    model._sample_with_conditions.return_value = expected

    # Run
    result = BaseTabularModel._sample_remaining_columns(
        model, conditions, 100, None, True, None)

    # Asserts
    model._sample_with_conditions.assert_called_once_with(
        DataFrameMatcher(conditions), 100, None, ANY, None)
    pd.testing.assert_frame_equal(result, expected)
def test_sample_valid_num_rows(self, tqdm_mock):
    """Test the `BaseTabularModel.sample` method with a valid `num_rows` argument.

    Expect that the expected call to `_sample_batch` is made.

    Input:
    - num_rows = 5

    Output:
    - The requested number of sampled rows.

    Side Effect:
    - Call `_sample_batch` method once with the expected number of rows and no
      progress bar.
    """
    # Setup
    model = Mock(spec_set=CTGAN)
    # ``sample`` forwards whatever ``_validate_file_path`` returns as ``output_file_path``;
    # pin it to None so the expected call below is deterministic.
    model._validate_file_path.return_value = None
    valid_sampled_data = pd.DataFrame({
        'column1': [28, 28, 21, 1, 2],
        'column2': [37, 37, 1, 4, 5],
        'column3': [93, 93, 6, 4, 12],
    })
    model._sample_batch.return_value = valid_sampled_data

    # Run
    output = BaseTabularModel.sample(model, 5)

    # Assert
    # ``assert model._sample_batch.called_once_with(5)`` was always truthy; check the
    # real call instead.
    model._sample_batch.assert_called_once_with(
        5, batch_size_per_try=5, progress_bar=ANY, output_file_path=None)
    assert tqdm_mock.call_count == 0
    assert len(output) == 5
def test__sample_batch_output_file_path(self, path_mock):
    """Test the `BaseTabularModel._sample_batch` method with a valid output file path.

    Expect that if the output file is empty, the sampled rows are written to the file
    with the header included in the first batch write.

    Input:
    - num_rows = 4
    - output_file_path = temp file

    Output:
    - The requested number of sampled rows (4).

    Side Effects:
    - The new rows are written once with ``to_csv(output_file_path, index=False)``
      and no extra keyword arguments.
    """
    # Setup
    model = Mock(spec_set=CTGAN)
    sampled_mock = MagicMock()
    sampled_mock.__len__.return_value = 4
    model._sample_rows.return_value = (sampled_mock, 4)
    output_file_path = 'test.csv'
    path_mock.getsize.return_value = 0  # empty file -> header must be written

    # Run
    output = BaseTabularModel._sample_batch(
        model, num_rows=4, output_file_path=output_file_path)

    # Assert
    assert model._sample_rows.call_count == 1
    assert output == sampled_mock.head.return_value
    # The old ``assert ...to_csv.called_once_with(call(2)...)`` was always truthy and
    # compared against a meaningless ``call`` object; check the real write instead.
    to_csv_mock = sampled_mock.head.return_value.tail.return_value.to_csv
    to_csv_mock.assert_called_once_with(output_file_path, index=False)
def test__sample_with_conditions_no_transformed_columns(self):
    """Check ``BaseTabularModel._sample_with_conditions`` when nothing gets transformed.

    If the transformed conditions DataFrame ends up with no columns, the conditional
    sampling step must receive ``None`` as its transformed condition.

    Setup:
    - ``_make_condition_dfs`` returns the expected conditions and
      ``get_fields`` lists the conditioned column.
    - ``_metadata.transform`` returns a column-less DataFrame.
    - ``_conditionally_sample_rows`` returns the sampled rows.
    - ``make_ids_unique`` returns the expected result.

    Input:
    - number of rows
    - one set of conditions

    Output:
    - the expected sampled rows

    Side Effects:
    - ``_conditionally_sample_rows`` is called with the condition values and a
      transformed_condition of None.
    """
    # Setup
    condition_dataframe = pd.DataFrame({'a': ['a'] * 3})
    expected = pd.DataFrame(['a'] * 3)
    sampled_rows = pd.DataFrame({'a': ['a'] * 3, COND_IDX: [0, 1, 2]})

    model = Mock(spec_set=CTGAN)
    model._make_condition_dfs.return_value = condition_dataframe
    model._metadata.get_fields.return_value = ['a']
    model._metadata.transform.return_value = pd.DataFrame({}, index=[0, 1, 2])
    model._conditionally_sample_rows.return_value = sampled_rows
    model._metadata.make_ids_unique.return_value = expected

    # Run
    result = BaseTabularModel._sample_with_conditions(model, condition_dataframe, 100, None)

    # Asserts
    expected_frame = pd.DataFrame({COND_IDX: [0, 1, 2], 'a': ['a'] * 3})
    model._conditionally_sample_rows.assert_called_once_with(
        DataFrameMatcher(expected_frame),
        {'a': 'a'},
        None,
        100,
        None,
        progress_bar=None,
        output_file_path=None,
    )
    pd.testing.assert_frame_equal(result, expected)
def test__randomize_samples_true():
    """Test that ``_randomize_samples`` leaves sampling random when requested.

    ``randomize_samples=True`` asks for non-reproducible samples, so the random state
    is expected to be reset with ``None`` rather than pinned to ``FIXED_RNG_SEED``.

    Input:
    - randomize_samples as True

    Side Effect:
    - ``_set_random_state`` is called once with ``None``
    """
    # Setup
    instance = Mock()
    randomize_samples = True

    # Run
    BaseTabularModel._randomize_samples(instance, randomize_samples)

    # Assert
    # NOTE(review): the previous ``assert ...called_once_with(FIXED_RNG_SEED)`` was always
    # truthy, which hid that it expected the seed of the ``False`` case.
    instance._set_random_state.assert_called_once_with(None)
def test__validate_conditions_with_conditions_valid_columns(self):
    """Check that `BaseTabularModel._validate_conditions` accepts known columns.

    Every condition column exists in the metadata, so no error should be raised.

    Input:
    - Conditions DataFrame contains only valid columns.
    """
    # Setup
    metadata = Mock()
    metadata.get_fields.return_value = {'cola': {}}
    model = Mock(spec_set=CTGAN)
    model._metadata = metadata
    valid_conditions = pd.DataFrame([{'cola': 'a'}] * 5)

    # Run and Assert: the test passes as long as nothing is raised
    BaseTabularModel._validate_conditions(model, valid_conditions)
def test__validate_file_path(self, path_mock):
    """Check that `BaseTabularModel._validate_file_path` refuses existing paths.

    Input:
    - A file path that already exists.

    Side Effects:
    - An AssertionError mentioning the absolute path.
    """
    # Setup
    model = Mock(spec_set=CTGAN)
    path_mock.abspath.return_value = 'path/to/file'
    path_mock.exists.return_value = True

    # Run and Assert
    with pytest.raises(AssertionError, match='path/to/file already exists'):
        BaseTabularModel._validate_file_path(model, 'file_path')
def test_sample_num_rows_none(self):
    """Passing `num_rows=None` to `BaseTabularModel.sample` must be rejected.

    Input:
    - num_rows = None

    Side Effect:
    - ValueError asking the caller to specify the number of rows
    """
    # Setup
    instance = BaseTabularModel()
    expected = r'You must specify the number of rows to sample \(e.g. num_rows=100\)'

    # Run and assert
    with pytest.raises(ValueError, match=expected):
        instance.sample(None)
def test__sample_conditions_with_multiple_conditions(self):
    """Check `BaseTabularModel._sample_conditions` with several conditions.

    `_make_condition_dfs` yields one DataFrame per condition, and
    `_sample_with_conditions` is expected to run once for each of them. The
    per-condition results are concatenated in order.

    Input:
    - 2 conditions with different columns

    Output:
    - The expected sampled rows
    """
    # Setup
    first_values = {'cola': 'a'}
    second_values = {'colb': 1}
    first_condition = Condition(first_values, num_rows=2)
    second_condition = Condition(second_values, num_rows=3)
    first_sampled = pd.DataFrame({'a': ['a', 'a'], 'b': [1, 2]})
    second_sampled = pd.DataFrame({'a': ['b', 'c', 'a'], 'b': [1, 1, 1]})
    expected = pd.DataFrame({
        'a': ['a', 'a', 'b', 'c', 'a'],
        'b': [1, 2, 1, 1, 1],
    })

    model = Mock(spec_set=CTGAN)
    model._validate_file_path.return_value = None
    model._make_condition_dfs.return_value = [
        pd.DataFrame([first_values] * 2),
        pd.DataFrame([second_values] * 3),
    ]
    model._sample_with_conditions.side_effect = [first_sampled, second_sampled]

    # Run
    result = BaseTabularModel._sample_conditions(
        model, [first_condition, second_condition], 100, None, True, None)

    # Asserts
    expected_calls = [
        call(DataFrameMatcher(pd.DataFrame([first_values] * 2)), 100, None, ANY, None),
        call(DataFrameMatcher(pd.DataFrame([second_values] * 3)), 100, None, ANY, None),
    ]
    model._sample_with_conditions.assert_has_calls(expected_calls)
    pd.testing.assert_frame_equal(result, expected)
def test_sample_with_custom_file_path(self, os_mock):
    """Test the `BaseTabularModel.sample` method with a custom file path.

    Expect that the file is not removed if a custom file path is given.

    Input:
    - output_file_path='temp.csv'.

    Side Effects:
    - None
    """
    # Setup
    model = Mock()
    model._validate_file_path.return_value = 'temp.csv'
    model._sample_batch.return_value = pd.DataFrame({'test': [1]})

    # Run
    BaseTabularModel.sample(model, 1, output_file_path='temp.csv')

    # Assert
    # ``called_once_with`` (without ``assert_``) verifies nothing; use the real helper.
    model._sample_batch.assert_called_once_with(
        1, batch_size_per_try=1, progress_bar=ANY, output_file_path='temp.csv')
    assert os_mock.remove.call_count == 0
def test__conditionally_sample_rows_graceful_reject_sampling_false(self):
    """Check `BaseTabularModel._conditionally_sample_rows` without graceful rejection.

    With `graceful_reject_sampling=False`, producing zero valid rows must raise.

    Input:
    - An impossible condition

    Side Effect:
    - A ValueError is thrown
    """
    # Setup
    values = {'cola': 'c'}
    condition = Condition(values, num_rows=2)
    transformed = pd.DataFrame([values] * 2)
    model = Mock(spec_set=CTGAN)
    model._validate_file_path.return_value = None
    model._sample_batch.return_value = pd.DataFrame()
    expected_message = 'Unable to sample any rows for the given conditions'

    # Run and assert
    with pytest.raises(ValueError, match=expected_message):
        BaseTabularModel._conditionally_sample_rows(
            model,
            pd.DataFrame([values] * 2),
            condition,
            transformed,
            graceful_reject_sampling=False,
        )

    model._sample_batch.assert_called_once_with(
        2, None, None, condition, transformed, 0.01, None, None)
def test_sample_remaining_columns(self):
    """Check that `BaseTabularModel.sample_remaining_columns` delegates correctly.

    Input:
    - valid DataFrame

    Side Effects:
    - `_sample_remaining_columns` receives the DataFrame followed by
      (100, None, True, None), and its result is returned as-is.
    """
    # Setup
    known_conditions = pd.DataFrame([{'cola': 'a'}] * 5)
    model = Mock(spec_set=CTGAN)

    # Run
    result = BaseTabularModel.sample_remaining_columns(model, known_conditions)

    # Assert
    delegate = model._sample_remaining_columns
    delegate.assert_called_once_with(known_conditions, 100, None, True, None)
    assert result == delegate.return_value
def test__sample_batch_zero_valid(self):
    """Test the `BaseTabularModel._sample_batch` method with zero valid rows.

    Expect that the requested number of rows are returned when the first
    `_sample_rows` call returns zero valid rows and the second one returns enough
    valid rows.

    See https://github.com/sdv-dev/SDV/issues/285.

    Input:
    - num_rows = 5
    - condition on `column1` = 2

    Output:
    - The requested number of sampled rows (5).
    """
    # Setup
    model = Mock(spec_set=CTGAN)
    valid_sampled_data = pd.DataFrame({
        'column1': [28, 28, 21, 1, 2],
        'column2': [37, 37, 1, 4, 5],
        'column3': [93, 93, 6, 4, 12],
    })
    model._sample_rows.side_effect = [(pd.DataFrame({}), 0), (valid_sampled_data, 5)]
    # A dict literal keeps only the last value of a repeated key, so the previous
    # five identical ``'column1': 2`` entries already collapsed to this one entry.
    conditions = {'column1': 2}

    # Run
    output = BaseTabularModel._sample_batch(model, num_rows=5, conditions=conditions)

    # Assert
    assert model._sample_rows.call_count == 2
    assert len(output) == 5
def test_sample_batch_size(self, tqdm_mock):
    """Test the `BaseTabularModel.sample` method with a valid `batch_size` argument.

    Expect that the expected calls to `_sample_batch` are made.

    Input:
    - num_rows = 10
    - batch_size = 5

    Output:
    - The requested number of sampled rows.

    Side Effect:
    - Call `_sample_batch` method twice with the expected number of rows.
    """
    # Setup
    model = Mock(spec_set=CTGAN)
    # ``sample`` forwards whatever ``_validate_file_path`` returns as ``output_file_path``;
    # pin it to None so the expected calls below can match.
    model._validate_file_path.return_value = None
    sampled_data = pd.DataFrame({
        'column1': [28, 28, 21, 1, 2],
        'column2': [37, 37, 1, 4, 5],
        'column3': [93, 93, 6, 4, 12],
    })
    model._sample_batch.side_effect = [sampled_data, sampled_data]

    # Run
    output = BaseTabularModel.sample(model, 10, batch_size=5)

    # Assert
    # ``assert mock.has_calls(...)`` was always truthy; ``assert_has_calls`` checks.
    assert model._sample_batch.call_count == 2
    model._sample_batch.assert_has_calls([
        call(5, batch_size_per_try=5, progress_bar=ANY, output_file_path=None),
        call(5, batch_size_per_try=5, progress_bar=ANY, output_file_path=None),
    ])
    tqdm_mock.assert_has_calls([call(total=10)])
    assert len(output) == 10
def test_sample_conditions(self):
    """Check that `BaseTabularModel.sample_conditions` delegates correctly.

    Input:
    - valid conditions

    Side Effects:
    - `_sample_conditions` receives the condition list followed by
      (100, None, True, None), and its result is returned as-is.
    """
    # Setup
    conditions = [Condition({'column1': 'b'}, num_rows=5)]
    model = Mock(spec_set=CTGAN)

    # Run
    result = BaseTabularModel.sample_conditions(model, conditions)

    # Assert
    delegate = model._sample_conditions
    delegate.assert_called_once_with(conditions, 100, None, True, None)
    assert result == delegate.return_value
def test__sample_batch_with_batch_size_per_try(self):
    """Test the `BaseTabularModel._sample_batch` method with `batch_size_per_try`.

    Expect that the expected calls to `_sample_rows` are made.

    Input:
    - num_rows = 10
    - batch_size_per_try = 5

    Output:
    - The requested number of sampled rows.

    Side Effect:
    - Call `_sample_rows` method twice with the expected number of rows. The second
      call receives the rows accumulated by the first one.
    """
    # Setup
    model = Mock(spec_set=CTGAN)
    sampled_data = pd.DataFrame({
        'column1': [28, 28, 21, 1, 2],
        'column2': [37, 37, 1, 4, 5],
        'column3': [93, 93, 6, 4, 12],
    })
    model._sample_rows.side_effect = [
        (sampled_data, 5),
        # ``DataFrame.append`` was removed in pandas 2.0; ``pd.concat`` is equivalent.
        (pd.concat([sampled_data, sampled_data], ignore_index=False), 10),
    ]

    # Run
    output = BaseTabularModel._sample_batch(model, num_rows=10, batch_size_per_try=5)

    # Assert
    # ``assert mock.has_calls(...)`` was always truthy; ``assert_has_calls`` checks.
    model._sample_rows.assert_has_calls([
        call(5, None, None, 0.01, DataFrameMatcher(pd.DataFrame())),
        call(5, None, None, 0.01, DataFrameMatcher(sampled_data)),
    ])
    assert len(output) == 10