Example #1
0
    def test_sparse_index_field_equality(self):
        """Check == and != on SparseIndexField for equal and differing names."""
        foo_field, foo_twin, bar_field = (
            sch.SparseIndexField(field_name, False)
            for field_name in ('foo', 'foo', 'bar'))

        # Two fields built from identical arguments must compare equal.
        self.assertEqual(foo_field, foo_twin)
        # Only the name differs here, which must break equality.
        self.assertNotEqual(foo_twin, bar_field)
Example #2
0
def _sparse_column_schema_from_json(feature_dict):
    """Build a ColumnSchema from a JSON dict describing a sparse feature.

    Args:
      feature_dict: dict with a 'valueFeature' list, whose first entry
        supplies 'domain' and 'name', and an 'indexFeature' list whose
        entries each supply 'size', 'name' and 'isSorted'.

    Returns:
      A `sch.ColumnSchema` with one axis and one sparse index field per
      entry of 'indexFeature'.

    Raises:
      KeyError: if any of the keys listed above is absent.
    """
    # Only the first value feature is read; a single value column is assumed.
    value_spec = feature_dict['valueFeature'][0]
    domain = _domain_from_json(value_spec['domain'])

    index_specs = feature_dict['indexFeature']

    # Protobuf JSON encodes int64 as a string, hence the int() conversion.
    axes = [sch.Axis(int(spec['size'])) for spec in index_specs]

    value_name = value_spec['name']
    index_fields = [
        sch.SparseIndexField(spec['name'], spec['isSorted'])
        for spec in index_specs
    ]

    return sch.ColumnSchema(
        domain, axes,
        sch.SparseColumnRepresentation(value_name, index_fields))
Example #3
0
        def preprocessing_fn(inputs):
            """Return the inputs plus a sum and SparseTensor copies of them.

            'sparse' is reduced over axis 1 and also copied; 'varlen' is
            copied. Only the copy of 'sparse' has a schema assigned here.
            """
            summed = tft.map(
                lambda sp: tf.sparse_reduce_sum(sp, axis=1),
                inputs['sparse'])
            sparse_dup = tft.map(
                lambda sp: tf.SparseTensor(
                    sp.indices, sp.values, sp.dense_shape),
                inputs['sparse'])
            varlen_dup = tft.map(
                lambda sp: tf.SparseTensor(
                    sp.indices, sp.values, sp.dense_shape),
                inputs['varlen'])

            # Explicit schema for the copy of 'sparse': float32 domain, a
            # single axis of size 10, sparse representation using the
            # 'val_copy' value field and the 'idx_copy' index field.
            copy_logical_schema = sch.LogicalColumnSchema(
                sch.dtype_to_domain(tf.float32),
                sch.LogicalShape([sch.Axis(10)]))
            copy_representation = sch.SparseColumnRepresentation(
                'val_copy', [sch.SparseIndexField('idx_copy', False)])
            sparse_dup.schema = sch.ColumnSchema(
                copy_logical_schema, copy_representation)

            return {
                # Sum of 'sparse' over axis 1; no schema is set here.
                'fixed': summed,
                # Input columns handed back unchanged.
                'sparse': inputs['sparse'],
                'varlen': inputs['varlen'],
                # Copy whose schema was assigned explicitly above.
                'sparse_copy': sparse_dup,
                # Copy with no schema set in this function.
                'varlen_copy': varlen_dup,
            }
Example #4
0
    def test_column_representation_equality(self):
        """Verify == and != across fixed, list and sparse representations."""

        def make_sparse(second_index_flag):
            # Two index fields; only the flag passed for 'idx2' varies.
            return sch.SparseColumnRepresentation('val', [
                sch.SparseIndexField('idx1', False),
                sch.SparseIndexField('idx2', second_index_flag),
            ])

        fixed_rep = sch.FixedColumnRepresentation(1.1)
        fixed_rep_twin = sch.FixedColumnRepresentation(1.1)
        fixed_rep_no_default = sch.FixedColumnRepresentation()

        list_rep = sch.ListColumnRepresentation()
        list_rep_twin = sch.ListColumnRepresentation()

        sparse_rep = make_sparse(True)
        sparse_rep_twin = make_sparse(True)
        sparse_rep_other_flag = make_sparse(False)

        # Fixed: equal only to another fixed representation with the same
        # default value.
        self.assertEqual(fixed_rep, fixed_rep_twin)
        for unequal_rep in (fixed_rep_no_default, list_rep, sparse_rep):
            self.assertNotEqual(fixed_rep, unequal_rep)

        # List: equal to another list representation, never to a sparse one.
        self.assertEqual(list_rep, list_rep_twin)
        self.assertNotEqual(list_rep, sparse_rep)

        # Sparse: a differing flag on one index field breaks equality.
        self.assertEqual(sparse_rep, sparse_rep_twin)
        self.assertNotEqual(sparse_rep, sparse_rep_other_flag)
Example #5
0
def get_manually_created_schema():
    """Assemble a test schema by hand, one column at a time.

    Returns:
      A `sch.Schema` whose `column_schemas` holds fixed-length,
      variable-length and sparse columns over the bool, int64, float32
      and string dtypes.
    """
    schema = sch.Schema()

    def add_column(name, domain, axes, representation):
        # Register one column under `name`.
        schema.column_schemas[name] = sch.ColumnSchema(
            domain, axes, representation)

    # Fixed-length columns.
    add_column('fixed_bool_with_default', tf.bool, [1],
               sch.FixedColumnRepresentation(default_value=False))
    add_column('fixed_bool_without_default', tf.bool, [5],
               sch.FixedColumnRepresentation())
    add_column('fixed_int_with_default', tf.int64, [1],
               sch.FixedColumnRepresentation(default_value=0))
    add_column('fixed_categorical_int_with_range',
               sch.IntDomain(tf.int64, -5, 10, True), [1],
               sch.FixedColumnRepresentation(0))
    add_column('fixed_categorical_int_with_vocab',
               sch.IntDomain(tf.int64, vocabulary_file='test_filename'), [1],
               sch.FixedColumnRepresentation(0))
    add_column('fixed_int_without_default', tf.int64, [5],
               sch.FixedColumnRepresentation())
    add_column('fixed_float_with_default', tf.float32, [1],
               sch.FixedColumnRepresentation(default_value=0.0))
    add_column('fixed_float_without_default', tf.float32, [5],
               sch.FixedColumnRepresentation())
    add_column('fixed_string_with_default', tf.string, [1],
               sch.FixedColumnRepresentation(default_value='default'))
    add_column('fixed_string_without_default', tf.string, [5],
               sch.FixedColumnRepresentation())
    add_column('3d_fixed_int_without_default', tf.int64, [5, 6, 7],
               sch.FixedColumnRepresentation())

    # Variable-length columns: no axes given, list representation.
    varlen_specs = (
        ('var_bool', tf.bool),
        ('var_int', tf.int64),
        ('var_float', tf.float32),
        ('var_string', tf.string),
    )
    for name, dtype in varlen_specs:
        add_column(name, dtype, None, sch.ListColumnRepresentation())

    # Sparse columns: (column, dtype, size, value field, index field, flag).
    sparse_specs = (
        ('sparse_bool', tf.bool, 15,
         'sparse_bool_value', 'sparse_bool_index', True),
        ('sparse_int', tf.int64, 150,
         'sparse_int_value', 'sparse_int_index', False),
        ('sparse_float', tf.float32, 1500,
         'sparse_float_value', 'sparse_float_index', False),
        ('sparse_string', tf.string, 15000,
         'sparse_string_value', 'sparse_string_index', True),
    )
    for name, dtype, size, value_field, index_field, flag in sparse_specs:
        add_column(
            name, dtype, [size],
            sch.SparseColumnRepresentation(
                value_field, [sch.SparseIndexField(index_field, flag)]))

    return schema
Example #6
0
def get_manually_created_schema():
    """Assemble a test schema by hand from the Schema classes.

    Returns:
      A `sch.Schema` whose `column_schemas` holds fixed-length,
      variable-length and sparse columns over the bool, int64, float32
      and string dtypes.
    """
    schema = sch.Schema()

    # This verbose construction may be replaced with convenience methods
    # in the future.

    def logical(dtype, *sizes):
        # Logical schema for `dtype` with one Axis per entry of `sizes`.
        return sch.LogicalColumnSchema(
            sch.dtype_to_domain(dtype),
            sch.LogicalShape([sch.Axis(size) for size in sizes]))

    # Fixed-length columns.
    schema.column_schemas['fixed_bool_with_default'] = sch.ColumnSchema(
        logical(tf.bool, 1), sch.FixedColumnRepresentation(False))
    schema.column_schemas['fixed_bool_without_default'] = sch.ColumnSchema(
        logical(tf.bool, 5), sch.FixedColumnRepresentation())
    schema.column_schemas['fixed_int_with_default'] = sch.ColumnSchema(
        logical(tf.int64, 1), sch.FixedColumnRepresentation(0))
    schema.column_schemas['fixed_int_without_default'] = sch.ColumnSchema(
        logical(tf.int64, 5), sch.FixedColumnRepresentation())
    schema.column_schemas['fixed_float_with_default'] = sch.ColumnSchema(
        logical(tf.float32, 1), sch.FixedColumnRepresentation(0.0))
    schema.column_schemas['fixed_float_without_default'] = sch.ColumnSchema(
        logical(tf.float32, 5), sch.FixedColumnRepresentation())
    schema.column_schemas['fixed_string_with_default'] = sch.ColumnSchema(
        logical(tf.string, 1), sch.FixedColumnRepresentation('default'))
    schema.column_schemas['fixed_string_without_default'] = sch.ColumnSchema(
        logical(tf.string, 5), sch.FixedColumnRepresentation())
    schema.column_schemas['3d_fixed_int_without_default'] = sch.ColumnSchema(
        logical(tf.int64, 5, 6, 7), sch.FixedColumnRepresentation())

    # Variable-length columns: a single Axis(None), list representation.
    varlen_specs = (
        ('var_bool', tf.bool),
        ('var_int', tf.int64),
        ('var_float', tf.float32),
        ('var_string', tf.string),
    )
    for name, dtype in varlen_specs:
        schema.column_schemas[name] = sch.ColumnSchema(
            logical(dtype, None), sch.ListColumnRepresentation())

    # Sparse columns: (column, dtype, size, value field, index field, flag).
    sparse_specs = (
        ('sparse_bool', tf.bool, 15,
         'sparse_bool_value', 'sparse_bool_index', True),
        ('sparse_int', tf.int64, 150,
         'sparse_int_value', 'sparse_int_index', False),
        ('sparse_float', tf.float32, 1500,
         'sparse_float_value', 'sparse_float_index', False),
        ('sparse_string', tf.string, 15000,
         'sparse_string_value', 'sparse_string_index', True),
    )
    for name, dtype, size, value_field, index_field, flag in sparse_specs:
        schema.column_schemas[name] = sch.ColumnSchema(
            logical(dtype, size),
            sch.SparseColumnRepresentation(
                value_field, [sch.SparseIndexField(index_field, flag)]))

    return schema