Пример #1
0
    def test_deserialize_config_missing_key(self):
        """Checks that a config lacking a required key is rejected.

        The 'NumericColumn' config below deliberately leaves out 'dtype';
        `serialization.deserialize_feature_column` is expected to raise a
        `ValueError` whose message matches
        'Invalid config:.*expected keys.*dtype'.
        """
        config_missing_key = {
            'config': {
                # Dtype is missing and should cause a failure.
                # 'dtype': 'int32',
                'default_value': None,
                'key': 'a',
                'normalizer_fn': None,
                'shape': (2, )
            },
            'class_name': 'NumericColumn'
        }

        # assertRaisesRegex is used instead of assertRaisesRegexp: the latter
        # is a deprecated unittest alias that was removed in Python 3.12.
        with self.assertRaisesRegex(ValueError,
                                    'Invalid config:.*expected keys.*dtype'):
            serialization.deserialize_feature_column(config_missing_key)
Пример #2
0
 def from_config(cls, config, custom_objects=None, columns_by_name=None):
     """Creates an instance of `cls` from `config`.

     See the `FeatureColumn` base class.

     Args:
       config: Mapping whose keys are checked against `cls._fields`; its
         'categorical_column' entry is itself a serialized column config.
       custom_objects: Forwarded unchanged to `deserialize_feature_column`.
       columns_by_name: Forwarded unchanged to `deserialize_feature_column`.

     Returns:
       The result of calling `cls` with the standardized config as keyword
       arguments, where 'categorical_column' holds the deserialized column.
     """
     # Function-scope import; presumably deferred to avoid a circular
     # import between the two modules -- TODO confirm.
     from tensorflow.python.feature_column.serialization import (  # pylint: disable=g-import-not-at-top
         deserialize_feature_column)

     _check_config_keys(config, cls._fields)
     init_args = _standardize_and_copy_config(config)

     # The nested column is read from the original `config`, not the copy.
     nested_config = config['categorical_column']
     init_args['categorical_column'] = deserialize_feature_column(
         nested_config, custom_objects, columns_by_name)
     return cls(**init_args)
  def from_config(cls, config, custom_objects=None):
    """Creates an instance of `cls` from `config`.

    Args:
      config: Mapping with at least 'feature_columns' (an iterable of
        serialized column configs) and 'partitioner' entries. A shallow copy
        is modified; `config` itself is not assigned to.
      custom_objects: Forwarded to both deserialization helpers.

    Returns:
      The result of calling `cls` with the copied config as keyword
      arguments, with 'feature_columns' and 'partitioner' replaced by their
      deserialized values.
    """
    # Import here to avoid circular imports.
    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top

    init_kwargs = config.copy()

    # One dict is shared by every deserialize call below; presumably it lets
    # columns be looked up by name across calls -- TODO confirm.
    shared_columns = {}
    restored_columns = []
    for column_config in config['feature_columns']:
      restored_columns.append(
          serialization.deserialize_feature_column(
              column_config, custom_objects, shared_columns))
    init_kwargs['feature_columns'] = restored_columns

    partitioner_config = config['partitioner']
    init_kwargs['partitioner'] = generic_utils.deserialize_keras_object(
        partitioner_config, custom_objects)

    return cls(**init_kwargs)
Пример #4
0
    def from_config(cls, config, custom_objects=None, columns_by_name=None):
        """Creates an instance of `cls` from `config`.

        See the `FeatureColumn` base class.

        Args:
            config: Mapping whose keys are checked against `cls._fields`;
                its "categorical_columns" entry is an iterable of serialized
                column configs.
            custom_objects: Forwarded unchanged to
                `deserialize_feature_column`.
            columns_by_name: Forwarded unchanged to
                `deserialize_feature_column`.

        Returns:
            The result of calling `cls` with the standardized config as
            keyword arguments, where "categorical_columns" is a tuple of the
            deserialized columns.
        """
        # Function-scope import; presumably deferred to avoid a circular
        # import between the two modules -- TODO confirm.
        from tensorflow.python.feature_column.serialization import (  # pylint: disable=g-import-not-at-top
            deserialize_feature_column)

        fc_lib._check_config_keys(config, cls._fields)
        init_args = fc_lib._standardize_and_copy_config(config)

        # The nested configs are read from the original `config`, in order.
        init_args["categorical_columns"] = tuple(
            deserialize_feature_column(column_config, custom_objects,
                                       columns_by_name)
            for column_config in config["categorical_columns"])

        return cls(**init_args)
Пример #5
0
  def test_serialization(self):
    """Round-trips a sequence numeric column through (de)serialization.

    Also checks that serializing a plain integer raises a `ValueError`.
    """
    def _custom_fn(input_tensor):
      return input_tensor + 42

    original = sfc.sequence_numeric_column(
        key='my-key',
        shape=(2,),
        default_value=3,
        dtype=dtypes.int32,
        normalizer_fn=_custom_fn)
    serialized = serialization.serialize_feature_column(original)

    # The custom normalizer is registered under its own function name.
    custom_objects = {_custom_fn.__name__: _custom_fn}
    restored = serialization.deserialize_feature_column(
        serialized, custom_objects=custom_objects)

    self.assertEqual(restored.key, 'my-key')
    self.assertEqual(restored.shape, (2,))
    self.assertEqual(restored.default_value, 3)
    # 3 + 42 == 45 shows the restored normalizer behaves like _custom_fn.
    self.assertEqual(restored.normalizer_fn(3), 45)

    # An int is not a feature column, so serializing it must fail.
    not_a_column = 0
    with self.assertRaisesRegex(
        ValueError, 'Instance: 0 is not a FeatureColumn'):
      serialization.serialize_feature_column(not_a_column)
Пример #6
0
 def test_deserialize_invalid_config(self):
     """Checks that an empty config dict is rejected with a `ValueError`.

     `serialization.deserialize_feature_column({})` is expected to raise an
     error whose message matches 'Improper config format: {}'.
     """
     # assertRaisesRegex is used instead of assertRaisesRegexp: the latter is
     # a deprecated unittest alias that was removed in Python 3.12.
     with self.assertRaisesRegex(ValueError, 'Improper config format: {}'):
         serialization.deserialize_feature_column({})