コード例 #1
0
def test_fixed_random_seed():
    """Check that ``set_random_state`` makes CTGANSynthesizer sampling reproducible.

    After seeding with 0, two consecutive draws of 10 rows must be repeated exactly
    once the synthesizer is re-seeded with 0. Neither seeded draw may equal a draw
    taken before any seed was set.
    """
    # Setup: 100 rows with one continuous and one categorical column
    n_rows = 100
    data = pd.DataFrame({
        'continuous': np.random.random(n_rows),
        'discrete': np.random.choice(['a', 'b', 'c'], n_rows),
    })
    synthesizer = CTGANSynthesizer(epochs=1)

    # Run: one unseeded draw, then two seeded rounds of two draws each
    synthesizer.fit(data, ['discrete'])
    unseeded = synthesizer.sample(10)

    rounds = []
    for _ in range(2):
        synthesizer.set_random_state(0)
        # Tuple elements are evaluated left to right: first draw, then second draw.
        rounds.append((synthesizer.sample(10), synthesizer.sample(10)))
    (first_a, first_b), (second_a, second_b) = rounds

    # Assert
    assert not np.array_equal(unseeded, first_a)
    assert not np.array_equal(unseeded, first_b)
    np.testing.assert_array_equal(first_a, second_a)
    np.testing.assert_array_equal(first_b, second_b)
コード例 #2
0
ファイル: test_ctgan.py プロジェクト: ppeddada97/CTGAN
def test_wrong_sampling_conditions():
    """Conditional sampling with an unknown column or unseen category raises ValueError."""
    # NOTE(review): if another ``test_wrong_sampling_conditions`` is defined later in
    # the same module, that later definition shadows this one and pytest skips it.
    n_rows = 100
    continuous = np.random.random(n_rows)
    discrete = np.random.choice(['a', 'b', 'c'], n_rows)
    training_frame = pd.DataFrame({'continuous': continuous, 'discrete': discrete})

    ctgan = CTGANSynthesizer(epochs=1)
    ctgan.fit(training_frame, ['discrete'])

    bad_conditions = [
        ('cardinal', "doesn't matter"),  # column absent from the training data
        ('discrete', "d"),               # category never seen in 'discrete'
    ]
    for column, value in bad_conditions:
        with pytest.raises(ValueError):
            ctgan.sample(1, column, value)
コード例 #3
0
def test_synthesizer_sample():
    """Conditional sampling on ``discrete == 'a'`` returns a pandas DataFrame."""
    categories = np.random.choice(['a', 'b', 'c'], 100)
    single_column_frame = pd.DataFrame({'discrete': categories})

    model = CTGANSynthesizer(epochs=1)
    model.fit(single_column_frame, ['discrete'])

    result = model.sample(1000, 'discrete', 'a')
    assert isinstance(result, pd.DataFrame)
コード例 #4
0
def test_wrong_sampling_conditions():
    """Check that invalid sampling conditions make CTGANSynthesizer.sample raise ValueError.

    An unknown column name must fail with a descriptive message. A category that
    never appeared in training must also fail, but its message is not checked.
    """
    n_rows = 100
    continuous = np.random.random(n_rows)
    discrete = np.random.choice(['a', 'b', 'c'], n_rows)
    training_frame = pd.DataFrame({'continuous': continuous, 'discrete': discrete})

    ctgan = CTGANSynthesizer(epochs=1)
    ctgan.fit(training_frame, ['discrete'])

    unknown_column_message = "The column_name `cardinal` doesn't exist in the data."
    with pytest.raises(ValueError, match=unknown_column_message):
        ctgan.sample(1, 'cardinal', "doesn't matter")

    # No ``match`` here: RDT currently raises this error with a tuple instead of a
    # string as its message, so the text cannot be matched reliably.
    with pytest.raises(ValueError):
        ctgan.sample(1, 'discrete', 'd')
コード例 #5
0
def test_ctgan_no_categoricals():
    """A purely continuous table can be fit (no discrete columns) and sampled."""
    frame = pd.DataFrame({'continuous': np.random.random(1000)})

    model = CTGANSynthesizer(epochs=1)
    model.fit(frame, [])
    result = model.sample(100)

    expected_shape = (100, 1)
    assert result.shape == expected_shape
    assert isinstance(result, pd.DataFrame)
    assert set(result.columns) == {'continuous'}
コード例 #6
0
ファイル: test_ctgan.py プロジェクト: ppeddada97/CTGAN
def test_log_frequency():
    """Check how ``log_frequency`` changes sampling of a heavily imbalanced category.

    In the training data, 950 of 1000 'discrete' values are 'a'. With the
    synthesizer's default settings, 'a' must appear fewer than 6500 times in 10000
    samples. With ``log_frequency=False`` it must appear more than 9000 times.
    """
    data = pd.DataFrame({
        'continuous': np.random.random(1000),
        'discrete': np.repeat(['a', 'b', 'c'], [950, 25, 25]),
    })
    discrete_columns = ['discrete']

    def count_of_a(**synthesizer_kwargs):
        # Train a fresh 100-epoch model and count how often 'a' is generated.
        synthesizer = CTGANSynthesizer(epochs=100, **synthesizer_kwargs)
        synthesizer.fit(data, discrete_columns)
        return synthesizer.sample(10000)['discrete'].value_counts()['a']

    # NOTE(review): no seed is fixed, so these thresholds are statistical, not exact.
    assert count_of_a() < 6500
    assert count_of_a(log_frequency=False) > 9000
コード例 #7
0
ファイル: test_ctgan.py プロジェクト: ppeddada97/CTGAN
def test_ctgan_numpy():
    """Fitting on a plain ndarray (discrete column given by index) samples an ndarray."""
    frame = pd.DataFrame({
        'continuous': np.random.random(100),
        'discrete': np.random.choice(['a', 'b', 'c'], 100),
    })
    raw = frame.values

    model = CTGANSynthesizer(epochs=1)
    model.fit(raw, [1])  # column index 1 is the 'discrete' column
    result = model.sample(100)

    assert result.shape == (100, 2)
    assert isinstance(result, np.ndarray)
    generated_categories = set(np.unique(result[:, 1]))
    assert generated_categories == {'a', 'b', 'c'}
コード例 #8
0
ファイル: test_ctgan.py プロジェクト: ppeddada97/CTGAN
def test_ctgan_dataframe():
    """Fitting on a DataFrame samples a DataFrame with the same columns and categories."""
    continuous_values = np.random.random(100)
    discrete_values = np.random.choice(['a', 'b', 'c'], 100)
    data = pd.DataFrame({'continuous': continuous_values, 'discrete': discrete_values})

    synthesizer = CTGANSynthesizer(epochs=1)
    synthesizer.fit(data, ['discrete'])
    sampled = synthesizer.sample(100)

    expected_columns = {'continuous', 'discrete'}
    expected_categories = {'a', 'b', 'c'}
    assert sampled.shape == (100, 2)
    assert isinstance(sampled, pd.DataFrame)
    assert set(sampled.columns) == expected_columns
    assert set(sampled['discrete'].unique()) == expected_categories
コード例 #9
0
ファイル: test_ctgan.py プロジェクト: ppeddada97/CTGAN
def test_save_load():
    """A fitted CTGANSynthesizer survives a save/load round trip.

    The reloaded model must still generate both original columns and every
    category of 'discrete'.
    """
    import os

    data = pd.DataFrame({
        'continuous': np.random.random(100),
        'discrete': np.random.choice(['a', 'b', 'c'], 100)
    })
    discrete_columns = ['discrete']

    ctgan = CTGANSynthesizer(epochs=1)
    ctgan.fit(data, discrete_columns)

    # ``tf`` is presumably the ``tempfile`` module imported under an alias.
    with tf.TemporaryDirectory() as temporary_directory:
        # Join with a path separator. Plain ``dir + "name"`` concatenation would write
        # the file *next to* the temporary directory rather than inside it, so it
        # would never be cleaned up.
        model_path = os.path.join(temporary_directory, 'test_ctgan.pkl')
        ctgan.save(model_path)
        loaded = CTGANSynthesizer.load(model_path)

    sampled = loaded.sample(1000)
    assert set(sampled.columns) == {'continuous', 'discrete'}
    assert set(sampled['discrete'].unique()) == {'a', 'b', 'c'}
コード例 #10
0
ファイル: __main__.py プロジェクト: sdv-dev/CTGAN
def main():
    """CLI entry point: train (or load) a CTGANSynthesizer and write sampled rows.

    Reads the training table (TSV with metadata, or CSV). It then either loads a
    saved model or builds one from the command-line hyper-parameters, and fits it
    on the data. If requested, the model is saved. Finally it samples
    ``num_samples`` rows (default: as many as the training data), optionally
    conditioned on a column value, and writes them to ``args.output``.

    Raises:
        ValueError: if ``sample_condition_column`` is given without
            ``sample_condition_column_value``.
    """
    args = _parse_args()

    # Validate the conditional-sampling arguments before any I/O or (potentially
    # long) training. Use an explicit raise rather than ``assert``, because
    # asserts are stripped when Python runs with -O.
    if (args.sample_condition_column is not None
            and args.sample_condition_column_value is None):
        raise ValueError(
            'sample_condition_column_value is required when '
            'sample_condition_column is set.')

    if args.tsv:
        data, discrete_columns = read_tsv(args.data, args.metadata)
    else:
        data, discrete_columns = read_csv(args.data, args.metadata,
                                          args.header, args.discrete)

    if args.load:
        model = CTGANSynthesizer.load(args.load)
    else:
        # Layer sizes arrive as comma-separated integers, e.g. "256,256".
        generator_dim = [int(x) for x in args.generator_dim.split(',')]
        discriminator_dim = [int(x) for x in args.discriminator_dim.split(',')]
        model = CTGANSynthesizer(embedding_dim=args.embedding_dim,
                                 generator_dim=generator_dim,
                                 discriminator_dim=discriminator_dim,
                                 generator_lr=args.generator_lr,
                                 generator_decay=args.generator_decay,
                                 discriminator_lr=args.discriminator_lr,
                                 discriminator_decay=args.discriminator_decay,
                                 batch_size=args.batch_size,
                                 epochs=args.epochs)
    # NOTE: a model obtained via ``--load`` is also fitted again on ``data`` here.
    model.fit(data, discrete_columns)

    if args.save is not None:
        model.save(args.save)

    # A falsy num_samples (e.g. unset or 0) falls back to the training-set size.
    num_samples = args.num_samples or len(data)

    sampled = model.sample(num_samples, args.sample_condition_column,
                           args.sample_condition_column_value)

    if args.tsv:
        write_tsv(sampled, args.metadata, args.output)
    else:
        sampled.to_csv(args.output, index=False)
コード例 #11
0
ファイル: test_ctgan.py プロジェクト: ppeddada97/CTGAN
def test_categorical_nan():
    """NaN in a categorical column is learned and generated as its own category."""
    n_rows = 30
    data = pd.DataFrame({
        'continuous': np.random.random(n_rows),
        # Keep this a Python list: inside an np.array the NaN would become a string.
        'discrete': [np.nan, 'b', 'c'] * 10,
    })

    synthesizer = CTGANSynthesizer(epochs=1)
    synthesizer.fit(data, ['discrete'])
    sampled = synthesizer.sample(100)

    assert sampled.shape == (100, 2)
    assert isinstance(sampled, pd.DataFrame)
    assert set(sampled.columns) == {'continuous', 'discrete'}

    # NaN != NaN, so detect it with pd.isnull rather than set membership.
    categories = set(sampled['discrete'].unique())
    assert len(categories) == 3
    assert any(pd.isnull(category) for category in categories)
    assert {"b", "c"}.issubset(categories)