Example #1
0
    def test__sample_valid_rows_fitted(self):
        """Check the helper call counts of _sample_valid_rows for a fitted model.

        ``_get_missing_valid_rows`` is mocked to return ``(True, {})`` on its
        first call and ``(False, {})`` on its second one. The test asserts that
        both ``_get_missing_valid_rows`` and ``_sample_model`` end up being
        called exactly twice.
        """
        # Setup
        synthesized_mock = pd.DataFrame({'foo': [0, 1.1], 'bar': [1, 0]})
        tables = {'DEMO': pd.DataFrame({'a_field': [1, 0], 'b_field': [0, 1]})}

        pk_keys_mock = Mock(return_value=('pk_name', [1, 2, 3, 4]))
        sample_model_mock = Mock(return_value=synthesized_mock)

        # Only ``side_effect`` is configured: when it is an iterable, each call
        # consumes the next item and ``return_value`` is never consulted, so
        # setting both would be dead (and misleading) configuration.
        missing_valid_rows_mock = Mock(side_effect=[(True, {}), (False, {})])

        dn_mock = Mock()
        dn_mock.get_meta_data.return_value = {
            'fields': {
                'foo': {
                    'type': 'categorical',
                },
                'bar': {
                    'type': 'numeric'
                }
            }
        }

        # The sampler itself is a Mock; the real method is invoked unbound
        # below with this mock standing in for ``self``.
        sampler_mock = Mock()
        sampler_mock._get_primary_keys = pk_keys_mock
        sampler_mock._sample_model = sample_model_mock
        sampler_mock._get_missing_valid_rows = missing_valid_rows_mock
        sampler_mock.modeler.tables = tables
        sampler_mock.dn = dn_mock

        model_mock = Mock()
        model_mock.fitted = True

        # Run
        Sampler._sample_valid_rows(sampler_mock, model_mock, 5, 'DEMO')

        # Asserts
        assert missing_valid_rows_mock.call_count == 2
        assert sample_model_mock.call_count == 2
Example #2
0
    def test__sample_valid_rows_raises_unfitted_model(self):
        """_sample_valid_rows raises ValueError when the given model is None.

        Also checks that the modeler mock is left untouched and that the data
        navigator mock only receives one ``get_parents('table_name')`` call.
        """
        # Setup
        navigator_mock = MagicMock(spec=DataNavigator)
        modeler_mock = MagicMock(spec=Modeler)
        instance = Sampler(navigator_mock, modeler_mock)

        navigator_mock.get_parents.return_value = set()

        # Run: model=None, num_rows=5, table_name='table_name'
        with self.assertRaises(ValueError):
            instance._sample_valid_rows(None, 5, 'table_name')

        # Check: the modeler was neither called nor had any method called
        modeler_mock.assert_not_called()
        assert not modeler_mock.method_calls

        # Check: the navigator itself was not called, only its get_parents
        navigator_mock.assert_not_called()
        navigator_mock.get_parents.assert_called_once_with('table_name')
Example #3
0
    def test__sample_valid_rows_respect_categorical_values(self):
        """_sample_valid_rows only returns rows with valid categorical values.

        The mocked model yields the first ``x`` rows of a fixed five-row frame
        whose last two rows carry a value of 1.5 (flagged as invalid in the
        fixture). The test expects five rows of 0.5 values back and expects
        ``model.sample`` to have been called with 5 and then with 2.
        """
        # Setup
        navigator_mock = MagicMock(spec=DataNavigator)
        modeler_mock = MagicMock(spec=Modeler)
        instance = Sampler(navigator_mock, modeler_mock)

        modeler_mock.tables = {
            'table_name': pd.DataFrame(columns=['field_A', 'field_B']),
        }

        categorical_fields = [
            {'name': 'field_A', 'type': 'categorical'},
            {'name': 'field_B', 'type': 'categorical'},
        ]
        navigator_mock.meta = {
            'tables': [{
                'name': 'table_name',
                'fields': categorical_fields,
            }]
        }

        fitted_model = MagicMock(spec=GaussianMultivariate)
        fitted_model.fitted = True

        # Template for a row the fixture treats as valid; copied per use so
        # that every DataFrame row is built from its own dict.
        valid_row = {'field_A': 0.5, 'field_B': 0.5}

        sampled_rows = [dict(valid_row) for _ in range(3)]
        sampled_rows.append({'field_A': 0.5, 'field_B': 1.5})  # Invalid field_B
        sampled_rows.append({'field_A': 1.5, 'field_B': 0.5})  # Invalid field_A
        sample_dataframe = pd.DataFrame(sampled_rows)

        def first_rows(x):
            """Return a copy of the first ``x`` rows of the fixed sample."""
            return sample_dataframe.iloc[:x].copy()

        fitted_model.sample.side_effect = first_rows

        expected_sample_calls = [((5, ), {}), ((2, ), {})]
        expected_frame = pd.DataFrame([dict(valid_row) for _ in range(5)])

        # Run
        outcome = instance._sample_valid_rows(fitted_model, 5, 'table_name')

        # Check
        assert outcome.equals(expected_frame)

        modeler_mock.assert_not_called()
        assert not modeler_mock.method_calls

        navigator_mock.assert_not_called()
        assert not navigator_mock.method_calls

        assert fitted_model.sample.call_args_list == expected_sample_calls