def test_no_column_clause(self):
    """Infer every feature column when the model clause declares none.

    Only the label column is supplied. Each non-label field selected from
    ``iris.train`` is expected back as a FLOAT32, PLAIN, dense NumericColumn
    of shape [1], in SELECT order, under the single "feature_columns" key.
    """
    feature_names = [
        "sepal_length",
        "sepal_width",
        "petal_length",
        "petal_width",
    ]
    sql = "select %s, class from iris.train" % ",".join(feature_names)
    db = testing.get_singleton_db_connection()

    declared_label = NumericColumn(
        FieldDesc(name='class', dtype=DataType.INT64, shape=[1]))
    # No feature columns are declared: pass None and let inference fill
    # them all in.
    inferred, label = fd.infer_feature_columns(db, sql, None,
                                               declared_label)

    self.check_json_dump(inferred)
    self.check_json_dump(label)

    self.assertEqual(len(inferred), 1)
    self.assertTrue("feature_columns" in inferred)
    inferred_columns = inferred["feature_columns"]
    self.assertEqual(len(inferred_columns), 4)

    # The length check above guarantees zip pairs every column with its
    # expected name.
    for expected_name, column in zip(feature_names, inferred_columns):
        self.assertTrue(isinstance(column, NumericColumn))
        self.assertEqual(len(column.get_field_desc()), 1)
        desc = column.get_field_desc()[0]
        self.assertEqual(desc.name, expected_name)
        self.assertEqual(desc.dtype, DataType.FLOAT32)
        self.assertEqual(desc.format, DataFormat.PLAIN)
        self.assertFalse(desc.is_sparse)
        self.assertEqual(desc.shape, [1])

    # The label was declared with shape [1] but is expected back with
    # shape [].
    self.assertTrue(isinstance(label, NumericColumn))
    self.assertEqual(len(label.get_field_desc()), 1)
    label_desc = label.get_field_desc()[0]
    self.assertEqual(label_desc.name, "class")
    self.assertEqual(label_desc.dtype, DataType.INT64)
    self.assertEqual(label_desc.format, DataFormat.PLAIN)
    self.assertFalse(label_desc.is_sparse)
    self.assertEqual(label_desc.shape, [])
def test_with_cross(self):
    """Infer columns and field metadata around declared CrossColumns.

    The model clause declares c1, c2 and two CrossColumns, over [c4, c5]
    and over [c1, c2]. c3 appears only in the SELECT. After inference the
    test expects five feature columns in the order c1, c2, c3,
    cross(c4, c5), cross(c1, c2), with the dtype/format/sparsity/shape
    values asserted below.
    """

    def assert_field(desc, name, dtype, fmt, is_sparse, shape=None):
        """Check one FieldDesc; pass ``shape=None`` to skip the shape."""
        self.assertEqual(desc.name, name)
        self.assertEqual(desc.dtype, dtype)
        self.assertEqual(desc.format, fmt)
        if is_sparse:
            self.assertTrue(desc.is_sparse)
        else:
            self.assertFalse(desc.is_sparse)
        if shape is not None:
            self.assertEqual(desc.shape, shape)

    def field_descs(column, column_type, count):
        """Check a column's type and FieldDesc count; return its descs."""
        self.assertIsInstance(column, column_type)
        descs = column.get_field_desc()
        self.assertEqual(len(descs), count)
        return descs

    c1 = NumericColumn(
        FieldDesc(name='c1', dtype=DataType.INT64, shape=[1]))
    c2 = NumericColumn(
        FieldDesc(name='c2', dtype=DataType.INT64, shape=[1]))
    c4 = NumericColumn(
        FieldDesc(name='c4', dtype=DataType.INT64, shape=[1]))
    c5 = NumericColumn(
        FieldDesc(name='c5', dtype=DataType.INT64, shape=[1],
                  is_sparse=True))
    features = {
        'feature_columns': [
            c1,
            c2,
            CrossColumn([c4, c5], 128),
            CrossColumn([c1, c2], 256),
        ]
    }
    label = NumericColumn(
        FieldDesc(name='class', dtype=DataType.INT64, shape=[1]))
    select = ("select c1, c2, c3, c4, c5, class "
              "from feature_derivation_case.train")
    conn = testing.get_singleton_db_connection()
    features, label = fd.infer_feature_columns(conn, select, features,
                                               label)
    self.check_json_dump(features)
    self.check_json_dump(label)

    self.assertEqual(len(features), 1)
    self.assertIn("feature_columns", features)
    features = features["feature_columns"]
    self.assertEqual(len(features), 5)

    # c1 and c2 were declared INT64 but are expected back as dense
    # FLOAT32 scalars.
    assert_field(field_descs(features[0], NumericColumn, 1)[0],
                 "c1", DataType.FLOAT32, DataFormat.PLAIN, False, [1])
    assert_field(field_descs(features[1], NumericColumn, 1)[0],
                 "c2", DataType.FLOAT32, DataFormat.PLAIN, False, [1])

    # c3 is absent from the model clause; inference is expected to add it
    # as a NumericColumn over a 4-element CSV field.
    assert_field(field_descs(features[2], NumericColumn, 1)[0],
                 "c3", DataType.INT64, DataFormat.CSV, False, [4])

    # Cross of c4 and c5. c5 keeps its declared sparsity; its shape is
    # not asserted by this test.
    cross_c4_c5 = field_descs(features[3], CrossColumn, 2)
    assert_field(cross_c4_c5[0],
                 "c4", DataType.FLOAT32, DataFormat.CSV, False, [4])
    assert_field(cross_c4_c5[1],
                 "c5", DataType.INT64, DataFormat.CSV, True)

    # Cross of c1 and c2: its own FieldDesc count is checked here,
    # independently of the previous cross column.
    cross_c1_c2 = field_descs(features[4], CrossColumn, 2)
    assert_field(cross_c1_c2[0],
                 "c1", DataType.FLOAT32, DataFormat.PLAIN, False, [1])
    assert_field(cross_c1_c2[1],
                 "c2", DataType.FLOAT32, DataFormat.PLAIN, False, [1])

    # The label was declared with shape [1] but is expected back with
    # shape [].
    assert_field(field_descs(label, NumericColumn, 1)[0],
                 "class", DataType.INT64, DataFormat.PLAIN, False, [])
def test_without_cross(self):
    """Mix declared embedding columns with inferred ones.

    c3 and c5 are declared as EmbeddingColumns (c5 with an explicit
    CategoryIDColumn); the other SELECT fields are left to inference.
    Six feature columns are expected in SELECT order, including c6, a
    STRING field expected back as an EmbeddingColumn with a vocabulary.
    """

    def sole_field(column, column_type):
        # Check the column type and that it wraps exactly one FieldDesc,
        # then return that FieldDesc.
        self.assertTrue(isinstance(column, column_type))
        self.assertEqual(len(column.get_field_desc()), 1)
        return column.get_field_desc()[0]

    def check_field(desc, name, dtype, fmt, sparse, shape):
        # Assert a FieldDesc's name, dtype, format, sparsity and shape.
        self.assertEqual(desc.name, name)
        self.assertEqual(desc.dtype, dtype)
        self.assertEqual(desc.format, fmt)
        if sparse:
            self.assertTrue(desc.is_sparse)
        else:
            self.assertFalse(desc.is_sparse)
        self.assertEqual(desc.shape, shape)

    def check_embedding(column, dimension, combiner, name, bucket_size):
        # Assert the embedding settings and its CategoryIDColumn bucket size.
        self.assertEqual(column.dimension, dimension)
        self.assertEqual(column.combiner, combiner)
        self.assertEqual(column.name, name)
        self.assertTrue(
            isinstance(column.category_column, CategoryIDColumn))
        self.assertEqual(column.category_column.bucket_size, bucket_size)

    features = {
        'feature_columns': [
            EmbeddingColumn(dimension=256, combiner="mean", name="c3"),
            EmbeddingColumn(
                category_column=CategoryIDColumn(
                    FieldDesc(name="c5", dtype=DataType.INT64,
                              shape=[10000], delimiter=",",
                              is_sparse=True),
                    bucket_size=5000),
                dimension=64,
                combiner="sqrtn",
                name="c5"),
        ]
    }
    label = NumericColumn(
        FieldDesc(name="class", dtype=DataType.INT64, shape=[1]))
    select = ("select c1, c2, c3, c4, c5, c6, class "
              "from feature_derivation_case.train")
    conn = testing.get_singleton_db_connection()
    features, label = fd.infer_feature_columns(conn, select, features,
                                               label)
    self.check_json_dump(features)
    self.check_json_dump(label)

    self.assertEqual(len(features), 1)
    self.assertTrue("feature_columns" in features)
    columns = features["feature_columns"]
    self.assertEqual(len(columns), 6)

    # c1, c2: inferred dense FLOAT32 scalars.
    for column, name in ((columns[0], "c1"), (columns[1], "c2")):
        check_field(sole_field(column, NumericColumn), name,
                    DataType.FLOAT32, DataFormat.PLAIN, False, [1])

    # c3: declared embedding; inference supplies its field and category.
    c3_column = columns[2]
    check_field(sole_field(c3_column, EmbeddingColumn), "c3",
                DataType.INT64, DataFormat.CSV, False, [4])
    check_embedding(c3_column, 256, "mean", "c3", 10)

    # c4: inferred 4-element CSV FLOAT32 field.
    check_field(sole_field(columns[3], NumericColumn), "c4",
                DataType.FLOAT32, DataFormat.CSV, False, [4])

    # c5: declared embedding over a sparse CategoryIDColumn.
    c5_column = columns[4]
    check_field(sole_field(c5_column, EmbeddingColumn), "c5",
                DataType.INT64, DataFormat.CSV, True, [10000])
    check_embedding(c5_column, 64, "sqrtn", "c5", 5000)

    # c6: STRING field expected back as an embedding over its vocabulary.
    c6_column = columns[5]
    c6_desc = sole_field(c6_column, EmbeddingColumn)
    check_field(c6_desc, "c6", DataType.STRING, DataFormat.PLAIN, False,
                [1])
    self.assertEqual(c6_desc.vocabulary, {'FEMALE', 'MALE', 'NULL'})
    check_embedding(c6_column, 128, "sum", "c6", 3)

    # The label was declared with shape [1] but is expected back with
    # shape [].
    check_field(sole_field(label, NumericColumn), "class", DataType.INT64,
                DataFormat.PLAIN, False, [])