Exemplo n.º 1
0
  def test_build_json_example_transforming_serving_input_fn(self):
    """Exercises the JSON-example transforming serving input_fn end to end.

    Two JSON-serialized examples are fed through the input_fn built by
    `input_fn_maker.build_json_example_transforming_serving_input_fn`:
    one that sets `raw_a`, `raw_b` and `raw_label`, and one that sets only
    `raw_label`. The test checks the returned (outputs, labels, inputs)
    structures, the shapes of the transformed tensors, and their values.
    The expected numbers depend on the transform written by
    `_write_transform_savedmodel` and on the defaults declared by
    `_make_raw_schema`, neither of which is visible here.
    """
    import shutil  # pylint: disable=g-import-not-at-top

    example_all = {
        'features': {
            'feature': {
                'raw_a': {
                    'int64List': {
                        'value': [42]
                    }
                },
                'raw_b': {
                    'int64List': {
                        'value': [43]
                    }
                },
                'raw_label': {
                    'int64List': {
                        'value': [44]
                    }
                }
            }
        }
    }
    # Default values for raw_a and raw_b come from _make_raw_schema()
    example_missing = {
        'features': {
            'feature': {
                'raw_label': {
                    'int64List': {
                        'value': [3]
                    }
                }
            }
        }
    }
    # Batch of serialized JSON strings fed to the 'json_example' placeholder.
    feed_dict = [json.dumps(example_all), json.dumps(example_missing)]

    basedir = tempfile.mkdtemp()
    # mkdtemp() never deletes the directory; remove it (together with the
    # SavedModel written below) once the test finishes.
    self.addCleanup(shutil.rmtree, basedir, ignore_errors=True)

    raw_metadata = dataset_metadata.DatasetMetadata(schema=_make_raw_schema([]))

    transform_savedmodel_dir = os.path.join(basedir, 'transform-savedmodel')
    _write_transform_savedmodel(transform_savedmodel_dir)

    serving_input_fn = (
        input_fn_maker.build_json_example_transforming_serving_input_fn(
            raw_metadata=raw_metadata,
            raw_label_keys=[],
            raw_feature_keys=['raw_a', 'raw_b', 'raw_label'],
            transform_savedmodel_dir=transform_savedmodel_dir))

    with tf.Graph().as_default():
      # Enter the Session itself rather than Session.as_default(): the
      # latter does not close the session on exit, which leaks it.
      with tf.Session() as session:
        outputs, labels, inputs = serving_input_fn()

        self.assertItemsEqual(
            outputs.keys(),
            {'transformed_a', 'transformed_b', 'transformed_label'})
        self.assertIsNone(labels)
        self.assertEqual(set(inputs.keys()), {'json_example'})

        feed_inputs = {inputs['json_example']: feed_dict}
        transformed_a, transformed_b, transformed_label = session.run(
            [outputs['transformed_a'], outputs['transformed_b'],
             outputs['transformed_label']],
            feed_dict=feed_inputs)

    batch_shape = (len(feed_dict), 1)

    # transformed_b is sparse so _convert_scalars_to_vectors did not fix it
    sparse_batch_shape = (len(feed_dict),)
    # Key each sparse value by its index padded with a trailing 0, so the
    # lookups below use the same (row, 0) form as the dense outputs.
    transformed_b_dict = dict(zip([tuple(x + [0])
                                   for x in transformed_b.indices.tolist()],
                                  transformed_b.values.tolist()))

    self.assertEqual(batch_shape, tuple(transformed_a.shape))
    self.assertEqual(sparse_batch_shape, tuple(transformed_b.dense_shape))
    self.assertEqual(batch_shape, tuple(transformed_label.shape))

    self.assertEqual(85, transformed_a[0][0])
    self.assertEqual(-1, transformed_b_dict[(0, 0)])
    self.assertEqual(44000, transformed_label[0][0])
    self.assertEqual(1, transformed_a[1][0])
    self.assertEqual(-1, transformed_b_dict[(1, 0)])
    self.assertEqual(3000, transformed_label[1][0])
Exemplo n.º 2
0
    def test_build_json_example_transforming_serving_input_fn(self):
        """Exercises the JSON-example transforming serving input_fn end to end.

        Two JSON-serialized examples are fed through the input_fn built by
        `input_fn_maker.build_json_example_transforming_serving_input_fn`:
        one that sets `raw_a`, `raw_b` and `raw_label`, and one that sets
        only `raw_label`. The test checks the transformed values and the
        returned (outputs, labels, inputs) structures. The expected numbers
        depend on the transform written by `_write_transform_savedmodel` and
        on the defaults declared by `_make_raw_schema`, neither of which is
        visible here.
        """
        import shutil  # pylint: disable=g-import-not-at-top

        example_all = {
            'features': {
                'feature': {
                    'raw_a': {
                        'int64List': {
                            'value': [42]
                        }
                    },
                    'raw_b': {
                        'int64List': {
                            'value': [43]
                        }
                    },
                    'raw_label': {
                        'int64List': {
                            'value': [44]
                        }
                    }
                }
            }
        }
        # Default values for raw_a and raw_b come from _make_raw_schema()
        example_missing = {
            'features': {
                'feature': {
                    'raw_label': {
                        'int64List': {
                            'value': [3]
                        }
                    }
                }
            }
        }
        # Batch of serialized JSON strings fed to the 'json_example' input.
        feed_dict = [json.dumps(example_all), json.dumps(example_missing)]

        basedir = tempfile.mkdtemp()
        # mkdtemp() never deletes the directory; remove it (together with
        # the SavedModel written below) once the test finishes.
        self.addCleanup(shutil.rmtree, basedir, ignore_errors=True)

        raw_metadata = dataset_metadata.DatasetMetadata(
            schema=_make_raw_schema([]))

        transform_savedmodel_dir = os.path.join(basedir,
                                                'transform-savedmodel')
        _write_transform_savedmodel(transform_savedmodel_dir)

        serving_input_fn = (
            input_fn_maker.build_json_example_transforming_serving_input_fn(
                raw_metadata=raw_metadata,
                raw_label_keys=[],
                raw_feature_keys=['raw_a', 'raw_b', 'raw_label'],
                transform_savedmodel_dir=transform_savedmodel_dir))

        with tf.Graph().as_default():
            # Enter the Session itself rather than Session.as_default(): the
            # latter does not close the session on exit, which leaks it.
            with tf.Session() as session:
                outputs, labels, inputs = serving_input_fn()
                feed_inputs = {inputs['json_example']: feed_dict}
                transformed_a, transformed_b, transformed_label = session.run(
                    [
                        outputs['transformed_a'], outputs['transformed_b'],
                        outputs['transformed_label']
                    ],
                    feed_dict=feed_inputs)

        # session.run returned concrete values, so they remain usable after
        # the session is closed.
        self.assertEqual(85, transformed_a[0][0])
        self.assertEqual(-1, transformed_b[0][0])
        self.assertEqual(44000, transformed_label[0][0])
        self.assertEqual(1, transformed_a[1][0])
        self.assertEqual(-1, transformed_b[1][0])
        self.assertEqual(3000, transformed_label[1][0])
        self.assertItemsEqual(
            outputs, {'transformed_a', 'transformed_b', 'transformed_label'})
        self.assertIsNone(labels)
        self.assertEqual(set(inputs.keys()), {'json_example'})