def test_prepare_data_compress_sparse(self):
    util.clear_training_cache()

    # With compress_sparse=True, the column containing only sparse vectors should be
    # stored in the custom sparse format, while the mixed dense/sparse column falls
    # back to a plain array.
    expected_metadata = \
        {
            'float': {
                'spark_data_type': FloatType,
                'is_sparse_vector_only': False,
                'intermediate_format': constants.NOCHANGE,
                'max_size': 1,
                'shape': 1
            },
            'dense': {
                'spark_data_type': DenseVector,
                'is_sparse_vector_only': False,
                'intermediate_format': constants.ARRAY,
                'max_size': 2,
                'shape': 2
            },
            'sparse': {
                'spark_data_type': SparseVector,
                'is_sparse_vector_only': True,
                'intermediate_format': constants.CUSTOM_SPARSE,
                'max_size': 1,
                'shape': 2
            },
            'mixed': {
                'spark_data_type': DenseVector,
                'is_sparse_vector_only': False,
                'intermediate_format': constants.ARRAY,
                'max_size': 2,
                'shape': 2
            },
        }

    with mock.patch('horovod.spark.common.util._get_metadata',
                    side_effect=util._get_metadata) as mock_get_metadata:
        with spark_session('test_prepare_data') as spark:
            data = [[
                0.0,
                DenseVector([1.0, 1.0]),
                SparseVector(2, {1: 1.0}),
                DenseVector([1.0, 1.0])
            ], [
                1.0,
                DenseVector([1.0, 1.0]),
                SparseVector(2, {1: 1.0}),
                SparseVector(2, {1: 1.0})
            ]]

            schema = StructType([
                StructField('float', FloatType()),
                StructField('dense', VectorUDT()),
                StructField('sparse', VectorUDT()),
                StructField('mixed', VectorUDT())
            ])

            df = create_test_data_from_schema(spark, data, schema)

            with local_store() as store:
                with util.prepare_data(num_processes=2,
                                       store=store,
                                       df=df,
                                       feature_columns=['dense', 'sparse', 'mixed'],
                                       label_columns=['float'],
                                       compress_sparse=True) as dataset_idx:
                    mock_get_metadata.assert_called()
                    assert dataset_idx == 0

                    train_rows, val_rows, metadata, avg_row_size = \
                        util.get_dataset_properties(dataset_idx)
                    self.assertDictEqual(metadata, expected_metadata)
def test_df_cache(self):
    # Clean the cache before starting the test
    util.clear_training_cache()
    util._training_cache.get_dataset = mock.Mock(
        side_effect=util._training_cache.get_dataset)

    with spark_session('test_df_cache') as spark:
        with local_store() as store:
            df = create_xor_data(spark)
            df2 = create_xor_data(spark)
            df3 = create_xor_data(spark)

            key = util._training_cache.create_key(df, store, None)
            key2 = util._training_cache.create_key(df2, store, None)
            key3 = util._training_cache.create_key(df3, store, None)

            # All keys are distinct
            assert key != key2
            assert key != key3
            assert key2 != key3

            # The cache should be empty to start
            assert not util._training_cache.is_cached(key, store)
            assert not util._training_cache.is_cached(key2, store)
            assert not util._training_cache.is_cached(key3, store)

            # First insertion into the cache
            with util.prepare_data(num_processes=2,
                                   store=store,
                                   df=df,
                                   feature_columns=['features'],
                                   label_columns=['y']) as dataset_idx:
                train_rows, val_rows, metadata, avg_row_size = \
                    util.get_dataset_properties(dataset_idx)
                util._training_cache.get_dataset.assert_not_called()
                assert len(util._training_cache._key_to_dataset) == 1
                assert util._training_cache.is_cached(key, store)
                assert dataset_idx == 0

                # The first dataset is still in use, so we assign the next
                # integer in sequence to this dataset
                assert not util._training_cache.is_cached(key2, store)
                with util.prepare_data(num_processes=2,
                                       store=store,
                                       df=df2,
                                       feature_columns=['features'],
                                       label_columns=['y']) as dataset_idx2:
                    util._training_cache.get_dataset.assert_not_called()
                    assert len(util._training_cache._key_to_dataset) == 2
                    assert util._training_cache.is_cached(key2, store)
                    assert dataset_idx2 == 1

            # Even though the first dataset is no longer in use, it is still cached
            with util.prepare_data(num_processes=2,
                                   store=store,
                                   df=df,
                                   feature_columns=['features'],
                                   label_columns=['y']) as dataset_idx1:
                train_rows1, val_rows1, metadata1, avg_row_size1 = \
                    util.get_dataset_properties(dataset_idx1)
                util._training_cache.get_dataset.assert_called()
                assert train_rows == train_rows1
                assert val_rows == val_rows1
                assert metadata == metadata1
                assert avg_row_size == avg_row_size1
                assert dataset_idx1 == 0

            # The first dataset is no longer in use, so we can reclaim its dataset index
            assert not util._training_cache.is_cached(key3, store)
            with util.prepare_data(num_processes=2,
                                   store=store,
                                   df=df3,
                                   feature_columns=['features'],
                                   label_columns=['y']) as dataset_idx3:
                train_rows3, val_rows3, metadata3, avg_row_size3 = \
                    util.get_dataset_properties(dataset_idx3)
                assert train_rows == train_rows3
                assert val_rows == val_rows3
                assert metadata == metadata3
                assert avg_row_size == avg_row_size3
                assert dataset_idx3 == 0

            # Same dataframe, different validation
            bad_key = util._training_cache.create_key(df, store, 0.1)
            assert not util._training_cache.is_cached(bad_key, store)