コード例 #1
0
def test_recreate():
    """Fit CTGAN three ways and check each sample against the training table.

    The model is first built without metadata, then from the first model's
    metadata object, and finally from that metadata's dict form. Every variant
    is fitted on the demo ``users`` table and sampled with default arguments.
    """
    users = load_demo(metadata=False)['users']

    def fit_and_check(synthesizer):
        # The sample must match the input's shape and dtypes, and every
        # sampled row must contain at least one non-null value.
        synthesizer.fit(users)
        synthetic = synthesizer.sample()

        assert synthetic.shape == users.shape
        assert (synthetic.dtypes == users.dtypes).all()
        assert (synthetic.notnull().sum(axis=1) != 0).all()

    # No metadata supplied
    model = CTGAN(epochs=1)
    fit_and_check(model)

    # Metadata object taken from the first fitted model
    fit_and_check(CTGAN(epochs=1, table_metadata=model.get_metadata()))

    # Same metadata, passed as a plain dict
    metadata_dict = model.get_metadata().to_dict()
    fit_and_check(CTGAN(epochs=1, table_metadata=metadata_dict))
コード例 #2
0
def test_ctgan():
    """Fit CTGAN with ``user_id`` as primary key and check sample and metadata.

    Checks the sample shape, that ``user_id`` comes out as 0..n-1, and the
    exact dict returned by ``get_metadata().to_dict()``.
    """
    users = load_demo(metadata=False)['users']

    model = CTGAN(primary_key='user_id', epochs=1)
    model.fit(users)
    synthetic = model.sample()

    # Sample has the same shape as the training table.
    assert synthetic.shape == users.shape

    # user_id has been generated as an ID field: consecutive integers from 0.
    assert list(synthetic['user_id']) == list(range(len(users)))

    expected_fields = {
        'user_id': {'type': 'id', 'subtype': 'integer'},
        'country': {'type': 'categorical'},
        'gender': {'type': 'categorical'},
        'age': {'type': 'numerical', 'subtype': 'integer'},
    }
    expected_metadata = {
        'fields': expected_fields,
        'constraints': [],
        'model_kwargs': {},
    }
    assert model.get_metadata().to_dict() == expected_metadata
コード例 #3
0
ファイル: test_ctgan.py プロジェクト: surajitdb/SDV
def test_unique_combination_constraint():
    """Smoke test: CTGAN fits and samples with a ``UniqueCombinations`` constraint.

    There are no assertions; the test passes if fitting and sampling 10 rows
    complete without raising.
    """
    constraint = UniqueCombinations(
        columns=['company', 'department'],
        handling_strategy='transform',
    )
    employees = load_tabular_demo()

    synthesizer = CTGAN(constraints=[constraint])
    synthesizer.fit(employees)
    synthesizer.sample(10)
コード例 #4
0
ファイル: test_ctgan.py プロジェクト: sdv-dev/SDV
def test_conditional_sampling_dict():
    """Sample 30 rows conditioned on ``column2 == 'b'`` via a ``Condition``.

    The result must have the same shape as the 30-row training frame and
    contain only ``'b'`` in ``column2``.
    """
    data = pd.DataFrame({
        'column1': [1.0, 0.5, 2.5] * 10,
        'column2': ['a', 'b', 'c'] * 10,
    })
    synthesizer = CTGAN(epochs=1)
    synthesizer.fit(data)

    only_b = Condition({'column2': 'b'}, num_rows=30)
    sampled = synthesizer.sample_conditions(conditions=[only_b])

    assert sampled.shape == data.shape
    observed_values = set(sampled['column2'].unique())
    assert observed_values == {'b'}
コード例 #5
0
ファイル: test_ctgan.py プロジェクト: sdv-dev/SDV
def test_fixed_combination_constraint():
    """Check that CTGAN samples satisfy a ``FixedCombinations`` constraint.

    The constraint covers the ``company`` and ``department`` columns; every
    entry returned by ``is_valid`` for the 10 sampled rows must be truthy.
    """
    # Setup
    constraint = FixedCombinations(column_names=['company', 'department'])
    employees = load_tabular_demo()
    synthesizer = CTGAN(constraints=[constraint])

    # Run
    synthesizer.fit(employees)
    synthetic = synthesizer.sample(10)

    # Assert
    validity = constraint.is_valid(synthetic)
    assert all(validity)
コード例 #6
0
ファイル: test_ctgan.py プロジェクト: sdv-dev/SDV
def test_conditional_sampling_two_conditions():
    """Sample 5 rows conditioned on two columns at once via a ``Condition``.

    Both conditioned columns must hold the requested value in every row.
    """
    data = pd.DataFrame({
        'column1': [1.0, 0.5, 2.5] * 10,
        'column2': ['a', 'b', 'c'] * 10,
        'column3': ['d', 'e', 'f'] * 10,
    })
    synthesizer = CTGAN(epochs=1)
    synthesizer.fit(data)

    b_and_f = Condition({'column2': 'b', 'column3': 'f'}, num_rows=5)
    samples = synthesizer.sample_conditions(conditions=[b_and_f])

    assert list(samples.column2) == ['b'] * 5
    assert list(samples.column3) == ['f'] * 5
コード例 #7
0
ファイル: test_ctgan.py プロジェクト: sdv-dev/SDV
def test_conditional_sampling_dataframe():
    """Fill in the remaining columns for a DataFrame of known ``column2`` values.

    The output must have one row per condition row and keep the requested
    ``column2`` values in the same order.
    """
    data = pd.DataFrame({
        'column1': [1.0, 0.5, 2.5] * 10,
        'column2': ['a', 'b', 'c'] * 10,
    })
    synthesizer = CTGAN(epochs=1)
    synthesizer.fit(data)

    requested = ['b', 'b', 'b', 'c', 'c']
    conditions = pd.DataFrame({'column2': requested})
    sampled = synthesizer.sample_remaining_columns(conditions)

    assert sampled.shape[0] == len(conditions['column2'])
    matches = sampled['column2'] == np.array(requested)
    assert matches.all()
コード例 #8
0
ファイル: test_ctgan.py プロジェクト: surajitdb/SDV
def test_conditional_sampling_dict():
    """Sample 30 rows with ``conditions`` given as a dict to ``sample``.

    The result must have the same shape as the 30-row training frame and
    contain only ``"b"`` in ``column2``.
    """
    data = pd.DataFrame({
        "column1": [1.0, 0.5, 2.5] * 10,
        "column2": ["a", "b", "c"] * 10,
    })
    synthesizer = CTGAN(epochs=1)
    synthesizer.fit(data)

    sampled = synthesizer.sample(30, conditions={"column2": "b"})

    assert sampled.shape == data.shape
    observed_values = set(sampled["column2"].unique())
    assert observed_values == {"b"}
コード例 #9
0
ファイル: test_ctgan.py プロジェクト: surajitdb/SDV
def test_conditional_sampling_two_conditions():
    """Sample 5 rows with a two-column ``conditions`` dict passed to ``sample``.

    Both conditioned columns must hold the requested value in every row.
    """
    data = pd.DataFrame({
        "column1": [1.0, 0.5, 2.5] * 10,
        "column2": ["a", "b", "c"] * 10,
        "column3": ["d", "e", "f"] * 10,
    })
    synthesizer = CTGAN(epochs=1)
    synthesizer.fit(data)

    wanted = {"column2": "b", "column3": "f"}
    samples = synthesizer.sample(5, conditions=wanted)

    assert list(samples.column2) == ['b'] * 5
    assert list(samples.column3) == ['f'] * 5
コード例 #10
0
ファイル: test_ctgan.py プロジェクト: surajitdb/SDV
def test_conditional_sampling_dataframe():
    """Sample with ``conditions`` given as a DataFrame of ``column2`` values.

    The output must have one row per condition row and keep the requested
    ``column2`` values in the same order.
    """
    data = pd.DataFrame({
        "column1": [1.0, 0.5, 2.5] * 10,
        "column2": ["a", "b", "c"] * 10,
    })
    synthesizer = CTGAN(epochs=1)
    synthesizer.fit(data)

    requested = ["b", "b", "b", "c", "c"]
    conditions = pd.DataFrame({"column2": requested})
    sampled = synthesizer.sample(conditions=conditions)

    assert sampled.shape[0] == len(conditions["column2"])
    matches = sampled["column2"] == np.array(requested)
    assert matches.all()
コード例 #11
0
ファイル: test_ctgan.py プロジェクト: csala/SDV
def test_conditional_sampling_numerical():
    """Sample 5 rows conditioned on a numerical column via a ``conditions`` dict.

    Every sampled ``column1`` value must equal the requested ``1.0``.
    """
    data = pd.DataFrame({
        "column1": [1.0, 0.5, 2.5] * 10,
        "column2": ["a", "b", "c"] * 10,
        "column3": ["d", "e", "f"] * 10,
    })
    synthesizer = CTGAN(epochs=1)
    synthesizer.fit(data)

    sampled = synthesizer.sample(5, conditions={"column1": 1.0})

    observed = list(sampled.column1)
    assert observed == [1.0] * 5
コード例 #12
0
ファイル: test_ctgan.py プロジェクト: sdv-dev/SDV
def test_conditional_sampling_numerical():
    """Sample 5 rows conditioned on a numerical column via a ``Condition``.

    Every sampled ``column1`` value must equal the requested ``1.0``.
    """
    data = pd.DataFrame({
        'column1': [1.0, 0.5, 2.5] * 10,
        'column2': ['a', 'b', 'c'] * 10,
        'column3': ['d', 'e', 'f'] * 10,
    })
    synthesizer = CTGAN(epochs=1)
    synthesizer.fit(data)

    fixed_column1 = Condition({'column1': 1.0}, num_rows=5)
    sampled = synthesizer.sample_conditions(conditions=[fixed_column1])

    observed = list(sampled.column1)
    assert observed == [1.0] * 5
コード例 #13
0
ファイル: test_ctgan.py プロジェクト: csala/SDV
def test_ctgan():
    """Fit CTGAN with ``user_id`` as primary key and check sample and metadata.

    Checks the sample shape, that ``user_id`` comes out as 0..n-1, and the
    exact dict returned by ``get_metadata().to_dict()``, including the
    per-field transformers and the table-level keys.
    """
    users = load_demo(metadata=False)['users']

    model = CTGAN(primary_key='user_id', epochs=1)
    model.fit(users)
    synthetic = model.sample()

    # Sample has the same shape as the training table.
    assert synthetic.shape == users.shape

    # user_id has been generated as an ID field: consecutive integers from 0.
    assert list(synthetic['user_id']) == list(range(len(users)))

    expected_fields = {
        'user_id': {'type': 'id', 'subtype': 'integer', 'transformer': 'integer'},
        'country': {'type': 'categorical', 'transformer': None},
        'gender': {'type': 'categorical', 'transformer': None},
        'age': {'type': 'numerical', 'subtype': 'integer', 'transformer': 'integer'},
    }
    expected_metadata = {
        'fields': expected_fields,
        'primary_key': 'user_id',
        'constraints': [],
        'sequence_index': None,
        'context_columns': [],
        'entity_columns': [],
        'model_kwargs': {},
        'name': None,
    }
    assert model.get_metadata().to_dict() == expected_metadata
コード例 #14
0
ファイル: test_base.py プロジェクト: csala/SDV
def test__init__passes_correct_parameters(metadata_mock):
    """
    Tests the ``BaseTabularModel.__init__`` method.

    The method should pass the parameters to the ``Table``
    class.

    Input:
    - rounding set to an int
    - max_value set to an int
    - min_value set to an int
    Side Effects:
    - ``instance._metadata`` should receive the correct parameters

    ``metadata_mock`` is presumably the patched ``Table`` class injected by a
    decorator that is not visible in this snippet -- TODO confirm.
    """
    numeric_limits = {'rounding': -1, 'max_value': 100, 'min_value': -50}

    # Run
    GaussianCopula(**numeric_limits)
    CTGAN(epochs=1, **numeric_limits)
    TVAE(epochs=1, **numeric_limits)
    CopulaGAN(epochs=1, **numeric_limits)

    # Asserts
    # NOTE(review): four models are built above but five mock calls are
    # expected; presumably one constructor touches the mock an extra time --
    # confirm against the model implementations.
    assert len(metadata_mock.mock_calls) == 5

    def expected_call(object_transformer):
        # The expected calls differ only in the transformer for dtype 'O'.
        return call(
            field_names=None,
            primary_key=None,
            field_types=None,
            field_transformers=None,
            anonymize_fields=None,
            constraints=None,
            dtype_transformers={'O': object_transformer},
            rounding=-1,
            max_value=100,
            min_value=-50,
        )

    expected_calls = [
        expected_call('one_hot_encoding'),
        expected_call(None),
        expected_call(None),
        expected_call(None),
    ]
    metadata_mock.assert_has_calls(expected_calls, any_order=True)
コード例 #15
0
ファイル: test_base.py プロジェクト: csala/SDV
from unittest.mock import Mock, call, patch

import pandas as pd
import pytest

from sdv.metadata.table import Table
from sdv.tabular.base import COND_IDX
from sdv.tabular.copulagan import CopulaGAN
from sdv.tabular.copulas import GaussianCopula
from sdv.tabular.ctgan import CTGAN, TVAE
from tests.utils import DataFrameMatcher

# One instance of each tabular model; models that accept ``epochs`` use 1.
MODELS = [CTGAN(epochs=1), TVAE(epochs=1), GaussianCopula(), CopulaGAN(epochs=1)]


class TestBaseTabularModel:
    def test_sample_no_transformed_columns(self):
        """Test the ``BaseTabularModel.sample`` method with no transformed columns.

        When the transformed conditions DataFrame has no columns, expect that sample
        does not pass through any conditions when conditionally sampling.

        Setup:
            - Mock the ``_make_conditions_df`` method to return a dataframe representing
              the expected conditions, and the ``get_fields`` method to return metadata
              fields containing the expected conditioned column.
コード例 #16
0
ファイル: test_base.py プロジェクト: csala/SDV
from unittest.mock import patch

import pandas as pd
import pytest
from copulas.multivariate.gaussian import GaussianMultivariate

from sdv.constraints import Unique, UniqueCombinations
from sdv.constraints.tabular import GreaterThan
from sdv.tabular.copulagan import CopulaGAN
from sdv.tabular.copulas import GaussianCopula
from sdv.tabular.ctgan import CTGAN, TVAE

# Model instances for the parametrized tests; each id names the model class.
MODELS = [
    pytest.param(instance, id=label)
    for label, instance in (
        ('CTGAN', CTGAN(epochs=1)),
        ('TVAE', TVAE(epochs=1)),
        ('GaussianCopula', GaussianCopula()),
        ('CopulaGAN', CopulaGAN(epochs=1)),
    )
]


@pytest.mark.parametrize('model', MODELS)
def test_conditional_sampling_graceful_reject_sampling_True_dict(model):
    data = pd.DataFrame({
        'column1': list(range(100)),
        'column2': list(range(100)),
        'column3': list(range(100))
    })

    model.fit(data)
    conditions = {'column1': 28, 'column2': 37, 'column3': 93}