def test__conditionally_sample_rows_graceful_reject_sampling_true(self):
    """Check `BaseTabularModel._conditionally_sample_rows` tolerates empty batches.

    When `_conditionally_sample_rows` is invoked with
    `graceful_reject_sampling=True` and the underlying `_sample_batch`
    produces no rows, the call is expected to finish without raising.

    Setup:
        - A `CTGAN` mock whose `_sample_batch` returns an empty DataFrame,
          standing in for a condition that cannot be satisfied.

    Input:
        - An impossible condition (`cola == 'c'`) requesting two rows.

    Output:
        - A result of length zero.

    Side Effects:
        - `_sample_batch` is called exactly once with the expected arguments.
    """
    # Setup
    impossible = {'cola': 'c'}
    requested = Condition(impossible, num_rows=2)
    transformed = pd.DataFrame([impossible] * 2)
    # Built separately from `transformed` so the two arguments stay
    # distinct objects.
    conditions_frame = pd.DataFrame([impossible] * 2)

    ctgan_mock = Mock(spec_set=CTGAN)
    ctgan_mock._validate_file_path.return_value = None
    # An empty batch simulates "no valid rows could be generated".
    ctgan_mock._sample_batch.return_value = pd.DataFrame()

    # Run
    result = BaseTabularModel._conditionally_sample_rows(
        ctgan_mock,
        conditions_frame,
        requested,
        transformed,
        graceful_reject_sampling=True,
    )

    # Assert
    assert len(result) == 0
    expected_args = (2, None, None, requested, transformed, 0.01, None, None)
    ctgan_mock._sample_batch.assert_called_once_with(*expected_args)
def test__conditionally_sample_rows_graceful_reject_sampling_false(self):
    """Check `BaseTabularModel._conditionally_sample_rows` raises on empty batches.

    When `_conditionally_sample_rows` is invoked with
    `graceful_reject_sampling=False` and the underlying `_sample_batch`
    produces no rows, a `ValueError` is expected.

    Setup:
        - A `CTGAN` mock whose `_sample_batch` returns an empty DataFrame,
          standing in for a condition that cannot be satisfied.

    Input:
        - An impossible condition (`cola == 'c'`) requesting two rows.

    Side Effects:
        - A `ValueError` matching 'Unable to sample any rows for the given
          conditions' is raised.
        - `_sample_batch` is called exactly once with the expected arguments.
    """
    # Setup
    impossible = {'cola': 'c'}
    requested = Condition(impossible, num_rows=2)
    transformed = pd.DataFrame([impossible] * 2)
    # Built separately from `transformed` so the two arguments stay
    # distinct objects.
    conditions_frame = pd.DataFrame([impossible] * 2)

    ctgan_mock = Mock(spec_set=CTGAN)
    ctgan_mock._validate_file_path.return_value = None
    # An empty batch simulates "no valid rows could be generated".
    ctgan_mock._sample_batch.return_value = pd.DataFrame()

    expected_message = 'Unable to sample any rows for the given conditions'

    # Run and assert
    with pytest.raises(ValueError, match=expected_message):
        BaseTabularModel._conditionally_sample_rows(
            ctgan_mock,
            conditions_frame,
            requested,
            transformed,
            graceful_reject_sampling=False,
        )

    # Checked after the `with` block: a statement following the raising
    # call inside the block would never execute.
    expected_args = (2, None, None, requested, transformed, 0.01, None, None)
    ctgan_mock._sample_batch.assert_called_once_with(*expected_args)