Exemplo n.º 1
0
  def from_config(cls, config, custom_objects=None):
    """Creates an instance of `cls` from a config dict.

    The config is presumably the dict produced by the matching `get_config`
    (not visible here). Its serialized entries are deserialized before the
    remaining entries are forwarded to the constructor as keyword arguments.

    Args:
      config: Dict of constructor arguments. The 'feature_columns' entry holds
        serialized feature columns. The 'partitioner' entry, if present, holds
        a serialized Keras object; a missing entry is treated as None.
      custom_objects: Optional dict mapping names to user-defined classes or
        functions, used to resolve custom objects during deserialization.

    Returns:
      A new instance of `cls` built from the deserialized config.
    """
    # Import here to avoid circular imports.
    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
    # Shallow copy so the caller's dict is not modified.
    config_cp = config.copy()
    config_cp['feature_columns'] = serialization.deserialize_feature_columns(
        config['feature_columns'], custom_objects=custom_objects)
    # Pass custom_objects by keyword: the second positional parameter of
    # deserialize_keras_object is `module_objects`, not `custom_objects`.
    # deserialize_keras_object(None) returns None, so a config without a
    # 'partitioner' entry yields partitioner=None instead of a KeyError.
    config_cp['partitioner'] = generic_utils.deserialize_keras_object(
        config.get('partitioner'), custom_objects=custom_objects)

    return cls(**config_cp)
Exemplo n.º 2
0
  def test_deserialization_deduping(self):
    """Round-tripping dependent columns shares the restored dependency."""
    price = fc.numeric_column('price')
    bucketized_price = fc.bucketized_column(price, boundaries=[0, 1])
    originals = [price, bucketized_price]

    serialized = serialization.serialize_feature_columns(originals)
    restored = serialization.deserialize_feature_columns(serialized)

    self.assertLen(restored, 2)
    restored_price, restored_bucketized_price = restored

    # The round trip must hand back fresh objects...
    for original, clone in zip(originals, restored):
      self.assertIsNot(original, clone)
    # ...which nevertheless compare equal to what was serialized.
    for original, clone in zip(originals, restored):
      self.assertEqual(original, clone)

    # The bucketized column must point at the single restored `price`
    # instance rather than at a separately deserialized duplicate.
    self.assertIs(restored_bucketized_price.source_column, restored_price)
Exemplo n.º 3
0
    def test_deserialization_custom_objects(self):
        """Public deserialize wrapper resolves functions via custom_objects."""
        # Note that custom_objects is also tested extensively above per class;
        # this test ensures that the public wrappers also handle it correctly.
        def _custom_fn(input_tensor):
            return input_tensor + 42.

        price = fc.numeric_column('price', normalizer_fn=_custom_fn)

        configs = serialization.serialize_feature_columns([price])

        # _custom_fn is local to this test, so the deserializer presumably
        # cannot look it up by name on its own. Supply it through
        # custom_objects (keyed by the function's name), which is exactly
        # what this test is meant to exercise.
        deserialized_feature_columns = serialization.deserialize_feature_columns(
            configs, custom_objects={'_custom_fn': _custom_fn})

        self.assertLen(deserialized_feature_columns, 1)
        new_price = deserialized_feature_columns[0]

        # Ensure these are not the original objects:
        self.assertIsNot(price, new_price)
        # But they are equivalent:
        self.assertEqual(price, new_price)

        # Check that normalizer_fn points to the correct function.
        self.assertIs(new_price.normalizer_fn, _custom_fn)

    # Backward-compatible alias for the previous name. That name lacked the
    # `test` prefix, so unittest never discovered or ran it.
    deserialization_custom_objects = test_deserialization_custom_objects