def test_example_protos_from_path_get_two(self):
    """Passing 2 as the count yields only the first two of three examples."""
    temp_dir = tf.compat.v1.test.get_temp_dir()
    cns_path = os.path.join(temp_dir, 'dummy_example')
    written = [test_utils.make_fake_example(i) for i in (1, 2, 3)]
    test_utils.write_out_examples(written, cns_path)

    loaded = platform_utils.example_protos_from_path(cns_path, 2)

    self.assertEqual(2, len(loaded))
    # The two returned examples must be the first two written, in order.
    for expected, actual in zip(written[:2], loaded):
        self.assertEqual(expected, actual)
    def test_example_protos_from_path_use_wildcard(self):
        """A glob pattern reads the examples from every matching file.

        Two single-example files are written under names sharing the
        'wildcard_example' prefix, then read back through one wildcard path.
        """
        temp_dir = tf.compat.v1.test.get_temp_dir()
        example1 = test_utils.make_fake_example(1)
        test_utils.write_out_examples(
            [example1], os.path.join(temp_dir, 'wildcard_example1'))
        example2 = test_utils.make_fake_example(2)
        test_utils.write_out_examples(
            [example2], os.path.join(temp_dir, 'wildcard_example2'))

        wildcard_path = os.path.join(temp_dir, 'wildcard_example*')
        dummy_examples = platform_utils.example_protos_from_path(wildcard_path)
        self.assertEqual(2, len(dummy_examples))
        # Check the contents, not just the count. The order in which matching
        # files are read is not assumed, so compare as an unordered collection
        # (assertCountEqual falls back to == for unhashable proto messages).
        self.assertCountEqual([example1, example2], dummy_examples)
    def test_get_numeric_features_to_observed_range(self):
        """Each numeric feature maps to its observed min and max values."""
        example = test_utils.make_fake_example(single_int_val=2)

        data = inference_utils.get_numeric_features_to_observed_range(
            [example])

        # Returns a dict keyed by numeric feature name, each mapping to its
        # observedMin/observedMax. The non-numeric feature has no entry. Dict
        # equality ignores key order, so no ordering is asserted here.
        self.assertDictEqual(
            {
                'repeated_float': {
                    'observedMin': 1.,
                    'observedMax': 4.,
                },
                'repeated_int': {
                    'observedMin': 10,
                    'observedMax': 20,
                },
                'single_float': {
                    'observedMin': 24.5,
                    'observedMax': 24.5,
                },
                'single_int': {
                    'observedMin': 2.,
                    'observedMax': 2.,
                },
            }, data)
 def test_example_protos_from_path_get_all_in_file(self):
     """With no count limit, the single written example is read back."""
     cns_path = os.path.join(
         tf.compat.v1.test.get_temp_dir(), 'dummy_example')
     expected = test_utils.make_fake_example()
     test_utils.write_out_examples([expected], cns_path)

     loaded = platform_utils.example_protos_from_path(cns_path)

     self.assertEqual(1, len(loaded))
     self.assertEqual(expected, loaded[0])
    def test_infer_mutants_handler(self, mock_mutant_charts_for_feature):
        """infer_mutants forwards its URL parameters to mutant_charts_for_feature.

        Args:
          mock_mutant_charts_for_feature: a mock standing in for
            mutant_charts_for_feature (presumably injected by a mock.patch
            decorator outside this view). Its side effect is replaced with a
            pass-through, so the assertions check how the handler parsed the
            query string rather than any real mutant inference.
        """

        # A no-op that just passes the example passed to mutant_charts_for_feature
        # back through. This tests that the URL parameters get processed properly
        # within infer_mutants_handler.
        def pass_through(example, feature_name, serving_bundles, viz_params):
            return {
                'example':
                str(example),
                'feature_name':
                feature_name,
                'serving_bundles': [{
                    'inference_address':
                    serving_bundles[0].inference_address,
                    'model_name':
                    serving_bundles[0].model_name,
                    'model_type':
                    serving_bundles[0].model_type,
                }],
                'viz_params': {
                    'x_min': viz_params.x_min,
                    'x_max': viz_params.x_max
                }
            }

        mock_mutant_charts_for_feature.side_effect = pass_through

        example = test_utils.make_fake_example()
        self.plugin.examples = [example]

        response = self.server.get(
            '/data/plugin/whatif/infer_mutants?' +
            urllib_parse.urlencode({
                'feature_name': 'single_int',
                'model_name': '/ml/cassandrax/iris_classification',
                'inference_address': 'ml-serving-temp.prediction',
                'model_type': 'classification',
                'model_version': ',',
                'model_signature': ',',
                'x_min': '-10',
                'x_max': '10',
            }))
        # If the handler errored, fail here with a clear status mismatch
        # rather than on decoding a non-JSON error body.
        self.assertEqual(200, response.status_code)
        result = self._DeserializeResponse(response.get_data())
        # The handler passes the plugin's example list through, so the echoed
        # value is the string form of that list.
        self.assertEqual(str([example]), result['example'])
        self.assertEqual('single_int', result['feature_name'])
        self.assertEqual('ml-serving-temp.prediction',
                         result['serving_bundles'][0]['inference_address'])
        self.assertEqual('/ml/cassandrax/iris_classification',
                         result['serving_bundles'][0]['model_name'])
        self.assertEqual('classification',
                         result['serving_bundles'][0]['model_type'])
        self.assertAlmostEqual(-10, result['viz_params']['x_min'])
        self.assertAlmostEqual(10, result['viz_params']['x_max'])
    def test_parse_original_feature_from_example(self):
        """Parsed features report their name, values, list type and length."""
        example = test_utils.make_fake_example()
        expectations = [
            # (feature name, original values, feature type, length)
            ('repeated_float', [1.0, 2.0, 3.0, 4.0], 'float_list', 4),
            ('repeated_int', [10, 20], 'int64_list', 2),
            ('single_int', [0], 'int64_list', 1),
        ]
        for name, values, feature_type, length in expectations:
            parsed = inference_utils.parse_original_feature_from_example(
                example, name)
            self.assertEqual(name, parsed.feature_name)
            self.assertEqual(values, parsed.original_value)
            self.assertEqual(feature_type, parsed.feature_type)
            self.assertEqual(length, parsed.length)
    def test_eligible_features_from_example_proto(self):
        """eligible_features lists all features by name, with numeric ranges."""
        self.plugin.examples = [test_utils.make_fake_example(single_int_val=2)]

        response = self.server.get('/data/plugin/whatif/eligible_features')
        self.assertEqual(200, response.status_code)

        # The response is a list of per-feature dicts ordered by feature name.
        features = self._DeserializeResponse(response.get_data())

        expected_names = [
            'non_numeric', 'repeated_float', 'repeated_int', 'single_float',
            'single_int'
        ]
        self.assertEqual(expected_names, [f['name'] for f in features])

        # -1 stands in for a missing range key (non_numeric has no range).
        observed_min = [f.get('observedMin', -1) for f in features]
        np.testing.assert_almost_equal([-1, 1., 10, 24.5, 2.], observed_min)
        observed_max = [f.get('observedMax', -1) for f in features]
        np.testing.assert_almost_equal([-1, 4., 20, 24.5, 2.], observed_max)

        # Only the non_numeric feature (first entry) carries samples.
        numeric_features = features[1:]
        self.assertFalse(any(f.get('samples') for f in numeric_features))
        self.assertEqual(['cat'], features[0]['samples'])
 def make_and_write_fake_example(self):
     """Create a fake example, write it to self.examples_path, and return it."""
     fake = test_utils.make_fake_example()
     test_utils.write_out_examples([fake], self.examples_path)
     return fake
 def test_get_numeric_features(self):
     """Only the numeric feature names are returned, in alphabetical order."""
     expected = [
         'repeated_float', 'repeated_int', 'single_float', 'single_int'
     ]
     example = test_utils.make_fake_example(single_int_val=2)
     self.assertEqual(
         expected, inference_utils.get_numeric_feature_names(example))