def testDegenerateSplines(self):
  """Degenerate spline configurations must behave as the identity map.

  Two configurations are checked: no bins at all (with a scalar slope), and
  a single bin with no interior knots. For each, the forward map must return
  its input unchanged and the forward log-det-Jacobian must be zero for both
  event_ndims=1 (scalar total) and event_ndims=0 (elementwise).
  """
  # (bin_widths, bin_heights, knot_slopes), passed positionally.
  degenerate_params = (
      ([], [], 1),
      ([2.], [2.], []),
  )
  for widths, heights, slopes in degenerate_params:
    spline = tfb.RationalQuadraticSpline(
        widths, heights, slopes, validate_args=True)
    grid = np.linspace(-2, 2, 20, dtype=np.float32)
    self.assertAllClose(grid, self.evaluate(spline.forward(grid)))
    self.assertAllClose(
        0,
        self.evaluate(spline.forward_log_det_jacobian(grid, event_ndims=1)))
    self.assertAllClose(
        np.zeros_like(grid),
        self.evaluate(spline.forward_log_det_jacobian(grid, event_ndims=0)))
def testVerifiesBroadcastingStatic(self):
  """Statically incompatible parameter shapes must raise ValueError."""
  # Each entry: (expected error regex, positional spline arguments).
  bad_cases = (
      # Widths have leading dim 2 and heights have leading dim 3.
      ('`bin_heights` must broadcast',
       ([[2, 1, .5]] * 2, [[.5, 2, 1]] * 3, [.3, 2])),
      # Three bins need two interior slopes, but three are given.
      ('non-scalar `knot_slopes` must broadcast',
       ([2, 1, .5], [.5, 2, 1], [.3, 2, .5])),
  )
  for pattern, spline_args in bad_cases:
    with self.assertRaisesRegex(ValueError, pattern):
      tfb.RationalQuadraticSpline(*spline_args)
def testAssertsNonPositiveSlope(self):
  """With validate_args, negative or zero knot slopes must fail at runtime."""
  # One negative slope and one zero slope; both must be rejected.
  for bad_slopes in ([-.5, 1], [1, 0.]):
    with self.assertRaisesOpError('`knot_slopes` must be positive'):
      spline = tfb.RationalQuadraticSpline(
          bin_widths=[.1, .2, 1],
          bin_heights=[1, .2, .1],
          knot_slopes=bad_slopes,
          validate_args=True)
      self.evaluate(spline.forward([.3]))
def testAssertsNonPositiveBinSizes(self):
  """With validate_args, non-positive bin widths/heights must fail."""
  # Each entry: (expected message, bin_widths, bin_heights, knot_slopes).
  bad_configs = (
      # A negative width.
      ('`bin_widths` must be positive', [.3, .2, -.1], [.1, .2, .1], [.4, .5]),
      # A zero height.
      ('`bin_heights` must be positive', [.3, .2, .1], [.5, 0, .1], [.3, .7]),
  )
  for message, widths, heights, slopes in bad_configs:
    with self.assertRaisesOpError(message):
      spline = tfb.RationalQuadraticSpline(
          bin_widths=widths,
          bin_heights=heights,
          knot_slopes=slopes,
          validate_args=True)
      self.evaluate(spline.forward([.3]))
def __call__(self, x, nunits):
  """Lazily build parameter layers and return a spline conditioned on `x`.

  On the first call this creates three Dense layers. They map `x` to bin
  widths, bin heights and knot slopes for `nunits` splines of
  `self._nbins` bins each. The activation closures capture `nunits` from
  that first call, and `self._built` prevents rebuilding. Later calls
  therefore reuse the same layers and the same `nunits`.

  Args:
    x: Input tensor fed to the Dense layers.
    nunits: Number of independent spline units to parameterize. Only the
      value from the first call is used.

  Returns:
    A `tfb.RationalQuadraticSpline` whose parameters are computed from `x`.
  """
  if not self._built:

    def _bin_positions(x):
      # Unflatten the Dense output to [..., nunits, nbins].
      out_shape = ps.concat((ps.shape(x)[:-1], (nunits, self._nbins)), 0)
      x = tf.reshape(x, out_shape)
      # The softmax sums to 1, so the bins sum to
      # (2 - nbins * 1e-2) + nbins * 1e-2 = 2, each floored at 1e-2.
      return tf.math.softmax(x, axis=-1) * (2 - self._nbins * 1e-2) + 1e-2

    def _slopes(x):
      # Unflatten to [..., nunits, nbins - 1] (one slope per interior knot).
      # Use ps.concat, like _bin_positions, so a statically known shape stays
      # static instead of always becoming a dynamic tf.concat tensor.
      out_shape = ps.concat(
          (ps.shape(x)[:-1], (nunits, self._nbins - 1)), 0)
      x = tf.reshape(x, out_shape)
      # softplus is non-negative; the offset keeps slopes strictly positive.
      return tf.math.softplus(x) + 1e-2

    self._bin_widths = tf.keras.layers.Dense(
        nunits * self._nbins, activation=_bin_positions, name='w')
    self._bin_heights = tf.keras.layers.Dense(
        nunits * self._nbins, activation=_bin_positions, name='h')
    self._knot_slopes = tf.keras.layers.Dense(
        nunits * (self._nbins - 1), activation=_slopes, name='s')
    self._built = True
  return tfb.RationalQuadraticSpline(
      bin_widths=self._bin_widths(x),
      bin_heights=self._bin_heights(x),
      knot_slopes=self._knot_slopes(x))
def testTheoreticalFldjSimple(self):
  """Compare the analytic FLDJ against a numerically derived one (2 bins)."""
  # Heights sum to 2, matching the widths' total of 1 + 1.
  heights = [np.sqrt(.5), 2 - np.sqrt(.5)]
  spline = tfb.RationalQuadraticSpline(
      bin_widths=[1., 1], bin_heights=heights, knot_slopes=1)
  # Presumably the dtype is inferred from the np.float64 heights.
  self.assertEqual(tf.float64, spline.dtype)

  dim = 5
  # 10 points spanning slightly past +/-1, laid out as a [2, dim] batch.
  # NOTE(review): the overshoot presumably exercises the tails outside the
  # spline's interval.
  grid = np.linspace(-1.05, 1.05, num=2 * dim, dtype=np.float64)
  x = grid.reshape(2, dim)
  y = self.evaluate(spline.forward(x))
  bijector_test_util.assert_bijective_and_finite(
      spline, x, y,
      eval_func=self.evaluate,
      event_ndims=0,
      inverse_event_ndims=0,
      rtol=1e-5)

  analytic = spline.forward_log_det_jacobian(x, event_ndims=0)
  reference = bijector_test_util.get_fldj_theoretical(
      spline, x, event_ndims=0)
  self.assertAllClose(
      self.evaluate(reference), self.evaluate(analytic),
      atol=1e-5, rtol=1e-5)
def rq_splines(draw, batch_shape=None, dtype=tf.float32):
  """Draw a randomly parameterized `tfb.RationalQuadraticSpline`.

  Written as a Hypothesis composite body, taking `draw` as its first
  argument.

  Args:
    draw: Hypothesis draw callable.
    batch_shape: Batch shape for the spline parameters. If None, one is
      drawn from `tfp_hps.shapes()`.
    dtype: dtype forwarded to the parameter constraint functions.

  Returns:
    A `tfb.RationalQuadraticSpline` with `range_min=lo`. The bin sizes are
    constrained using the drawn interval [lo, hi], where hi - lo >= 0.2.
    `validate_args` is drawn at random.
  """
  if batch_shape is None:
    batch_shape = draw(tfp_hps.shapes())

  # Draw two endpoints, order them, then widen the interval by 0.2.
  first = draw(hps.floats(min_value=-5, max_value=.5))
  second = draw(hps.floats(min_value=-.5, max_value=5))
  lo = min(first, second)
  hi = max(first, second) + .2
  hp.note('lo, hi: {!r}'.format((lo, hi)))

  # Widths and heights share the same interval-based size constraint.
  size_constraint = functools.partial(
      bijector_hps.spline_bin_size_constraint, hi=hi, lo=lo, dtype=dtype)
  slope_constraint = functools.partial(
      bijector_hps.spline_slope_constraint, dtype=dtype)
  constraints = dict(
      bin_widths=size_constraint,
      bin_heights=size_constraint,
      knot_slopes=slope_constraint)

  event_ndims = dict(bin_widths=1, bin_heights=1, knot_slopes=1)
  params = draw(
      tfp_hps.broadcasting_params(
          batch_shape,
          params_event_ndims=event_ndims,
          constraint_fn_for=constraints.get))
  hp.note('params: {!r}'.format(params))

  validate = draw(hps.booleans())
  return tfb.RationalQuadraticSpline(
      range_min=lo, validate_args=validate, **params)
def testAssertsMismatchedSums(self):
  """With validate_args, widths and heights with different totals must fail."""
  # The widths sum to 0.8 but the heights sum to 5.8.
  expected_error = (r'`sum\(bin_widths, axis=-1\)` must equal '
                    r'`sum\(bin_heights, axis=-1\)`')
  with self.assertRaisesOpError(expected_error):
    spline = tfb.RationalQuadraticSpline(
        bin_widths=[.2, .1, .5],
        bin_heights=[.1, .3, 5.4],
        knot_slopes=[.3, .5],
        validate_args=True)
    self.evaluate(spline.forward([.3]))
def __call__(self, x, nunits):
  """Lazily build parameter layers and return a spline conditioned on `x`.

  The first call creates Dense layers for bin widths ('w'), bin heights
  ('h') and knot slopes ('s'). Their activations are `self._bin_positions`
  and `self._slopes`. `self._built` guards against rebuilding, so later
  calls reuse the layers sized by the first call's `nunits`.

  Args:
    x: Input tensor fed to the Dense layers.
    nunits: Number of spline units. Only used when the layers are built.

  Returns:
    A `tfb.RationalQuadraticSpline` whose parameters are computed from `x`.
  """
  if not self._built:
    # Widths and heights need one output per bin per unit.
    per_bin_units = nunits * self._nbins
    self._bin_widths = tf.keras.layers.Dense(
        per_bin_units, activation=self._bin_positions, name='w')
    self._bin_heights = tf.keras.layers.Dense(
        per_bin_units, activation=self._bin_positions, name='h')
    # Slopes need one output per interior knot per unit.
    self._knot_slopes = tf.keras.layers.Dense(
        nunits * (self._nbins - 1), activation=self._slopes, name='s')
    self._built = True

  widths = self._bin_widths(x)
  heights = self._bin_heights(x)
  slopes = self._knot_slopes(x)
  return tfb.RationalQuadraticSpline(
      bin_widths=widths, bin_heights=heights, knot_slopes=slopes)
def f(bin_sizes, slopes):
  """Apply a spline with equal widths and heights to `bin_sizes` itself.

  `bin_sizes` serves as bin widths, bin heights and the forward input.
  NOTE(review): presumably a helper for gradient or shape checks in an
  enclosing test; confirm against the caller.
  """
  spline = tfb.RationalQuadraticSpline(
      bin_sizes, bin_sizes, slopes, validate_args=True)
  return spline.forward(bin_sizes)