Esempio n. 1
0
    def test__sample_batch_output_file_path(self, path_mock):
        """Test the `BaseTabularModel._sample_batch` method with a valid output file path.

        Expect that if the output file is empty, the sampled rows are written to the file
        with the header included in the first batch write.

        Setup:
            - `path_mock.getsize` reports a size of 0, i.e. an empty output file.
              (`path_mock` is presumably injected by a `patch` decorator that sits
              outside this block -- confirm against the decorator.)
        Input:
            - num_rows = 4
            - output_file_path = 'test.csv'
        Output:
            - The requested number of sampled rows (4).
        Side Effect:
            - `to_csv` is called exactly once with the output path and `index=False`
              (no append/header-suppressing keyword arguments).
        """
        # Setup
        model = Mock(spec_set=CTGAN)
        sampled_mock = MagicMock()
        sampled_mock.__len__.return_value = 4
        model._sample_rows.return_value = (sampled_mock, 4)
        output_file_path = 'test.csv'
        path_mock.getsize.return_value = 0

        # Run
        output = BaseTabularModel._sample_batch(
            model, num_rows=4, output_file_path=output_file_path)

        # Assert
        assert model._sample_rows.call_count == 1
        assert output == sampled_mock.head.return_value
        # Use the real mock assertion: `called_once_with` is not part of the Mock API,
        # so wrapping it in `assert` never verified anything.
        to_csv_mock = sampled_mock.head.return_value.tail.return_value.to_csv
        to_csv_mock.assert_called_once_with(output_file_path, index=False)
Esempio n. 2
0
    def test__sample_batch_zero_valid(self):
        """Test the `BaseTabularModel._sample_batch` method with zero valid rows.

        Expect that the requested number of rows are returned, if the first `_sample_rows` call
        returns zero valid rows, and the second one returns enough valid rows.
        See https://github.com/sdv-dev/SDV/issues/285.

        Input:
            - num_rows = 5
            - condition on `column1` = 2
        Output:
            - The requested number of sampled rows (5).
        Side Effect:
            - `_sample_rows` is called twice.
        """
        # Setup
        model = Mock(spec_set=CTGAN)
        valid_sampled_data = pd.DataFrame({
            'column1': [28, 28, 21, 1, 2],
            'column2': [37, 37, 1, 4, 5],
            'column3': [93, 93, 6, 4, 12],
        })
        # First attempt yields no valid rows; the second yields all five.
        model._sample_rows.side_effect = [
            (pd.DataFrame({}), 0),
            (valid_sampled_data, 5),
        ]

        # A dict holds one value per key, so the condition is stated exactly once.
        conditions = {'column1': 2}

        # Run
        output = BaseTabularModel._sample_batch(model,
                                                num_rows=5,
                                                conditions=conditions)

        # Assert
        assert model._sample_rows.call_count == 2
        assert len(output) == 5
Esempio n. 3
0
    def test__sample_batch_with_batch_size_per_try(self):
        """Test the `BaseTabularModel._sample_batch` method with `batch_size_per_try`.

        Expect that the expected calls to `_sample_rows` are made.

        Input:
            - num_rows = 10
            - batch_size_per_try = 5
        Output:
            - The requested number of sampled rows.
        Side Effect:
            - Call `_sample_rows` method twice with the expected number of rows.
        """
        # Setup
        model = Mock(spec_set=CTGAN)
        sampled_data = pd.DataFrame({
            'column1': [28, 28, 21, 1, 2],
            'column2': [37, 37, 1, 4, 5],
            'column3': [93, 93, 6, 4, 12],
        })
        # `DataFrame.append` no longer exists in pandas >= 2.0; `pd.concat` builds the
        # same 10-row frame (original index labels are kept, as before).
        doubled_data = pd.concat([sampled_data, sampled_data], ignore_index=False)
        model._sample_rows.side_effect = [
            (sampled_data, 5),
            (doubled_data, 10),
        ]

        # Run
        output = BaseTabularModel._sample_batch(model,
                                                num_rows=10,
                                                batch_size_per_try=5)

        # Assert
        # `assert_has_calls` is the real mock assertion; `has_calls` is not part of the
        # Mock API, so asserting on it never checked the recorded calls.
        model._sample_rows.assert_has_calls([
            call(5, None, None, 0.01, DataFrameMatcher(pd.DataFrame())),
            call(5, None, None, 0.01, DataFrameMatcher(sampled_data)),
        ])
        assert len(output) == 10