def test_transform_bijection(batch_shape, event_shape):
    """OrderedTransform must be a true bijection.

    Two properties are checked:
      * inverting the inverse hands back the very same transform object;
      * pushing a standard-normal sample forward and then back through the
        inverse reproduces the sample (up to floating-point tolerance).
    """
    transform = OrderedTransform()
    assert transform.inv.inv is transform

    full_shape = torch.Size(batch_shape + event_shape)
    original = Normal(0, 1).expand(full_shape).sample()

    # Forward then inverse should land back where we started.
    forward = transform(original)
    recovered = transform.inv(forward)
    assert torch.allclose(original, recovered)
def test_transform_log_abs_det(batch_shape, event_shape):
    """Check OrderedTransform.log_abs_det_jacobian against an explicit Jacobian.

    ``log_abs_det_jacobian`` yields one value per batch element: the
    log |det J| of the transform applied along the last (event) dimension.
    The reference value is computed from the Jacobian matrices returned by the
    ``cjald`` helper defined elsewhere in this module (presumably shaped
    ``batch_shape + (n, n)`` with ``n = event_shape[-1]`` -- TODO confirm).
    """
    tf = OrderedTransform()
    # Normalize to torch.Size so the shape assertion below also holds when the
    # parametrization supplies lists rather than tuples (torch.Size is a tuple
    # subclass and compares unequal to a list with the same entries).
    batch = torch.Size(batch_shape)
    shape = batch + torch.Size(event_shape)

    x = torch.randn(shape, requires_grad=True)
    y = tf(x)
    log_det = tf.log_abs_det_jacobian(x, y)
    assert log_det.shape == batch

    # slogdet returns (sign, log|det|) and computes the log directly, avoiding
    # the overflow/underflow that det().abs().log() can hit for large events.
    log_det_actual = torch.slogdet(cjald(tf, x))[1]
    assert torch.allclose(log_det, log_det_actual)
def test_autograd():
    """Gradients of the OrderedLogistic log-likelihood reach every input.

    Both the per-observation predictor and the unconstrained cutpoints (fed
    through OrderedTransform) must receive a gradient, and every entry of
    those gradients must be nonzero.
    """
    predictor = torch.randn(5, requires_grad=True)
    to_ordered = OrderedTransform()
    pre_cutpoints = torch.randn(3, requires_grad=True)
    cutpoints = to_ordered(pre_cutpoints)

    # Category labels 0..3 stored as float64, matching dtype=float.
    observations = torch.tensor([0, 1, 2, 3, 0], dtype=torch.float64)

    model = OrderedLogistic(predictor, cutpoints, validate_args=True)
    total_log_prob = model.log_prob(observations).sum()
    total_log_prob.backward()

    for leaf in (predictor, pre_cutpoints):
        assert leaf.grad is not None
        assert (leaf.grad != 0).all().item()