def _transform(self,
               input: Tensor,
               input_log_det: Optional[Tensor],
               inverse: bool,
               compute_log_det: bool
               ) -> Tuple[Tensor, Optional[Tensor]]:
    """
    Apply the affine map ``y = (x - 1) * 0.5`` or its inverse ``x = y * 2 + 1``.

    Args:
        input: The tensor to transform (`x` when `inverse` is False,
            `y` when `inverse` is True).
        input_log_det: Optional log-determinant that is added to the
            log-determinant computed here.
        inverse: Whether to apply the inverse map instead of the forward map.
        compute_log_det: Whether to compute the log-determinant at all.

    Returns:
        A tuple ``(output, output_log_det)``.  ``output_log_det`` is None
        when `compute_log_det` is False.
    """
    if inverse:
        output = input * 2.0 + 1
        event_ndims = self.x_event_ndims
    else:
        output = (input - 1.0) * 0.5
        event_ndims = self.y_event_ndims

    if compute_log_det:
        # Per-element log |d output / d input|:
        #   forward: output = (input - 1) * 0.5  ->  log(0.5) = -log(2)
        #   inverse: output = input * 2 + 1      ->  log(2)
        # `float_scalar_like` presumably builds a scalar matching `output`'s
        # dtype/device -- confirm against its definition.
        if inverse:
            output_log_det = float_scalar_like(math.log(2.), output)
        else:
            output_log_det = float_scalar_like(-math.log(2.), output)

        # Every element of an event contributes the same term, so scale the
        # per-element value by the size of each of the last `event_ndims` axes.
        for axis in int_range(-event_ndims, 0):
            output_log_det = output_log_det * output.shape[axis]

        if input_log_det is not None:
            output_log_det = output_log_det + input_log_det
    else:
        output_log_det: Optional[Tensor] = None

    return output, output_log_det
def log_prob_fn(t):
    """
    Compute the expected log-probability of the transformed sample `t`.

    The origin tensor of `t` is scored under `distribution`, pushed through
    `flow`, and the flow's log-determinant is subtracted from the origin
    log-probability summed over the extra event axes introduced by the flow.
    Along the way the flow output is checked against `t.tensor` and the rank
    of the log-determinant is checked against the rank of the origin
    log-probability.
    """
    origin = t.transform_origin.tensor
    log_px = distribution.log_prob(origin, group_ndims=0)

    # y and log |dy/dx|
    y, log_det = flow(origin)
    assert_allclose(y, t.tensor, atol=1e-4, rtol=1e-6)

    # number of axes the flow treats as event axes beyond the distribution's
    extra_ndims = flow.get_x_event_ndims() - distribution.event_ndims
    ctx.assertEqual(T.rank(log_det), T.rank(log_px) - extra_ndims)

    summed_log_px = T.reduce_sum(log_px, T.int_range(-extra_ndims, 0))
    return -log_det + summed_log_px
def check_x(layer):
    """
    Check the statistics of `layer(x)`.

    The mean and variance of the output are computed over the axes
    ``T.int_range(-T.rank(x), -1)``.  The variance must be close to one;
    the mean must be close to zero only when `use_bias` is set.
    """
    output = layer(x)
    reduce_axes = T.int_range(-T.rank(x), -1)
    mean, var = T.calculate_mean_and_var(output, axis=reduce_axes)

    tolerance = {'atol': 1e-6, 'rtol': 1e-4}
    if use_bias:
        assert_allclose(mean, T.zeros_like(mean), **tolerance)
    # NOTE(review): reconstructed as unconditional (outside `if use_bias`);
    # confirm against the original layout.
    assert_allclose(var, T.ones_like(var), **tolerance)
def check_scale(ctx, scale: Scale, x, pre_scale, expected_y, expected_log_det):
    """
    Exercise a `Scale` instance in both directions and on invalid arguments.

    Args:
        ctx: Object providing ``assertIsNone`` (used like a test case).
        scale: The scale object under test; called as
            ``scale(input, pre_scale, event_ndims=..., ...)``.
        x: Input of the forward direction.
        pre_scale: The pre-scale argument passed to every call.
        expected_y: Expected forward output for `x`.
        expected_log_det: Expected element-wise log-determinant; must have
            the same shape as `x`.
    """
    assert T.shape(x) == T.shape(expected_log_det)

    # dimension error: each call must raise with the given message pattern
    failing_calls = (
        (r'`rank\(input\) >= event_ndims` does not hold',
         lambda: scale(T.random.randn([1]),
                       T.random.randn([1]),
                       event_ndims=2)),
        (r'`rank\(input\) >= rank\(pre_scale\)` does not hold',
         lambda: scale(T.random.randn([1]),
                       T.random.randn([1, 2]),
                       event_ndims=1)),
        (r'The shape of `input_log_det` is not expected',
         lambda: scale(T.random.randn([2, 3]),
                       T.random.randn([2, 3]),
                       event_ndims=1,
                       input_log_det=T.random.randn([3]))),
        (r'The shape of `input_log_det` is not expected',
         lambda: scale(T.random.randn([2, 3]),
                       T.random.randn([2, 3]),
                       event_ndims=2,
                       input_log_det=T.random.randn([2]))),
    )
    for pattern, call in failing_calls:
        with pytest.raises(Exception, match=pattern):
            _ = call()

    # check call
    tol = {'rtol': 1e-4, 'atol': 1e-6}
    for event_ndims in range(T.rank(pre_scale), T.rank(x)):
        summed_log_det = T.reduce_sum(
            expected_log_det, axis=T.int_range(-event_ndims, 0))
        prior_log_det = T.random.randn(T.shape(summed_log_det))

        # (extra kwargs, input, expected output, expected log-det):
        # the forward pass first, then the inverse pass with negated log-det
        directions = (
            ({}, x, expected_y, summed_log_det),
            ({'inverse': True}, expected_y, x, -summed_log_det),
        )
        for direction, source, target, signed_log_det in directions:
            # check no compute log_det
            out, log_det = scale(source, pre_scale, event_ndims=event_ndims,
                                 compute_log_det=False, **direction)
            assert_allclose(out, target, **tol)
            ctx.assertIsNone(log_det)

            # check compute log_det
            out, log_det = scale(source, pre_scale, event_ndims=event_ndims,
                                 **direction)
            assert_allclose(out, target, **tol)
            assert_allclose(log_det, signed_log_det, **tol)

            # check compute log_det with input_log_det
            out, log_det = scale(source, pre_scale, event_ndims=event_ndims,
                                 input_log_det=prior_log_det, **direction)
            assert_allclose(out, target, **tol)
            assert_allclose(log_det, prior_log_det + signed_log_det, **tol)
def check_distribution_instance(ctx,
                                d,
                                event_ndims,
                                batch_shape,
                                min_event_ndims,
                                max_event_ndims,
                                log_prob_fn,
                                transform_origin_distribution=None,
                                transform_origin_group_ndims=None,
                                **expected_attrs):
    """
    Check the attributes, sampling and (log-)probability of distribution `d`.

    Args:
        ctx: Object providing the ``assert*`` methods used below
            (used like a ``unittest.TestCase``).
        d: The distribution instance under test.
        event_ndims: Expected ``d.event_ndims``.
        batch_shape: Expected ``d.batch_shape`` (a list).
        min_event_ndims: Expected ``d.min_event_ndims``.
        max_event_ndims: Largest event ndims to exercise; must satisfy
            ``max_event_ndims - event_ndims <= d.batch_ndims``.
        log_prob_fn: Callable mapping a sample `t` to its expected
            log-probability, before the extra axes are summed.
        transform_origin_distribution: If not None, samples and their
            (log-)probabilities must carry a ``transform_origin`` whose
            ``distribution`` is this very object.
        transform_origin_group_ndims: Expected ``group_ndims`` of that
            ``transform_origin``.
        **expected_attrs: Further ``attribute name -> expected value`` pairs
            checked on `d`.  ``event_shape`` and ``validate_tensors`` are
            also read from here, falling back to None and
            ``settings.validate_tensors`` respectively.
    """
    ctx.assertLessEqual(max_event_ndims - event_ndims, d.batch_ndims)

    event_shape = expected_attrs.get('event_shape', None)
    ctx.assertEqual(d.min_event_ndims, min_event_ndims)
    ctx.assertEqual(d.value_ndims, len(batch_shape) + event_ndims)
    if event_shape is not None:
        ctx.assertEqual(d.value_shape, batch_shape + event_shape)
    ctx.assertEqual(d.batch_shape, batch_shape)
    ctx.assertEqual(d.batch_ndims, len(batch_shape))
    ctx.assertEqual(d.event_ndims, event_ndims)
    ctx.assertEqual(d.event_shape, event_shape)

    for attr, val in expected_attrs.items():
        ctx.assertEqual(getattr(d, attr), val)
    ctx.assertEqual(
        d.validate_tensors,
        expected_attrs.get('validate_tensors', settings.validate_tensors))

    # the same candidates are used for `sample()` and for `log_prob()`/`prob()`
    group_ndims_choices = (None, 0,
                           -(event_ndims - min_event_ndims),
                           max_event_ndims - event_ndims)
    tolerance = {'rtol': 1e-4, 'atol': 1e-6}

    # check sample
    for n_samples in (None, 5):
        for group_ndims in group_ndims_choices:
            for reparameterized2 in (None, True, False):
                if reparameterized2 and not d.reparameterized:
                    continue

                # sample(): only pass the arguments that are not None, and
                # track the expected values in separate locals.  Overwriting
                # the loop variable `group_ndims` here would leak into the
                # following `reparameterized2` iterations, so the default
                # `group_ndims` path would only be exercised once.
                sample_kwargs = {}
                if n_samples is not None:
                    sample_kwargs['n_samples'] = n_samples
                    sample_shape = [n_samples]
                else:
                    sample_shape = []

                if group_ndims is not None:
                    sample_kwargs['group_ndims'] = group_ndims
                    expected_group_ndims = group_ndims
                else:
                    expected_group_ndims = 0

                if reparameterized2 is not None:
                    sample_kwargs['reparameterized'] = reparameterized2
                    expected_reparameterized = reparameterized2
                else:
                    expected_reparameterized = d.reparameterized

                t = d.sample(**sample_kwargs)
                ctx.assertEqual(t.group_ndims, expected_group_ndims)
                ctx.assertEqual(t.reparameterized, expected_reparameterized)
                ctx.assertEqual(T.rank(t.tensor),
                                d.value_ndims + len(sample_shape))
                ctx.assertEqual(
                    T.shape(t.tensor)[:(d.batch_ndims + len(sample_shape))],
                    sample_shape + d.batch_shape)

                if transform_origin_distribution is not None:
                    ctx.assertIsInstance(t.transform_origin, StochasticTensor)
                    ctx.assertIs(t.transform_origin.distribution,
                                 transform_origin_distribution)
                    # compare by value: identity of ints is an
                    # implementation detail of the interpreter
                    ctx.assertEqual(t.transform_origin.group_ndims,
                                    transform_origin_group_ndims)

                # log_prob()
                expected_log_prob = log_prob_fn(t)
                for group_ndims2 in group_ndims_choices:
                    if group_ndims2 is not None:
                        log_prob_kwargs = {'group_ndims': group_ndims2}
                        effective_group_ndims = group_ndims2
                    else:
                        # without an explicit argument, the sample's own
                        # `group_ndims` is expected to apply
                        log_prob_kwargs = {}
                        effective_group_ndims = expected_group_ndims

                    log_prob = t.log_prob(**log_prob_kwargs)
                    kept_ndims = (T.rank(t.tensor) -
                                  (effective_group_ndims + event_ndims))
                    ctx.assertEqual(T.shape(log_prob),
                                    T.shape(t.tensor)[:kept_ndims])
                    assert_allclose(
                        log_prob,
                        T.reduce_sum(
                            expected_log_prob,
                            T.int_range(
                                -(effective_group_ndims +
                                  (event_ndims - min_event_ndims)),
                                0)),
                        **tolerance
                    )

                    prob = t.prob(**log_prob_kwargs)
                    assert_allclose(prob, T.exp(log_prob), **tolerance)

                    if transform_origin_distribution is not None:
                        for p in (log_prob, prob):
                            ctx.assertIsInstance(p.transform_origin,
                                                 StochasticTensor)
                            ctx.assertIs(p.transform_origin.distribution,
                                         transform_origin_distribution)
                            ctx.assertEqual(p.transform_origin.group_ndims,
                                            transform_origin_group_ndims)