def testLDJRatio(self):
  """LDJ-ratio helpers on JointMap agree with differences of direct LDJs."""
  q = tfb.JointMap({'a': tfb.Exp(), 'b': tfb.Scale(2.), 'c': tfb.Shift(3.)})
  p = tfb.JointMap({'a': tfb.Exp(), 'b': tfb.Scale(3.), 'c': tfb.Shift(4.)})

  x = {
      'a': np.asarray([[[1, 2], [2, 3]]], dtype=np.float32),  # shape=[1, 2, 2]
      'b': np.asarray([[0, 4]], dtype=np.float32),  # shape=[1, 2]
      'c': np.asarray([[5, 6]], dtype=np.float32),  # shape=[1, 2]
  }
  # Every part of `y` is the matching part of `x` shifted by one.
  y = {part: value + 1 for part, value in x.items()}

  # Same checks under two event-rank configurations; the second gives 'b' an
  # event rank of 2 instead of 0.
  for event_ndims in ({'a': 1, 'b': 0, 'c': 0}, {'a': 1, 'b': 2, 'c': 0}):
    expected_fldj = (p.forward_log_det_jacobian(x, event_ndims) -
                     q.forward_log_det_jacobian(y, event_ndims))
    actual_fldj = ldj_ratio.forward_log_det_jacobian_ratio(
        p, x, q, y, event_ndims)
    self.assertAllClose(expected_fldj, actual_fldj)

    expected_ildj = (p.inverse_log_det_jacobian(x, event_ndims) -
                     q.inverse_log_det_jacobian(y, event_ndims))
    actual_ildj = ldj_ratio.inverse_log_det_jacobian_ratio(
        p, x, q, y, event_ndims)
    self.assertAllClose(expected_ildj, actual_ildj)
def _ildj_ratio_chain(p, x, q, y):
  """Computes the ILDJ ratio of two Chains as a sum of per-link ratios.

  Both chains are walked in lockstep, in `bijectors` list order. At each link,
  the ratio is evaluated at the current points, and then each point is pulled
  back through its own link's `inverse` before moving on.

  Args:
    p: Chain whose ILDJ forms the numerator.
    x: Point at which `p`'s ILDJ is evaluated.
    q: Chain whose ILDJ forms the denominator.
    y: Point at which `q`'s ILDJ is evaluated.

  Returns:
    `tf.add_n` of the per-link ILDJ ratios.

  Raises:
    ValueError: if `p` and `q` hold different numbers of bijectors.
  """
  p_links = p.bijectors
  q_links = q.bijectors
  if len(p_links) != len(q_links):
    raise ValueError('Mismatched lengths of bijectors: `p` has '
                     f'{len(p.bijectors)} but `q` has {len(q.bijectors)}.')

  # NOTE(review): an empty Chain leaves this list empty, and tf.add_n is
  # documented to reject an empty input list -- confirm empty Chains never
  # reach this function.
  per_link_ratios = []
  x_cur, y_cur = x, y
  for p_link, q_link in zip(p_links, q_links):
    # Each link is compared at its own minimal inverse event rank.
    per_link_ratios.append(
        ldj_ratio.inverse_log_det_jacobian_ratio(
            p_link, x_cur, q_link, y_cur, p_link.inverse_min_event_ndims))
    x_cur = p_link.inverse(x_cur)
    y_cur = q_link.inverse(y_cur)
  return tf.add_n(per_link_ratios)
def _ildj_ratio_composition(p, x, q, y, event_ndims, p_kwargs, q_kwargs):
  """Computes the ILDJ ratio of two Compositions, component by component.

  Each Composition is expanded into its per-component metadata (in the inverse
  direction). The ILDJ ratio of each aligned pair of components is then
  computed and accumulated.

  Args:
    p: Composition whose ILDJ forms the numerator.
    x: Point at which `p`'s ILDJ is evaluated.
    q: Composition whose ILDJ forms the denominator.
    y: Point at which `q`'s ILDJ is evaluated.
    event_ndims: Event rank(s) passed to both metadata expansions.
    p_kwargs: Extra keyword arguments forwarded to `p`'s metadata expansion.
    q_kwargs: Extra keyword arguments forwarded to `q`'s metadata expansion.

  Returns:
    The accumulated ratio, wrapped in `tf.identity` under control dependencies
    on every component's assertions.

  Raises:
    ValueError: if the two Compositions expand to different numbers of
      components.
    TypeError: if any component's ratio has a non-floating dtype.
  """
  # pylint: disable=protected-access
  p_steps = p._get_bijectors_with_metadata(
      x, event_ndims, forward=False, **p_kwargs)
  q_steps = q._get_bijectors_with_metadata(
      y, event_ndims, forward=False, **q_kwargs)
  # pylint: enable=protected-access

  if len(p_steps) != len(q_steps):
    raise ValueError(
        f'Composition "{p.name}" and "{q.name}" have different numbers of '
        f'component bijectors: {len(p_steps)} != '
        f'{len(q_steps)}.')

  # A float32 scalar seeds the running sum. Its current dtype is used as the
  # dtype hint when each component's ratio is converted to a tensor.
  total = tf.zeros([], dtype=tf.float32)
  deps = []
  for p_step, q_step in zip(p_steps, q_steps):
    term = ldj_ratio_lib.inverse_log_det_jacobian_ratio(
        p=p_step.bijector, x=p_step.x,
        q=q_step.bijector, y=q_step.x,
        event_ndims=p_step.x_event_ndims,
        p_kwargs=p_step.kwargs, q_kwargs=q_step.kwargs)
    term = tf.convert_to_tensor(term, dtype_hint=total.dtype)
    if not dtype_util.is_floating(term.dtype):
      raise TypeError(
          f'Nested bijector "{p_step.bijector.name}" of Composition "{p.name}" '
          f'and bijector "{q_step.bijector.name}" of Composition "{q.name}" '
          f'returned ILDJ ratio with a non-floating dtype: {term.dtype}'
      )
    total = _max_precision_sum(total, term)
    deps += p_step.assertions
    deps += q_step.assertions

  with tf.control_dependencies(deps):
    return tf.identity(total, name='ildj_ratio')
def _transformed_log_prob_ratio(p, x, q, y):
  """Computes `p.log_prob(x) - q.log_prob(y)` for p and q both TDs.

  The ratio is split into two parts:

  * the base distributions' log-prob ratio, evaluated at the bijectors'
    pre-images of `x` and `y`;
  * the bijectors' inverse-log-det-Jacobian ratio, evaluated at `x` and `y`
    themselves.

  Args:
    p: TransformedDistribution whose log-prob forms the numerator.
    x: Point (possibly a nested structure) at which `p` is evaluated.
    q: TransformedDistribution whose log-prob forms the denominator.
    y: Point (possibly a nested structure) at which `q` is evaluated.

  Returns:
    The log-prob ratio. The ILDJ ratio is cast to the base ratio's dtype
    before the two parts are added.
  """
  x_ = p.bijector.inverse(x)
  y_ = q.bijector.inverse(y)
  base_log_prob_ratio = log_prob_ratio.log_prob_ratio(
      p.distribution, x_, q.distribution, y_)

  # Combine the static event shapes of p and q leaf-by-leaf.
  merged_event_shape = tf.nest.map_structure(
      tensorshape_util.merge_with, p.event_shape, q.event_shape)
  if tf.nest.is_nested(merged_event_shape):
    # Structured (joint) events: `event_shape_tensor()` returns a structure
    # mirroring `event_shape`, so it must be called before it can be zipped
    # leaf-by-leaf with the static shapes. A bound method is a single nest
    # leaf, and `map_structure` would fail with a structure mismatch.
    # NOTE(review): relies on `ps.rank_from_shape` accepting a shape Tensor
    # as well as a callable.
    event_shape_tensor = p.event_shape_tensor()
  else:
    # Single event: keep passing the bound method, as before, so the dynamic
    # shape is only computed on demand.
    event_shape_tensor = p.event_shape_tensor
  event_ndims = tf.nest.map_structure(
      ps.rank_from_shape, event_shape_tensor, merged_event_shape)

  ildj_ratio = ldj_ratio.inverse_log_det_jacobian_ratio(
      p.bijector, x, q.bijector, y, event_ndims)
  return base_log_prob_ratio + tf.cast(ildj_ratio, base_log_prob_ratio.dtype)