def check_unitarity(nw):
    """Assert that the complex response of ``nw`` to an identity input is unitary.

    The network is initialized inside a single-timestep, frequency-domain
    environment and driven with an identity matrix (real and imaginary parts
    stacked along the first axis). The complex response ``R`` is rebuilt from
    the output and ``R @ R^H`` is compared against the identity.

    Args:
        nw: network under test; must provide ``initialize()``, ``num_sources``
            and be callable with ``(source, power=False)``.

    Raises:
        AssertionError: if ``R @ R^H`` does not match the identity to the
            default precision of ``assert_array_almost_equal``.
    """
    with pt.Environment(num_t=1, freqdomain=True):
        nw.initialize()

        # Complex identity, one row/column per network source.
        identity = np.eye(nw.num_sources, dtype=np.complex64)

        # Real part at index 0, imaginary part at index 1 of the leading axis.
        # NOTE(review): names "c", "s", "b" presumably mean complex / source /
        # batch -- confirm against the network's input conventions.
        real_imag = np.stack((identity.real, identity.imag), axis=0)
        excitation = torch.tensor(real_imag, names=["c", "s", "b"])

        # Keep index 0 of the 2nd and 3rd output axes (presumably time and
        # wavelength -- TODO confirm) and move the result to a numpy array.
        detected = nw(excitation, power=False)
        stacked = detected[:, 0, 0, :, :].data.cpu().numpy()

        # Recombine the real/imaginary planes into one complex matrix.
        response = stacked[0] + 1j * stacked[1]

        # A unitary matrix times its conjugate transpose gives the identity.
        gram = response @ response.T.conj()
        np.testing.assert_array_almost_equal(identity, gram)
def test_forward_with_power_false(nw, tenv):
    """Call the network with source ``1`` and ``power=False`` inside ``tenv``.

    The test passes when the call completes without raising.
    """
    constant_source = 1
    with tenv:
        nw(constant_source, power=False)
def test_forward_with_detector(nw, tenv, lpdet):
    """Call the network with source ``1`` and the ``lpdet`` fixture as detector.

    The test passes when the call completes without raising.
    """
    constant_source = 1
    with tenv:
        nw(constant_source, detector=lpdet)
def test_forward_with_batch_weights(gen, nw, tenv):
    """Call the initialized network with a random 4-D source tensor.

    The source has shape ``(tenv.num_t, tenv.num_wl, nw.num_sources, 3)``.
    The test passes when the call completes without raising.

    The random values are drawn from the ``gen`` fixture, as in the sibling
    tests, so the input comes from the test's generator rather than the
    global RNG state.
    """
    with tenv:
        nw.initialize()
        # NOTE(review): the trailing axis of size 3 is presumably the batch
        # dimension -- confirm against the network's input conventions.
        source = torch.rand(
            tenv.num_t,
            tenv.num_wl,
            nw.num_sources,
            3,
            generator=gen,
        )
        nw(source)
def test_forward_with_different_value_for_each_source(gen, nw, tenv):
    """Call the initialized network with one random value per source.

    The source is a 1-D random tensor of length ``nw.num_sources`` whose only
    dimension is named ``"s"``. The test passes when the call completes
    without raising.
    """
    with tenv:
        nw.initialize()
        per_source_values = torch.rand(nw.num_sources, generator=gen)
        named_source = per_source_values.rename("s")
        nw(named_source)
def test_forward_with_timesource(gen, nw, tenv):
    """Call the network with a random source that has one value per timestep.

    The source is a 1-D random tensor of length ``tenv.num_t`` whose only
    dimension is named ``"t"``. The test passes when the call completes
    without raising.
    """
    with tenv:
        per_step_values = torch.rand(tenv.num_t, generator=gen)
        named_source = per_step_values.rename("t")
        nw(named_source)
def test_forward_with_constant_source(nw, tenv):
    """Call the network with the keyword argument ``source=1`` inside ``tenv``.

    The test passes when the call completes without raising.
    """
    call_kwargs = {"source": 1}
    with tenv:
        nw(**call_kwargs)
def test_forward_with_uninitialized_network(nw):
    """Expect ``RuntimeError`` when the network is called outside any environment.

    Unlike the sibling tests, no environment context is entered and
    ``initialize()`` is not called before the forward pass.
    """
    expected_error = RuntimeError
    with pytest.raises(expected_error):
        nw(source=1)