Example #1
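This method appears to come from Horovod's Spark estimator base class: `_fit` creates (or reuses) a distributed backend, hands the input DataFrame to `util.prepare_data`, and then fetches the prepared dataset's row counts, column metadata, and average row size via `util.get_dataset_properties` before delegating to `_fit_on_prepared_data`.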
    def _fit(self, df):
        backend = self._get_or_create_backend()
        # Stage the DataFrame in the store; prepare_data yields a dataset
        # index identifying the prepared (and possibly cached) data.
        with util.prepare_data(
                backend.num_processes(),
                self.getStore(),
                df,
                label_columns=self.getLabelCols(),
                feature_columns=self.getFeatureCols(),
                validation=self.getValidation(),
                sample_weight_col=self.getSampleWeightCol(),
                compress_sparse=self.getCompressSparseCols(),
                partitions_per_process=self.getPartitionsPerProcess(),
                verbose=self.getVerbose()) as dataset_idx:
            # Look up row counts, column metadata, and average row size for
            # the prepared dataset, then run training on it.
            train_rows, val_rows, metadata, avg_row_size = util.get_dataset_properties(
                dataset_idx)
            self._check_metadata_compatibility(metadata)
            return self._fit_on_prepared_data(backend, train_rows, val_rows,
                                              metadata, avg_row_size,
                                              dataset_idx)
Example #2
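A unit test for `util.prepare_data` with `compress_sparse=True`: it builds a small DataFrame with float, dense-vector, sparse-vector, and mixed vector columns and asserts that the inferred per-column metadata (Spark data type, intermediate format, max size, and shape) matches the expected values. It relies on pyspark types (`FloatType`, `StructType`, `StructField`, `VectorUDT`, `DenseVector`, `SparseVector`), `mock`, and what appear to be Horovod test helpers (`spark_session`, `local_store`, `create_test_data_from_schema`).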
    def test_prepare_data_compress_sparse(self):
        util.clear_training_cache()

        expected_metadata = \
            {
                'float': {
                    'spark_data_type': FloatType,
                    'is_sparse_vector_only': False,
                    'intermediate_format': constants.NOCHANGE,
                    'max_size': 1,
                    'shape': 1
                },
                'dense': {
                    'spark_data_type': DenseVector,
                    'is_sparse_vector_only': False,
                    'intermediate_format': constants.ARRAY,
                    'max_size': 2,
                    'shape': 2
                },
                'sparse': {
                    'spark_data_type': SparseVector,
                    'is_sparse_vector_only': True,
                    'intermediate_format': constants.CUSTOM_SPARSE,
                    'max_size': 1,
                    'shape': 2
                },
                'mixed': {
                    'spark_data_type': DenseVector,
                    'is_sparse_vector_only': False,
                    'intermediate_format': constants.ARRAY,
                    'max_size': 2,
                    'shape': 2
                },
            }

        with mock.patch('horovod.spark.common.util._get_metadata',
                        side_effect=util._get_metadata) as mock_get_metadata:
            with spark_session('test_prepare_data') as spark:
                data = [
                    [0.0,
                     DenseVector([1.0, 1.0]),
                     SparseVector(2, {1: 1.0}),
                     DenseVector([1.0, 1.0])],
                    [1.0,
                     DenseVector([1.0, 1.0]),
                     SparseVector(2, {1: 1.0}),
                     SparseVector(2, {1: 1.0})],
                ]

                schema = StructType([
                    StructField('float', FloatType()),
                    StructField('dense', VectorUDT()),
                    StructField('sparse', VectorUDT()),
                    StructField('mixed', VectorUDT())
                ])

                df = create_test_data_from_schema(spark, data, schema)

                with local_store() as store:
                    with util.prepare_data(
                            num_processes=2,
                            store=store,
                            df=df,
                            feature_columns=['dense', 'sparse', 'mixed'],
                            label_columns=['float'],
                            compress_sparse=True) as dataset_idx:
                        mock_get_metadata.assert_called()
                        assert dataset_idx == 0

                        train_rows, val_rows, metadata, avg_row_size = util.get_dataset_properties(
                            dataset_idx)
                        self.assertDictEqual(metadata, expected_metadata)
Example #3
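A unit test for the training-data cache behind `util.prepare_data`: it checks that distinct DataFrames produce distinct cache keys, that a dataset index is only reclaimed once the dataset holding it goes out of scope, and that re-preparing an already cached DataFrame returns the cached properties (via `get_dataset`) instead of recomputing them.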
    def test_df_cache(self):
        # Clean the cache before starting the test
        util.clear_training_cache()
        util._training_cache.get_dataset = mock.Mock(
            side_effect=util._training_cache.get_dataset)

        with spark_session('test_df_cache') as spark:
            with local_store() as store:
                df = create_xor_data(spark)
                df2 = create_xor_data(spark)
                df3 = create_xor_data(spark)

                key = util._training_cache.create_key(df, store, None)
                key2 = util._training_cache.create_key(df2, store, None)
                key3 = util._training_cache.create_key(df3, store, None)

                # All keys are distinct
                assert key != key2
                assert key != key3
                assert key2 != key3

                # The cache should be empty to start
                assert not util._training_cache.is_cached(key, store)
                assert not util._training_cache.is_cached(key2, store)
                assert not util._training_cache.is_cached(key3, store)

                # First insertion into the cache
                with util.prepare_data(num_processes=2,
                                       store=store,
                                       df=df,
                                       feature_columns=['features'],
                                       label_columns=['y']) as dataset_idx:
                    train_rows, val_rows, metadata, avg_row_size = util.get_dataset_properties(
                        dataset_idx)
                    util._training_cache.get_dataset.assert_not_called()
                    assert len(util._training_cache._key_to_dataset) == 1
                    assert util._training_cache.is_cached(key, store)
                    assert dataset_idx == 0

                    # The first dataset is still in use, so the new dataset is
                    # assigned the next index in sequence
                    assert not util._training_cache.is_cached(key2, store)
                    with util.prepare_data(num_processes=2,
                                           store=store,
                                           df=df2,
                                           feature_columns=['features'],
                                           label_columns=['y']) as dataset_idx2:
                        util._training_cache.get_dataset.assert_not_called()
                        assert len(util._training_cache._key_to_dataset) == 2
                        assert util._training_cache.is_cached(key2, store)
                        assert dataset_idx2 == 1

                # Even though the first dataset is no longer in use, it is still cached
                with util.prepare_data(num_processes=2,
                                       store=store,
                                       df=df,
                                       feature_columns=['features'],
                                       label_columns=['y']) as dataset_idx1:
                    train_rows1, val_rows1, metadata1, avg_row_size1 = util.get_dataset_properties(
                        dataset_idx1)
                    util._training_cache.get_dataset.assert_called()
                    assert train_rows == train_rows1
                    assert val_rows == val_rows1
                    assert metadata == metadata1
                    assert avg_row_size == avg_row_size1
                    assert dataset_idx1 == 0

                # The first dataset is no longer in use, so we can reclaim its dataset index
                assert not util._training_cache.is_cached(key3, store)
                with util.prepare_data(num_processes=2,
                                       store=store,
                                       df=df3,
                                       feature_columns=['features'],
                                       label_columns=['y']) as dataset_idx3:
                    train_rows3, val_rows3, metadata3, avg_row_size3 = util.get_dataset_properties(
                        dataset_idx3)
                    assert train_rows == train_rows3
                    assert val_rows == val_rows3
                    assert metadata == metadata3
                    assert avg_row_size == avg_row_size3
                    assert dataset_idx3 == 0

                # Same dataframe, different validation
                bad_key = util._training_cache.create_key(df, store, 0.1)
                assert not util._training_cache.is_cached(bad_key, store)
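
                # Hedged sketch (not part of the original test): assuming
                # prepare_data accepts a validation fraction, preparing the
                # same DataFrame with validation=0.1 would populate bad_key,
                # since the validation setting is part of the cache key.
                with util.prepare_data(num_processes=2,
                                       store=store,
                                       df=df,
                                       feature_columns=['features'],
                                       label_columns=['y'],
                                       validation=0.1) as dataset_idx4:
                    assert util._training_cache.is_cached(bad_key, store)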