Example #1
# Assumed imports for this snippet (not part of the original excerpt).
import numpy as np
from tensorflow.keras.layers import Dense


def getScalingDenseLayer(input_location, input_scale):
    # Frozen Dense layer that standardizes its inputs: (x - location) / scale.
    recip_input_scale = np.reciprocal(input_scale)

    waux = np.diag(recip_input_scale)
    baux = -input_location * recip_input_scale

    dL = Dense(input_location.shape[0], activation=None,
               input_shape=input_location.shape)
    dL.build(input_shape=input_location.shape)
    dL.set_weights([waux, baux])
    dL.trainable = False
    return dL
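A minimal usage sketch (hypothetical location/scale statistics chosen for illustration, not taken from the original project): standardizing a two-column input.

import numpy as np
import tensorflow as tf

loc = np.array([10.0, -2.0], dtype='float32')
scale = np.array([5.0, 0.5], dtype='float32')
scaling_layer = getScalingDenseLayer(loc, scale)
x = tf.constant([[15.0, -1.5], [10.0, -2.0]])
print(scaling_layer(x))  # expected: [[1., 1.], [0., 0.]]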
Example #2
def inputsSelection(inputs_shape, ndex):
    # Frozen Dense layer that selects the input columns listed in `ndex`.
    if not hasattr(ndex, 'index'):
        ndex = list(ndex)
    # Binary mask with a 1 at (input column, output position) for each selected column.
    input_mask = np.zeros([inputs_shape[-1], len(ndex)])
    for i in range(inputs_shape[-1]):
        for v in ndex:
            if i == v:
                input_mask[i, ndex.index(v)] = 1

    dL = Dense(len(ndex), activation=None, input_shape=inputs_shape,
               use_bias=False)
    dL.build(input_shape=inputs_shape)
    dL.set_weights([input_mask])
    dL.trainable = False
    return dL
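A minimal usage sketch (hypothetical shapes and indices, not from the original project): keeping columns 0 and 2 of a three-feature input.

import tensorflow as tf

select_layer = inputsSelection((None, 3), [0, 2])
x = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(select_layer(x))  # expected: [[1., 3.], [4., 6.]]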
Example #3
class ContextualConvexDense(ContextualDense):
    """
    Contextual Dense layer that generates weights using a convex combination
    of Dense models from a dictionary.
    """
    def __init__(
        self,
        units,
        dict_size,
        dict_kernel_initializer="glorot_uniform",
        dict_bias_initializer="glorot_uniform",
        dict_kernel_regularizer=None,
        dict_bias_regularizer=None,
        dict_kernel_constraint=None,
        dict_bias_constraint=None,
        selector_use_bias=True,
        **kwargs,
    ):
        super(ContextualConvexDense, self).__init__(units, **kwargs)

        # Regularizers and constraints for the weight generator.
        self.dict_size = dict_size
        self.dict_kernel_initializer = initializers.get(
            dict_kernel_initializer)
        self.dict_bias_initializer = initializers.get(dict_bias_initializer)
        self.dict_kernel_regularizer = regularizers.get(
            dict_kernel_regularizer)
        self.dict_bias_regularizer = regularizers.get(dict_bias_regularizer)
        self.dict_kernel_constraint = constraints.get(dict_kernel_constraint)
        self.dict_bias_constraint = constraints.get(dict_bias_constraint)
        self.selector_use_bias = selector_use_bias

        # Contextual (soft) model selector from the dictionary.
        self.selector = Dense(
            self.dict_size,
            use_bias=self.selector_use_bias,
            activation="softmax",
            name="selector",
        )

        # Internal.
        self.dict_weights = None

    def build_weight_generator(self, context_shape, feature_shape):
        # Build attention.
        self.selector.build(context_shape)

        # Build dictionary of weights.
        self.dict_weights = {
            "kernels":
            self.add_weight(
                "kernels",
                shape=(self.dict_size, self.feature_dim * self.units),
                initializer=self.dict_kernel_initializer,
                regularizer=self.dict_kernel_regularizer,
                constraint=self.dict_kernel_constraint,
                dtype=self.dtype,
                trainable=True,
            )
        }

        # Build dictionary of biases, if necessary.
        if self.use_bias:
            self.dict_weights["biases"] = self.add_weight(
                "biases",
                shape=(self.dict_size, self.units),
                initializer=self.dict_bias_initializer,
                regularizer=self.dict_bias_regularizer,
                constraint=self.dict_bias_constraint,
                dtype=self.dtype,
                trainable=True,
            )

        self.built = True

    def generate_contextual_weights(self, context):
        context = tf.convert_to_tensor(context)

        # Compute attention over the dictionary elements.
        attention = self.selector(context)

        # Compute contextual weights.
        contextual_weights = {
            # <float32> [batch_size, weights_dim].
            name: tf.tensordot(attention, weights, [[-1], [0]])
            for name, weights in self.dict_weights.items()
        }

        # Reshape contextual kernels appropriately.
        contextual_weights["kernels"] = tf.reshape(
            contextual_weights["kernels"], (-1, self.feature_dim, self.units))

        return contextual_weights

    def get_config(self):
        config = {
            "dict_size":
            self.dict_size,
            "dict_kernel_initializer":
            initializers.serialize(self.dict_kernel_initializer),
            "dict_bias_initializer":
            initializers.serialize(self.dict_bias_initializer),
            "dict_kernel_regularizer":
            regularizers.serialize(self.dict_kernel_regularizer),
            "dict_bias_regularizer":
            regularizers.serialize(self.dict_bias_regularizer),
            "dict_kernel_constraint":
            regularizers.serialize(self.dict_kernel_constraint),
            "dict_bias_constraint":
            regularizers.serialize(self.dict_bias_constraint),
        }
        base_config = super(ContextualConvexDense, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
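The core of `generate_contextual_weights` is a convex (softmax-weighted) combination over a dictionary of flattened kernels. Below is a standalone sketch of that tensor algebra in plain TensorFlow, with made-up sizes and independent of the `ContextualDense` base class.

import tensorflow as tf

batch, ctx_dim, feat_dim, units, dict_size = 8, 5, 16, 4, 3
selector = tf.keras.layers.Dense(dict_size, activation='softmax')
kernels = tf.random.normal([dict_size, feat_dim * units])  # dictionary of flattened kernels

context = tf.random.normal([batch, ctx_dim])
attention = selector(context)                                       # [batch, dict_size], rows sum to 1
contextual_kernels = tf.tensordot(attention, kernels, [[-1], [0]])  # [batch, feat_dim * units]
contextual_kernels = tf.reshape(contextual_kernels, (-1, feat_dim, units))

features = tf.random.normal([batch, feat_dim])
outputs = tf.einsum('bf,bfu->bu', features, contextual_kernels)     # per-example Dense, [batch, units]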
Example #4
class BiDenseLatents(HierarchicalLatents):
    """Bidirectional inference for hierarchical latent variables

  Parameters
  ----------
  layer : `keras.layers.Layer`
      the decoder layer for top-down (generative)
  encoder : `keras.layers.Layer`, optional
      the encoder layer for bottom-up (inference)
  units : int
      number of latent units
  dense_kw : `Dict[str, Any]`, optional
      keyword arguments for initializing the `Dense` layer for the latents
  pool_mode : {'avg', 'max'}
      pooling mode used to downsample images before the `Dense` projection
  pool_size : int
      pooling size
  output_activation : {str, Callable}
      last activation before the residual connection
  deterministic_features : bool
      if True, concatenate deterministic features to the samples from the
      posterior (or prior)
  residual_coef : float
      if greater than 0, add a residual connection
  merge_normal : bool
      merge the posterior and prior normal distributions
  """
    def __init__(self,
                 layer: Layer,
                 encoder: Optional[Layer] = None,
                 units: int = 32,
                 dense_kw: Optional[Dict[str, Any]] = None,
                 pool_mode: Literal['avg', 'max'] = 'avg',
                 pool_size: Optional[int] = None,
                 output_activation: Union[None, str, Callable[[Any], Any]] = None,
                 deterministic_features: bool = True,
                 residual_coef: float = 1.0,
                 merge_normal: bool = False,
                 **kwargs):
        super().__init__(layer=layer, name=kwargs.pop('name', None))
        if encoder is not None:
            encoder._old_call = encoder.call
            encoder.call = MethodType(_call, encoder)
        self.encoder = encoder
        self.residual_coef = residual_coef
        self.deterministic_features = deterministic_features
        if output_activation is None and hasattr(self.layer, 'activation'):
            output_activation = self.layer.activation
        self.output_activation = keras.activations.get(output_activation)
        # === 1. for creating layer
        self._network_kw = dict(units=2 * units)
        if dense_kw is not None:
            self._network_kw.update(dense_kw)
        # === 2. distribution
        if merge_normal:
            self._merge_normal = MergeNormal()
        else:
            self._merge_normal = None
        # === 3. others
        self._latents_shape = None
        self._dense_prior = None
        self._dense_posterior = None
        self._dense_deter = None
        self._dense_out = None
        self._dist_prior = None
        self._dist_posterior = None
        # === 4. util layers
        self.concat = keras.layers.Concatenate(axis=-1)
        self.flatten = keras.layers.Flatten()
        if pool_size is not None and pool_size > 1 and self.input_ndim > 2:
            self.pooling = _NDIMS_POOL[self.input_ndim][pool_mode](
                pool_size, name='Pooling')
            self.unpooling = _NDIMS_UNPOOL[self.input_ndim](pool_size,
                                                            name='Unpooling')
        else:
            self.pooling = Activation('linear', name='Pooling')
            self.unpooling = Activation('linear', name='Unpooling')

    @property
    def is_inference(self) -> bool:
        return (not self._is_sampling and self.encoder is not None
                and hasattr(self.encoder, '_last_outputs')
                and self.encoder._last_outputs is not None)

    def build(self, input_shape=None):
        super().build(input_shape)
        if self._disable:
            return
        org_decoder_shape = self.layer.compute_output_shape(input_shape)
        # === 0. pooling
        self.pooling.build(org_decoder_shape)
        pool_decoder_shape = self.pooling.compute_output_shape(
            org_decoder_shape)
        decoder_shape = self.flatten.compute_output_shape(pool_decoder_shape)
        # === 1. create projection layer
        if self.encoder is not None:
            self._dense_posterior = Dense(**self._network_kw,
                                          name='DensePosterior')
            # posterior projection
            shape = self.concat.compute_output_shape(
                [decoder_shape, decoder_shape])
            self._dense_posterior.build(shape)
        # prior projection
        self._dense_prior = Dense(**self._network_kw, name='DensePrior')
        self._dense_prior.build(decoder_shape)
        # deterministic projection
        kw = dict(self._network_kw)
        kw['units'] //= 2  # deterministic path uses half the distribution-parameter units
        if self.deterministic_features:
            self._dense_deter = Dense(**kw, name='DenseDeterministic')
            self._dense_deter.build(decoder_shape)
        # === 2. create distribution
        # compute the parameter shape for the distribution
        params_shape = self._dense_prior.compute_output_shape(decoder_shape)

        self._dist_posterior = DistributionLambda(
            make_distribution_fn=partial(_create_dist,
                                         event_ndims=len(params_shape) - 1,
                                         dtype=self.dtype),
            name=f'{self.name}_posterior')
        self._dist_posterior.build(params_shape)
        self._dist_prior = DistributionLambda(make_distribution_fn=partial(
            _create_dist, event_ndims=len(params_shape) - 1, dtype=self.dtype),
                                              name=f'{self.name}_prior')
        self._dist_prior.build(params_shape)
        # dynamically infer the shape
        latents_shape = tf.convert_to_tensor(
            self._dist_posterior(keras.layers.Input(params_shape[1:]))).shape
        self._latents_shape = latents_shape[1:]
        if self.deterministic_features:
            deter_shape = self._dense_deter.compute_output_shape(decoder_shape)
            latents_shape = self.concat.compute_output_shape(
                [deter_shape, latents_shape])

        # === 3. final output affine
        if self.residual_coef > 0:
            units = int(np.prod(pool_decoder_shape[1:]))
            layers = [
                Dense(units),
                Reshape(pool_decoder_shape[1:]), self.unpooling
            ]
            if self.input_ndim > 2:
                conv, _ = _NDIMS_CONV[self.input_ndim]
                layers.append(conv(org_decoder_shape[-1], 3, 1,
                                   padding='same'))
            self._dense_out = keras.Sequential(layers, name='DenseOutput')
            self._dense_out.build(latents_shape)

    def call(self, inputs, training=None, mask=None, **kwargs):
        # === 1. call the layer
        hidden_d = super().call(inputs, training=training, mask=mask, **kwargs)
        if self._disable:
            return hidden_d
        # === 2. project and create the distribution
        flat_hd = self.flatten(self.pooling(hidden_d))
        prior = self._dist_prior(self._dense_prior(flat_hd))
        self._prior = prior
        # === 3. inference
        dist = prior
        if self.is_inference:
            hidden_e = self.encoder._last_outputs
            # just stop inference if there is no Encoder state
            tf.debugging.assert_equal(
                tf.shape(hidden_e), tf.shape(hidden_d),
                f'Shape of inference {hidden_e.shape} and '
                f'generative {hidden_d.shape} mismatch. '
                f'Change to sampling mode if possible')
            # (Kingma 2016) uses addition; here we concatenate
            h = self.concat([hidden_e, hidden_d])
            posterior = self._dist_posterior(
                self._dense_posterior(self.flatten(self.pooling(h))))
            # (Maaloe 2016) merge the two Normal distributions
            if self._merge_normal is not None:
                posterior = self._merge_normal([posterior, prior])
            self._posterior = posterior
            dist = posterior
        # === 4. output
        outputs = tf.convert_to_tensor(dist)
        if self.deterministic_features:
            hidden_deter = self._dense_deter(flat_hd)
            outputs = self.concat([outputs, hidden_deter])
        if self.residual_coef > 0.:
            outputs = self._dense_out(outputs)
            outputs = self.output_activation(outputs)
            outputs = outputs + self.residual_coef * hidden_d
        return outputs
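The `merge_normal` option combines the posterior and prior Gaussians. `MergeNormal` itself is not shown here; the sketch below is only one common way to merge two diagonal Gaussians (the precision-weighted combination used in ladder-style inference), written with explicit mean/stddev tensors and not necessarily the project's implementation.

import tensorflow as tf

def precision_weighted_merge(mu_q, sigma_q, mu_p, sigma_p):
    # Product of two diagonal Gaussians (renormalized): precisions add,
    # and the mean is the precision-weighted average of the two means.
    prec_q = 1.0 / tf.square(sigma_q)
    prec_p = 1.0 / tf.square(sigma_p)
    variance = 1.0 / (prec_q + prec_p)
    mean = variance * (mu_q * prec_q + mu_p * prec_p)
    return mean, tf.sqrt(variance)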
Example #5
class DistributionDense(Layer):
  """ Using `Dense` layer to parameterize the tensorflow_probability
  `Distribution`

  Parameters
  ----------
  event_shape : List[int]
      distribution event shape, by default ()
  posterior : {str, DistributionLambda, Callable[..., Distribution]}
      Instruction for creating the posterior distribution; it can be one of
      the following:
      - string : alias of the distribution, e.g. 'normal', 'mvndiag', etc.
      - DistributionLambda : an instance or type.
      - Callable : a callable that accepts a Tensor as input and returns a Distribution.
  posterior_kwargs : Dict[str, Any], optional
      keyword arguments for initializing the DistributionLambda if a type is
      given as posterior.
  prior : Union[Distribution, Callable[[], Distribution]]
      prior Distribution, or a callable which returns a prior.
  autoregressive : bool
      use a masked autoregressive dense network, by default False
  dropout : float, optional
      dropout on the dense layer, by default 0.0
  projection : bool, optional
      enable dense layers for projecting the inputs into parameters for distribution,
      by default True
  flatten_inputs : bool, optional
      flatten to 2D, by default False
  units : Optional[int], optional
      explicitly given total number of distribution parameters, by default None

  Returns
  -------
  `tensorflow_probability.Distribution`
  """

  def __init__(
      self,
      event_shape: Union[int, Sequence[int]] = (),
      units: Optional[int] = None,
      posterior: Union[str, DistributionLambda,
                       Callable[[Tensor], Distribution]] = 'normal',
      posterior_kwargs: Optional[Dict[str, Any]] = None,
      prior: Optional[Union[Distribution, Callable[[], Distribution]]] = None,
      convert_to_tensor_fn: Callable[
        [Distribution], Tensor] = Distribution.sample,
      activation: Union[str, Callable[[Tensor], Tensor]] = 'linear',
      autoregressive: bool = False,
      use_bias: bool = True,
      kernel_initializer: Union[str, Initializer] = 'glorot_normal',
      bias_initializer: Union[str, Initializer] = 'zeros',
      kernel_regularizer: Union[None, str, Regularizer] = None,
      bias_regularizer: Union[None, str, Regularizer] = None,
      activity_regularizer: Union[None, str, Regularizer] = None,
      kernel_constraint: Union[None, str, Constraint] = None,
      bias_constraint: Union[None, str, Constraint] = None,
      dropout: float = 0.0,
      projection: bool = True,
      flatten_inputs: bool = False,
      **kwargs,
  ):
    if posterior_kwargs is None:
      posterior_kwargs = {}
    ## store init arguments (this is not intended for serialization but
    # for cloning)
    init_args = dict(locals())
    del init_args['self']
    del init_args['__class__']
    del init_args['kwargs']
    init_args.update(kwargs)
    self._init_args = init_args
    ## check prior type
    assert isinstance(prior, (Distribution, type(None))) or callable(prior), \
      ("prior can only be None or instance of Distribution, DistributionLambda"
       f",  but given: {prior}-{type(prior)}")
    self._projection = bool(projection)
    self.flatten_inputs = bool(flatten_inputs)
    ## duplicated event_shape or event_size in posterior_kwargs
    posterior_kwargs = dict(posterior_kwargs)
    if 'event_shape' in posterior_kwargs:
      event_shape = posterior_kwargs.pop('event_shape')
    if 'event_size' in posterior_kwargs:
      event_shape = posterior_kwargs.pop('event_size')
    convert_to_tensor_fn = posterior_kwargs.pop('convert_to_tensor_fn',
                                                Distribution.sample)
    ## process the posterior
    if isinstance(posterior, DistributionLambda):  # instance
      self._posterior_layer = posterior
      self._posterior_class = type(posterior)
    elif (inspect.isclass(posterior) and
          issubclass(posterior, DistributionLambda)):  # subclass
      self._posterior_layer = None
      self._posterior_class = posterior
    elif isinstance(posterior, string_types):  # alias
      from odin.bay.distribution_alias import parse_distribution
      self._posterior_layer = None
      self._posterior_class, _ = parse_distribution(posterior)
    elif callable(posterior):  # callable
      if isinstance(posterior, LambdaType):
        posterior = tf.autograph.experimental.do_not_convert(posterior)
      self._posterior_layer = DistributionLambda(
        make_distribution_fn=posterior,
        convert_to_tensor_fn=convert_to_tensor_fn)
      self._posterior_class = type(posterior)
    else:
      raise ValueError('posterior must be a string, DistributionLambda, '
                       f'callable or type; but given: {posterior}')
    self._posterior = posterior
    self._posterior_kwargs = posterior_kwargs
    self._posterior_sample_shape = ()
    ## create layers
    self._convert_to_tensor_fn = convert_to_tensor_fn
    self._prior = prior
    self._event_shape = event_shape
    self._dropout = dropout
    ## set more descriptive name
    name = kwargs.pop('name', None)
    if name is None:
      posterior_name = (posterior if isinstance(posterior, string_types) else
                        posterior.__class__.__name__)
      name = f'dense_{posterior_name}'
    kwargs['name'] = name
    ## params_size could be static function or method
    if not projection:
      self._params_size = 0
    else:
      if not hasattr(self.posterior_layer, 'params_size'):
        if units is None:
          raise ValueError(
            f'posterior layer of type {type(self.posterior_layer)} '
            "doesn't has method params_size, number of parameters "
            'must be provided as `units` argument, but given: None')
        self._params_size = int(units)
      else:
        self._params_size = int(
          _params_size(self.posterior_layer, event_shape,
                       **self._posterior_kwargs))
    super().__init__(**kwargs)
    self.autoregressive = autoregressive
    if autoregressive:
      from odin.bay.layers.autoregressive_layers import AutoregressiveDense
      self._dense = AutoregressiveDense(
        params=self._params_size / self.event_size,
        event_shape=(self.event_size,),
        activation=activation,
        use_bias=use_bias,
        kernel_initializer=kernel_initializer,
        bias_initializer=bias_initializer,
        kernel_regularizer=kernel_regularizer,
        bias_regularizer=bias_regularizer,
        activity_regularizer=activity_regularizer,
        kernel_constraint=kernel_constraint,
        bias_constraint=bias_constraint)
    else:
      self._dense = Dense(units=self._params_size,
                          activation=activation,
                          use_bias=use_bias,
                          kernel_initializer=kernel_initializer,
                          bias_initializer=bias_initializer,
                          kernel_regularizer=kernel_regularizer,
                          bias_regularizer=bias_regularizer,
                          activity_regularizer=activity_regularizer,
                          kernel_constraint=kernel_constraint,
                          bias_constraint=bias_constraint)
    # store the distribution from last call,
    self._most_recently_built_distribution = None
    spec = inspect.getfullargspec(self.posterior_layer)
    self._posterior_call_kw = set(spec.args + spec.kwonlyargs)

  def build(self, input_shape=None) -> 'DistributionDense':
    self._dense.build(input_shape)
    return self

  def compute_output_shape(self, input_shape):
    return tuple(input_shape[:-1]) + as_tuple(self.event_shape)

  @property
  def params_size(self) -> int:
    return self._params_size

  @property
  def projection(self) -> bool:
    return self._projection and self.params_size > 0

  @property
  def is_binary(self) -> bool:
    return is_binary_distribution(self.posterior_layer)

  @property
  def is_discrete(self) -> bool:
    return is_discrete_distribution(self.posterior_layer)

  @property
  def is_mixture(self) -> bool:
    return is_mixture_distribution(self.posterior_layer)

  @property
  def is_zero_inflated(self) -> bool:
    return is_zeroinflated_distribution(self.posterior_layer)

  @property
  def event_shape(self) -> List[int]:
    shape = self._event_shape
    if not (tf.is_tensor(shape) or isinstance(shape, tf.TensorShape)):
      shape = tf.nest.flatten(shape)
    return shape

  @property
  def event_size(self) -> int:
    return tf.cast(tf.reduce_prod(self._event_shape), tf.int32)

  @property
  def prior(self) -> Optional[Union[Distribution, Callable[[], Distribution]]]:
    return self._prior

  @prior.setter
  def prior(self,
            p: Optional[Union[Distribution, Callable[[],
                                                     Distribution]]] = None):
    self._prior = p

  def set_prior(self,
                p: Optional[Union[Distribution,
                                  Callable[[], Distribution]]] = None):
    self.prior = p
    return self

  def _sample_fn(self, dist):
    return dist.sample(sample_shape=self._posterior_sample_shape)

  @property
  def convert_to_tensor_fn(self) -> Callable[..., Tensor]:
    if self._convert_to_tensor_fn == Distribution.sample:
      return self._sample_fn
    else:
      return self._convert_to_tensor_fn

  @property
  def posterior_layer(
      self) -> Union[DistributionLambda, Callable[..., Distribution]]:
    if not isinstance(self._posterior_layer, DistributionLambda):
      self._posterior_layer = self._posterior_class(
        self._event_shape,
        convert_to_tensor_fn=self.convert_to_tensor_fn,
        **self._posterior_kwargs)
    return self._posterior_layer

  @property
  def posterior(self) -> Distribution:
    r""" Return the most recent parametrized distribution,
    i.e. the result from the last `call` """
    return self._most_recently_built_distribution

  def sample(self, sample_shape=(), seed=None):
    """ Sample from prior distribution """
    if self._prior is None:
      raise RuntimeError("prior hasn't been provided for the %s" %
                         self.__class__.__name__)
    return self.prior.sample(sample_shape=sample_shape, seed=seed)

  def __call__(self, inputs, *args, **kwargs):
    distribution = super().__call__(inputs, *args, **kwargs)
    return distribution

  def call(self,
           inputs,
           training=None,
           sample_shape=(),
           projection=None,
           **kwargs):
    ## NOTE: 2-D inputs are important here, but we don't want to flatten
    # them automatically
    if self.flatten_inputs:
      inputs = tf.reshape(inputs, (tf.shape(inputs)[0], -1))
    params = inputs
    ## do not use tf.cond here, it infers the wrong shape when
    # trying to build the layer in Graph mode.
    projection = projection if projection is not None else self.projection
    if projection:
      params = self._dense(params)
      if self.autoregressive:
        params = tf.concat(tf.unstack(params, axis=-1), axis=-1)
    ## applying dropout
    if self._dropout > 0:
      params = bk.dropout(params, p_drop=self._dropout, training=training)
    ## create posterior distribution
    self._posterior_sample_shape = sample_shape
    kw = dict()
    if 'training' in self._posterior_call_kw:
      kw['training'] = training
    if 'sample_shape' in self._posterior_call_kw:
      kw['sample_shape'] = sample_shape
    for k, v in kwargs.items():
      if k in self._posterior_call_kw:
        kw[k] = v
    posterior = self.posterior_layer(params, **kw)
    # tensorflow tries to serialize the distribution, which raises an exception
    # when saving the graph; to avoid this, store it without dependency tracking.
    with trackable.no_automatic_dependency_tracking_scope(self):
      # self._no_dependency
      self._most_recently_built_distribution = posterior
    ## NOTE: every distribution already has a kl_divergence method, so we cannot reuse that name
    posterior.KL_divergence = KLdivergence(
      posterior, prior=self.prior,
      sample_shape=None)  # None mean reuse sampled data here
    return posterior

  def kl_divergence(self,
                    prior=None,
                    analytic=True,
                    sample_shape=1,
                    reverse=True):
    """ KL(q||p) where `p` is the posterior distribution returned from last
    call

    Parameters
    -----------
    prior : instance of `tensorflow_probability.Distribution`
        prior distribution of the latent
    analytic : `bool` (default=`True`). Using closed form solution for
        calculating divergence, otherwise, sampling with MCMC
    reverse : `bool`.
        If `True`, calculate `KL(q||p)` else `KL(p||q)`
    sample_shape : `int` (default=`1`)
        number of MCMC sample if `analytic=False`

    Returns
    --------
      kullback_divergence : Tensor [sample_shape, batch_size, ...]
    """
    if prior is None:
      prior = self._prior
    assert isinstance(prior, Distribution), "prior is not given!"
    if self.posterior is None:
      raise RuntimeError(
        "DistributionDense must be called to create the distribution before "
        "calculating the kl-divergence.")

    kullback_div = kl_divergence(q=self.posterior,
                                 p=prior,
                                 analytic=bool(analytic),
                                 reverse=reverse,
                                 q_sample=sample_shape)
    if analytic:
      kullback_div = tf.expand_dims(kullback_div, axis=0)
      if isinstance(sample_shape, Number) and sample_shape > 1:
        ndims = kullback_div.shape.ndims
        kullback_div = tf.tile(kullback_div, [sample_shape] + [1] * (ndims - 1))
    return kullback_div

  def log_prob(self, x):
    r""" Calculating the log probability (i.e. log likelihood) using the last
    distribution returned from call """
    return self.posterior.log_prob(x)

  def __repr__(self):
    return self.__str__()

  def __str__(self):
    if self.prior is None:
      prior = 'None'
    elif isinstance(self.prior, Distribution):
      prior = (
        f"<{self.prior.__class__.__name__} "
        f"batch:{self.prior.batch_shape} event:{self.prior.event_shape}>")
    else:
      prior = str(self.prior)
    posterior = self._posterior_class.__name__
    if hasattr(self, 'input_shape'):
      inshape = self.input_shape
    else:
      inshape = None
    if hasattr(self, 'output_shape'):
      outshape = self.output_shape
    else:
      outshape = None
    return (
      f"<'{self.name}' autoregr:{self.autoregressive} proj:{self.projection} "
      f"in:{inshape} out:{outshape} event:{self.event_shape} "
      f"#params:{self._params_size} post:{posterior} prior:{prior} "
      f"dropout:{self._dropout:.2f} kw:{self._posterior_kwargs}>")

  def get_config(self) -> dict:
    return dict(self._init_args)
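The pattern `DistributionDense` packages is a plain `Dense` projection feeding a `DistributionLambda`. A minimal self-contained sketch of that pattern with a diagonal Normal posterior (made-up sizes; this shows the general idea, not the class above).

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
event_size = 4

params_layer = tf.keras.layers.Dense(2 * event_size)  # projects features into loc and scale parameters
dist_layer = tfp.layers.DistributionLambda(
    lambda t: tfd.Independent(
        tfd.Normal(loc=t[..., :event_size],
                   scale=tf.nn.softplus(t[..., event_size:]) + 1e-5),
        reinterpreted_batch_ndims=1))

x = tf.random.normal([8, 16])
posterior = dist_layer(params_layer(x))  # a Distribution object
z = posterior.sample()                   # [8, 4]
log_prob = posterior.log_prob(z)         # [8]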
Example #6
class ContextualMixture(Layer):
    """
    The layer that represents a contextual mixture.

    Internally, the gating mechanism is represented by a dense layer built on
    the contextual inputs which softly combines the outputs of each expert.
    Each expert can be an arbitrary network as long as it is compatible with
    the features as inputs.
    """
    def __init__(
        self,
        experts,
        activity_regularizer=None,
        gate_use_bias=False,
        gate_kernel_initializer="glorot_uniform",
        gate_bias_initializer="zeros",
        gate_kernel_regularizer=None,
        gate_bias_regularizer=None,
        gate_kernel_constraint=None,
        gate_bias_constraint=None,
        **kwargs,
    ):
        super(ContextualMixture, self).__init__(
            activity_regularizer=regularizers.get(activity_regularizer),
            **kwargs,
        )

        # Sanity check.
        self.experts = tuple(experts)
        for expert in self.experts:
            if not isinstance(expert, tf.Module):
                raise ValueError(
                    "Please initialize `{name}` expert with a "
                    "`tf.Module` instance. You passed: {input}".format(
                        name=self.__class__.__name__, input=expert))

        # Regularizers and constraints for the weight generator.
        self.gate_use_bias = gate_use_bias
        self.gate_kernel_initializer = initializers.get(
            gate_kernel_initializer)
        self.gate_bias_initializer = initializers.get(gate_bias_initializer)
        self.gate_kernel_regularizer = regularizers.get(
            gate_kernel_regularizer)
        self.gate_bias_regularizer = regularizers.get(gate_bias_regularizer)
        self.gate_kernel_constraint = constraints.get(gate_kernel_constraint)
        self.gate_bias_constraint = constraints.get(gate_bias_constraint)

        self.supports_masking = True
        self.input_spec = [
            InputSpec(min_ndim=2),  # Context input spec.
            InputSpec(min_ndim=2),  # Features input spec.
        ]

        # Instantiate contextual attention for gating.
        self.gating_attention = Dense(len(self.experts),
                                      activation=tf.nn.softmax,
                                      name="attention")

        # Internals.
        self.context_shape = None
        self.feature_shape = None

    def _build_sanity_check(self, context_shape, feature_shape):
        dtype = tf.dtypes.as_dtype(self.dtype or K.floatx())
        if not (dtype.is_floating or dtype.is_complex):
            raise TypeError("Unable to build `ContextualDense` layer with "
                            f"non-floating point dtype {dtype}.")
        context_shape = tensor_shape.TensorShape(context_shape)
        if tensor_shape.dimension_value(context_shape[-1]) is None:
            raise ValueError("The last dimension of the context "
                             "should be defined. Found `None`.")
        feature_shape = tensor_shape.TensorShape(feature_shape)
        if tensor_shape.dimension_value(feature_shape[-1]) is None:
            raise ValueError("The last dimension of the features "
                             "should be defined. Found `None`.")

        # Ensure that output shapes of all experts are identical.
        expected_shape = self.experts[0].compute_output_shape(feature_shape)
        for expert in self.experts[1:]:
            expert_output_shape = expert.compute_output_shape(feature_shape)
            if expert_output_shape != expected_shape:
                raise ValueError(
                    f"Experts have different output shapes! "
                    f"Expected shape: {expected_shape}. "
                    f"Found expert {expert} with shape: {expert_output_shape}."
                )

    def build(self, input_shape=None):
        context_shape, feature_shape = input_shape

        # Sanity checks.
        self._build_sanity_check(context_shape, feature_shape)

        # Build gating attention.
        self.gating_attention.build(context_shape)

        # Build the wrapped layer.
        for expert in self.experts:
            expert.build(feature_shape)

        self.built = True

    def call(self, inputs, **kwargs):
        context, features = inputs
        context = tf.convert_to_tensor(context)
        features = tf.convert_to_tensor(features)

        # Compute outputs for each expert.
        # <float32> [batch_size, num_experts, units].
        expert_outputs = tf.stack(
            [expert(features) for expert in self.experts], axis=1)

        # Compute gating attention.
        # <float32> [batch_size, num_experts, 1].
        gating_attention = tf.expand_dims(self.gating_attention(context), -1)

        # Compute output as attention-weighted linear combination.
        # <float32> [batch_size, units].
        outputs = tf.reduce_sum(gating_attention * expert_outputs, axis=1)

        return outputs

    def compute_output_shape(self, input_shape):
        # `input_shape` is a pair: (context_shape, feature_shape).
        _, feature_shape = input_shape
        return self.experts[0].compute_output_shape(feature_shape)

    def get_config(self):
        config = {
            "gate_use_bias":
            self.gate_use_bias,
            "gate_kernel_initializer":
            initializers.serialize(self.gate_kernel_initializer),
            "gate_bias_initializer":
            initializers.serialize(self.gate_bias_initializer),
            "gate_kernel_regularizer":
            regularizers.serialize(self.gate_kernel_regularizer),
            "gate_bias_regularizer":
            regularizers.serialize(self.gate_bias_regularizer),
            "gate_kernel_constraint":
            regularizers.serialize(self.gate_kernel_constraint),
            "gate_bias_constraint":
            regularizers.serialize(self.gate_bias_constraint),
        }
        base_config = super(ContextualMixture, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
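For reference, the gating in `call` reduces to a softmax over experts followed by an attention-weighted sum. A standalone sketch with made-up sizes and plain Keras `Dense` experts:

import tensorflow as tf

batch, ctx_dim, feat_dim, units = 8, 5, 16, 4
experts = [tf.keras.layers.Dense(units) for _ in range(3)]
gate = tf.keras.layers.Dense(len(experts), activation=tf.nn.softmax, use_bias=False)

context = tf.random.normal([batch, ctx_dim])
features = tf.random.normal([batch, feat_dim])

expert_outputs = tf.stack([expert(features) for expert in experts], axis=1)  # [batch, 3, units]
attention = tf.expand_dims(gate(context), -1)                                # [batch, 3, 1]
outputs = tf.reduce_sum(attention * expert_outputs, axis=1)                  # [batch, units]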