コード例 #1
0
        def analyzer_fn(inputs):
            """Build the Tukey analyzer outputs for feature 'a' with dtype checks.

            Args:
                inputs: Mapping-like object of features; only key 'a' is read.

            Returns:
                A dict keyed by 'tukey_location', 'tukey_scale', 'tukey_hl'
                and 'tukey_hr'. Each value was asserted to have the dtype
                listed under the same key in `output_dtypes` and was then
                cast to `tft_unit.canonical_numeric_dtype` of that dtype.
            """
            a = tf.cast(inputs['a'], input_dtype)
            # Every analyzer below uses the same reduction setting.
            reduce_instance_dims = not elementwise

            def assert_and_cast_dtype(tensor, out_dtype):
                # Check the analyzer's own dtype before the cast hides it.
                self.assertEqual(tensor.dtype, out_dtype)
                return tf.cast(tensor,
                               tft_unit.canonical_numeric_dtype(out_dtype))

            location = tft.tukey_location(
                a, reduce_instance_dims=reduce_instance_dims)
            scale = tft.tukey_scale(
                a, reduce_instance_dims=reduce_instance_dims)
            # Invoke the analyzer once and index the result, instead of
            # registering it a second time just to read the other element.
            h_params = tft.tukey_h_params(
                a, reduce_instance_dims=reduce_instance_dims)

            return {
                'tukey_location':
                assert_and_cast_dtype(location,
                                      output_dtypes['tukey_location']),
                'tukey_scale':
                assert_and_cast_dtype(scale, output_dtypes['tukey_scale']),
                'tukey_hl':
                assert_and_cast_dtype(h_params[0],
                                      output_dtypes['tukey_hl']),
                'tukey_hr':
                assert_and_cast_dtype(h_params[1],
                                      output_dtypes['tukey_hr']),
            }
コード例 #2
0
        def analyzer_fn(inputs):
            """Build the Tukey analyzer outputs for feature 'a'.

            All analyzers are called with `reduce_instance_dims=False`.

            Args:
                inputs: Mapping-like object of features; only key 'a' is read.

            Returns:
                A dict keyed by 'tukey_location', 'tukey_scale', 'tukey_hl'
                and 'tukey_hr' holding the corresponding analyzer outputs.
            """
            a = inputs['a']

            location = tft.tukey_location(a, reduce_instance_dims=False)
            scale = tft.tukey_scale(a, reduce_instance_dims=False)
            # Invoke the analyzer once and index the result, instead of
            # registering it a second time just to read the other element.
            h_params = tft.tukey_h_params(a, reduce_instance_dims=False)

            return {
                'tukey_location': location,
                'tukey_scale': scale,
                'tukey_hl': h_params[0],
                'tukey_hr': h_params[1],
            }
コード例 #3
0
        def analyzer_fn(inputs):
            """Build the Tukey analyzer outputs for feature 'a' with dtype checks.

            The analyzers are called with their default arguments.

            Args:
                inputs: Mapping-like object of features; only key 'a' is read.

            Returns:
                A dict keyed by 'tukey_location', 'tukey_scale', 'tukey_hl'
                and 'tukey_hr'. Each value was asserted to have dtype
                `output_dtype` and was then cast to `canonical_output_dtype`.
            """
            a = tf.cast(inputs['a'], input_dtype)

            def assert_and_cast_dtype(tensor):
                # Check the analyzer's own dtype before the cast hides it.
                self.assertEqual(tensor.dtype, output_dtype)
                return tf.cast(tensor, canonical_output_dtype)

            location = tft.tukey_location(a)
            scale = tft.tukey_scale(a)
            # Invoke the analyzer once and index the result, instead of
            # registering it a second time just to read the other element.
            h_params = tft.tukey_h_params(a)

            return {
                'tukey_location': assert_and_cast_dtype(location),
                'tukey_scale': assert_and_cast_dtype(scale),
                'tukey_hl': assert_and_cast_dtype(h_params[0]),
                'tukey_hr': assert_and_cast_dtype(h_params[1]),
            }