def test_build_csv_transforming_serving_input_fn_with_label(self):
    """Transform two CSV rows that include a label column.

    Builds a CSV transforming serving input fn from a freshly written
    transform SavedModel, feeds two CSV rows through it and checks the
    transformed output keys, shapes and values. The serving input fn must
    return no labels and expose a single 'csv_example' input.

    NOTE(review): a later method in this file is defined with this same
    name; Python keeps only the later definition on the class, so this body
    does not run. One of the two should be renamed or removed.
    """
    import shutil  # only needed to clean up the temporary directory

    feed_dict = ['15,6,1', '12,17,2']
    basedir = tempfile.mkdtemp()
    # mkdtemp() leaves deletion to the caller; remove the directory (and the
    # SavedModel written into it) once the test finishes.
    self.addCleanup(shutil.rmtree, basedir, ignore_errors=True)
    raw_metadata = dataset_metadata.DatasetMetadata(
        schema=_make_raw_schema([]))
    transform_savedmodel_dir = os.path.join(basedir, 'transform-savedmodel')
    _write_transform_savedmodel(transform_savedmodel_dir)
    serving_input_fn = (
        input_fn_maker.build_csv_transforming_serving_input_fn(
            raw_metadata=raw_metadata,
            raw_keys=['raw_a', 'raw_b', 'raw_label'],
            transform_savedmodel_dir=transform_savedmodel_dir))

    with tf.Graph().as_default():
        with tf.Session().as_default() as session:
            outputs, labels, inputs = serving_input_fn()

            self.assertCountEqual(
                outputs.keys(),
                {'transformed_a', 'transformed_b', 'transformed_label'})
            self.assertIsNone(labels)
            self.assertEqual(set(inputs.keys()), {'csv_example'})

            feed_inputs = {inputs['csv_example']: feed_dict}
            transformed_a, transformed_b, transformed_label = session.run(
                [
                    outputs['transformed_a'],
                    outputs['transformed_b'],
                    outputs['transformed_label'],
                ],
                feed_dict=feed_inputs)

            batch_shape = (len(feed_dict), 1)
            # transformed_b is sparse so _convert_scalars_to_vectors did not
            # fix it: its dense shape stays rank 1 (one entry per row).
            sparse_batch_shape = (len(feed_dict),)
            self.assertEqual(batch_shape, tuple(transformed_a.shape))
            self.assertEqual(sparse_batch_shape,
                             tuple(transformed_b.dense_shape))
            self.assertEqual(batch_shape, tuple(transformed_label.shape))

            # Map each sparse index (a 1-tuple, matching the rank-1 dense
            # shape asserted above) to its value.
            transformed_b_dict = {
                tuple(index): value
                for index, value in zip(transformed_b.indices.tolist(),
                                        transformed_b.values.tolist())
            }

            self.assertEqual(21, transformed_a[0][0])
            self.assertEqual(9, transformed_b_dict[(0,)])
            self.assertEqual(1000, transformed_label[0][0])
            self.assertEqual(29, transformed_a[1][0])
            self.assertEqual(-5, transformed_b_dict[(1,)])
            self.assertEqual(2000, transformed_label[1][0])
def test_build_csv_transforming_serving_input_fn_with_defaults(self):
    """Transform a CSV row whose fields are all empty.

    Feeds a single row ',,' (three empty fields) through the CSV
    transforming serving input fn, so every raw value falls back to the
    default declared by _make_raw_schema(), and checks the resulting
    transformed shapes and values.
    """
    import shutil  # only needed to clean up the temporary directory

    feed_dict = [',,']
    basedir = tempfile.mkdtemp()
    # mkdtemp() leaves deletion to the caller; remove the directory (and the
    # SavedModel written into it) once the test finishes.
    self.addCleanup(shutil.rmtree, basedir, ignore_errors=True)
    raw_metadata = dataset_metadata.DatasetMetadata(
        schema=_make_raw_schema([]))
    transform_savedmodel_dir = os.path.join(basedir, 'transform-savedmodel')
    _write_transform_savedmodel(transform_savedmodel_dir)
    serving_input_fn = (
        input_fn_maker.build_csv_transforming_serving_input_fn(
            raw_metadata=raw_metadata,
            raw_keys=['raw_a', 'raw_b', 'raw_label'],
            transform_savedmodel_dir=transform_savedmodel_dir))

    with tf.Graph().as_default():
        with tf.Session().as_default() as session:
            outputs, labels, inputs = serving_input_fn()

            self.assertCountEqual(
                outputs.keys(),
                {'transformed_a', 'transformed_b', 'transformed_label'})
            self.assertIsNone(labels)
            self.assertEqual(set(inputs.keys()), {'csv_example'})

            feed_inputs = {inputs['csv_example']: feed_dict}
            transformed_a, transformed_b, transformed_label = session.run(
                [
                    outputs['transformed_a'],
                    outputs['transformed_b'],
                    outputs['transformed_label'],
                ],
                feed_dict=feed_inputs)

            self.assertEqual((1, 1), tuple(transformed_a.shape))
            # transformed_b is sparse so _convert_scalars_to_vectors did not
            # fix it: its dense shape stays rank 1 (one entry per row).
            self.assertEqual((1,), tuple(transformed_b.dense_shape))
            self.assertEqual((1, 1), tuple(transformed_label.shape))

            # Map each sparse index (a 1-tuple) to its value.
            transformed_b_dict = {
                tuple(index): value
                for index, value in zip(transformed_b.indices.tolist(),
                                        transformed_b.values.tolist())
            }

            # The fed row has every field empty, so these values come from
            # the defaults declared in _make_raw_schema().
            self.assertEqual(1, transformed_a[0][0])
            self.assertEqual(-1, transformed_b_dict[(0,)])
            self.assertEqual(-1000, transformed_label[0][0])
def test_build_csv_transforming_serving_input_fn_with_label(self):
    """Transform two CSV rows that include a label column.

    Builds a CSV transforming serving input fn from a freshly written
    transform SavedModel, feeds two CSV rows through it and checks the
    transformed output keys, shapes and values. transformed_b comes back as
    a sparse value, so it is read through its indices/values rather than by
    positional indexing.

    NOTE(review): an earlier method in this file has this same name; this
    later definition replaces it on the class, so only this body runs. One
    of the two should be renamed or removed.
    """
    import shutil  # only needed to clean up the temporary directory

    feed_dict = ['15,6,1', '12,17,2']
    basedir = tempfile.mkdtemp()
    # mkdtemp() leaves deletion to the caller; remove the directory (and the
    # SavedModel written into it) once the test finishes.
    self.addCleanup(shutil.rmtree, basedir, ignore_errors=True)
    raw_metadata = dataset_metadata.DatasetMetadata(
        schema=_make_raw_schema([]))
    transform_savedmodel_dir = os.path.join(basedir, 'transform-savedmodel')
    _write_transform_savedmodel(transform_savedmodel_dir)
    serving_input_fn = (
        input_fn_maker.build_csv_transforming_serving_input_fn(
            raw_metadata=raw_metadata,
            raw_keys=['raw_a', 'raw_b', 'raw_label'],
            transform_savedmodel_dir=transform_savedmodel_dir))

    with tf.Graph().as_default():
        with tf.Session().as_default() as session:
            outputs, labels, inputs = serving_input_fn()

            # Check the fn's structure before feeding so a missing
            # 'csv_example' input fails with a clear assertion, not KeyError.
            self.assertCountEqual(
                outputs.keys(),
                {'transformed_a', 'transformed_b', 'transformed_label'})
            self.assertIsNone(labels)
            self.assertEqual(set(inputs.keys()), {'csv_example'})

            feed_inputs = {inputs['csv_example']: feed_dict}
            transformed_a, transformed_b, transformed_label = session.run(
                [
                    outputs['transformed_a'],
                    outputs['transformed_b'],
                    outputs['transformed_label'],
                ],
                feed_dict=feed_inputs)

            batch_shape = (len(feed_dict), 1)
            # transformed_b is sparse so _convert_scalars_to_vectors did not
            # fix it: its dense shape stays rank 1 (one entry per row).
            sparse_batch_shape = (len(feed_dict),)
            self.assertEqual(batch_shape, tuple(transformed_a.shape))
            self.assertEqual(sparse_batch_shape,
                             tuple(transformed_b.dense_shape))
            self.assertEqual(batch_shape, tuple(transformed_label.shape))

            # Positional indexing (transformed_b[0][0]) on a sparse value
            # reads its indices, not its values; map index -> value instead.
            transformed_b_dict = {
                tuple(index): value
                for index, value in zip(transformed_b.indices.tolist(),
                                        transformed_b.values.tolist())
            }

            self.assertEqual(21, transformed_a[0][0])
            self.assertEqual(9, transformed_b_dict[(0,)])
            self.assertEqual(1000, transformed_label[0][0])
            self.assertEqual(29, transformed_a[1][0])
            self.assertEqual(-5, transformed_b_dict[(1,)])
            self.assertEqual(2000, transformed_label[1][0])