def new(params,
        event_shape=(),
        given_logits=True,
        validate_args=False,
        name='ZIBernoulliLayer'):
  """Build a zero-inflated Bernoulli distribution from a flat `params` tensor.

  The last axis of `params` is split into two equal halves (so its size must
  be divisible by 2, as required by `tf.split`):

  * the first half parameterizes the Bernoulli count distribution, read as
    logits when `given_logits` is True and as probabilities otherwise;
  * the second half is passed as the `logits` of the zero-inflation rate.

  Both halves are reshaped to `params.shape[:-1] + event_shape`. The
  resulting `ZeroInflated` distribution is wrapped in `Independent` with
  `tf.size(event_shape)` reinterpreted batch dimensions.

  Args:
    params: tensor-like holding both parameter halves on its last axis.
    event_shape: event shape of one sample; converted to an int32 vector.
    given_logits: whether the first half of `params` holds logits (True) or
      probabilities (False) for the Bernoulli component.
    validate_args: forwarded to `Bernoulli` and `ZeroInflated`.
    name: name of the returned `Independent` distribution.

  Returns:
    An `Independent(ZeroInflated(Bernoulli))` distribution.
  """
  params = tf.convert_to_tensor(value=params, name='params')
  event_vector = tf.convert_to_tensor(value=event_shape,
                                      name='event_shape',
                                      dtype=tf.int32)
  event_shape = dist_util.expand_to_vector(event_vector,
                                           tensor_name='event_shape')
  batch_plus_event = tf.concat([tf.shape(input=params)[:-1], event_shape],
                               axis=0)

  count_half, inflation_half = tf.split(params, 2, axis=-1)
  count_half = tf.reshape(count_half, batch_plus_event)

  # Exactly one of logits/probs is non-None, depending on `given_logits`.
  if given_logits:
    count_dist = Bernoulli(logits=count_half,
                           probs=None,
                           validate_args=validate_args)
  else:
    count_dist = Bernoulli(logits=None,
                           probs=count_half,
                           validate_args=validate_args)

  inflated = ZeroInflated(count_distribution=count_dist,
                          logits=tf.reshape(inflation_half, batch_plus_event),
                          validate_args=validate_args)
  return Independent(inflated,
                     reinterpreted_batch_ndims=tf.size(input=event_shape),
                     name=name)
def observation(distribution: str):
  """Create the image observation (decoder output) layer.

  Args:
    distribution: likelihood family, either 'qlogistic' (quantized
      logistic over [0, 255]) or 'bernoulli'.

  Returns:
    A tuple `(n_params, obs)`.
    * `n_params` is the number of parameter groups the layer expects
      along the last axis of its input: 2 (loc, scale) for 'qlogistic'
      and 1 (logits) for 'bernoulli'.
    * `obs` is a `DistributionLambda` named 'image' whose tensor conversion
      draws a sample.

  Raises:
    NotImplementedError: if `distribution` is not one of the supported names.
  """
  if distribution == 'qlogistic':
    n_params = 2

    def _quantized_logistic(params):
      loc, raw_scale = tf.split(params, 2, -1)
      # Ensure scales are positive and do not collapse to near-zero
      scale = tf.nn.softplus(raw_scale) + tf.cast(tf.exp(-7.), tf.float32)
      return QuantizedLogistic(loc,
                               scale,
                               low=0,
                               high=255,
                               inputs_domain='sigmoid',
                               reinterpreted_batch_ndims=3)

    obs = DistributionLambda(_quantized_logistic,
                             convert_to_tensor_fn=Distribution.sample,
                             name='image')
  elif distribution == 'bernoulli':
    n_params = 1
    obs = DistributionLambda(lambda params: Independent(
        Bernoulli(logits=params, dtype=tf.float32), len(IMAGE_SHAPE)),
                             convert_to_tensor_fn=Distribution.sample,
                             name='image')
  else:
    raise NotImplementedError(
        f"Unsupported observation distribution: {distribution!r}; "
        "expected 'qlogistic' or 'bernoulli'.")
  return n_params, obs
def llk_pixels(model: VariationalModel, valid_ds: tf.data.Dataset):
  """Summarize per-pixel log-likelihoods of the decoder on validation data.

  Decodes up to 5 batches of `valid_ds`, evaluates the element-wise
  log-likelihood of each batch under the first decoded distribution, and
  averages over examples and over the last (channel) axis. It then writes a
  'valid/llk_pixels' histogram and an 'llk_heatmap' image summary, both at
  `model.step`.

  Nothing is written when the decoded distribution is not a Bernoulli,
  Normal or QuantizedLogistic, or when the dataset yields no batches.

  Args:
    model: the variational model to decode with.
    valid_ds: dataset yielding `(x, y)` pairs; `x` is presumably an image
      batch of shape (batch, height, width, channels) -- TODO confirm.
  """
  llk = []
  for x, y in valid_ds.take(5):
    px, _ = _call(model, x, y, decode=True)
    px = as_tuple(px)[0]
    # Unwrap a wrapper (presumably `Independent`) and rebuild the base
    # distribution so `log_prob` returns per-element values.
    if hasattr(px, 'distribution'):
      px = px.distribution
    if isinstance(px, Bernoulli):
      px = Bernoulli(logits=px.logits)
    elif isinstance(px, Normal):
      px = Normal(loc=px.loc, scale=px.scale)
    elif isinstance(px, QuantizedLogistic):
      px = QuantizedLogistic(loc=px.loc,
                             scale=px.scale,
                             low=px.low,
                             high=px.high,
                             inputs_domain=px.inputs_domain,
                             reinterpreted_batch_ndims=None)
    else:
      return  # unsupported likelihood: nothing to do
    llk.append(px.log_prob(x))
  if not llk:
    return  # empty dataset: tf.concat would fail on an empty list
  # average over examples, then over all channels
  llk_image = tf.reduce_mean(tf.reduce_mean(tf.concat(llk, 0), axis=0),
                             axis=-1)
  llk = tf.reshape(llk_image, -1)
  tf.summary.histogram('valid/llk_pixels', llk, step=model.step)
  vmin = np.min(llk)
  vmax = np.max(llk)
  # show the image heatmap of llk pixels
  fig = plt.figure(figsize=(3, 3))
  ax = plt.gca()
  im = ax.pcolormesh(llk_image.numpy(), cmap='Spectral', vmin=vmin, vmax=vmax)
  ax.axis('off')
  ax.margins(0.)
  # color bar
  ticks = np.linspace(vmin, vmax, 5)
  cbar = plt.colorbar(im, ax=ax, fraction=0.04, pad=0.02, ticks=ticks)
  cbar.ax.set_yticklabels([f'{i:.2f}' for i in ticks])
  cbar.ax.tick_params(labelsize=6)
  plt.tight_layout()
  # Pass the step explicitly, like the histogram above; otherwise this fails
  # unless a default summary step has been set.
  tf.summary.image('llk_heatmap',
                   vs.plot_to_image(fig, dpi=100),
                   step=model.step)
  # Release the figure; this is a no-op if plot_to_image already closed it.
  plt.close(fig)
def elbo_components(self, inputs, training=None, mask=None, **kwargs):
  """Compute the reconstruction and KL components of the ELBO.

  Args:
    inputs: observed input(s). A single tensor or a tuple, matched
      positionally with the decoded distributions returned by `self(...)`.
    training: forwarded to `self(...)`.
    mask: forwarded to `self(...)`.
    **kwargs: accepted for interface compatibility; unused here.

  Returns:
    A pair of dicts `(llk, kl)`, keyed by `name`, the second
    `'_'`-separated token of each distribution's name.

    * `llk[f'llk_{name}']`: log-likelihood of each input. Element-wise values
      are clipped from above at `self.R` when `self.R != 0`, then summed over
      every axis except the first (presumably the batch axis).
    * `kl[f'kl_{name}']`: KL term of each latent, scaled by `self.beta`.
      When `self.C > 0` it is first replaced by `|KL - C|`, where
      `C = self.C * prod(event_shape)`, optionally multiplied by a
      uniform(0, 1) draw if `self.random_capacity`.
  """
  px, qz = self(inputs, training=training, mask=mask)
  # === 1. reconstructed information
  llk = {}
  for p, x in zip(as_tuple(px), as_tuple(inputs)):
    name = p.name.split('_')[1]
    # Rebuild the base distribution (dropping any wrapper that exposes
    # `.distribution`) so log_prob is element-wise before clipping.
    if hasattr(p, 'distribution'):
      p = p.distribution
    if isinstance(p, Bernoulli):
      p = Bernoulli(logits=p.logits)
    elif isinstance(p, Normal):  # check the per-output `p`, not the tuple `px`
      p = Normal(loc=p.loc, scale=p.scale)
    elif isinstance(p, QuantizedLogistic):
      p = QuantizedLogistic(loc=p.loc,
                            scale=p.scale,
                            low=p.low,
                            high=p.high,
                            inputs_domain=p.inputs_domain,
                            reinterpreted_batch_ndims=None)
    lk = p.log_prob(x)
    if self.R != 0.:
      lk = tf.minimum(lk, self.R)
    lk = tf.reduce_sum(lk, tf.range(1, x.shape.rank))
    llk[f'llk_{name}'] = lk
  # === 2. latent capacity
  kl = {}
  for q in as_tuple(qz):
    name = q.name.split('_')[1]
    kl_q = q.KL_divergence(analytic=self.analytic)
    if self.C > 0:
      zdim = int(np.prod(q.event_shape))
      C = tf.constant(self.C * zdim, dtype=self.dtype)
      if self.random_capacity:
        C = C * tf.random.uniform(shape=[],
                                  minval=0.,
                                  maxval=1.,
                                  dtype=self.dtype)
      kl_q = tf.math.abs(kl_q - C)
    kl_q = self.beta * kl_q
    kl[f'kl_{name}'] = kl_q
  return llk, kl
def new(params,
        event_shape=(),
        dtype=None,
        validate_args=False,
        continuous=False,
        lims=(0.499, 0.501),
        name='BernoulliLayer'):
  """Build an independent (continuous) Bernoulli from a `params` tensor.

  `params` is read as logits and reshaped to
  `params.shape[:-1] + event_shape`. The base distribution is a
  `ContinuousBernoulli` (using `lims`) when `continuous` is True and a
  `Bernoulli` otherwise. It is wrapped in `Independent` with
  `tf.size(event_shape)` reinterpreted batch dimensions.

  Args:
    params: tensor-like of logits.
    event_shape: event shape of one sample (int32 dtype hint).
    dtype: sample dtype; defaults to the base dtype of `params`.
    validate_args: forwarded to the base distribution.
    continuous: select `ContinuousBernoulli` instead of `Bernoulli`.
    lims: forwarded to `ContinuousBernoulli`; ignored otherwise.
    name: name of the returned `Independent` distribution.

  Returns:
    The `Independent` distribution. Its `logits` and `probs` attributes
    mirror the wrapped distribution's private `_logits` / `_probs`.
  """
  params = tf.convert_to_tensor(value=params, name='params')
  event_vector = tf.convert_to_tensor(value=event_shape,
                                      name='event_shape',
                                      dtype_hint=tf.int32)
  event_shape = dist_util.expand_to_vector(event_vector,
                                           tensor_name='event_shape')
  target_shape = tf.concat([tf.shape(input=params)[:-1], event_shape], axis=0)

  logits = tf.reshape(params, target_shape)
  sample_dtype = dtype or params.dtype.base_dtype
  if not continuous:
    base = Bernoulli(logits=logits,
                     dtype=sample_dtype,
                     validate_args=validate_args)
  else:
    base = ContinuousBernoulli(logits=logits,
                               dtype=sample_dtype,
                               lims=lims,
                               validate_args=validate_args)

  wrapped = Independent(base,
                        reinterpreted_batch_ndims=tf.size(input=event_shape),
                        name=name)
  # Surface the base distribution's parameters on the wrapper for callers.
  inner = wrapped.distribution
  wrapped.logits = inner._logits  # pylint: disable=protected-access
  wrapped.probs = inner._probs  # pylint: disable=protected-access
  return wrapped
def __init__(self,
             count_distribution,
             inflated_distribution=None,
             logits=None,
             probs=None,
             eps=1e-8,
             validate_args=False,
             allow_nan_stats=True,
             name="ZeroInflated"):
  r"""Initialize a zero-inflated distribution.

  The distribution combines a `count_distribution` with an excess-zero
  component. The excess-zero rate comes from `inflated_distribution`, or,
  when that is omitted, from an int32 `Bernoulli` built from
  `logits`/`probs`.

  Arguments:
    count_distribution: A `tfp.distributions.Distribution` instance whose
      batch rank must be known statically. Its `dtype` and
      `reparameterization_type` become those of this distribution.
    inflated_distribution: Optional `Distribution` (Bernoulli-like) giving
      the probability of an excess zero. When provided, `logits` and `probs`
      are ignored. If it has fewer event dimensions than
      `count_distribution`, it is wrapped in `Independent` to make up the
      difference.
    logits: Log-odds of an excess zero; the probability of an excess zero is
      `sigmoid(logits)`. Used only when `inflated_distribution` is None.
      Pass at most one of `logits` and `probs`.
    probs: Probability of an excess zero. Used only when
      `inflated_distribution` is None. Pass at most one of `logits` and
      `probs`.
    eps: Small constant, converted to a tensor (dtype hint
      `count_distribution.dtype`) and stored as `self._eps`. It is
      presumably used for numerical stability elsewhere in the class --
      confirm.
    validate_args: Python `bool`, forwarded to the default Bernoulli and to
      the base `Distribution`.
    allow_nan_stats: Python `bool`, forwarded to the default Bernoulli and
      to the base `Distribution`.
    name: Name scope for this distribution.

  Raises:
    TypeError: if `count_distribution` or `inflated_distribution` is not a
      `Distribution` instance.
    ValueError: if the batch rank of either component is statically unknown.

  References:
    Liu, L. & Blei, D.M.. (2017). Zero-Inflated Exponential Family
    Embeddings. Proceedings of the 34th International Conference on Machine
    Learning, in PMLR 70:2140-2148
  """
  # Must be captured first, before any other local is bound.
  parameters = dict(locals())
  with tf.compat.v1.name_scope(name) as name:
    if not isinstance(count_distribution, distribution.Distribution):
      raise TypeError(
          "count_distribution must be a Distribution instance"
          " but saw: %s" % count_distribution)

    if inflated_distribution is not None:
      if not isinstance(inflated_distribution, distribution.Distribution):
        raise TypeError(
            "inflated_distribution must be a Distribution instance"
            " but saw: %s" % inflated_distribution)
    else:
      inflated_distribution = Bernoulli(logits=logits,
                                        probs=probs,
                                        dtype=tf.int32,
                                        validate_args=validate_args,
                                        allow_nan_stats=allow_nan_stats,
                                        name="ZeroInflatedRate")

    # Give the rate the same event rank as the count distribution by
    # reinterpreting its extra batch dimensions as event dimensions.
    rate_event_rank = len(inflated_distribution.event_shape)
    count_event_rank = len(count_distribution.event_shape)
    missing_event_rank = count_event_rank - rate_event_rank
    if missing_event_rank > 0:
      inflated_distribution = Independent(inflated_distribution,
                                          missing_event_rank)

    self._count_distribution = count_distribution
    self._inflated_distribution = inflated_distribution

    for component, message in (
        (self._count_distribution,
         "Expected to know rank(batch_shape) from count_distribution"),
        (self._inflated_distribution,
         "Expected to know rank(batch_shape) from inflated_distribution"),
    ):
      if component.batch_shape.ndims is None:
        raise ValueError(message)

    self._eps = tensor_util.convert_nonref_to_tensor(
        eps, dtype_hint=count_distribution.dtype, name='eps')

    count_reparam = self._count_distribution.reparameterization_type
    super(ZeroInflated, self).__init__(
        dtype=self._count_distribution.dtype,
        reparameterization_type=count_reparam,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        name=name,
    )
def __init__(self,
             count_distribution,
             inflated_distribution=None,
             logits=None,
             probs=None,
             validate_args=False,
             allow_nan_stats=True,
             name="ZeroInflated"):
  """Initialize a zero-inflated distribution.

  A `ZeroInflated` combines a zero-inflation rate (`inflated_distribution`,
  the probability of an excess zero) with a `count_distribution`.

  Parameters
  ----------
  count_distribution : A `tfp.distributions.Distribution` instance whose
      batch rank must be known statically. Its `dtype` becomes the dtype of
      this distribution.
  inflated_distribution : optional `tfp.distributions.Bernoulli`-like
      instance giving the probability of an excess zero. When None, an int32
      `Bernoulli` named "ZeroInflatedRate" is built from `logits`/`probs`.
      Its batch rank must be at least that of `count_distribution`. Any extra
      batch dimensions are reinterpreted as event dimensions via
      `Independent`. A smaller batch rank (including a scalar batch shape
      when `count_distribution` is batched) raises `ValueError`.
  logits : An N-D `Tensor` of log-odds of an excess zero; the probability of
      an excess zero is `sigmoid(logits)`. Used only when
      `inflated_distribution` is None. Pass at most one of `logits` and
      `probs`.
  probs : An N-D `Tensor` of the probability of an excess zero. Used only
      when `inflated_distribution` is None. Pass at most one of `logits` and
      `probs`.
  validate_args : Python `bool`, default `False`. If `True`, a runtime
      assertion that both components have equal batch shapes is recorded in
      `self._runtime_assertions`. The flag is also forwarded to the default
      Bernoulli and the base `Distribution`.
  allow_nan_stats : Boolean, default `True`, forwarded to the default
      Bernoulli and the base `Distribution`.
  name : A name for this distribution (optional).

  Raises
  ------
  TypeError
      If either component is not a `Distribution` instance.
  ValueError
      If a component's batch rank is statically unknown, or if
      `count_distribution` has a larger batch rank than
      `inflated_distribution`.

  References
  ----------
  Liu, L. & Blei, D.M.. (2017). Zero-Inflated Exponential Family Embeddings.
  Proceedings of the 34th International Conference on Machine Learning,
  in PMLR 70:2140-2148
  """
  # Must be captured first, before any other local is bound.
  parameters = dict(locals())
  self._runtime_assertions = []
  with tf.compat.v1.name_scope(name) as name:
    if not isinstance(count_distribution, distribution.Distribution):
      raise TypeError("count_distribution must be a Distribution instance"
                      " but saw: %s" % count_distribution)
    self._count_distribution = count_distribution
    if inflated_distribution is None:
      inflated_distribution = Bernoulli(logits=logits,
                                        probs=probs,
                                        dtype=tf.int32,
                                        validate_args=validate_args,
                                        allow_nan_stats=allow_nan_stats,
                                        name="ZeroInflatedRate")
    elif not isinstance(inflated_distribution, distribution.Distribution):
      raise TypeError("inflated_distribution must be a Distribution instance"
                      " but saw: %s" % inflated_distribution)
    self._inflated_distribution = inflated_distribution

    if self._count_distribution.batch_shape.ndims is None:
      raise ValueError(
          "Expected to know rank(batch_shape) from count_distribution")
    if self._inflated_distribution.batch_shape.ndims is None:
      raise ValueError(
          "Expected to know rank(batch_shape) from inflated_distribution")

    # Make the inflated distribution's batch rank match count_distribution's
    # by reinterpreting its extra batch dims as event dims.
    inflated_batch_ndims = self._inflated_distribution.batch_shape.ndims
    count_batch_ndims = self._count_distribution.batch_shape.ndims
    if count_batch_ndims < inflated_batch_ndims:
      self._inflated_distribution = Independent(
          self._inflated_distribution,
          reinterpreted_batch_ndims=inflated_batch_ndims - count_batch_ndims,
          name="ZeroInflatedRate")
    elif count_batch_ndims > inflated_batch_ndims:
      raise ValueError(
          "count_distribution has %d-D batch_shape, which is larger than "
          "the %d-D batch_shape of inflated_distribution" %
          (count_batch_ndims, inflated_batch_ndims))

    # Ensure that all batch and event ndims are consistent.
    if validate_args:
      self._runtime_assertions.append(
          tf.assert_equal(
              self._count_distribution.batch_shape_tensor(),
              self._inflated_distribution.batch_shape_tensor(),
              message=(
                  "dist batch shape must match logits|probs batch shape")))

    # Fully reparameterized only if both components are; otherwise not.
    reparameterization_type = [
        self._count_distribution.reparameterization_type,
        self._inflated_distribution.reparameterization_type
    ]
    if any(i == reparameterization.NOT_REPARAMETERIZED
           for i in reparameterization_type):
      reparameterization_type = reparameterization.NOT_REPARAMETERIZED
    else:
      reparameterization_type = reparameterization.FULLY_REPARAMETERIZED

    # Forward both components' graph parents, treating this distribution
    # like a base class over them.
    super(ZeroInflated, self).__init__(
        dtype=self._count_distribution.dtype,
        reparameterization_type=reparameterization_type,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=self._count_distribution._graph_parents +
        self._inflated_distribution._graph_parents,
        name=name)
p1 = getattr(d1, name) p2 = getattr(d2, name) p = getattr(d, name) assert np.all( np.isclose(p.numpy(), tf.concat((p1, p2), axis=0).numpy())) except NotImplementedError: pass shape = (8, 2) count = np.random.randint(0, 20, size=shape).astype('float32') probs = np.random.rand(*shape).astype('float32') logits = np.random.rand(*shape).astype('float32') assert_consistent_statistics(Bernoulli(probs=probs), Bernoulli(logits=logits)) assert_consistent_statistics(Bernoulli(logits=logits), Bernoulli(logits=logits)) assert_consistent_statistics( Independent(Bernoulli(probs=probs), reinterpreted_batch_ndims=1), Independent(Bernoulli(logits=logits), reinterpreted_batch_ndims=1)) assert_consistent_statistics( NegativeBinomial(total_count=count, logits=logits), NegativeBinomial(total_count=count, probs=probs)) assert_consistent_statistics( Independent(NegativeBinomial(total_count=count, logits=logits), reinterpreted_batch_ndims=1), Independent(NegativeBinomial(total_count=count, probs=probs), reinterpreted_batch_ndims=1)) assert_consistent_statistics(