Example #1
0
    def _infer_mutants_impl(self, feature_name, example_index,
                            inference_addresses, model_names, model_type,
                            model_versions, model_signatures, use_predict,
                            predict_input_tensor, predict_output_tensor, x_min,
                            x_max, feature_index_pattern, custom_predict_fn):
        """Helper for generating PD plots for a feature.

        Args:
          feature_name: Name of the feature whose mutant charts are built.
          example_index: Index into `self.examples` of the single example to
            use, or -1 to use every example in `self.examples`.
          inference_addresses: One entry per model; a `ServingBundle` is built
            for each entry.
          model_names: Per-model names, indexed in parallel with
            `inference_addresses`.
          model_type: Model type shared by every `ServingBundle`.
          model_versions: Per-model versions, parallel to
            `inference_addresses`.
          model_signatures: Per-model signatures, parallel to
            `inference_addresses`.
          use_predict: Shared flag forwarded to every `ServingBundle`.
          predict_input_tensor: Shared value forwarded to every bundle.
          predict_output_tensor: Shared value forwarded to every bundle.
          x_min: Forwarded to `inference_utils.VizParams`.
          x_max: Forwarded to `inference_utils.VizParams`.
          feature_index_pattern: Forwarded to `inference_utils.VizParams`.
          custom_predict_fn: Forwarded to every bundle as `custom_predict_fn`.

        Returns:
          The result of `inference_utils.mutant_charts_for_feature` for the
          selected examples.

        Raises:
          IndexError: Presumably, when `example_index` is out of range for
            `self.examples`, or when a per-model list is shorter than
            `inference_addresses` (assuming these are plain lists).
        """
        examples = (self.examples
                    if example_index == -1 else [self.examples[example_index]])
        # Iterate with the builtin `enumerate` instead of the Python-2-only
        # `xrange`. The per-model lists are still indexed (not zipped), so a
        # list shorter than `inference_addresses` keeps raising IndexError
        # instead of being silently truncated.
        serving_bundles = [
            inference_utils.ServingBundle(
                address,
                model_names[model_num],
                model_type,
                model_versions[model_num],
                model_signatures[model_num],
                use_predict,
                predict_input_tensor,
                predict_output_tensor,
                custom_predict_fn=custom_predict_fn)
            for model_num, address in enumerate(inference_addresses)
        ]

        # Only the first NUM_EXAMPLES_TO_SCAN examples are handed to VizParams.
        viz_params = inference_utils.VizParams(
            x_min, x_max, self.examples[0:NUM_EXAMPLES_TO_SCAN], NUM_MUTANTS,
            feature_index_pattern)
        return inference_utils.mutant_charts_for_feature(
            examples, feature_name, serving_bundles, viz_params)
    def test_mutant_charts_for_feature(
            self, mock_call_servo, mock_make_json_formatted_for_single_chart):
        """Checks chart type and chart count per feature, with no index pattern.

        The two mock arguments are presumably injected by `mock.patch`
        decorators outside this block. The test relies on the mocks' default
        behavior and does not configure them.
        """
        # Rebinding these parameter names (e.g. to lambdas) would not change
        # the patched objects. Set `return_value` on the mocks instead if a
        # specific result is ever needed.
        del mock_call_servo, mock_make_json_formatted_for_single_chart  # Unused.

        example = self.make_and_write_fake_example()
        serving_bundles = [
            inference_utils.ServingBundle('', '', 'classification', '', '',
                                          False, '', '')
        ]
        viz_params = inference_utils.VizParams(x_min=1,
                                               x_max=10,
                                               examples=[example],
                                               num_mutants=10,
                                               feature_index_pattern=None)

        # (feature name, expected chart type, expected number of charts)
        expected_charts = [
            ('repeated_float', 'numeric', 4),
            ('repeated_int', 'numeric', 2),
            ('single_int', 'numeric', 1),
            ('non_numeric', 'categorical', 3),
        ]
        for feature_name, chart_type, num_charts in expected_charts:
            charts = inference_utils.mutant_charts_for_feature(
                [example], feature_name, serving_bundles, viz_params)
            # Pass the feature name as the failure message so a mismatch
            # identifies which case failed.
            self.assertEqual(chart_type, charts['chartType'], feature_name)
            self.assertEqual(num_charts, len(charts['data']), feature_name)
    def test_mutant_charts_for_feature_with_feature_index_pattern(
            self, mock_call_servo, mock_make_json_formatted_for_single_chart):
        """Checks chart counts when a feature index pattern is supplied.

        With the pattern '0 , 2-3', every numeric feature below is expected
        to yield 3 charts. The two mock arguments are presumably injected by
        `mock.patch` decorators outside this block and are left at their
        default behavior.
        """
        # Rebinding these parameter names (e.g. to lambdas) would not change
        # the patched objects. Set `return_value` on the mocks instead if a
        # specific result is ever needed.
        del mock_call_servo, mock_make_json_formatted_for_single_chart  # Unused.

        example = self.make_and_write_fake_example()
        serving_bundles = [
            inference_utils.ServingBundle('', '', 'classification', '', '',
                                          False, '', '')
        ]
        viz_params = inference_utils.VizParams(x_min=1,
                                               x_max=10,
                                               examples=[example],
                                               num_mutants=10,
                                               feature_index_pattern='0 , 2-3')

        # 'repeated_int' and 'single_int' should also return 3 charts, even
        # though not every index from the pattern exists in the example.
        for feature_name in ('repeated_float', 'repeated_int', 'single_int'):
            charts = inference_utils.mutant_charts_for_feature(
                [example], feature_name, serving_bundles, viz_params)
            # Pass the feature name as the failure message so a mismatch
            # identifies which case failed.
            self.assertEqual('numeric', charts['chartType'], feature_name)
            self.assertEqual(3, len(charts['data']), feature_name)
Example #4
0
 def infer_mutants_impl(self, info):
     """Performs mutant inference on specified examples.

     Args:
       info: Mapping read for the keys 'example_index' (converted with int();
         -1 selects all of `self.examples`), 'feature_name', 'x_min', 'x_max'
         and 'feature_index_pattern'.

     Returns:
       The charts produced by `inference_utils.mutant_charts_for_feature`.
     """
     example_index = int(info['example_index'])
     feature_name = info['feature_name']
     examples = (self.examples
                 if example_index == -1 else [self.examples[example_index]])
     examples = [self.json_to_proto(ex) for ex in examples]
     # Only the first 50 examples are converted and passed to VizParams.
     scan_examples = [self.json_to_proto(ex) for ex in self.examples[0:50]]
     serving_bundles = []
     serving_bundles.append(
         inference_utils.ServingBundle(
             self.config.get('inference_address'),
             self.config.get('model_name'), self.config.get('model_type'),
             self.config.get('model_version'),
             self.config.get('model_signature'),
             self.config.get('uses_predict_api'),
             self.config.get('predict_input_tensor'),
             self.config.get('predict_output_tensor'),
             self.estimator_and_spec.get('estimator'),
             self.estimator_and_spec.get('feature_spec'),
             self.custom_predict_fn))
     # A second bundle is added when any comparison-model source is configured.
     if ('inference_address_2' in self.config
             or self.compare_estimator_and_spec.get('estimator')
             or self.compare_custom_predict_fn):
         serving_bundles.append(
             inference_utils.ServingBundle(
                 self.config.get('inference_address_2'),
                 self.config.get('model_name_2'),
                 self.config.get('model_type'),
                 self.config.get('model_version_2'),
                 self.config.get('model_signature_2'),
                 self.config.get('uses_predict_api'),
                 self.config.get('predict_input_tensor'),
                 self.config.get('predict_output_tensor'),
                 self.compare_estimator_and_spec.get('estimator'),
                 self.compare_estimator_and_spec.get('feature_spec'),
                 self.compare_custom_predict_fn))
     viz_params = inference_utils.VizParams(info['x_min'], info['x_max'],
                                            scan_examples, 10,
                                            info['feature_index_pattern'])
     self.running_mutant_infer = True
     try:
         return inference_utils.mutant_charts_for_feature(
             examples, feature_name, serving_bundles, viz_params)
     finally:
         # Reset the flag even when inference raises. Previously a failure
         # here would leave `running_mutant_infer` stuck at True.
         self.running_mutant_infer = False