def test_category_id_column(self):
    """Compile CategoryIDColumn for every supported model type.

    A plain column must keep its key and bucket count; a column whose
    FieldDesc carries a vocabulary must expose it as ``vocabulary_list``.
    """
    model_types = [TENSORFLOW, XGBOOST]

    plain_column = CategoryIDColumn(FieldDesc(name='c1'), 128)
    for target in model_types:
        compiled = self.compile_fc(plain_column, target)
        self.assertEqual(compiled.key, 'c1')
        self.assertEqual(compiled.num_buckets, 128)

    vocab_column = CategoryIDColumn(
        FieldDesc(name='c1', vocabulary={'a', 'b'}), 128)
    for target in model_types:
        compiled = self.compile_fc(vocab_column, target)
        # vocabulary comes from a set, so compare in sorted order
        self.assertEqual(sorted(compiled.vocabulary_list), ['a', 'b'])
def test_indicator_column(self):
    """Compile IndicatorColumn and check its wrapped categorical column."""
    indicator = IndicatorColumn(
        category_column=CategoryIDColumn(FieldDesc(name='c1'), 128))
    for target in [TENSORFLOW, XGBOOST]:
        inner = self.compile_fc(indicator, target).categorical_column
        self.assertEqual(inner.key, 'c1')
        self.assertEqual(inner.num_buckets, 128)
def _lookup_field_desc(fd_map, name):
    """Return the FieldDesc named ``name`` or raise ValueError if absent."""
    # Use .get() so a missing key reaches the explicit ValueError below
    # instead of escaping as a bare KeyError.
    field_desc = fd_map.get(name)
    if field_desc is None:
        raise ValueError("column not found or inferred: %s" % name)
    return field_desc


def update_feature_column(fc, fd_map):
    """
    Update the FeatureColumn object by the FieldDesc map.

    Only an EmbeddingColumn or IndicatorColumn whose ``category_column`` is
    None is updated: a CategoryIDColumn is derived from the matching
    FieldDesc and attached in place. Any other column is left untouched.

    Args:
        fc (FeatureColumn): the FeatureColumn object to update in place.
        fd_map (dict[str -> FieldDesc]): a FieldDesc map, where the key is
            the field name.

    Returns:
        None.

    Raises:
        ValueError: if ``fc.name`` has no (non-None) entry in ``fd_map``.
        AssertionError: if the FieldDesc is unsuitable for the column
            (dense column without a positive ``max_id``, or a sparse column
            used with an IndicatorColumn).
    """
    if isinstance(fc, EmbeddingColumn) and fc.category_column is None:
        field_desc = _lookup_field_desc(fd_map, fc.name)

        # FIXME(typhoonzero): when to use sequence_category_id_column?
        # if column fieldDesc is SPARSE, the sparse shape should
        # be in cs.Shape[0]
        bucket_size = field_desc.shape[0]
        if not field_desc.is_sparse:
            assert field_desc.max_id > 0, \
                "use dense column on embedding column " \
                "but did not got a correct MaxID"
            bucket_size = field_desc.max_id + 1

        fc.category_column = CategoryIDColumn(field_desc, bucket_size)
        return

    if isinstance(fc, IndicatorColumn) and fc.category_column is None:
        field_desc = _lookup_field_desc(fd_map, fc.name)

        # The bucket count comes from max_id, which (as in the embedding
        # branch above) is only consulted for dense columns, so sparse
        # input must be rejected -- matching the assertion message.
        assert not field_desc.is_sparse, \
            "cannot use sparse column with indicator column"
        assert field_desc.max_id > 0, \
            "use indicator column but did not got a correct MaxID"
        bucket_size = field_desc.max_id + 1
        fc.category_column = CategoryIDColumn(field_desc, bucket_size)
def new_feature_column(field_desc):
    """
    Build a default FeatureColumn for a single FieldDesc.

    Args:
        field_desc (FieldDesc): the field to wrap.

    Returns:
        A NumericColumn when field_desc.dtype is not STRING. Otherwise an
        EmbeddingColumn named after the field, with dimension 128 and the
        "sum" combiner, over a CategoryIDColumn whose bucket size equals the
        vocabulary size.
    """
    if field_desc.dtype != DataType.STRING:
        return NumericColumn(field_desc)

    # STRING fields are treated as categorical: one bucket per vocabulary
    # entry. NOTE(review): len() fails if vocabulary is None -- confirm that
    # STRING fields always arrive with a vocabulary.
    vocab_size = len(field_desc.vocabulary)
    id_column = CategoryIDColumn(field_desc, vocab_size)

    # NOTE(typhoonzero): a default embedding size of 128 is enough
    # for most cases.
    embedding = EmbeddingColumn(category_column=id_column,
                                dimension=128,
                                combiner="sum")
    embedding.name = field_desc.name
    return embedding
def test_without_cross(self):
    """Infer feature columns for feature_derivation_case.train, no crosses.

    Only c3 (an EmbeddingColumn without a category column) and c5 (an
    EmbeddingColumn with an explicit sparse CategoryIDColumn) are supplied;
    inference must return six columns, including a derived category column
    for c3 and new columns for c1, c2, c4 and c6.
    """

    def check_field_desc(column, column_type, name, dtype, data_format,
                         is_sparse, shape):
        # Each column under test wraps exactly one FieldDesc; verify it and
        # return it so callers can run column-specific checks.
        self.assertTrue(isinstance(column, column_type))
        descs = column.get_field_desc()
        self.assertEqual(len(descs), 1)
        desc = descs[0]
        self.assertEqual(desc.name, name)
        self.assertEqual(desc.dtype, dtype)
        self.assertEqual(desc.format, data_format)
        if is_sparse:
            self.assertTrue(desc.is_sparse)
        else:
            self.assertFalse(desc.is_sparse)
        self.assertEqual(desc.shape, shape)
        return desc

    def check_embedding(column, name, dimension, combiner, bucket_size):
        # Embedding-specific attributes plus its CategoryIDColumn size.
        self.assertEqual(column.dimension, dimension)
        self.assertEqual(column.combiner, combiner)
        self.assertEqual(column.name, name)
        self.assertTrue(
            isinstance(column.category_column, CategoryIDColumn))
        self.assertEqual(column.category_column.bucket_size, bucket_size)

    supplied = {
        'feature_columns': [
            EmbeddingColumn(dimension=256, combiner="mean", name="c3"),
            EmbeddingColumn(
                category_column=CategoryIDColumn(
                    FieldDesc(name="c5",
                              dtype=DataType.INT64,
                              shape=[10000],
                              delimiter=",",
                              is_sparse=True),
                    bucket_size=5000),
                dimension=64,
                combiner="sqrtn",
                name="c5"),
        ]
    }
    label_column = NumericColumn(
        FieldDesc(name="class", dtype=DataType.INT64, shape=[1]))
    query = ("select c1, c2, c3, c4, c5, c6, class "
             "from feature_derivation_case.train")
    conn = testing.get_singleton_db_connection()
    inferred, label = fd.infer_feature_columns(conn, query, supplied,
                                               label_column)

    self.check_json_dump(inferred)
    self.check_json_dump(label)

    self.assertEqual(len(inferred), 1)
    self.assertTrue("feature_columns" in inferred)
    columns = inferred["feature_columns"]
    self.assertEqual(len(columns), 6)
    fc1, fc2, fc3, fc4, fc5, fc6 = columns

    check_field_desc(fc1, NumericColumn, "c1", DataType.FLOAT32,
                     DataFormat.PLAIN, False, [1])

    check_field_desc(fc2, NumericColumn, "c2", DataType.FLOAT32,
                     DataFormat.PLAIN, False, [1])

    check_field_desc(fc3, EmbeddingColumn, "c3", DataType.INT64,
                     DataFormat.CSV, False, [4])
    check_embedding(fc3, "c3", 256, "mean", 10)

    check_field_desc(fc4, NumericColumn, "c4", DataType.FLOAT32,
                     DataFormat.CSV, False, [4])

    check_field_desc(fc5, EmbeddingColumn, "c5", DataType.INT64,
                     DataFormat.CSV, True, [10000])
    check_embedding(fc5, "c5", 64, "sqrtn", 5000)

    c6_desc = check_field_desc(fc6, EmbeddingColumn, "c6", DataType.STRING,
                               DataFormat.PLAIN, False, [1])
    self.assertEqual(c6_desc.vocabulary, {'FEMALE', 'MALE', 'NULL'})
    check_embedding(fc6, "c6", 128, "sum", 3)

    check_field_desc(label, NumericColumn, "class", DataType.INT64,
                     DataFormat.PLAIN, False, [])