Ejemplo n.º 1
0
 def testGetFeatureValuesForModelSpecFieldNoValues(self):
     """Expects None from the lookup when the extracts dict is empty."""
     spec = config_pb2.ModelSpec(
         name='model1', example_weight_key='feature2')
     empty_extracts = {}
     # NOTE(review): the field names passed below ('example_weight',
     # 'example_weights') differ from the `example_weight_key` attribute set
     # on the spec above -- confirm this is intentional.
     actual = model_util.get_feature_values_for_model_spec_field(
         [spec], 'example_weight', 'example_weights', empty_extracts)
     self.assertIsNone(actual)
Ejemplo n.º 2
0
 def testGetFeatureValuesForModelSpecFieldWithMultiModelTransforedFeatures(
         self, model_specs, field, multi_output_field, expected_values):
     """Checks the lookup when transformed features are keyed by model name.

     The extracts carry one raw-features dict plus one transformed-features
     dict that is nested under 'model1' and 'model2'.
     """
     raw_features = {
         'feature1': [1.0, 1.1, 1.2],
         'feature2': [2.0, 2.1, 2.2],
     }
     transformed_by_model = {
         'model1': {
             'feature2': [4.0, 4.1, 4.2],
             'feature3': [5.0, 5.1, 5.2],
         },
         'model2': {
             'feature2': [6.0, 6.1, 6.2],
             'feature3': [7.0, 7.1, 7.2],
         },
     }
     # Only num_rows is needed from the RecordBatch, so a dummy single-row
     # column with the same length as the features list is enough.
     dummy_batch = pa.RecordBatch.from_arrays([pa.array([1])], ['dummy'])
     batched_extracts = {
         constants.ARROW_RECORD_BATCH_KEY: dummy_batch,
         constants.FEATURES_KEY: [raw_features],
         constants.TRANSFORMED_FEATURES_KEY: [transformed_by_model],
     }
     actual = model_util.get_feature_values_for_model_spec_field(
         model_specs, field, multi_output_field, batched_extracts)
     # NOTE(review): for container values assertAlmostEqual only succeeds on
     # exact equality -- confirm that is the intended strictness.
     self.assertAlmostEqual(expected_values, actual)
Ejemplo n.º 3
0
 def extract_labels(  # pylint: disable=invalid-name
     batched_extracts: types.Extracts) -> types.Extracts:
     """Return a shallow copy of the extracts with labels added.

     Labels are looked up via the 'label_key' / 'label_keys' fields of the
     model specs in `eval_config` and stored under constants.LABELS_KEY,
     whatever value the lookup returns.
     """
     output = copy.copy(batched_extracts)
     specs = list(eval_config.model_specs)
     # The trailing True is passed positionally -- presumably the
     # allow-missing flag; confirm against model_util.
     labels = model_util.get_feature_values_for_model_spec_field(
         specs, 'label_key', 'label_keys', output, True)
     output[constants.LABELS_KEY] = labels
     return output
 def extract_example_weights(  # pylint: disable=invalid-name
     batched_extracts: types.Extracts) -> types.Extracts:
     """Return a shallow copy of the extracts with example weights added.

     The weights are looked up via the 'example_weight_key' /
     'example_weight_keys' fields of the model specs in `eval_config`. When
     the lookup yields None the copy is returned without a
     constants.EXAMPLE_WEIGHTS_KEY entry being set.
     """
     output = copy.copy(batched_extracts)
     specs = list(eval_config.model_specs)
     weights = model_util.get_feature_values_for_model_spec_field(
         specs, 'example_weight_key', 'example_weight_keys', output)
     if weights is None:
         return output
     output[constants.EXAMPLE_WEIGHTS_KEY] = weights
     return output
Ejemplo n.º 5
0
 def testGetFeatureValuesForModelSpecFieldNoValues(self):
     """Expects None when the extracts hold only a record batch."""
     spec = config_pb2.ModelSpec(
         name='model1', example_weight_key='feature2')
     dummy_batch = pa.RecordBatch.from_arrays([pa.array([1])], ['dummy'])
     batch_only_extracts = {constants.ARROW_RECORD_BATCH_KEY: dummy_batch}
     # NOTE(review): the field names passed below ('example_weight',
     # 'example_weights') differ from the `example_weight_key` attribute set
     # on the spec above -- confirm this is intentional.
     actual = model_util.get_feature_values_for_model_spec_field(
         [spec], 'example_weight', 'example_weights', batch_only_extracts)
     self.assertIsNone(actual)
Ejemplo n.º 6
0
 def extract_predictions(  # pylint: disable=invalid-name
     batched_extracts: types.Extracts) -> types.Extracts:
   """Return a shallow copy of the extracts with predictions added.

   Predictions are looked up via the 'prediction_key' / 'prediction_keys'
   fields of the model specs in `eval_config`. When the lookup yields None
   the copy is returned without a constants.PREDICTIONS_KEY entry being set.
   """
   output = copy.copy(batched_extracts)
   specs = list(eval_config.model_specs)
   found = model_util.get_feature_values_for_model_spec_field(
       specs, 'prediction_key', 'prediction_keys', output)
   if found is None:
     return output
   output[constants.PREDICTIONS_KEY] = found
   return output
Ejemplo n.º 7
0
 def testGetFeatureValuesForModelSpecField(self, model_specs, field,
                                           multi_output_field,
                                           expected_values):
     """Checks the lookup against features supplied as a plain dict."""
     features = {
         'feature1': [1.0, 1.1, 1.2],
         'feature2': [2.0, 2.1, 2.2],
         'feature3': [3.0, 3.1, 3.2],
     }
     test_extracts = {constants.FEATURES_KEY: features}
     actual = model_util.get_feature_values_for_model_spec_field(
         model_specs, field, multi_output_field, test_extracts)
     # NOTE(review): for container values assertAlmostEqual only succeeds on
     # exact equality -- confirm that is the intended strictness.
     self.assertAlmostEqual(expected_values, actual)
Ejemplo n.º 8
0
 def testGetFeatureValuesForModelSpecFieldWithSingleModelTransforedFeatures(
         self, model_specs, field, multi_output_field, expected_values):
     """Checks the lookup when raw and transformed features are both present.

     'feature2' appears in both dicts with different values; the
     transformed-features dict is flat (not nested under a model name).
     """
     raw_features = {
         'feature1': [1.0, 1.1, 1.2],
         'feature2': [2.0, 2.1, 2.2],
     }
     transformed_features = {
         'feature2': [4.0, 4.1, 4.2],
     }
     test_extracts = {
         constants.FEATURES_KEY: raw_features,
         constants.TRANSFORMED_FEATURES_KEY: transformed_features,
     }
     actual = model_util.get_feature_values_for_model_spec_field(
         model_specs, field, multi_output_field, test_extracts)
     # NOTE(review): for container values assertAlmostEqual only succeeds on
     # exact equality -- confirm that is the intended strictness.
     self.assertAlmostEqual(expected_values, actual)
Ejemplo n.º 9
0
 def testGetFeatureValuesForModelSpecField(self, model_specs, field,
                                           multi_output_field,
                                           expected_values):
     """Checks the lookup against a one-element list of feature dicts."""
     features = {
         'feature1': [1.0, 1.1, 1.2],
         'feature2': [2.0, 2.1, 2.2],
         'feature3': [3.0, 3.1, 3.2],
     }
     # Only num_rows is needed from the RecordBatch, so a dummy single-row
     # column with the same length as the features list is enough.
     dummy_batch = pa.RecordBatch.from_arrays([pa.array([1])], ['dummy'])
     batched_extracts = {
         constants.ARROW_RECORD_BATCH_KEY: dummy_batch,
         constants.FEATURES_KEY: [features],
     }
     actual = model_util.get_feature_values_for_model_spec_field(
         model_specs, field, multi_output_field, batched_extracts)
     # NOTE(review): for container values assertAlmostEqual only succeeds on
     # exact equality -- confirm that is the intended strictness.
     self.assertAlmostEqual(expected_values, actual)
Ejemplo n.º 10
0
 def _extract_features_and_labels(self, batched_extract):
     """Extract features, raw input and labels from an arrow record batch.

     This function is a combination of
     _ExtractFeatures.extract_features in extractors/features_extractor.py
     and _ExtractLabels.extract_labels in extractors/labels_extractor.py.

     Args:
       batched_extract: Extracts dict holding a record batch under
         constants.ARROW_RECORD_BATCH_KEY and, optionally, an existing
         features mapping under constants.FEATURES_KEY. This method does not
         assign into the dict or into its features mapping.

     Returns:
       A shallow copy of `batched_extract` with constants.FEATURES_KEY,
       constants.INPUT_KEY and constants.LABELS_KEY set. The labels are the
       result of passing the looked-up label values through
       `self._transform_labels`.
     """
     result = copy.copy(batched_extract)
     record_batch, serialized_examples = (
         features_extractor.
         _drop_unsupported_columns_and_fetch_raw_data_column(  # pylint: disable=protected-access
             batched_extract[constants.ARROW_RECORD_BATCH_KEY]))
     # `result` is only a shallow copy, so an existing features mapping is
     # still the very object referenced by the caller's extract. Copy it
     # before updating so the input element is never mutated in place.
     if constants.FEATURES_KEY in result:
         features = copy.copy(result[constants.FEATURES_KEY])
     else:
         features = {}
     features.update(util.record_batch_to_tensor_values(record_batch))
     result[constants.FEATURES_KEY] = features
     result[constants.INPUT_KEY] = serialized_examples
     # The trailing True is passed positionally -- presumably the
     # allow-missing flag; confirm against model_util.
     labels = model_util.get_feature_values_for_model_spec_field(
         list(self._eval_config.model_specs), "label_key", "label_keys",
         result, True)
     result[constants.LABELS_KEY] = self._transform_labels(labels)
     return result