def test__sample_remaining_columns(self):
    """Test the `BaseTabularModel._sample_remaining_columns` method.

    When a valid DataFrame is given, expect `_sample_with_conditions` to be
    called with the input DataFrame.

    Input:
        - DataFrame with condition column values populated.
    Output:
        - The expected sampled rows.
    Side Effects:
        - `_sample_with_conditions` is called once.
    """
    # Setup
    model = Mock(spec_set=CTGAN)
    # Make `_validate_file_path` return None instead of an auto-generated Mock.
    model._validate_file_path.return_value = None
    conditions = pd.DataFrame([{'cola': 'a'}] * 5)
    sampled = pd.DataFrame({
        'cola': ['a', 'a', 'a', 'a', 'a'],
        'colb': [1, 2, 1, 1, 1],
    })
    model._sample_with_conditions.return_value = sampled

    # Run
    out = BaseTabularModel._sample_remaining_columns(
        model, conditions, 100, None, True, None)

    # Asserts
    # DataFrameMatcher (defined elsewhere) is used because DataFrames cannot be
    # compared with plain `==` in a mock assertion; presumably it matches by
    # content. ANY leaves the fourth positional argument unchecked.
    model._sample_with_conditions.assert_called_once_with(
        DataFrameMatcher(conditions), 100, None, ANY, None)
    pd.testing.assert_frame_equal(out, sampled)
def test__sample_remaining_columns_no_rows(self):
    """Test `BaseTabularModel._sample_remaining_columns` with invalid condition.

    If no valid rows are returned for any condition, expect a ValueError.

    Input:
        - Condition that is impossible to satisfy.
    Side Effects:
        - ValueError is raised.
    """
    # Setup
    model = Mock(spec_set=CTGAN)
    # Match the setup of `test__sample_remaining_columns`: return None rather
    # than an auto-generated Mock from `_validate_file_path`.
    model._validate_file_path.return_value = None
    conditions = pd.DataFrame([{'cola': 'a'}] * 5)
    # An empty result simulates no rows satisfying the conditions.
    model._sample_with_conditions.return_value = pd.DataFrame()

    # Run and assert
    with pytest.raises(
            ValueError, match='Unable to sample any rows for the given conditions.'):
        BaseTabularModel._sample_remaining_columns(
            model, conditions, 100, None, True, None)