def call(self, inputs):
    """Map ``inputs`` to a mixture-of-logistics random variable.

    The output of ``self.layer`` is split into three equal parts along the
    last axis: mixture logits, component locations, and unconstrained
    component scales. The scales are made positive with a softplus, plus
    ``tf.keras.backend.epsilon()``.

    Args:
      inputs: Value passed straight to ``self.layer``.

    Returns:
      A ``generated_random_variables.MixtureSameFamily`` random variable. Its
      mixture distribution is Categorical over the logits, and its components
      are Logistic with the computed loc/scale.
    """
    params = self.layer(inputs)
    mixture_logits, component_loc, raw_scale = tf.split(params, 3, axis=-1)
    # softplus yields non-negative values; adding epsilon keeps the scale
    # strictly above zero.
    component_scale = tf.nn.softplus(raw_scale) + tf.keras.backend.epsilon()

    mixture = generated_random_variables.Categorical(logits=mixture_logits)
    components = generated_random_variables.Logistic(
        loc=component_loc, scale=component_scale)
    return generated_random_variables.MixtureSameFamily(
        mixture_distribution=mixture.distribution,
        components_distribution=components.distribution)
def call(self, inputs):
    """Map ``inputs`` to a mixture-of-logistics random variable.

    The output of ``self.layer`` is split into three equal parts along the
    last axis: mixture logits, component locations, and component scales.
    Each part is passed through its matching constraint
    (``self.logits_constraint``, ``self.loc_constraint``,
    ``self.scale_constraint``) when that constraint is truthy. Otherwise the
    part is used as-is.

    Args:
      inputs: Value passed straight to ``self.layer``.

    Returns:
      A ``generated_random_variables.MixtureSameFamily`` random variable. Its
      mixture distribution is Categorical over the logits, and its components
      are Logistic with the (possibly constrained) loc/scale.
    """
    def _apply(value, constraint):
        # Apply the optional constraint; a falsy constraint leaves value as-is.
        return constraint(value) if constraint else value

    params = self.layer(inputs)
    raw_logits, raw_loc, raw_scale = tf.split(params, 3, axis=-1)
    # NOTE(review): no positivity transform is applied to the scale here, so
    # validity presumably relies on scale_constraint — confirm its default.
    mixture_logits = _apply(raw_logits, self.logits_constraint)
    component_loc = _apply(raw_loc, self.loc_constraint)
    component_scale = _apply(raw_scale, self.scale_constraint)

    mixture = generated_random_variables.Categorical(logits=mixture_logits)
    components = generated_random_variables.Logistic(
        loc=component_loc, scale=component_scale)
    return generated_random_variables.MixtureSameFamily(
        mixture_distribution=mixture.distribution,
        components_distribution=components.distribution)
def __call__(self, shape, dtype=None):
    """Return an equally weighted mixture of point masses over ``shape``.

    On the first call, ``self.build(shape, dtype)`` is invoked. The trainable
    ``self.loc`` is passed through ``self.loc_constraint`` when that
    constraint is truthy. The result is a uniform mixture of
    ``self.num_components`` Deterministic components, wrapped in an
    Independent that reinterprets ``len(shape)`` batch dimensions as event
    dimensions.

    Args:
      shape: Target shape; forwarded to ``self.build`` on the first call and
        used as the batch shape of the mixture weights.
      dtype: Forwarded to ``self.build`` on the first call.

    Returns:
      A ``generated_random_variables.Independent`` random variable.
    """
    if not self.built:
        self.build(shape, dtype)
    loc = self.loc
    if self.loc_constraint:
        loc = self.loc_constraint(loc)

    num_components = self.num_components
    # Uniform mixture weights of shape list(shape) + [num_components].
    # A rank-1 vector broadcasts to that target for every shape, including
    # the scalar shape (). A [1, num_components] literal cannot, because
    # tf.broadcast_to never lowers rank. The float numerator avoids
    # integer division.
    probs = tf.broadcast_to([1. / num_components] * num_components,
                            list(shape) + [num_components])
    mixture = generated_random_variables.Categorical(probs=probs)
    components = generated_random_variables.Deterministic(loc=loc)
    point_mass_mixture = generated_random_variables.MixtureSameFamily(
        mixture_distribution=mixture.distribution,
        components_distribution=components.distribution)
    return generated_random_variables.Independent(
        point_mass_mixture.distribution,
        reinterpreted_batch_ndims=len(shape))