def _assert_valid_sample(sampled, data):
    """Check that ``sampled`` looks like a valid synthetic copy of ``data``.

    Asserts that the sample has the same shape and the same column dtypes as
    the training data, and that no row is entirely null.
    """
    assert sampled.shape == data.shape
    assert (sampled.dtypes == data.dtypes).all()
    # Every row must contain at least one non-null value.
    assert (sampled.notnull().sum(axis=1) != 0).all()


def test_recreate():
    """CTGAN can be re-created from the metadata of an already-fitted model.

    Fits a model with no explicit metadata, then builds two new models from
    that model's metadata: once as the metadata object and once as its
    ``to_dict()`` form. Each of the three models must produce a valid sample.
    """
    data = load_demo(metadata=False)['users']

    # No explicit metadata: the model learns it from ``data`` during fit.
    model = CTGAN(epochs=1)
    model.fit(data)
    _assert_valid_sample(model.sample(), data)

    # Re-create from the fitted model's metadata object.
    model_meta = CTGAN(epochs=1, table_metadata=model.get_metadata())
    model_meta.fit(data)
    _assert_valid_sample(model_meta.sample(), data)

    # Re-create from the dict form of the same metadata.
    model_meta_dict = CTGAN(epochs=1, table_metadata=model.get_metadata().to_dict())
    model_meta_dict.fit(data)
    _assert_valid_sample(model_meta_dict.sample(), data)
def test_ctgan():
    """Fit CTGAN with ``user_id`` as primary key; check the sample and metadata.

    NOTE(review): ``test_ctgan`` is defined again later in this file. That later
    definition rebinds the name, so pytest only collects it and this body never
    runs. The metadata expected here also lacks keys the later version expects
    (``transformer``, ``primary_key``, ...). This looks like a stale copy;
    confirm, then delete it or give it a distinct name.
    """
    demo_users = load_demo(metadata=False)['users']
    model = CTGAN(primary_key='user_id', epochs=1)
    model.fit(demo_users)
    synthetic = model.sample()

    # Default sample size reproduces the training table's shape.
    assert sampled_shape_matches(synthetic, demo_users) if False else synthetic.shape == demo_users.shape

    # The primary key is regenerated as a 0..n-1 integer sequence.
    expected_ids = list(range(len(demo_users)))
    assert list(synthetic['user_id']) == expected_ids

    expected_metadata = {
        'fields': {
            'user_id': {'type': 'id', 'subtype': 'integer'},
            'country': {'type': 'categorical'},
            'gender': {'type': 'categorical'},
            'age': {'type': 'numerical', 'subtype': 'integer'},
        },
        'constraints': [],
        'model_kwargs': {},
    }
    assert model.get_metadata().to_dict() == expected_metadata
def test_ctgan():
    """Fit CTGAN with ``user_id`` as primary key; check the sample and metadata.

    Verifies three things:

    - The sample has the training table's shape.
    - ``user_id`` is regenerated as ``0..n-1``.
    - ``get_metadata().to_dict()`` reports the expected field types,
      transformers and table-level settings.
    """
    users_df = load_demo(metadata=False)['users']
    model = CTGAN(
        primary_key='user_id',
        epochs=1,
    )
    model.fit(users_df)
    synthetic = model.sample()

    # Default sample size reproduces the training table's shape.
    assert synthetic.shape == users_df.shape

    # The primary key is regenerated as a 0..n-1 integer sequence.
    assert list(synthetic['user_id']) == list(range(len(users_df)))

    expected_fields = {
        'user_id': {'type': 'id', 'subtype': 'integer', 'transformer': 'integer'},
        'country': {'type': 'categorical', 'transformer': None},
        'gender': {'type': 'categorical', 'transformer': None},
        'age': {'type': 'numerical', 'subtype': 'integer', 'transformer': 'integer'},
    }
    expected_metadata = {
        'fields': expected_fields,
        'primary_key': 'user_id',
        'constraints': [],
        'sequence_index': None,
        'context_columns': [],
        'entity_columns': [],
        'model_kwargs': {},
        'name': None,
    }
    actual_metadata = model.get_metadata().to_dict()
    assert actual_metadata == expected_metadata