def getScalingDenseLayer(input_location, input_scale):
    """Return a frozen Dense layer that computes (x - input_location) / input_scale."""
    recip_input_scale = np.reciprocal(input_scale)

    # Diagonal kernel and bias implementing the affine normalization.
    waux = np.diag(recip_input_scale)
    baux = -input_location * recip_input_scale

    dL = Dense(input_location.shape[0],
               activation=None,
               input_shape=input_location.shape)
    dL.build(input_shape=input_location.shape)
    dL.set_weights([waux, baux])
    dL.trainable = False
    return dL
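# --- Hypothetical usage sketch (added for illustration; not in the original source).
# It relies on the same module-level imports the function itself assumes
# (`numpy as np` and `Dense` from `tensorflow.keras.layers`). The frozen layer
# maps x -> (x - input_location) / input_scale, which is handy for normalizing
# raw physical inputs before the trainable layers.
def _demo_scaling_dense_layer():
    import numpy as np  # local import so the sketch is self-contained
    loc = np.array([10.0, 0.0], dtype='float32')    # per-feature shift
    scale = np.array([5.0, 2.0], dtype='float32')   # per-feature scale
    scaler = getScalingDenseLayer(loc, scale)
    x = np.array([[15.0, 4.0]], dtype='float32')
    # The frozen layer computes (x - loc) / scale -> [[1.0, 2.0]]
    return scaler(x)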
def inputsSelection(inputs_shape, ndex):
    """Return a frozen, bias-free Dense layer that selects the input columns listed in `ndex`."""
    if not hasattr(ndex, 'index'):
        ndex = list(ndex)

    # Build a 0/1 mask with one column per selected index.
    input_mask = np.zeros([inputs_shape[-1], len(ndex)])
    for i in range(inputs_shape[-1]):
        for v in ndex:
            if i == v:
                input_mask[i, ndex.index(v)] = 1

    dL = Dense(len(ndex),
               activation=None,
               input_shape=inputs_shape,
               use_bias=False)
    dL.build(input_shape=inputs_shape)
    dL.set_weights([input_mask])
    dL.trainable = False
    return dL
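# --- Hypothetical usage sketch (added for illustration; not in the original source).
# inputsSelection builds a frozen, bias-free Dense layer whose kernel is a 0/1 mask,
# so calling it simply picks out the feature columns listed in `ndex`.
def _demo_inputs_selection():
    import numpy as np
    x = np.array([[1.0, 2.0, 3.0, 4.0]], dtype='float32')
    select_last_two = inputsSelection(x.shape, [2, 3])
    # Picks columns 2 and 3 -> [[3.0, 4.0]]
    return select_last_two(x)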
class ContextualConvexDense(ContextualDense):
    """Contextual Dense layer that generates weights using a convex
    combination of Dense models from a dictionary.
    """

    def __init__(
        self,
        units,
        dict_size,
        dict_kernel_initializer="glorot_uniform",
        dict_bias_initializer="glorot_uniform",
        dict_kernel_regularizer=None,
        dict_bias_regularizer=None,
        dict_kernel_constraint=None,
        dict_bias_constraint=None,
        selector_use_bias=True,
        **kwargs,
    ):
        super(ContextualConvexDense, self).__init__(units, **kwargs)

        # Regularizers and constraints for the weight generator.
        self.dict_size = dict_size
        self.dict_kernel_initializer = initializers.get(dict_kernel_initializer)
        self.dict_bias_initializer = initializers.get(dict_bias_initializer)
        self.dict_kernel_regularizer = regularizers.get(dict_kernel_regularizer)
        self.dict_bias_regularizer = regularizers.get(dict_bias_regularizer)
        self.dict_kernel_constraint = constraints.get(dict_kernel_constraint)
        self.dict_bias_constraint = constraints.get(dict_bias_constraint)
        self.selector_use_bias = selector_use_bias

        # Contextual (soft) model selector from the dictionary.
        self.selector = Dense(
            self.dict_size,
            use_bias=self.selector_use_bias,
            activation="softmax",
            name="selector",
        )

        # Internal.
        self.dict_weights = None

    def build_weight_generator(self, context_shape, feature_shape):
        # Build attention.
        self.selector.build(context_shape)
        # Build dictionary of weights.
        self.dict_weights = {
            "kernels": self.add_weight(
                "kernels",
                shape=(self.dict_size, self.feature_dim * self.units),
                initializer=self.dict_kernel_initializer,
                regularizer=self.dict_kernel_regularizer,
                constraint=self.dict_kernel_constraint,
                dtype=self.dtype,
                trainable=True,
            )
        }
        # Build dictionary of biases, if necessary.
        if self.use_bias:
            self.dict_weights["biases"] = self.add_weight(
                "biases",
                shape=(self.dict_size, self.units),
                initializer=self.dict_bias_initializer,
                regularizer=self.dict_bias_regularizer,
                constraint=self.dict_bias_constraint,
                dtype=self.dtype,
                trainable=True,
            )
        self.built = True

    def generate_contextual_weights(self, context):
        context = tf.convert_to_tensor(context)
        # Compute attention over the dictionary elements.
        attention = self.selector(context)
        # Compute contextual weights.
        contextual_weights = {
            # <float32> [batch_size, weights_dim].
            name: tf.tensordot(attention, weights, [[-1], [0]])
            for name, weights in self.dict_weights.items()
        }
        # Reshape contextual kernels appropriately.
        contextual_weights["kernels"] = tf.reshape(
            contextual_weights["kernels"], (-1, self.feature_dim, self.units))
        return contextual_weights

    def get_config(self):
        config = {
            "dict_size": self.dict_size,
            "dict_kernel_initializer":
                initializers.serialize(self.dict_kernel_initializer),
            "dict_bias_initializer":
                initializers.serialize(self.dict_bias_initializer),
            "dict_kernel_regularizer":
                regularizers.serialize(self.dict_kernel_regularizer),
            "dict_bias_regularizer":
                regularizers.serialize(self.dict_bias_regularizer),
            # Constraints are serialized with the `constraints` module
            # (the original passed them to `regularizers.serialize`).
            "dict_kernel_constraint":
                constraints.serialize(self.dict_kernel_constraint),
            "dict_bias_constraint":
                constraints.serialize(self.dict_bias_constraint),
        }
        base_config = super(ContextualConvexDense, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
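# --- Illustrative sketch (added; not in the original source) of the convex
# combination performed by generate_contextual_weights: a softmax attention over
# the dictionary mixes `dict_size` flattened kernels into one kernel per example,
# W(c) = sum_k softmax(selector(c))_k * W_k. All names below are hypothetical.
def _demo_convex_kernel_mixing(batch_size=4, dict_size=3, feature_dim=5, units=2):
    import tensorflow as tf
    attention = tf.nn.softmax(tf.random.normal([batch_size, dict_size]))  # selector output
    kernels = tf.random.normal([dict_size, feature_dim * units])          # weight dictionary
    mixed = tf.tensordot(attention, kernels, [[-1], [0]])                 # [batch, feature_dim * units]
    return tf.reshape(mixed, (-1, feature_dim, units))                    # per-example kernels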
class BiDenseLatents(HierarchicalLatents):
    """Bidirectional inference for hierarchical latent variables

    Parameters
    ----------
    layer : `keras.layers.Layer`
        the decoder layer for the top-down (generative) path
    encoder : `keras.layers.Layer`, optional
        the encoder layer for the bottom-up (inference) path
    units : int
        number of latent units
    dense_kw : `Dict[str, Any]`, optional
        keyword arguments for initializing the `Dense` layers for the latents
    pool_mode : {'avg', 'max'}
        how to downsample image inputs before the `Dense` projection
    pool_size : int
        pooling size
    output_activation : str or Callable
        last activation applied before the residual connection
    deterministic_features : bool
        if True, concatenate deterministic features to the samples from the
        posterior (or prior)
    residual_coef : float
        if greater than 0, add a residual connection
    merge_normal : bool
        merge the posterior and prior normal distributions
    """

    def __init__(self,
                 layer: Layer,
                 encoder: Optional[Layer] = None,
                 units: int = 32,
                 dense_kw: Optional[Dict[str, Any]] = None,
                 pool_mode: Literal['avg', 'max'] = 'avg',
                 pool_size: Optional[int] = None,
                 output_activation: Union[None, str, Callable[[Any], Any]] = None,
                 deterministic_features: bool = True,
                 residual_coef: float = 1.0,
                 merge_normal: bool = False,
                 **kwargs):
        super().__init__(layer=layer, name=kwargs.pop('name', None))
        if encoder is not None:
            encoder._old_call = encoder.call
            encoder.call = MethodType(_call, encoder)
        self.encoder = encoder
        self.residual_coef = residual_coef
        self.deterministic_features = deterministic_features
        if output_activation is None and hasattr(self.layer, 'activation'):
            output_activation = self.layer.activation
        self.output_activation = keras.activations.get(output_activation)
        # === 1. for creating the projection layers
        self._network_kw = dict(units=2 * units)
        if dense_kw is not None:
            self._network_kw.update(dense_kw)
        # === 2. distribution
        if merge_normal:
            self._merge_normal = MergeNormal()
        else:
            self._merge_normal = None
        # === 3. others
        self._latents_shape = None
        self._dense_prior = None
        self._dense_posterior = None
        self._dense_deter = None
        self._dense_out = None
        self._dist_prior = None
        self._dist_posterior = None
        # === 4. util layers
        self.concat = keras.layers.Concatenate(axis=-1)
        self.flatten = keras.layers.Flatten()
        if pool_size is not None and pool_size > 1 and self.input_ndim > 2:
            self.pooling = _NDIMS_POOL[self.input_ndim][pool_mode](
                pool_size, name='Pooling')
            self.unpooling = _NDIMS_UNPOOL[self.input_ndim](pool_size,
                                                            name='Unpooling')
        else:
            self.pooling = Activation('linear', name='Pooling')
            self.unpooling = Activation('linear', name='Unpooling')

    @property
    def is_inference(self) -> bool:
        return (not self._is_sampling and
                self.encoder is not None and
                hasattr(self.encoder, '_last_outputs') and
                self.encoder._last_outputs is not None)

    def build(self, input_shape=None):
        super().build(input_shape)
        if self._disable:
            return
        org_decoder_shape = self.layer.compute_output_shape(input_shape)
        # === 0. pooling
        self.pooling.build(org_decoder_shape)
        pool_decoder_shape = self.pooling.compute_output_shape(org_decoder_shape)
        decoder_shape = self.flatten.compute_output_shape(pool_decoder_shape)
        # === 1. create projection layers
        if self.encoder is not None:
            self._dense_posterior = Dense(**self._network_kw,
                                          name='DensePosterior')  # posterior projection
            shape = self.concat.compute_output_shape([decoder_shape, decoder_shape])
            self._dense_posterior.build(shape)
        # prior projection
        self._dense_prior = Dense(**self._network_kw, name='DensePrior')
        self._dense_prior.build(decoder_shape)
        # deterministic projection
        kw = dict(self._network_kw)
        kw['units'] = int(kw['units'] // 2)  # the deterministic path uses half the parameters
        if self.deterministic_features:
            self._dense_deter = Dense(**kw, name='DenseDeterministic')
            self._dense_deter.build(decoder_shape)
        # === 2. create distributions
        # compute the parameter shape for the distribution
        params_shape = self._dense_prior.compute_output_shape(decoder_shape)
        self._dist_posterior = DistributionLambda(
            make_distribution_fn=partial(_create_dist,
                                         event_ndims=len(params_shape) - 1,
                                         dtype=self.dtype),
            name=f'{self.name}_posterior')
        self._dist_posterior.build(params_shape)
        self._dist_prior = DistributionLambda(
            make_distribution_fn=partial(_create_dist,
                                         event_ndims=len(params_shape) - 1,
                                         dtype=self.dtype),
            name=f'{self.name}_prior')
        self._dist_prior.build(params_shape)
        # dynamically infer the latents shape
        latents_shape = tf.convert_to_tensor(
            self._dist_posterior(keras.layers.Input(params_shape[1:]))).shape
        self._latents_shape = latents_shape[1:]
        if self.deterministic_features:
            deter_shape = self._dense_deter.compute_output_shape(decoder_shape)
            latents_shape = self.concat.compute_output_shape(
                [deter_shape, latents_shape])
        # === 3. final output affine
        if self.residual_coef > 0:
            units = int(np.prod(pool_decoder_shape[1:]))
            layers = [
                Dense(units),
                Reshape(pool_decoder_shape[1:]),
                self.unpooling,
            ]
            if self.input_ndim > 2:
                conv, _ = _NDIMS_CONV[self.input_ndim]
                layers.append(conv(org_decoder_shape[-1], 3, 1, padding='same'))
            self._dense_out = keras.Sequential(layers, name='DenseOutput')
            self._dense_out.build(latents_shape)

    def call(self, inputs, training=None, mask=None, **kwargs):
        # === 1. call the layer
        hidden_d = super().call(inputs, training=training, mask=mask, **kwargs)
        if self._disable:
            return hidden_d
        # === 2. project and create the distribution
        flat_hd = self.flatten(self.pooling(hidden_d))
        prior = self._dist_prior(self._dense_prior(flat_hd))
        self._prior = prior
        # === 3. inference
        dist = prior
        if self.is_inference:
            hidden_e = self.encoder._last_outputs
            # just stop inference if there is no Encoder state
            tf.debugging.assert_equal(
                tf.shape(hidden_e), tf.shape(hidden_d),
                f'Shape of inference {hidden_e.shape} and '
                f'generative {hidden_d.shape} mismatch. '
                f'Change to sampling mode if possible')
            # (Kingma 2016) uses add, we concat here
            h = self.concat([hidden_e, hidden_d])
            posterior = self._dist_posterior(
                self._dense_posterior(self.flatten(self.pooling(h))))
            # (Maaloe 2016) merging the two Normal distributions
            if self._merge_normal is not None:
                posterior = self._merge_normal([posterior, prior])
            self._posterior = posterior
            dist = posterior
        # === 4. output
        outputs = tf.convert_to_tensor(dist)
        if self.deterministic_features:
            hidden_deter = self._dense_deter(flat_hd)
            outputs = self.concat([outputs, hidden_deter])
        if self.residual_coef > 0.:
            outputs = self._dense_out(outputs)
            outputs = self.output_activation(outputs)
            outputs = outputs + self.residual_coef * hidden_d
        return outputs
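# --- Illustrative sketch (added; not in the original source) of the core
# projection-to-distribution pattern used by BiDenseLatents.build/call: a Dense
# layer emits 2*units parameters that a DistributionLambda turns into a Normal.
# The project's `_create_dist` helper is not shown above, so this sketch assumes
# a loc/softplus-scale parameterization; the names below are hypothetical.
def _demo_dense_to_normal(units=8):
    import tensorflow as tf
    import tensorflow_probability as tfp
    tfd = tfp.distributions
    dense = tf.keras.layers.Dense(2 * units)
    dist_layer = tfp.layers.DistributionLambda(
        lambda params: tfd.Independent(
            tfd.Normal(loc=params[..., :units],
                       scale=tf.nn.softplus(params[..., units:])),
            reinterpreted_batch_ndims=1))
    hidden = tf.random.normal([4, 16])   # flattened decoder features
    return dist_layer(dense(hidden))     # a batch of Normal posteriors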
class DistributionDense(Layer):
    """Use a `Dense` layer to parameterize a tensorflow_probability `Distribution`

    Parameters
    ----------
    event_shape : List[int]
        distribution event shape, by default ()
    posterior : {str, DistributionLambda, Callable[..., Distribution]}
        Instruction for creating the posterior distribution, could be one of
        the following:
        - string : alias of the distribution, e.g. 'normal', 'mvndiag', etc.
        - DistributionLambda : an instance or type.
        - Callable : a callable that accepts a Tensor as inputs and returns a
          Distribution.
    posterior_kwargs : Dict[str, Any], optional
        keyword arguments for initializing the DistributionLambda if a type is
        given as posterior.
    prior : Union[Distribution, Callable[[], Distribution]]
        prior Distribution, or a callable which returns a prior.
    autoregressive : bool
        use a masked autoregressive dense network, by default False
    dropout : float, optional
        dropout on the dense layer, by default 0.0
    projection : bool, optional
        enable dense layers for projecting the inputs into parameters for the
        distribution, by default True
    flatten_inputs : bool, optional
        flatten the inputs to 2D, by default False
    units : Optional[int], optional
        explicitly given total number of distribution parameters, by default None

    Returns
    -------
    `tensorflow_probability.Distribution`
    """

    def __init__(
        self,
        event_shape: Union[int, Sequence[int]] = (),
        units: Optional[int] = None,
        posterior: Union[str, DistributionLambda,
                         Callable[[Tensor], Distribution]] = 'normal',
        posterior_kwargs: Optional[Dict[str, Any]] = None,
        prior: Optional[Union[Distribution, Callable[[], Distribution]]] = None,
        convert_to_tensor_fn: Callable[[Distribution],
                                       Tensor] = Distribution.sample,
        activation: Union[str, Callable[[Tensor], Tensor]] = 'linear',
        autoregressive: bool = False,
        use_bias: bool = True,
        kernel_initializer: Union[str, Initializer] = 'glorot_normal',
        bias_initializer: Union[str, Initializer] = 'zeros',
        kernel_regularizer: Union[None, str, Regularizer] = None,
        bias_regularizer: Union[None, str, Regularizer] = None,
        activity_regularizer: Union[None, str, Regularizer] = None,
        kernel_constraint: Union[None, str, Constraint] = None,
        bias_constraint: Union[None, str, Constraint] = None,
        dropout: float = 0.0,
        projection: bool = True,
        flatten_inputs: bool = False,
        **kwargs,
    ):
        if posterior_kwargs is None:
            posterior_kwargs = {}
        ## store init arguments (this is not intended for serialization but
        # for cloning)
        init_args = dict(locals())
        del init_args['self']
        del init_args['__class__']
        del init_args['kwargs']
        init_args.update(kwargs)
        self._init_args = init_args
        ## check prior type
        assert isinstance(prior, (Distribution, type(None))) or callable(prior), \
            ("prior can only be None or an instance of Distribution, DistributionLambda"
             f", but given: {prior}-{type(prior)}")
        self._projection = bool(projection)
        self.flatten_inputs = bool(flatten_inputs)
        ## duplicated event_shape or event_size in posterior_kwargs
        posterior_kwargs = dict(posterior_kwargs)
        if 'event_shape' in posterior_kwargs:
            event_shape = posterior_kwargs.pop('event_shape')
        if 'event_size' in posterior_kwargs:
            event_shape = posterior_kwargs.pop('event_size')
        convert_to_tensor_fn = posterior_kwargs.pop('convert_to_tensor_fn',
                                                    Distribution.sample)
        ## process the posterior
        if isinstance(posterior, DistributionLambda):  # instance
            self._posterior_layer = posterior
            self._posterior_class = type(posterior)
        elif (inspect.isclass(posterior) and
              issubclass(posterior, DistributionLambda)):  # subclass
            self._posterior_layer = None
            self._posterior_class = posterior
        elif isinstance(posterior, string_types):  # alias
            from odin.bay.distribution_alias import parse_distribution
            self._posterior_layer = None
            self._posterior_class, _ = parse_distribution(posterior)
        elif callable(posterior):  # callable
            if isinstance(posterior, LambdaType):
                posterior = tf.autograph.experimental.do_not_convert(posterior)
            self._posterior_layer = DistributionLambda(
                make_distribution_fn=posterior,
                convert_to_tensor_fn=convert_to_tensor_fn)
            self._posterior_class = type(posterior)
        else:
            raise ValueError('posterior could be: string, DistributionLambda, '
                             f'callable or type; but given: {posterior}')
        self._posterior = posterior
        self._posterior_kwargs = posterior_kwargs
        self._posterior_sample_shape = ()
        ## create layers
        self._convert_to_tensor_fn = convert_to_tensor_fn
        self._prior = prior
        self._event_shape = event_shape
        self._dropout = dropout
        ## set a more descriptive name
        name = kwargs.pop('name', None)
        if name is None:
            posterior_name = (posterior if isinstance(posterior, string_types)
                              else posterior.__class__.__name__)
            name = f'dense_{posterior_name}'
        kwargs['name'] = name
        ## params_size could be a static function or method
        if not projection:
            self._params_size = 0
        else:
            if not hasattr(self.posterior_layer, 'params_size'):
                if units is None:
                    raise ValueError(
                        f'posterior layer of type {type(self.posterior_layer)} '
                        "doesn't have a params_size method, the number of "
                        'parameters must be provided as the `units` argument, '
                        'but given: None')
                self._params_size = int(units)
            else:
                self._params_size = int(
                    _params_size(self.posterior_layer, event_shape,
                                 **self._posterior_kwargs))
        super().__init__(**kwargs)
        self.autoregressive = autoregressive
        if autoregressive:
            from odin.bay.layers.autoregressive_layers import AutoregressiveDense
            self._dense = AutoregressiveDense(
                params=self._params_size / self.event_size,
                event_shape=(self.event_size,),
                activation=activation,
                use_bias=use_bias,
                kernel_initializer=kernel_initializer,
                bias_initializer=bias_initializer,
                kernel_regularizer=kernel_regularizer,
                bias_regularizer=bias_regularizer,
                activity_regularizer=activity_regularizer,
                kernel_constraint=kernel_constraint,
                bias_constraint=bias_constraint)
        else:
            self._dense = Dense(units=self._params_size,
                                activation=activation,
                                use_bias=use_bias,
                                kernel_initializer=kernel_initializer,
                                bias_initializer=bias_initializer,
                                kernel_regularizer=kernel_regularizer,
                                bias_regularizer=bias_regularizer,
                                activity_regularizer=activity_regularizer,
                                kernel_constraint=kernel_constraint,
                                bias_constraint=bias_constraint)
        # store the distribution from the last call
        self._most_recently_built_distribution = None
        spec = inspect.getfullargspec(self.posterior_layer)
        self._posterior_call_kw = set(spec.args + spec.kwonlyargs)

    def build(self, input_shape=None) -> 'DistributionDense':
        self._dense.build(input_shape)
        return self

    def compute_output_shape(self, input_shape):
        return tuple(input_shape[:-1]) + as_tuple(self.event_shape)

    @property
    def params_size(self) -> int:
        return self._params_size

    @property
    def projection(self) -> bool:
        return self._projection and self.params_size > 0

    @property
    def is_binary(self) -> bool:
        return is_binary_distribution(self.posterior_layer)

    @property
    def is_discrete(self) -> bool:
        return is_discrete_distribution(self.posterior_layer)

    @property
    def is_mixture(self) -> bool:
        return is_mixture_distribution(self.posterior_layer)

    @property
    def is_zero_inflated(self) -> bool:
        return is_zeroinflated_distribution(self.posterior_layer)

    @property
    def event_shape(self) -> List[int]:
        shape = self._event_shape
        if not (tf.is_tensor(shape) or isinstance(shape, tf.TensorShape)):
            shape = tf.nest.flatten(shape)
        return shape

    @property
    def event_size(self) -> int:
        return tf.cast(tf.reduce_prod(self._event_shape), tf.int32)

    @property
    def prior(self) -> Optional[Union[Distribution, Callable[[], Distribution]]]:
        return self._prior

    @prior.setter
    def prior(self,
              p: Optional[Union[Distribution, Callable[[], Distribution]]] = None):
        self._prior = p

    def set_prior(self,
                  p: Optional[Union[Distribution,
                                    Callable[[], Distribution]]] = None):
        self.prior = p
        return self

    def _sample_fn(self, dist):
        return dist.sample(sample_shape=self._posterior_sample_shape)

    @property
    def convert_to_tensor_fn(self) -> Callable[..., Tensor]:
        if self._convert_to_tensor_fn == Distribution.sample:
            return self._sample_fn
        else:
            return self._convert_to_tensor_fn

    @property
    def posterior_layer(
            self) -> Union[DistributionLambda, Callable[..., Distribution]]:
        if not isinstance(self._posterior_layer, DistributionLambda):
            self._posterior_layer = self._posterior_class(
                self._event_shape,
                convert_to_tensor_fn=self.convert_to_tensor_fn,
                **self._posterior_kwargs)
        return self._posterior_layer

    @property
    def posterior(self) -> Distribution:
        r"""Return the most recently parameterized distribution,
        i.e. the result from the last `call`"""
        return self._most_recently_built_distribution

    def sample(self, sample_shape=(), seed=None):
        """Sample from the prior distribution"""
        if self._prior is None:
            raise RuntimeError("prior hasn't been provided for the %s" %
                               self.__class__.__name__)
        return self.prior.sample(sample_shape=sample_shape, seed=seed)

    def __call__(self, inputs, *args, **kwargs):
        distribution = super().__call__(inputs, *args, **kwargs)
        return distribution

    def call(self, inputs, training=None, sample_shape=(), projection=None,
             **kwargs):
        ## NOTE: a 2D input is important here, but we don't want to flatten
        # automatically
        if self.flatten_inputs:
            inputs = tf.reshape(inputs, (tf.shape(inputs)[0], -1))
        params = inputs
        ## do not use tf.cond here, it infers the wrong shape when
        # trying to build the layer in Graph mode.
        projection = projection if projection is not None else self.projection
        if projection:
            params = self._dense(params)
            if self.autoregressive:
                params = tf.concat(tf.unstack(params, axis=-1), axis=-1)
        ## applying dropout
        if self._dropout > 0:
            params = bk.dropout(params, p_drop=self._dropout, training=training)
        ## create the posterior distribution
        self._posterior_sample_shape = sample_shape
        kw = dict()
        if 'training' in self._posterior_call_kw:
            kw['training'] = training
        if 'sample_shape' in self._posterior_call_kw:
            kw['sample_shape'] = sample_shape
        for k, v in kwargs.items():
            if k in self._posterior_call_kw:
                kw[k] = v
        posterior = self.posterior_layer(params, **kw)
        # tensorflow tries to serialize the distribution, which raises an
        # exception when saving the graph; to avoid this, store it as a
        # non-tracking attribute.
        with trackable.no_automatic_dependency_tracking_scope(self):
            # self._no_dependency
            self._most_recently_built_distribution = posterior
        ## NOTE: all distributions already have a `kl_divergence` method, so we
        # cannot use that name here
        posterior.KL_divergence = KLdivergence(
            posterior, prior=self.prior,
            sample_shape=None)  # None means reuse the sampled data here
        return posterior

    def kl_divergence(self, prior=None, analytic=True, sample_shape=1,
                      reverse=True):
        """KL(q||p) where `q` is the posterior distribution returned from the last call

        Parameters
        ----------
        prior : instance of `tensorflow_probability.Distribution`
            prior distribution of the latent
        analytic : `bool` (default=`True`)
            use the closed-form solution for calculating the divergence,
            otherwise, sample with MCMC
        reverse : `bool`
            If `True`, calculate `KL(q||p)` else `KL(p||q)`
        sample_shape : `int` (default=`1`)
            number of MCMC samples if `analytic=False`

        Returns
        -------
        kullback_divergence : Tensor [sample_shape, batch_size, ...]
        """
        if prior is None:
            prior = self._prior
        assert isinstance(prior, Distribution), "prior is not given!"
        if self.posterior is None:
            raise RuntimeError(
                "DistributionDense must be called to create the distribution "
                "before calculating the KL-divergence.")
        kullback_div = kl_divergence(q=self.posterior,
                                     p=prior,
                                     analytic=bool(analytic),
                                     reverse=reverse,
                                     q_sample=sample_shape)
        if analytic:
            kullback_div = tf.expand_dims(kullback_div, axis=0)
            if isinstance(sample_shape, Number) and sample_shape > 1:
                ndims = kullback_div.shape.ndims
                kullback_div = tf.tile(kullback_div,
                                       [sample_shape] + [1] * (ndims - 1))
        return kullback_div

    def log_prob(self, x):
        r"""Calculate the log probability (i.e. log likelihood) using the last
        distribution returned from `call`"""
        return self.posterior.log_prob(x)

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        if self.prior is None:
            prior = 'None'
        elif isinstance(self.prior, Distribution):
            prior = (f"<{self.prior.__class__.__name__} "
                     f"batch:{self.prior.batch_shape} "
                     f"event:{self.prior.event_shape}>")
        else:
            prior = str(self.prior)
        posterior = self._posterior_class.__name__
        if hasattr(self, 'input_shape'):
            inshape = self.input_shape
        else:
            inshape = None
        if hasattr(self, 'output_shape'):
            outshape = self.output_shape
        else:
            outshape = None
        return (
            f"<'{self.name}' autoregr:{self.autoregressive} proj:{self.projection} "
            f"in:{inshape} out:{outshape} event:{self.event_shape} "
            f"#params:{self._params_size} post:{posterior} prior:{prior} "
            f"dropout:{self._dropout:.2f} kw:{self._posterior_kwargs}>")

    def get_config(self) -> dict:
        return dict(self._init_args)
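# --- Hypothetical usage sketch (added; not in the original source). It assumes
# the 'normal' alias resolves, via odin.bay.distribution_alias.parse_distribution,
# to an independent Normal posterior layer, and that a tensorflow_probability
# prior is supplied for the KL term; all names below are illustrative only.
def _demo_distribution_dense():
    import tensorflow as tf
    import tensorflow_probability as tfp
    tfd = tfp.distributions
    latent_dim = 10
    qz = DistributionDense(event_shape=latent_dim,
                           posterior='normal',
                           prior=tfd.Independent(
                               tfd.Normal(loc=tf.zeros(latent_dim),
                                          scale=tf.ones(latent_dim)), 1))
    h = tf.random.normal([4, 32])          # encoder features
    posterior = qz(h)                      # a Distribution instance
    z = tf.convert_to_tensor(posterior)    # sample via convert_to_tensor_fn
    kl = qz.kl_divergence(analytic=True)   # KL(q||p) against the stored prior
    return z, kl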
class ContextualMixture(Layer):
    """The layer that represents a contextual mixture.

    Internally, the gating mechanism is represented by a dense layer built on
    the contextual inputs which softly combines the outputs of each expert.
    Each expert can be an arbitrary network as long as it is compatible with
    the features as inputs.
    """

    def __init__(
        self,
        experts,
        activity_regularizer=None,
        gate_use_bias=False,
        gate_kernel_initializer="glorot_uniform",
        gate_bias_initializer="zeros",
        gate_kernel_regularizer=None,
        gate_bias_regularizer=None,
        gate_kernel_constraint=None,
        gate_bias_constraint=None,
        **kwargs,
    ):
        super(ContextualMixture, self).__init__(
            activity_regularizer=regularizers.get(activity_regularizer),
            **kwargs,
        )

        # Sanity check.
        self.experts = tuple(experts)
        for expert in self.experts:
            if not isinstance(expert, tf.Module):
                raise ValueError(
                    "Please initialize `{name}` expert with a "
                    "`tf.Module` instance. You passed: {input}".format(
                        name=self.__class__.__name__, input=expert))

        # Regularizers and constraints for the weight generator.
        self.gate_use_bias = gate_use_bias
        self.gate_kernel_initializer = initializers.get(gate_kernel_initializer)
        self.gate_bias_initializer = initializers.get(gate_bias_initializer)
        self.gate_kernel_regularizer = regularizers.get(gate_kernel_regularizer)
        self.gate_bias_regularizer = regularizers.get(gate_bias_regularizer)
        self.gate_kernel_constraint = constraints.get(gate_kernel_constraint)
        self.gate_bias_constraint = constraints.get(gate_bias_constraint)

        self.supports_masking = True
        self.input_spec = [
            InputSpec(min_ndim=2),  # Context input spec.
            InputSpec(min_ndim=2),  # Features input spec.
        ]

        # Instantiate contextual attention for gating.
        self.gating_attention = Dense(len(self.experts),
                                      activation=tf.nn.softmax,
                                      name="attention")

        # Internals.
        self.context_shape = None
        self.feature_shape = None

    def _build_sanity_check(self, context_shape, feature_shape):
        dtype = tf.dtypes.as_dtype(self.dtype or K.floatx())
        if not (dtype.is_floating or dtype.is_complex):
            raise TypeError("Unable to build `ContextualMixture` layer with "
                            f"non-floating point dtype {dtype}.")
        context_shape = tensor_shape.TensorShape(context_shape)
        if tensor_shape.dimension_value(context_shape[-1]) is None:
            raise ValueError("The last dimension of the context "
                             "should be defined. Found `None`.")
        feature_shape = tensor_shape.TensorShape(feature_shape)
        if tensor_shape.dimension_value(feature_shape[-1]) is None:
            raise ValueError("The last dimension of the features "
                             "should be defined. Found `None`.")
        # Ensure that output shapes of all experts are identical.
        expected_shape = self.experts[0].compute_output_shape(feature_shape)
        for expert in self.experts[1:]:
            expert_output_shape = expert.compute_output_shape(feature_shape)
            if expert_output_shape != expected_shape:
                raise ValueError(
                    f"Experts have different output shapes! "
                    f"Expected shape: {expected_shape}. "
                    f"Found expert {expert} with shape: {expert_output_shape}.")

    def build(self, input_shape=None):
        context_shape, feature_shape = input_shape
        # Sanity checks.
        self._build_sanity_check(context_shape, feature_shape)
        # Build gating attention.
        self.gating_attention.build(context_shape)
        # Build the wrapped layers.
        for expert in self.experts:
            expert.build(feature_shape)
        self.built = True

    def call(self, inputs, **kwargs):
        context, features = inputs
        context = tf.convert_to_tensor(context)
        features = tf.convert_to_tensor(features)
        # Compute outputs for each expert.
        # <float32> [batch_size, num_experts, units].
        expert_outputs = tf.stack(
            [expert(features) for expert in self.experts], axis=1)
        # Compute gating attention.
        # <float32> [batch_size, num_experts, 1].
        gating_attention = tf.expand_dims(self.gating_attention(context), -1)
        # Compute output as attention-weighted linear combination.
        # <float32> [batch_size, units].
        outputs = tf.reduce_sum(gating_attention * expert_outputs, axis=1)
        return outputs

    def compute_output_shape(self, input_shape):
        # `input_shape` is a (context_shape, feature_shape) pair; the output
        # shape is determined by the experts applied to the features.
        _, feature_shape = input_shape
        return self.experts[0].compute_output_shape(feature_shape)

    def get_config(self):
        config = {
            "gate_use_bias": self.gate_use_bias,
            "gate_kernel_initializer":
                initializers.serialize(self.gate_kernel_initializer),
            "gate_bias_initializer":
                initializers.serialize(self.gate_bias_initializer),
            "gate_kernel_regularizer":
                regularizers.serialize(self.gate_kernel_regularizer),
            "gate_bias_regularizer":
                regularizers.serialize(self.gate_bias_regularizer),
            # Constraints are serialized with the `constraints` module
            # (the original passed them to `regularizers.serialize`).
            "gate_kernel_constraint":
                constraints.serialize(self.gate_kernel_constraint),
            "gate_bias_constraint":
                constraints.serialize(self.gate_bias_constraint),
        }
        base_config = super(ContextualMixture, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
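# --- Hypothetical usage sketch (added; not in the original source): a two-expert
# mixture whose gate reads a context vector and softly combines the experts'
# outputs on the feature vector; the expert sizes and shapes are illustrative.
def _demo_contextual_mixture():
    import tensorflow as tf
    experts = [tf.keras.layers.Dense(3), tf.keras.layers.Dense(3)]
    mixture = ContextualMixture(experts)
    context = tf.random.normal([8, 4])     # gating inputs
    features = tf.random.normal([8, 16])   # expert inputs
    return mixture([context, features])    # <float32> [8, 3]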