def get_config(self):
        """Return a serializable config dict for this features layer.

        The feature columns are serialized with the feature_column
        ``serialization`` helpers and the partitioner with
        ``generic_utils.serialize_keras_object``. The result is layered over
        the config returned by ``super(_BaseFeaturesLayer, self).get_config()``,
        so keys set here replace same-named keys from that parent config.
        """
        # Deferred import: importing at module level would be circular.
        from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top

        layer_config = {
            'feature_columns': serialization.serialize_feature_columns(
                self._feature_columns),
            'partitioner': generic_utils.serialize_keras_object(
                self._partitioner),
        }

        # MRO lookup starts after _BaseFeaturesLayer.
        parent_config = super(  # pylint: disable=bad-super-call
            _BaseFeaturesLayer, self).get_config()
        merged = dict(parent_config)
        merged.update(layer_config)
        return merged
# Example #2
 def get_config(self):
     """Return this model's config, including its feature-column settings.

     The item id, target id and item feature columns are converted with
     TensorFlow's feature_column serialization helpers, while
     ``item_embedding_size`` is stored unchanged. The result is layered on
     top of ``super(EgesModel, self).get_config()``, so keys defined here
     replace same-named keys from the parent config.
     """
     from tensorflow.python.feature_column.serialization import serialize_feature_columns, serialize_feature_column
     model_fields = dict(
         item_embedding_size=self.item_embedding_size,
         item_id_column=serialize_feature_column(self.item_id_column),
         item_feature_columns=serialize_feature_columns(
             self.item_feature_columns),
         target_id_column=serialize_feature_column(self.target_id_column),
     )
     parent_config = super(EgesModel, self).get_config()
     return {**parent_config, **model_fields}
# Example #3
  def test_deserialization_deduping(self):
    """A column shared as a source is deserialized into a single object.

    ``bucketized_price`` is built from ``price``. After a serialize /
    deserialize round trip, the new bucketized column's ``source_column``
    must be the same object as the new ``price`` column.
    """
    price = fc.numeric_column('price')
    bucketized_price = fc.bucketized_column(price, boundaries=[0, 1])

    configs = serialization.serialize_feature_columns([price, bucketized_price])

    deserialized_feature_columns = serialization.deserialize_feature_columns(
        configs)
    self.assertLen(deserialized_feature_columns, 2)
    new_price = deserialized_feature_columns[0]
    new_bucketized_price = deserialized_feature_columns[1]

    # Ensure these are not the original objects:
    self.assertIsNot(price, new_price)
    self.assertIsNot(bucketized_price, new_bucketized_price)
    # But they are equivalent:
    self.assertEqual(price, new_price)
    self.assertEqual(bucketized_price, new_bucketized_price)

    # Check that deduping worked:
    self.assertIs(new_bucketized_price.source_column, new_price)

  def test_deserialization_custom_objects(self):
    """The public (de)serialization wrappers honor ``custom_objects``.

    This was previously a nested, never-called function without a ``test_``
    prefix, so it never ran. It also never passed ``custom_objects``, so the
    behavior it describes was not exercised.
    """
    # Note that custom_objects is also tested extensively above per class, this
    # test ensures that the public wrappers also handle it correctly.
    def _custom_fn(input_tensor):
      return input_tensor + 42.

    price = fc.numeric_column('price', normalizer_fn=_custom_fn)

    configs = serialization.serialize_feature_columns([price])

    # Pass custom_objects explicitly so the wrapper's handling of it is what
    # is under test. _custom_fn is local to this test, so it is presumably
    # not resolvable by name without this mapping.
    deserialized_feature_columns = serialization.deserialize_feature_columns(
        configs, custom_objects={'_custom_fn': _custom_fn})

    self.assertLen(deserialized_feature_columns, 1)
    new_price = deserialized_feature_columns[0]

    # Ensure these are not the original objects:
    self.assertIsNot(price, new_price)
    # But they are equivalent:
    self.assertEqual(price, new_price)

    # Check that normalizer_fn points to the correct function.
    self.assertIs(new_price.normalizer_fn, _custom_fn)