def test_reverse_inversion_not_applied(reparam):
    """Check `_reverse_inversion` when no inversion edge is set for `x`.

    The edge entry for `x` is ``False`` and
    ``inverse_rescale_minus_one_to_one`` is patched with a stand-in that
    returns its input unchanged together with a fixed log-Jacobian of
    ``[5, 6]``. The test then verifies that:

    * the rescaling function receives the bounds as ``xmin``/``xmax``
      keyword arguments,
    * the returned ``x`` equals the input plus the offset of 1.0,
    * ``x_prime`` is returned as the very same object,
    * the returned log-Jacobian equals the value from the stand-in (the
      log-Jacobian passed in is all zeros).
    """
    reparam.parameters = ['x']
    reparam.prime_parameters = ['x_prime']
    reparam.offsets = {'x': 1.0}
    reparam._edges = {'x': False}
    reparam.bounds = {'x': [0, 5]}

    x = numpy_array_to_live_points(np.array([[1], [2]]), ['x'])
    x_prime = numpy_array_to_live_points(np.array([3, 4]), ['x_prime'])
    log_j = np.zeros(2)

    def fake_rescale(values, *args, **kwargs):
        # Identity on the values so only the offset changes the output
        return values, np.array([5, 6])

    target = 'nessai.reparameterisations.inverse_rescale_minus_one_to_one'
    with patch(target, side_effect=fake_rescale) as mock_rescale:
        outputs = RescaleToBounds._reverse_inversion(
            reparam, x, x_prime, log_j, 'x', 'x_prime',
        )
    x_out, x_prime_out, log_j_out = outputs

    # A recorded call unpacks as (args, kwargs); only the kwargs matter here
    _, first_call_kwargs = mock_rescale.call_args_list[0]
    assert first_call_kwargs == {'xmin': 0, 'xmax': 5}

    # Should be the (unchanged) output of the rescaling plus the offset
    np.testing.assert_array_equal(x_out['x'], np.array([2, 3]))
    # x_prime should be the same object that was passed in
    assert x_prime_out is x_prime
    # Jacobian should just include the jacobian from the rescaling
    np.testing.assert_array_equal(log_j_out, np.array([5, 6]))
def test_reverse_inversion(reparam):
    """Check `_reverse_inversion` when `x` has an ``'upper'`` edge.

    ``inverse_rescale_zero_to_one`` is patched with a stand-in that returns
    its input unchanged together with a fixed log-Jacobian of ``[5, 6]``,
    so the returned ``x`` reflects only the inversion handling and the
    offset. The inputs contain one negative and one positive value; the
    expected outputs ``[1.3, 1.6]`` equal ``1 - |x| + offset`` for inputs
    ``[-0.7, 0.4]`` and an offset of 1.0.

    The test also verifies that the bounds are passed to the rescaling
    function as the second and third positional arguments, that
    ``x_prime`` is returned as the very same object, and that the returned
    log-Jacobian equals the value from the stand-in (the log-Jacobian
    passed in is all zeros).
    """
    reparam.parameters = ['x']
    reparam.prime_parameters = ['x_prime']
    reparam.offsets = {'x': 1.0}
    reparam._edges = {'x': 'upper'}
    reparam.bounds = {'x': [0, 5]}
    x_val = np.array([[-0.7], [0.4]])
    x = numpy_array_to_live_points(x_val, ['x'])
    x_prime = numpy_array_to_live_points(np.array([3, 4]), ['x_prime'])
    log_j = np.zeros(2)

    # Return the same value to check that the negative values are handled
    # correctly
    with patch('nessai.reparameterisations.inverse_rescale_zero_to_one',
               side_effect=lambda x, *args: (x, np.array([5, 6]))) as f:
        x_out, x_prime_out, log_j_out = RescaleToBounds._reverse_inversion(
            reparam, x, x_prime, log_j, 'x', 'x_prime',
        )

    # Bounds are passed positionally after the values
    first_call_args = f.call_args_list[0][0]
    assert first_call_args[1] == 0.0
    assert first_call_args[2] == 5.0

    # Should be the (unchanged) output of the rescaling plus the offset.
    # The expected values are not exactly representable in binary floating
    # point and come out of several arithmetic steps, so compare with a
    # tolerance rather than requiring bit-for-bit equality.
    np.testing.assert_allclose(x_out['x'], np.array([1.3, 1.6]))
    # x_prime should be the same object that was passed in
    assert x_prime_out is x_prime
    # Jacobian should just include the jacobian from the rescaling
    np.testing.assert_array_equal(log_j_out, np.array([5, 6]))