Exemplo n.º 1
0
    def test_no_column_clause(self):
        """Infer every feature column when no COLUMN clause is given.

        ``features`` is passed as ``None``. The test expects each non-label
        field of the SELECT to come back as a plain FLOAT32 NumericColumn,
        in SELECT order, under the single ``"feature_columns"`` key.
        """
        columns = [
            "sepal_length",
            "sepal_width",
            "petal_length",
            "petal_width",
        ]

        select = "select %s, class from iris.train" % ",".join(columns)

        conn = testing.get_singleton_db_connection()
        label = NumericColumn(
            FieldDesc(name='class', dtype=DataType.INT64, shape=[1]))
        # No user-specified feature columns: all of them must be inferred.
        features, label = fd.infer_feature_columns(conn, select, None, label)

        self.check_json_dump(features)
        self.check_json_dump(label)

        self.assertEqual(len(features), 1)
        self.assertIn("feature_columns", features)
        features = features["feature_columns"]
        self.assertEqual(len(features), len(columns))

        # Inferred columns must follow the SELECT order of ``columns``.
        for column_name, f in zip(columns, features):
            self.assertIsInstance(f, NumericColumn)
            self.assertEqual(len(f.get_field_desc()), 1)
            field_desc = f.get_field_desc()[0]
            self.assertEqual(field_desc.name, column_name)
            self.assertEqual(field_desc.dtype, DataType.FLOAT32)
            self.assertEqual(field_desc.format, DataFormat.PLAIN)
            self.assertFalse(field_desc.is_sparse)
            self.assertEqual(field_desc.shape, [1])

        self.assertIsInstance(label, NumericColumn)
        self.assertEqual(len(label.get_field_desc()), 1)
        field_desc = label.get_field_desc()[0]
        self.assertEqual(field_desc.name, "class")
        self.assertEqual(field_desc.dtype, DataType.INT64)
        self.assertEqual(field_desc.format, DataFormat.PLAIN)
        self.assertFalse(field_desc.is_sparse)
        # The label was declared with shape=[1]; the inferred label is
        # expected to end up with an empty shape.
        self.assertEqual(field_desc.shape, [])
Exemplo n.º 2
0
    def test_with_cross(self):
        """Derive feature columns when the COLUMN clause contains CrossColumns.

        ``c3`` appears in the SELECT but is not declared, so the test expects
        it to be inferred and placed right after ``c2``. The field
        descriptions of the declared columns are expected to be refreshed
        from the table. For example, ``c4`` is declared as INT64 with shape
        [1] but is expected back as a FLOAT32 CSV field of shape [4].
        """
        c1 = NumericColumn(
            FieldDesc(name='c1', dtype=DataType.INT64, shape=[1]))
        c2 = NumericColumn(
            FieldDesc(name='c2', dtype=DataType.INT64, shape=[1]))
        c4 = NumericColumn(
            FieldDesc(name='c4', dtype=DataType.INT64, shape=[1]))
        c5 = NumericColumn(
            FieldDesc(name='c5',
                      dtype=DataType.INT64,
                      shape=[1],
                      is_sparse=True))

        features = {
            'feature_columns': [
                c1,
                c2,
                CrossColumn([c4, c5], 128),
                CrossColumn([c1, c2], 256),
            ]
        }

        label = NumericColumn(
            FieldDesc(name='class', dtype=DataType.INT64, shape=[1]))
        select = "select c1, c2, c3, c4, c5, class " \
                 "from feature_derivation_case.train"

        conn = testing.get_singleton_db_connection()
        features, label = fd.infer_feature_columns(conn, select, features,
                                                   label)

        self.check_json_dump(features)
        self.check_json_dump(label)

        self.assertEqual(len(features), 1)
        self.assertIn("feature_columns", features)
        features = features["feature_columns"]
        # 4 declared columns + the inferred c3.
        self.assertEqual(len(features), 5)

        fc1 = features[0]
        self.assertIsInstance(fc1, NumericColumn)
        self.assertEqual(len(fc1.get_field_desc()), 1)
        field_desc = fc1.get_field_desc()[0]
        self.assertEqual(field_desc.name, "c1")
        self.assertEqual(field_desc.dtype, DataType.FLOAT32)
        self.assertEqual(field_desc.format, DataFormat.PLAIN)
        self.assertFalse(field_desc.is_sparse)
        self.assertEqual(field_desc.shape, [1])

        fc2 = features[1]
        self.assertIsInstance(fc2, NumericColumn)
        self.assertEqual(len(fc2.get_field_desc()), 1)
        field_desc = fc2.get_field_desc()[0]
        self.assertEqual(field_desc.name, "c2")
        self.assertEqual(field_desc.dtype, DataType.FLOAT32)
        self.assertEqual(field_desc.format, DataFormat.PLAIN)
        self.assertFalse(field_desc.is_sparse)
        self.assertEqual(field_desc.shape, [1])

        # c3 was not declared; it is expected to be inferred here.
        fc3 = features[2]
        self.assertIsInstance(fc3, NumericColumn)
        self.assertEqual(len(fc3.get_field_desc()), 1)
        field_desc = fc3.get_field_desc()[0]
        self.assertEqual(field_desc.name, "c3")
        self.assertEqual(field_desc.dtype, DataType.INT64)
        self.assertEqual(field_desc.format, DataFormat.CSV)
        self.assertFalse(field_desc.is_sparse)
        self.assertEqual(field_desc.shape, [4])

        # CrossColumn([c4, c5], 128)
        fc4 = features[3]
        self.assertIsInstance(fc4, CrossColumn)
        self.assertEqual(len(fc4.get_field_desc()), 2)
        field_desc1 = fc4.get_field_desc()[0]
        self.assertEqual(field_desc1.name, "c4")
        self.assertEqual(field_desc1.dtype, DataType.FLOAT32)
        self.assertEqual(field_desc1.format, DataFormat.CSV)
        self.assertEqual(field_desc1.shape, [4])
        self.assertFalse(field_desc1.is_sparse)
        field_desc2 = fc4.get_field_desc()[1]
        self.assertEqual(field_desc2.name, "c5")
        self.assertEqual(field_desc2.dtype, DataType.INT64)
        self.assertEqual(field_desc2.format, DataFormat.CSV)
        self.assertTrue(field_desc2.is_sparse)

        # CrossColumn([c1, c2], 256)
        fc5 = features[4]
        self.assertIsInstance(fc5, CrossColumn)
        self.assertEqual(len(fc5.get_field_desc()), 2)
        field_desc1 = fc5.get_field_desc()[0]
        self.assertEqual(field_desc1.name, "c1")
        self.assertEqual(field_desc1.dtype, DataType.FLOAT32)
        self.assertEqual(field_desc1.format, DataFormat.PLAIN)
        self.assertEqual(field_desc1.shape, [1])
        self.assertFalse(field_desc1.is_sparse)
        field_desc2 = fc5.get_field_desc()[1]
        self.assertEqual(field_desc2.name, "c2")
        self.assertEqual(field_desc2.dtype, DataType.FLOAT32)
        self.assertEqual(field_desc2.format, DataFormat.PLAIN)
        self.assertEqual(field_desc2.shape, [1])
        self.assertFalse(field_desc2.is_sparse)

        self.assertIsInstance(label, NumericColumn)
        self.assertEqual(len(label.get_field_desc()), 1)
        field_desc = label.get_field_desc()[0]
        self.assertEqual(field_desc.name, "class")
        self.assertEqual(field_desc.dtype, DataType.INT64)
        self.assertEqual(field_desc.format, DataFormat.PLAIN)
        self.assertFalse(field_desc.is_sparse)
        # Declared with shape=[1]; expected to be inferred as [].
        self.assertEqual(field_desc.shape, [])
Exemplo n.º 3
0
    def test_without_cross(self):
        """Derive feature columns when only EmbeddingColumns are declared.

        Only ``c3`` and ``c5`` are declared. The test expects the following:

        * ``c1``, ``c2`` and ``c4`` are inferred as NumericColumns.
        * ``c3``, declared without a category column, receives a
          CategoryIDColumn with bucket_size 10.
        * ``c5`` keeps its declared CategoryIDColumn.
        * ``c6``, a STRING field, is inferred as an EmbeddingColumn whose
          CategoryIDColumn bucket size matches its 3-value vocabulary.
        """
        features = {
            'feature_columns': [
                EmbeddingColumn(dimension=256, combiner="mean", name="c3"),
                EmbeddingColumn(category_column=CategoryIDColumn(
                    FieldDesc(name="c5",
                              dtype=DataType.INT64,
                              shape=[10000],
                              delimiter=",",
                              is_sparse=True),
                    bucket_size=5000),
                                dimension=64,
                                combiner="sqrtn",
                                name="c5"),
            ]
        }

        label = NumericColumn(
            FieldDesc(name="class", dtype=DataType.INT64, shape=[1]))

        select = "select c1, c2, c3, c4, c5, c6, class " \
                 "from feature_derivation_case.train"
        conn = testing.get_singleton_db_connection()
        features, label = fd.infer_feature_columns(conn, select, features,
                                                   label)

        self.check_json_dump(features)
        self.check_json_dump(label)

        self.assertEqual(len(features), 1)
        self.assertIn("feature_columns", features)
        features = features["feature_columns"]
        # One column per selected non-label field, in SELECT order.
        self.assertEqual(len(features), 6)

        fc1 = features[0]
        self.assertIsInstance(fc1, NumericColumn)
        self.assertEqual(len(fc1.get_field_desc()), 1)
        field_desc = fc1.get_field_desc()[0]
        self.assertEqual(field_desc.name, "c1")
        self.assertEqual(field_desc.dtype, DataType.FLOAT32)
        self.assertEqual(field_desc.format, DataFormat.PLAIN)
        self.assertFalse(field_desc.is_sparse)
        self.assertEqual(field_desc.shape, [1])

        fc2 = features[1]
        self.assertIsInstance(fc2, NumericColumn)
        self.assertEqual(len(fc2.get_field_desc()), 1)
        field_desc = fc2.get_field_desc()[0]
        self.assertEqual(field_desc.name, "c2")
        self.assertEqual(field_desc.dtype, DataType.FLOAT32)
        self.assertEqual(field_desc.format, DataFormat.PLAIN)
        self.assertFalse(field_desc.is_sparse)
        self.assertEqual(field_desc.shape, [1])

        # c3 was declared without a category column; one must be filled in.
        fc3 = features[2]
        self.assertIsInstance(fc3, EmbeddingColumn)
        self.assertEqual(len(fc3.get_field_desc()), 1)
        field_desc = fc3.get_field_desc()[0]
        self.assertEqual(field_desc.name, "c3")
        self.assertEqual(field_desc.dtype, DataType.INT64)
        self.assertEqual(field_desc.format, DataFormat.CSV)
        self.assertFalse(field_desc.is_sparse)
        self.assertEqual(field_desc.shape, [4])
        self.assertEqual(fc3.dimension, 256)
        self.assertEqual(fc3.combiner, "mean")
        self.assertEqual(fc3.name, "c3")
        self.assertIsInstance(fc3.category_column, CategoryIDColumn)
        self.assertEqual(fc3.category_column.bucket_size, 10)

        fc4 = features[3]
        self.assertIsInstance(fc4, NumericColumn)
        self.assertEqual(len(fc4.get_field_desc()), 1)
        field_desc = fc4.get_field_desc()[0]
        self.assertEqual(field_desc.name, "c4")
        self.assertEqual(field_desc.dtype, DataType.FLOAT32)
        self.assertEqual(field_desc.format, DataFormat.CSV)
        self.assertFalse(field_desc.is_sparse)
        self.assertEqual(field_desc.shape, [4])

        # c5 keeps the user-declared category column settings.
        fc5 = features[4]
        self.assertIsInstance(fc5, EmbeddingColumn)
        self.assertEqual(len(fc5.get_field_desc()), 1)
        field_desc = fc5.get_field_desc()[0]
        self.assertEqual(field_desc.name, "c5")
        self.assertEqual(field_desc.dtype, DataType.INT64)
        self.assertEqual(field_desc.format, DataFormat.CSV)
        self.assertTrue(field_desc.is_sparse)
        self.assertEqual(field_desc.shape, [10000])
        self.assertEqual(fc5.dimension, 64)
        self.assertEqual(fc5.combiner, "sqrtn")
        self.assertEqual(fc5.name, "c5")
        self.assertIsInstance(fc5.category_column, CategoryIDColumn)
        self.assertEqual(fc5.category_column.bucket_size, 5000)

        # c6 was not declared; a STRING field is expected to be inferred
        # as an embedding over its vocabulary.
        fc6 = features[5]
        self.assertIsInstance(fc6, EmbeddingColumn)
        self.assertEqual(len(fc6.get_field_desc()), 1)
        field_desc = fc6.get_field_desc()[0]
        self.assertEqual(field_desc.name, "c6")
        self.assertEqual(field_desc.dtype, DataType.STRING)
        self.assertEqual(field_desc.format, DataFormat.PLAIN)
        self.assertFalse(field_desc.is_sparse)
        self.assertEqual(field_desc.shape, [1])
        self.assertEqual(field_desc.vocabulary, {'FEMALE', 'MALE', 'NULL'})
        self.assertEqual(fc6.dimension, 128)
        self.assertEqual(fc6.combiner, "sum")
        self.assertEqual(fc6.name, "c6")
        self.assertIsInstance(fc6.category_column, CategoryIDColumn)
        self.assertEqual(fc6.category_column.bucket_size, 3)

        self.assertIsInstance(label, NumericColumn)
        self.assertEqual(len(label.get_field_desc()), 1)
        field_desc = label.get_field_desc()[0]
        self.assertEqual(field_desc.name, "class")
        self.assertEqual(field_desc.dtype, DataType.INT64)
        self.assertEqual(field_desc.format, DataFormat.PLAIN)
        self.assertFalse(field_desc.is_sparse)
        # Declared with shape=[1]; expected to be inferred as [].
        self.assertEqual(field_desc.shape, [])