示例#1
0
  def from_config(cls, config, custom_objects=None):
    """Creates a RNNModel from its config.

    Args:
      config: A Python dictionary, typically the output of `get_config`. It is
        not modified; entries are consumed from a shallow copy instead.
      custom_objects: Optional dictionary mapping names (strings) to custom
        classes or functions to be considered during deserialization.

    Returns:
      A RNNModel.
    """
    # Pop from a shallow copy: popping from the caller's dict would strip the
    # serialized entries, so a second `from_config(config)` call would raise
    # KeyError on 'rnn_layer'.
    config = dict(config)
    rnn_layer = keras_layers.deserialize(
        config.pop('rnn_layer'), custom_objects=custom_objects)
    sequence_feature_columns = fc.deserialize_feature_columns(
        config.pop('sequence_feature_columns'), custom_objects=custom_objects)
    context_feature_columns = config.pop('context_feature_columns', None)
    # Context columns are optional; only deserialize when something was given.
    if context_feature_columns:
      context_feature_columns = fc.deserialize_feature_columns(
          context_feature_columns, custom_objects=custom_objects)
    activation = activations.deserialize(
        config.pop('activation', None), custom_objects=custom_objects)
    # Whatever remains in the copy is forwarded as keyword arguments.
    return cls(
        rnn_layer=rnn_layer, sequence_feature_columns=sequence_feature_columns,
        context_feature_columns=context_feature_columns, activation=activation,
        **config)
示例#2
0
 def from_config(cls, config, custom_objects=None):
     """Creates an instance of `cls` from its config.

     Args:
       config: A Python dictionary, typically the output of `get_config`.
         Must contain serialized feature columns under the "columns" and
         "cross_columns" keys. It is not modified.
       custom_objects: Optional dictionary mapping names (strings) to custom
         classes or functions to be considered during deserialization.

     Returns:
       An instance of `cls`.
     """
     from tensorflow.python.feature_column.feature_column_lib import deserialize_feature_columns
     # Work on a shallow copy so the caller's config dict is left untouched.
     config_cp = config.copy()
     # Forward custom_objects so user-defined feature column classes resolve.
     config_cp["columns"] = deserialize_feature_columns(
         config["columns"], custom_objects=custom_objects)
     config_cp["cross_columns"] = deserialize_feature_columns(
         config["cross_columns"], custom_objects=custom_objects)
     # NOTE(review): the config dict is passed positionally rather than
     # unpacked as keyword arguments (`cls(**config_cp)`); confirm this
     # matches the constructor signature.
     return cls(config_cp, custom_objects=custom_objects)
示例#3
0
def deserialize_feature_columns(feature_column_configs, custom_objects=None):
  """Deserializes dict of feature column configs.

  Args:
    feature_column_configs: (dict) A dict mapping feature names to Keras feature
      column config, could be generated using `serialize_feature_columns`.
    custom_objects: (dict) Optional dictionary mapping names to custom classes
      or functions to be considered during deserialization.

  Returns:
    A dict mapping feature names to feature columns. An empty dict is returned
    when `feature_column_configs` is empty or None.
  """
  # Guard also keeps `zip(*...)` below from unpacking an empty sequence.
  if not feature_column_configs:
    return {}

  # Sort by feature name so the configs are handed to
  # `fc.deserialize_feature_columns` in a deterministic order, independent of
  # the input dict's iteration order.
  sorted_names, sorted_configs = zip(*sorted(feature_column_configs.items()))

  sorted_feature_columns = fc.deserialize_feature_columns(
      sorted_configs, custom_objects=custom_objects)

  # Names and columns are paired by position, which presumes
  # `fc.deserialize_feature_columns` returns one column per config in input
  # order -- TODO confirm.
  return dict(zip(sorted_names, sorted_feature_columns))