def test__sample_batch_output_file_path(self, path_mock):
    """Test the `BaseTabularModel._sample_batch` method with a valid output file path.

    Expect that if the output file is empty, the sampled rows are written to the
    file with the header included in the first batch write.

    Input:
        - num_rows = 4
        - output_file_path = 'test.csv' (reported as empty by `path_mock`)
    Output:
        - The requested number of sampled rows (4).
    Side Effect:
        - `to_csv` is called exactly once with the output path and `index=False`.
    """
    # Setup
    model = Mock(spec_set=CTGAN)
    sampled_mock = MagicMock()
    sampled_mock.__len__.return_value = 4
    model._sample_rows.return_value = (sampled_mock, 4)
    output_file_path = 'test.csv'
    # NOTE(review): `path_mock` is injected by a patch decorator that is not
    # visible here (presumably the `os.path` used by the code under test).
    # A size of 0 makes the output file look empty.
    path_mock.getsize.return_value = 0

    # Run
    output = BaseTabularModel._sample_batch(
        model, num_rows=4, output_file_path=output_file_path)

    # Assert
    assert model._sample_rows.call_count == 1
    assert output == sampled_mock.head.return_value
    # Use the real Mock assertion: `called_once_with` is just an auto-created
    # attribute whose return value is always truthy, so it verified nothing.
    sampled_mock.head.return_value.tail.return_value.to_csv.assert_called_once_with(
        output_file_path,
        index=False,
    )
def test__sample_batch_zero_valid(self):
    """Test the `BaseTabularModel._sample_batch` method with zero valid rows.

    Expect that the requested number of rows are returned, if the first
    `_sample_rows` call returns zero valid rows, and the second one returns
    enough valid rows. See https://github.com/sdv-dev/SDV/issues/285.

    Input:
        - num_rows = 5
        - condition on `column1` = 2
    Output:
        - The requested number of sampled rows (5).
    """
    # Setup
    model = Mock(spec_set=CTGAN)
    valid_sampled_data = pd.DataFrame({
        'column1': [28, 28, 21, 1, 2],
        'column2': [37, 37, 1, 4, 5],
        'column3': [93, 93, 6, 4, 12],
    })
    # First try yields no valid rows; the second try yields all five.
    model._sample_rows.side_effect = [
        (pd.DataFrame({}), 0),
        (valid_sampled_data, 5),
    ]
    # A dict literal keeps only the last entry for a repeated key, so the
    # single-entry form below is the same value the test always passed.
    conditions = {'column1': 2}

    # Run
    output = BaseTabularModel._sample_batch(model, num_rows=5, conditions=conditions)

    # Assert
    assert model._sample_rows.call_count == 2
    assert len(output) == 5
def test__sample_batch_with_batch_size_per_try(self):
    """Test the `BaseTabularModel._sample_batch` method with `batch_size_per_try`.

    Expect that the expected calls to `_sample_rows` are made.

    Input:
        - num_rows = 10
        - batch_size_per_try = 5
    Output:
        - The requested number of sampled rows.
    Side Effect:
        - Call `_sample_rows` method twice with the expected number of rows.
    """
    # Setup
    model = Mock(spec_set=CTGAN)
    sampled_data = pd.DataFrame({
        'column1': [28, 28, 21, 1, 2],
        'column2': [37, 37, 1, 4, 5],
        'column3': [93, 93, 6, 4, 12],
    })
    # `DataFrame.append` was removed in pandas 2.0; `pd.concat` builds the same
    # ten-row frame (original index labels kept, as with `ignore_index=False`).
    accumulated_data = pd.concat([sampled_data, sampled_data], ignore_index=False)
    model._sample_rows.side_effect = [
        (sampled_data, 5),
        (accumulated_data, 10),
    ]

    # Run
    output = BaseTabularModel._sample_batch(model, num_rows=10, batch_size_per_try=5)

    # Assert
    # Use the real Mock assertion: `has_calls` is just an auto-created attribute
    # whose return value is always truthy, so it verified nothing.
    model._sample_rows.assert_has_calls([
        call(5, None, None, 0.01, DataFrameMatcher(pd.DataFrame())),
        call(5, None, None, 0.01, DataFrameMatcher(sampled_data)),
    ])
    assert model._sample_rows.call_count == 2
    assert len(output) == 10