def test__sample_conditions_no_rows(self):
    """Test `BaseTabularModel._sample_conditions` with an unsatisfiable condition.

    If no valid rows are sampled for any condition, expect a ValueError.

    Input:
    - a condition that is impossible to satisfy

    Side Effects:
    - ValueError is raised
    """
    # Setup
    model = Mock(spec_set=CTGAN)
    # Stub the output path so the method under test does not receive a Mock path.
    model._validate_file_path.return_value = None
    condition = Condition(
        {'column1': 'b'},
        num_rows=5,
    )
    # `_make_condition_dfs` is mocked as a list of condition dataframes; a bare
    # DataFrame would be iterated column-by-column, passing the column label
    # 'column1' to `_sample_with_conditions` instead of the condition frame.
    model._make_condition_dfs.return_value = [
        pd.DataFrame([{'column1': 'b'}] * 5),
    ]
    # Every sampling attempt comes back empty.
    model._sample_with_conditions.return_value = pd.DataFrame()

    # Run and assert
    with pytest.raises(
            ValueError, match=r'Unable to sample any rows for the given conditions\.'):
        BaseTabularModel._sample_conditions(model, [condition], 100, None, True, None)
def test__sample_conditions_with_multiple_conditions(self):
    """Test the `BaseTabularModel._sample_conditions` method with multiple conditions.

    When multiple condition dataframes are returned by `_make_condition_dfs`,
    expect `_sample_with_conditions` to be called exactly once for each
    condition dataframe, and the sampled rows to be concatenated in order.

    Input:
    - 2 conditions with different columns

    Output:
    - The expected sampled rows
    """
    # Setup
    model = Mock(spec_set=CTGAN)
    model._validate_file_path.return_value = None

    condition_values1 = {'cola': 'a'}
    condition1 = Condition(condition_values1, num_rows=2)
    sampled1 = pd.DataFrame({'a': ['a', 'a'], 'b': [1, 2]})

    condition_values2 = {'colb': 1}
    condition2 = Condition(condition_values2, num_rows=3)
    sampled2 = pd.DataFrame({'a': ['b', 'c', 'a'], 'b': [1, 1, 1]})

    expected = pd.DataFrame({
        'a': ['a', 'a', 'b', 'c', 'a'],
        'b': [1, 2, 1, 1, 1],
    })

    model._make_condition_dfs.return_value = [
        pd.DataFrame([condition_values1] * 2),
        pd.DataFrame([condition_values2] * 3),
    ]
    model._sample_with_conditions.side_effect = [
        sampled1,
        sampled2,
    ]

    # Run
    out = BaseTabularModel._sample_conditions(
        model, [condition1, condition2], 100, None, True, None)

    # Asserts
    # `assert_has_calls` tolerates extra calls, so pin the exact count as well.
    assert model._sample_with_conditions.call_count == 2
    model._sample_with_conditions.assert_has_calls([
        call(DataFrameMatcher(pd.DataFrame([condition_values1] * 2)), 100, None, ANY, None),
        call(DataFrameMatcher(pd.DataFrame([condition_values2] * 3)), 100, None, ANY, None),
    ])
    pd.testing.assert_frame_equal(out, expected)