예제 #1
0
    def from_config(cls, config):
        """Rebuild a layer instance from a configuration dictionary.

        Inverse of `get_config`: every serialized function entry is turned
        back into a callable before the dictionary is splatted into `cls`.

        Args:
          config: A Python dictionary, typically produced by `get_config`. For
            each function key it must also carry a `<key>_type` entry.

        Returns:
          layer: A new instance of `cls`.
        """
        restored = config.copy()  # never mutate the caller's dictionary
        for name in ('kernel_posterior_fn',
                     'kernel_posterior_tensor_fn',
                     'kernel_prior_fn',
                     'kernel_divergence_fn',
                     'bias_posterior_fn',
                     'bias_posterior_tensor_fn',
                     'bias_prior_fn',
                     'bias_divergence_fn'):
            serialized = restored[name]
            # The `_type` companion is deserialization metadata only; it is
            # removed unconditionally so it is never forwarded to `cls`.
            kind = restored.pop(name + '_type')
            if serialized is None:
                continue
            restored[name] = tfp_layers_util.deserialize_function(
                serialized, function_type=kind)
        return cls(**restored)
예제 #2
0
  def from_config(cls, config):
    """Rebuild a layer instance from a configuration dictionary.

    Inverse of `get_config`: every serialized function entry is turned back
    into a callable before the dictionary is splatted into `cls`.

    Args:
      config: A Python dictionary, typically produced by `get_config`. For
        each function key it must also carry a `<key>_type` entry.

    Returns:
      layer: A new instance of `cls`.
    """
    restored = config.copy()  # never mutate the caller's dictionary
    for name in ('kernel_posterior_fn',
                 'kernel_posterior_tensor_fn',
                 'kernel_prior_fn',
                 'kernel_divergence_fn',
                 'bias_posterior_fn',
                 'bias_posterior_tensor_fn',
                 'bias_prior_fn',
                 'bias_divergence_fn'):
      serialized = restored[name]
      # The `_type` companion is deserialization metadata only; it is removed
      # unconditionally so it is never forwarded to `cls`.
      kind = restored.pop(name + '_type')
      if serialized is None:
        continue
      restored[name] = tfp_layers_util.deserialize_function(
          serialized, function_type=kind)
    return cls(**restored)