def _default_prior(event_shape, posterior, prior, posterior_kwargs):
  """Return a prior distribution matching the given posterior.

  If ``prior`` is already a ``Distribution``, ``DistributionLambda`` or any
  callable it is returned unchanged. Otherwise a default prior is built for
  the (layer, distribution) pair that ``parse_distribution(posterior)``
  resolves to.

  Arguments:
    event_shape : list, tuple or ``tf.TensorShape`` of integers, the event
      shape used to size the default prior.
    posterior : posterior specification understood by ``parse_distribution``.
    prior : ``None``, a string, a ``dict`` of keyword arguments that override
      the defaults passed to the prior's constructor, or a ready-made
      ``Distribution`` / ``DistributionLambda`` / callable.
    posterior_kwargs : dict of posterior keyword arguments; only
      ``'n_components'`` (default 2) is read, for mixture posteriors.

  Returns:
    The prior distribution. When no default is defined for the posterior
    type, ``prior`` is returned as given (``None`` if a dict was passed).

  Raises:
    ValueError : if ``event_shape`` is not a sequence / ``tf.TensorShape``,
      if ``prior`` has an unsupported type, or if a mixture posterior has an
      unknown covariance type.
  """
  if not isinstance(event_shape, (Sequence, MutableSequence, tf.TensorShape)):
    raise ValueError("event_shape must be list of integer but given: "
                     f"{event_shape} type: {type(event_shape)}")
  if isinstance(prior, (Distribution, DistributionLambda, Callable)):
    return prior
  # A dict carries keyword overrides for the default prior (handled below),
  # so it must pass validation together with strings and None.
  elif not isinstance(prior, (string_types, type(None), dict)):
    raise ValueError("prior must be None, string, dict or instance of "
                     f"Distribution or DistributionLambda, but given: {prior}")
  # no ready-made prior given: build a default one from the posterior type
  layer, dist = parse_distribution(posterior)
  if isinstance(prior, dict):
    kw = dict(prior)
    prior = None
  else:
    kw = {}
  event_size = int(np.prod(event_shape))

  ## helper function
  def _kwargs(**defaults):
    """Add `defaults` to `kw` for every key the caller did not override."""
    for k, v in defaults.items():
      kw.setdefault(k, v)
    return kw

  ## Normal
  if layer == obl.GaussianLayer:
    prior = obd.Independent(
        obd.Normal(**_kwargs(loc=tf.zeros(shape=event_shape),
                             scale=tf.ones(shape=event_shape))),
        reinterpreted_batch_ndims=1,
    )
  ## Multivariate Normal
  elif issubclass(layer, obl.MultivariateNormalLayer):
    cov = layer._partial_kwargs['covariance']
    if cov == 'diag':  # diagonal covariance
      loc = tf.zeros(shape=event_shape)
      if tf.rank(loc) == 0:
        loc = tf.expand_dims(loc, axis=-1)
      # NOTE(review): `scale_identity_multiplier` is deprecated in newer TFP
      # releases -- confirm against the pinned TFP version.
      prior = obd.MultivariateNormalDiag(
          **_kwargs(loc=loc, scale_identity_multiplier=1.))
    else:  # low-triangle covariance
      bijector = tfp.bijectors.FillScaleTriL(
          diag_bijector=tfp.bijectors.Identity(), diag_shift=1e-5)
      size = tf.reduce_prod(event_shape)
      loc = tf.zeros(shape=[size])
      # NOTE(review): filling from a vector of ones also sets every
      # off-diagonal entry to 1, so this prior is correlated rather than
      # isotropic -- confirm this is intended.
      scale_tril = bijector.forward(tf.ones(shape=[size * (size + 1) // 2]))
      prior = obd.MultivariateNormalTriL(
          **_kwargs(loc=loc, scale_tril=scale_tril))
  ## Log Normal
  elif layer == obl.LogNormalLayer:
    prior = obd.Independent(
        obd.LogNormal(**_kwargs(loc=tf.zeros(shape=event_shape),
                                scale=tf.ones(shape=event_shape))),
        reinterpreted_batch_ndims=1,
    )
  ## mixture
  elif issubclass(layer, obl.MixtureGaussianLayer):
    if hasattr(layer, '_partial_kwargs'):
      cov = layer._partial_kwargs['covariance']
    else:
      cov = 'none'
    n_components = int(posterior_kwargs.get('n_components', 2))
    if cov == 'diag':
      scale_shape = [n_components, event_size]
      fn = lambda l, s: obd.MultivariateNormalDiag(
          loc=l, scale_diag=tf.nn.softplus(s))
    elif cov == 'none':
      scale_shape = [n_components, event_size]
      fn = lambda l, s: obd.Independent(
          obd.Normal(loc=l, scale=tf.math.softplus(s)),
          reinterpreted_batch_ndims=1,
      )
    elif cov in ('full', 'tril'):
      scale_shape = [n_components, event_size * (event_size + 1) // 2]
      fn = lambda l, s: obd.MultivariateNormalTriL(
          loc=l,
          scale_tril=tfp.bijectors.FillScaleTriL(diag_shift=1e-5)
          (tf.math.softplus(s)))
    else:
      # previously fell through to an UnboundLocalError on `fn`
      raise ValueError("Unsupported covariance type for mixture posterior: "
                       f"{cov}")
    loc = tf.cast(tf.fill([n_components, event_size], 0.), dtype=tf.float32)
    # softplus(log(expm1(1))) == 1, so 'diag'/'none' components get unit scale
    log_scale = tf.cast(tf.fill(scale_shape, np.log(np.expm1(1.))),
                        dtype=tf.float32)
    # Equal logits give uniform mixture weights; log(1/n) is finite for
    # n == 1, unlike log(p / (1 - p)) which divides by zero there.
    mixture_logits = tf.cast(
        tf.fill([n_components], np.log(1. / n_components)), dtype=tf.float32)
    prior = obd.MixtureSameFamily(
        components_distribution=fn(loc, log_scale),
        mixture_distribution=obd.Categorical(logits=mixture_logits))
  ## discrete
  elif dist in (obd.OneHotCategorical, obd.Categorical) or \
      layer == obl.RelaxedOneHotCategoricalLayer:
    # uniform categories; log(1/n) stays finite for a single category
    prior = dist(**_kwargs(logits=[np.log(1. / event_size)] * event_size),
                 dtype=tf.float32)
  elif dist == obd.Dirichlet:
    prior = dist(**_kwargs(concentration=[1.] * event_size))
  elif dist == obd.Bernoulli:
    # zero logits <=> p = 0.5 for every element
    prior = obd.Independent(
        obd.Bernoulli(**_kwargs(logits=np.zeros(event_shape)),
                      dtype=tf.float32),
        reinterpreted_batch_ndims=len(event_shape),
    )
  ## other
  return prior
# Experiment configuration with three groups.
# NOTE(review): judging by the entries, these appear to be (posterior names,
# candidate priors over the `encoded_size`-dimensional latent, activation
# names) -- confirm against the consumer below.
posteriors_info = [
    ('gaussian', 'mvndiag', 'mvntril'),
    (
        D.Sample(D.Normal(loc=0., scale=1.),
                 sample_shape=encoded_size,
                 name='independent'),
        D.MultivariateNormalDiag(loc=tf.zeros(encoded_size),
                                 scale_diag=tf.ones(encoded_size),
                                 name='mvndiag'),
        D.MultivariateNormalTriL(
            loc=tf.zeros(encoded_size),
            scale_tril=bj.FillScaleTriL()(
                tf.ones(encoded_size * (encoded_size + 1) // 2)),
            name='mvntril'),
        # Gaussian mixtures with 10 and 100 diagonal components. All logits
        # are equal, so the mixture weights are uniform.
        *(D.MixtureSameFamily(
            components_distribution=D.MultivariateNormalDiag(
                loc=tf.zeros([n_comp, encoded_size]),
                scale_diag=tf.ones([n_comp, encoded_size])),
            mixture_distribution=D.Categorical(
                logits=tf.fill([n_comp], 1.0 / n_comp)),
            name=gmm_name)
          for n_comp, gmm_name in ((10, 'gmm10'), (100, 'gmm100'))),
    ),
    ('identity', 'relu', 'softplus', 'softplus1'),
]

# ===========================================================================
# Main
def _default_prior(event_shape, posterior, prior, posterior_kwargs):
  """Return a prior distribution matching the given posterior.

  If ``prior`` is already an ``obd.Distribution`` it is returned unchanged.
  Otherwise a default prior is built for the (layer, distribution) pair that
  ``parse_distribution(posterior)`` resolves to.

  Arguments:
    event_shape : sequence of integers, the event shape used to size the
      default prior.
    posterior : posterior specification understood by ``parse_distribution``.
    prior : an ``obd.Distribution``, a ``dict`` of keyword arguments that
      override the defaults passed to the prior's constructor, or any other
      value (e.g. ``None``) to use the defaults.
    posterior_kwargs : dict of posterior keyword arguments; only
      ``'n_components'`` (default 2) is read, for mixture posteriors.

  Returns:
    The prior distribution. When no default is defined for the posterior
    type, ``prior`` is returned as given (``None`` if a dict was passed).

  Raises:
    ValueError : if a mixture posterior has an unknown covariance type.
  """
  if isinstance(prior, obd.Distribution):
    return prior
  layer, dist = parse_distribution(posterior)
  if isinstance(prior, dict):
    kw = dict(prior)
    prior = None
  else:
    kw = {}
  event_size = int(np.prod(event_shape))

  ## helper function
  def _kwargs(**defaults):
    """Add `defaults` to `kw` for every key the caller did not override."""
    for k, v in defaults.items():
      kw.setdefault(k, v)
    return kw

  ## Normal
  if layer == obl.GaussianLayer:
    prior = obd.Independent(
        obd.Normal(**_kwargs(loc=tf.zeros(shape=event_shape),
                             scale=tf.ones(shape=event_shape))), 1)
  ## Multivariate Normal
  elif issubclass(layer, obl.MultivariateNormalLayer):
    cov = layer._partial_kwargs['covariance']
    if cov == 'diag':  # diagonal covariance
      loc = tf.zeros(shape=event_shape)
      if tf.rank(loc) == 0:
        loc = tf.expand_dims(loc, axis=-1)
      # NOTE(review): `scale_identity_multiplier` is deprecated in newer TFP
      # releases -- confirm against the pinned TFP version.
      prior = obd.MultivariateNormalDiag(
          **_kwargs(loc=loc, scale_identity_multiplier=1.))
    else:  # low-triangle covariance
      bijector = tfp.bijectors.FillScaleTriL(
          diag_bijector=tfp.bijectors.Identity(), diag_shift=1e-5)
      size = tf.reduce_prod(event_shape)
      loc = tf.zeros(shape=[size])
      # NOTE(review): filling from a vector of ones also sets every
      # off-diagonal entry to 1, so this prior is correlated rather than
      # isotropic -- confirm this is intended.
      scale_tril = bijector.forward(
          tf.ones(shape=[size * (size + 1) // 2]))
      prior = obd.MultivariateNormalTriL(
          **_kwargs(loc=loc, scale_tril=scale_tril))
  ## Log Normal
  elif layer == obl.LogNormalLayer:
    prior = obd.Independent(
        obd.LogNormal(**_kwargs(loc=tf.zeros(shape=event_shape),
                                scale=tf.ones(shape=event_shape))), 1)
  ## mixture
  elif issubclass(layer, obl.MixtureGaussianLayer):
    if hasattr(layer, '_partial_kwargs'):
      cov = layer._partial_kwargs['covariance']
    else:
      cov = 'none'
    n_components = int(posterior_kwargs.get('n_components', 2))
    if cov == 'diag':
      scale_shape = [n_components, event_size]
      fn = lambda l, s: obd.MultivariateNormalDiag(
          loc=l, scale_diag=tf.nn.softplus(s))
    elif cov == 'none':
      scale_shape = [n_components, event_size]
      fn = lambda l, s: obd.Independent(
          obd.Normal(loc=l, scale=tf.math.softplus(s)), 1)
    elif cov in ('full', 'tril'):
      scale_shape = [n_components, event_size * (event_size + 1) // 2]
      fn = lambda l, s: obd.MultivariateNormalTriL(
          loc=l,
          scale_tril=tfp.bijectors.FillScaleTriL(diag_shift=1e-5)
          (tf.math.softplus(s)))
    else:
      # previously fell through to an UnboundLocalError on `fn`
      raise ValueError("Unsupported covariance type for mixture posterior: "
                       f"{cov}")
    loc = tf.cast(tf.fill([n_components, event_size], 0.), dtype=tf.float32)
    # softplus(log(expm1(1))) == 1, so 'diag'/'none' components get unit scale
    log_scale = tf.cast(tf.fill(scale_shape, np.log(np.expm1(1.))),
                        dtype=tf.float32)
    # equal logits -> uniform mixture weights
    mixture_logits = tf.cast(tf.fill([n_components], 1.), dtype=tf.float32)
    prior = obd.MixtureSameFamily(
        components_distribution=fn(loc, log_scale),
        mixture_distribution=obd.Categorical(logits=mixture_logits))
  ## discrete
  elif dist in (obd.OneHotCategorical, obd.Categorical) or \
      layer == obl.RelaxedOneHotCategoricalLayer:
    # uniform categories: logits are log-probabilities log(1/n)
    prior = dist(**_kwargs(logits=np.log([1. / event_size] * event_size),
                           dtype=tf.float32))
  elif dist == obd.Dirichlet:
    prior = dist(**_kwargs(concentration=[1.] * event_size))
  elif dist == obd.Bernoulli:
    # Bernoulli logits are log-odds: 0 <=> p = 0.5. The previous value
    # log(0.5) gave p = sigmoid(log 0.5) = 1/3, i.e. a non-uniform prior.
    prior = obd.Independent(
        obd.Bernoulli(**_kwargs(logits=np.zeros(event_shape),
                                dtype=tf.float32)), len(event_shape))
  ## other
  return prior