Exemplo n.º 1
0
    def test_update(self):
        """Check ``self.feature.update`` on a single-row dataframe.

        The row carries two string->float map columns (current and past
        features), each holding this feature's value plus the request-total
        and path-depth-average values. The updated column of the result is
        compared, to two decimal places, with ``update_variance`` called
        directly on the same literal numbers.
        """
        total_name = FeatureRequestTotal.feature_name_from_class()
        average_name = FeaturePathDepthAverage.feature_name_from_class()
        own_name = self.feature.feature_name
        current_col = self.feature.current_features_column
        past_col = self.feature.past_features_column

        # Both columns share the same map<string, float> layout.
        schema = T.StructType([
            T.StructField(col, T.MapType(T.StringType(), T.FloatType()))
            for col in (current_col, past_col)
        ])

        row = {
            current_col: {own_name: 6., total_name: 3., average_name: 5.},
            past_col: {own_name: 2., total_name: 1., average_name: 4.},
        }
        input_df = self.session.createDataFrame([row], schema=schema)
        result_df = self.feature.update(input_df)

        # Dumps the resulting dataframe to stdout for inspection.
        result_df.show()

        updated_col = self.feature.updated_feature_col_name
        first_row = result_df.select(updated_col).collect()[0]
        value = first_row[updated_col]

        from baskerville.features.helpers import update_variance
        # Arguments follow the row above: past/current value of this feature,
        # past/current request total, past/current path-depth average.
        expected_value = update_variance(2., 6., 1., 3., 4., 5.)
        print(expected_value)
        self.assertAlmostEqual(value, expected_value, places=2)
Exemplo n.º 2
0
 def update_row(cls, current, past, *args, **kwargs):
     """Combine the past and current values of this feature.

     Looks up this feature, the request total and the path-depth average in
     both mappings and hands the six values to ``update_variance``.

     :param current: mapping of feature name -> value; indexed directly, so
         a missing key raises ``KeyError``.
     :param past: mapping of feature name -> value; read with ``.get``, so a
         missing key is passed on as ``None``.
     :param args: accepted but unused.
     :param kwargs: accepted but unused.
     :return: whatever ``update_variance`` returns for the six values.
     """
     own_key = cls.feature_name_from_class()
     total_key = FeatureRequestTotal.feature_name_from_class()
     average_key = FeaturePathDepthAverage.feature_name_from_class()

     # Lookups are kept in the order past-then-current for each feature.
     past_value = past.get(own_key)
     current_value = current[own_key]
     past_total = past.get(total_key)
     current_total = current[total_key]
     past_average = past.get(average_key)
     current_average = current[average_key]

     return update_variance(
         past_value,
         current_value,
         past_total,
         current_total,
         past_average,
         current_average,
     )
    def test_update_row(self):
        """Check ``self.feature.update_row`` against ``update_variance``.

        Builds a current and a past feature dict keyed by this feature's
        name, the request-total name and the request-interval-average name,
        then compares the result to ``update_variance`` called on the same
        literal numbers, to two decimal places.
        """
        total_feature = FeatureRequestTotal()
        average_feature = FeatureRequestIntervalAverage()
        # NOTE(review): the mean feature used here is
        # FeatureRequestIntervalAverage - confirm it is the one that
        # self.feature.update_row actually reads.
        keys = (
            self.feature.feature_name,
            total_feature.feature_name,
            average_feature.feature_name,
        )
        test_current = dict(zip(keys, (6., 3., 5.)))
        test_past = dict(zip(keys, (2., 1., 4.)))

        value = self.feature.update_row(test_current, test_past)

        from baskerville.features.helpers import update_variance
        # Same numbers as the dicts above, interleaved past/current.
        expected_value = update_variance(2., 6., 1., 3., 4., 5.)

        self.assertAlmostEqual(value, expected_value, places=2)