Example #1
0
    def test_build_csv_transforming_serving_input_fn_with_label(self):
        """Serving input fn parses labelled CSV rows and applies the transform.

        Two CSV rows (raw_a, raw_b, raw_label) are fed through the serving
        input fn built from a transform SavedModel. The test checks the output
        keys, that no labels are returned, that the only input is
        'csv_example', and the shapes and values of the transformed features.
        """
        import shutil  # Local import: only needed for temp-dir cleanup.

        csv_lines = ['15,6,1', '12,17,2']

        basedir = tempfile.mkdtemp()
        # mkdtemp() never deletes the directory; remove it after the test.
        self.addCleanup(shutil.rmtree, basedir, ignore_errors=True)

        raw_metadata = dataset_metadata.DatasetMetadata(
            schema=_make_raw_schema([]))

        transform_savedmodel_dir = os.path.join(basedir,
                                                'transform-savedmodel')
        _write_transform_savedmodel(transform_savedmodel_dir)

        serving_input_fn = (
            input_fn_maker.build_csv_transforming_serving_input_fn(
                raw_metadata=raw_metadata,
                raw_keys=['raw_a', 'raw_b', 'raw_label'],
                transform_savedmodel_dir=transform_savedmodel_dir))

        with tf.Graph().as_default():
            # `with tf.Session()` closes the session on exit; entering only
            # `Session.as_default()` would leave it open.
            with tf.Session() as session:
                outputs, labels, inputs = serving_input_fn()

                self.assertCountEqual(
                    outputs.keys(),
                    {'transformed_a', 'transformed_b', 'transformed_label'})
                self.assertIsNone(labels)
                self.assertEqual(set(inputs.keys()), {'csv_example'})

                feed_inputs = {inputs['csv_example']: csv_lines}
                transformed_a, transformed_b, transformed_label = session.run(
                    [
                        outputs['transformed_a'], outputs['transformed_b'],
                        outputs['transformed_label']
                    ],
                    feed_dict=feed_inputs)

        batch_shape = (len(csv_lines), 1)

        # transformed_b is sparse so _convert_scalars_to_vectors did not fix
        # it: its dense shape stays rank 1 (one entry per row).
        sparse_batch_shape = (len(csv_lines), )
        # Map each sparse index tuple, i.e. (row,), to its value.
        transformed_b_dict = {
            tuple(index): value
            for index, value in zip(transformed_b.indices.tolist(),
                                    transformed_b.values.tolist())
        }

        self.assertEqual(batch_shape, tuple(transformed_a.shape))
        self.assertEqual(sparse_batch_shape, tuple(transformed_b.dense_shape))
        self.assertEqual(batch_shape, tuple(transformed_label.shape))

        self.assertEqual(21, transformed_a[0][0])
        self.assertEqual(9, transformed_b_dict[(0, )])
        self.assertEqual(1000, transformed_label[0][0])
        self.assertEqual(29, transformed_a[1][0])
        self.assertEqual(-5, transformed_b_dict[(1, )])
        self.assertEqual(2000, transformed_label[1][0])
Example #2
0
    def test_build_csv_transforming_serving_input_fn_with_defaults(self):
        """Empty CSV fields fall back to the raw schema's default values.

        A single row whose three fields (raw_a, raw_b, raw_label) are all
        empty is fed through the serving input fn. The transformed outputs
        must therefore derive from the defaults set up by _make_raw_schema().
        """
        import shutil  # Local import: only needed for temp-dir cleanup.

        # One CSV row with every field left empty.
        csv_lines = [',,']

        basedir = tempfile.mkdtemp()
        # mkdtemp() never deletes the directory; remove it after the test.
        self.addCleanup(shutil.rmtree, basedir, ignore_errors=True)

        raw_metadata = dataset_metadata.DatasetMetadata(
            schema=_make_raw_schema([]))

        transform_savedmodel_dir = os.path.join(basedir,
                                                'transform-savedmodel')
        _write_transform_savedmodel(transform_savedmodel_dir)

        serving_input_fn = (
            input_fn_maker.build_csv_transforming_serving_input_fn(
                raw_metadata=raw_metadata,
                raw_keys=['raw_a', 'raw_b', 'raw_label'],
                transform_savedmodel_dir=transform_savedmodel_dir))

        with tf.Graph().as_default():
            # `with tf.Session()` closes the session on exit; entering only
            # `Session.as_default()` would leave it open.
            with tf.Session() as session:
                outputs, labels, inputs = serving_input_fn()

                self.assertCountEqual(
                    outputs.keys(),
                    {'transformed_a', 'transformed_b', 'transformed_label'})
                self.assertIsNone(labels)
                self.assertEqual(set(inputs.keys()), {'csv_example'})

                feed_inputs = {inputs['csv_example']: csv_lines}
                transformed_a, transformed_b, transformed_label = session.run(
                    [
                        outputs['transformed_a'], outputs['transformed_b'],
                        outputs['transformed_label']
                    ],
                    feed_dict=feed_inputs)

        self.assertEqual((1, 1), tuple(transformed_a.shape))
        # transformed_b is sparse so _convert_scalars_to_vectors did not fix
        # it: its dense shape stays rank 1 (one entry per row).
        self.assertEqual((1, ), tuple(transformed_b.dense_shape))
        self.assertEqual((1, 1), tuple(transformed_label.shape))

        # Map each sparse index tuple, i.e. (row,), to its value.
        transformed_b_dict = {
            tuple(index): value
            for index, value in zip(transformed_b.indices.tolist(),
                                    transformed_b.values.tolist())
        }

        # Every field of the fed row is empty, so these values come from the
        # defaults in _make_raw_schema().
        self.assertEqual(1, transformed_a[0][0])
        self.assertEqual(-1, transformed_b_dict[(0, )])
        self.assertEqual(-1000, transformed_label[0][0])
Example #3
0
    def test_build_csv_transforming_serving_input_fn_with_label(self):
        """Serving input fn parses labelled CSV rows and applies the transform.

        Two CSV rows (raw_a, raw_b, raw_label) are fed through the serving
        input fn built from a transform SavedModel. The test checks the output
        keys, that no labels are returned, that the only input is
        'csv_example', and the transformed feature values.
        """
        import shutil  # Local import: only needed for temp-dir cleanup.

        csv_lines = ['15,6,1', '12,17,2']

        basedir = tempfile.mkdtemp()
        # mkdtemp() never deletes the directory; remove it after the test.
        self.addCleanup(shutil.rmtree, basedir, ignore_errors=True)

        raw_metadata = dataset_metadata.DatasetMetadata(
            schema=_make_raw_schema([]))

        transform_savedmodel_dir = os.path.join(basedir,
                                                'transform-savedmodel')
        _write_transform_savedmodel(transform_savedmodel_dir)

        serving_input_fn = (
            input_fn_maker.build_csv_transforming_serving_input_fn(
                raw_metadata=raw_metadata,
                raw_keys=['raw_a', 'raw_b', 'raw_label'],
                transform_savedmodel_dir=transform_savedmodel_dir))

        with tf.Graph().as_default():
            # `with tf.Session()` closes the session on exit; entering only
            # `Session.as_default()` would leave it open.
            with tf.Session() as session:
                outputs, labels, inputs = serving_input_fn()

                # Check the returned structure before running the graph so a
                # mismatch fails with a clear message rather than a KeyError.
                # assertCountEqual replaces assertItemsEqual, which Python 3's
                # unittest no longer provides.
                self.assertCountEqual(
                    outputs.keys(),
                    {'transformed_a', 'transformed_b', 'transformed_label'})
                self.assertIsNone(labels)
                self.assertEqual(set(inputs.keys()), {'csv_example'})

                feed_inputs = {inputs['csv_example']: csv_lines}
                transformed_a, transformed_b, transformed_label = session.run(
                    [
                        outputs['transformed_a'], outputs['transformed_b'],
                        outputs['transformed_label']
                    ],
                    feed_dict=feed_inputs)

        self.assertEqual(21, transformed_a[0][0])
        # NOTE(review): [row][0] indexing assumes transformed_b is fetched as
        # a dense (batch, 1) array here. If it is a SparseTensorValue, [0][0]
        # would index its `indices` field instead — confirm.
        self.assertEqual(9, transformed_b[0][0])
        self.assertEqual(1000, transformed_label[0][0])
        self.assertEqual(29, transformed_a[1][0])
        self.assertEqual(-5, transformed_b[1][0])
        self.assertEqual(2000, transformed_label[1][0])