Esempio n. 1
0
    def test_sample_with_default_file_path_error(self, os_mock):
        """Test the `BaseTabularModel.sample` method with the default file path.

        Expect that the file is not removed if there is an error with sampling.

        Input:
            - output_file_path=None.
        Side Effects:
            - ValueError is thrown.
            - `_sample_batch` is called once with the temporary file path.
            - The temporary file is not removed.
        """
        # Setup
        model = Mock()
        model._validate_file_path.return_value = TMP_FILE_NAME
        model._sample_batch.side_effect = ValueError('test error')

        # Run
        with pytest.raises(ValueError, match='test error'):
            BaseTabularModel.sample(model, 1, output_file_path=None)

        # Assert
        # `called_once_with` is not a Mock assertion (it just returns a child Mock
        # and never fails), so the real `assert_*` helper is required here.
        model._sample_batch.assert_called_once_with(
            1,
            batch_size_per_try=1,
            progress_bar=ANY,
            output_file_path=TMP_FILE_NAME,
        )
        os_mock.remove.assert_not_called()
Esempio n. 2
0
    def test_sample_with_default_file_path(self, os_mock):
        """Test the `BaseTabularModel.sample` method with the default file path.

        Expect that the file is removed after successfully sampling.

        Input:
            - output_file_path=None.
        Side Effects:
            - `_sample_batch` is called once with the temporary file path.
            - The file is removed.
        """
        # Setup
        model = Mock()
        model._validate_file_path.return_value = TMP_FILE_NAME
        model._sample_batch.return_value = pd.DataFrame({'test': [1]})
        os_mock.path.exists.return_value = True

        # Run
        BaseTabularModel.sample(model, 1, output_file_path=None)

        # Assert
        # `called_once_with` is not a Mock assertion (it just returns a child Mock
        # and never fails), so the real `assert_*` helpers are required here.
        model._sample_batch.assert_called_once_with(
            1,
            batch_size_per_try=1,
            progress_bar=ANY,
            output_file_path=TMP_FILE_NAME,
        )
        os_mock.remove.assert_called_once_with(TMP_FILE_NAME)
Esempio n. 3
0
    def test_sample_no_num_rows(self):
        """Check that `BaseTabularModel.sample` cannot be called without `num_rows`.

        Calling `sample` with no arguments must fail with a `TypeError` that
        names the missing positional argument.
        """
        # Setup
        instance = BaseTabularModel()
        expected_message = (
            r'sample\(\) missing 1 required positional argument: \'num_rows\''
        )

        # Run and assert
        with pytest.raises(TypeError, match=expected_message):
            instance.sample()
Esempio n. 4
0
    def test_sample_valid_num_rows(self, tqdm_mock):
        """Test the `BaseTabularModel.sample` method with a valid `num_rows` argument.

        Expect that the expected call to `_sample_batch` is made.

        Input:
            - num_rows = 5
        Output:
            - The requested number of sampled rows.
        Side Effect:
            - Call `_sample_batch` method once with the expected number of rows.
            - No progress bar is created.
        """
        # Setup
        model = Mock(spec_set=CTGAN)
        valid_sampled_data = pd.DataFrame({
            'column1': [28, 28, 21, 1, 2],
            'column2': [37, 37, 1, 4, 5],
            'column3': [93, 93, 6, 4, 12],
        })
        model._sample_batch.return_value = valid_sampled_data

        # Run
        output = BaseTabularModel.sample(model, 5)

        # Assert
        # `called_once_with` is not a Mock assertion (it returns a truthy child Mock),
        # so the original `assert ...called_once_with(5)` could never fail.
        # `_validate_file_path` is mocked, so its return value is what gets forwarded.
        model._sample_batch.assert_called_once_with(
            5,
            batch_size_per_try=5,
            progress_bar=ANY,
            output_file_path=model._validate_file_path.return_value,
        )
        assert tqdm_mock.call_count == 0
        assert len(output) == 5
Esempio n. 5
0
    def test_sample_num_rows_none(self):
        """Test the `BaseTabularModel.sample` method with a `num_rows` input of `None`.

        Expect that a `ValueError` is thrown.

        Input:
            - num_rows = None
        Side Effect:
            - ValueError
        """
        # Setup
        model = BaseTabularModel()
        num_rows = None

        # Run and assert
        # `pytest.raises(match=...)` is a regex search: escape the dots in "e.g."
        # so they match literally instead of any character.
        with pytest.raises(
                ValueError,
                match=
                r'You must specify the number of rows to sample \(e\.g\. num_rows=100\)'
        ):
            model.sample(num_rows)
Esempio n. 6
0
    def test_sample_with_custom_file_path(self, os_mock):
        """Test the `BaseTabularModel.sample` method with a custom file path.

        Expect that the file is not removed if a custom file path is given.

        Input:
            - output_file_path='temp.csv'.
        Side Effects:
            - `_sample_batch` is called once with the custom file path.
            - No file is removed.
        """
        # Setup
        model = Mock()
        model._validate_file_path.return_value = 'temp.csv'
        model._sample_batch.return_value = pd.DataFrame({'test': [1]})

        # Run
        BaseTabularModel.sample(model, 1, output_file_path='temp.csv')

        # Assert
        # `called_once_with` is not a Mock assertion (it just returns a child Mock
        # and never fails), so the real `assert_*` helper is required here.
        model._sample_batch.assert_called_once_with(
            1,
            batch_size_per_try=1,
            progress_bar=ANY,
            output_file_path='temp.csv',
        )
        os_mock.remove.assert_not_called()
Esempio n. 7
0
    def test_sample_batch_size(self, tqdm_mock):
        """Test the `BaseTabularModel.sample` method with a valid `batch_size` argument.

        Expect that the expected calls to `_sample_batch` are made.

        Input:
            - num_rows = 10
            - batch_size = 5
        Output:
            - The requested number of sampled rows.
        Side Effect:
            - Call `_sample_batch` method twice with the expected number of rows.
            - A progress bar sized to `num_rows` is created.
        """
        # Setup
        model = Mock(spec_set=CTGAN)
        sampled_data = pd.DataFrame({
            'column1': [28, 28, 21, 1, 2],
            'column2': [37, 37, 1, 4, 5],
            'column3': [93, 93, 6, 4, 12],
        })
        model._sample_batch.side_effect = [sampled_data, sampled_data]

        # Run
        output = BaseTabularModel.sample(model, 10, batch_size=5)

        # Assert
        # `has_calls` is not a Mock assertion (it returns a truthy child Mock), so the
        # original check never ran. `_validate_file_path` is mocked, so its return
        # value, not `None`, is the path forwarded to each batch.
        expected_call = call(
            5,
            batch_size_per_try=5,
            progress_bar=ANY,
            output_file_path=model._validate_file_path.return_value,
        )
        model._sample_batch.assert_has_calls([expected_call, expected_call])
        assert model._sample_batch.call_count == 2
        tqdm_mock.assert_has_calls([call(total=10)])
        assert len(output) == 10