def test_fixed_random_seed():
    """Check that CTGANSynthesizer sampling is reproducible under a fixed seed.

    Resetting the random state with the same seed must replay the same
    sequence of samples, and the seeded samples must differ from the data
    sampled before any seed was set.
    """
    # Setup
    continuous_values = np.random.random(100)
    discrete_values = np.random.choice(['a', 'b', 'c'], 100)
    data = pd.DataFrame({'continuous': continuous_values, 'discrete': discrete_values})
    model = CTGANSynthesizer(epochs=1)

    # Run
    model.fit(data, ['discrete'])
    unseeded = model.sample(10)

    model.set_random_state(0)
    first_run = [model.sample(10), model.sample(10)]

    model.set_random_state(0)
    second_run = [model.sample(10), model.sample(10)]

    # Assert
    for seeded in first_run:
        assert not np.array_equal(unseeded, seeded)

    for expected, actual in zip(first_run, second_run):
        np.testing.assert_array_equal(expected, actual)
def test_wrong_sampling_conditions():
    """Sampling with an unknown column or an unknown category raises ValueError."""
    # NOTE(review): a later definition in this file reuses this function name,
    # which shadows this version at import time -- consider renaming or
    # removing one of the two.
    frame = pd.DataFrame({
        'continuous': np.random.random(100),
        'discrete': np.random.choice(['a', 'b', 'c'], 100),
    })
    synthesizer = CTGANSynthesizer(epochs=1)
    synthesizer.fit(frame, ['discrete'])

    # First a column that is not in the data, then a value that is not one of
    # the categories used to build the discrete column.
    invalid_conditions = [('cardinal', "doesn't matter"), ('discrete', "d")]
    for column, value in invalid_conditions:
        with pytest.raises(ValueError):
            synthesizer.sample(1, column, value)
def test_synthesizer_sample():
    """Conditional sampling on a discrete-only dataset returns a DataFrame."""
    categories = np.random.choice(['a', 'b', 'c'], 100)
    frame = pd.DataFrame({'discrete': categories})

    model = CTGANSynthesizer(epochs=1)
    model.fit(frame, ['discrete'])

    result = model.sample(1000, 'discrete', 'a')

    assert isinstance(result, pd.DataFrame)
def test_wrong_sampling_conditions():
    """Ensure CTGANSynthesizer raises ValueError for incorrect sampling conditions."""
    frame = pd.DataFrame({
        'continuous': np.random.random(100),
        'discrete': np.random.choice(['a', 'b', 'c'], 100),
    })
    synthesizer = CTGANSynthesizer(epochs=1)
    synthesizer.fit(frame, ['discrete'])

    # Conditioning on a column that is not part of the fitted data.
    missing_column_message = "The column_name `cardinal` doesn't exist in the data."
    with pytest.raises(ValueError, match=missing_column_message):
        synthesizer.sample(1, 'cardinal', "doesn't matter")

    # The message is not checked here: RDT currently incorrectly raises a
    # tuple instead of a string.
    with pytest.raises(ValueError):
        synthesizer.sample(1, 'discrete', 'd')
def test_ctgan_no_categoricals():
    """Fit and sample on purely continuous data with an empty discrete column list."""
    frame = pd.DataFrame({'continuous': np.random.random(1000)})

    model = CTGANSynthesizer(epochs=1)
    model.fit(frame, [])
    result = model.sample(100)

    assert result.shape == (100, 1)
    assert isinstance(result, pd.DataFrame)
    column_names = set(result.columns)
    assert column_names == {'continuous'}
def test_log_frequency():
    """Compare how often the dominant category is sampled with and without log frequency.

    The training data is 95% ``'a'``. With the default settings fewer than
    6500 of 10000 sampled rows may be ``'a'``; with ``log_frequency=False``
    more than 9000 must be.
    """
    frame = pd.DataFrame({
        'continuous': np.random.random(1000),
        'discrete': np.repeat(['a', 'b', 'c'], [950, 25, 25]),
    })
    discrete_columns = ['discrete']

    def _sampled_a_count(model):
        # Fit the given synthesizer, draw 10000 rows and count the 'a' values.
        model.fit(frame, discrete_columns)
        drawn = model.sample(10000)
        return drawn['discrete'].value_counts()['a']

    assert _sampled_a_count(CTGANSynthesizer(epochs=100)) < 6500
    assert _sampled_a_count(CTGANSynthesizer(log_frequency=False, epochs=100)) > 9000
def test_ctgan_numpy():
    """Fit on a raw numpy array, identifying the discrete column by its index."""
    frame = pd.DataFrame({
        'continuous': np.random.random(100),
        'discrete': np.random.choice(['a', 'b', 'c'], 100),
    })

    model = CTGANSynthesizer(epochs=1)
    # Column 1 of ``frame.values`` is the 'discrete' column.
    model.fit(frame.values, [1])
    result = model.sample(100)

    assert result.shape == (100, 2)
    assert isinstance(result, np.ndarray)
    observed_categories = set(np.unique(result[:, 1]))
    assert observed_categories == {'a', 'b', 'c'}
def test_ctgan_dataframe():
    """Fit on a DataFrame and check the shape, type, columns and categories of the sample."""
    frame = pd.DataFrame({
        'continuous': np.random.random(100),
        'discrete': np.random.choice(['a', 'b', 'c'], 100),
    })

    model = CTGANSynthesizer(epochs=1)
    model.fit(frame, ['discrete'])
    result = model.sample(100)

    assert result.shape == (100, 2)
    assert isinstance(result, pd.DataFrame)

    column_names = set(result.columns)
    assert column_names == {'continuous', 'discrete'}

    observed_categories = set(result['discrete'].unique())
    assert observed_categories == {'a', 'b', 'c'}
def test_save_load():
    """Round-trip a fitted CTGANSynthesizer through ``save``/``load`` and sample from it.

    The model is saved into a temporary directory, loaded back, and the
    reloaded model must still produce both columns and all three categories.
    """
    import os

    data = pd.DataFrame({
        'continuous': np.random.random(100),
        'discrete': np.random.choice(['a', 'b', 'c'], 100)
    })
    discrete_columns = ['discrete']

    ctgan = CTGANSynthesizer(epochs=1)
    ctgan.fit(data, discrete_columns)

    with tf.TemporaryDirectory() as temporary_directory:
        # Join with a path separator so the pickle is written *inside* the
        # temporary directory. Plain string concatenation produced a sibling
        # path ("<tmpdir>test_tvae.pkl") that was never cleaned up.
        model_path = os.path.join(temporary_directory, "test_tvae.pkl")
        ctgan.save(model_path)
        ctgan = CTGANSynthesizer.load(model_path)

    sampled = ctgan.sample(1000)
    assert set(sampled.columns) == {'continuous', 'discrete'}
    assert set(sampled['discrete'].unique()) == {'a', 'b', 'c'}
def main():
    """Run the command-line interface.

    Parses the command-line arguments, reads the input table (TSV or CSV),
    loads a saved model or builds a new ``CTGANSynthesizer`` from the
    arguments, fits it, optionally saves it, then samples rows and writes
    them to ``args.output``.

    Raises:
        ValueError: If a sampling condition column is given without a
            sampling condition value.
    """
    args = _parse_args()

    if args.tsv:
        data, discrete_columns = read_tsv(args.data, args.metadata)
    else:
        data, discrete_columns = read_csv(args.data, args.metadata, args.header, args.discrete)

    if args.load:
        model = CTGANSynthesizer.load(args.load)
    else:
        # Layer sizes arrive as comma-separated strings, e.g. "256,256".
        generator_dim = [int(x) for x in args.generator_dim.split(',')]
        discriminator_dim = [int(x) for x in args.discriminator_dim.split(',')]
        model = CTGANSynthesizer(
            embedding_dim=args.embedding_dim,
            generator_dim=generator_dim,
            discriminator_dim=discriminator_dim,
            generator_lr=args.generator_lr,
            generator_decay=args.generator_decay,
            discriminator_lr=args.discriminator_lr,
            discriminator_decay=args.discriminator_decay,
            batch_size=args.batch_size,
            epochs=args.epochs,
        )

    # NOTE(review): the fit is applied to a loaded model as well as to a new
    # one -- confirm this is intended.
    model.fit(data, discrete_columns)

    if args.save is not None:
        model.save(args.save)

    # Default to as many rows as the input data when no count is requested.
    num_samples = args.num_samples or len(data)

    # Validate explicitly instead of with ``assert``: assertions are stripped
    # when Python runs with -O, which would let a missing value through.
    if args.sample_condition_column is not None and args.sample_condition_column_value is None:
        raise ValueError(
            'sample_condition_column_value must be provided when '
            'sample_condition_column is set.'
        )

    sampled = model.sample(
        num_samples,
        args.sample_condition_column,
        args.sample_condition_column_value,
    )

    if args.tsv:
        write_tsv(sampled, args.metadata, args.output)
    else:
        sampled.to_csv(args.output, index=False)
def test_categorical_nan():
    """Fit and sample when the discrete column contains NaN alongside 'b' and 'c'."""
    frame = pd.DataFrame({
        'continuous': np.random.random(30),
        # Keep this a plain list (not a np.array), or NaN will be cast to a string.
        'discrete': [np.nan, 'b', 'c'] * 10,
    })

    model = CTGANSynthesizer(epochs=1)
    model.fit(frame, ['discrete'])
    result = model.sample(100)

    assert result.shape == (100, 2)
    assert isinstance(result, pd.DataFrame)
    assert set(result.columns) == {'continuous', 'discrete'}

    # np.nan != np.nan, so the NaN category is detected with pd.isnull rather
    # than by comparing against a literal set.
    categories = set(result['discrete'].unique())
    assert len(categories) == 3
    assert any(pd.isnull(value) for value in categories)
    assert {"b", "c"} <= categories