def test_apply_channel_bad_args():
    """apply_channel rejects mismatched tensor shapes and channel qid shapes."""

    def zero_buffer_args(tensor, left, right):
        # Every buffer is an independent zero array with the same shape as `tensor`.
        return cirq.ApplyChannelArgs(
            target_tensor=tensor,
            out_buffer=np.zeros_like(tensor),
            auxiliary_buffer0=np.zeros_like(tensor),
            auxiliary_buffer1=np.zeros_like(tensor),
            left_axes=left,
            right_axes=right,
        )

    # Left axes (1, 2, 3) have sizes (1, 2, 3), but right axes (4, 5, 6) have
    # sizes (3, 1, 2). The two halves of the tensor therefore disagree.
    mismatched = np.zeros((3,) + (1, 2, 3) + (3, 1, 2) + (3,))
    with pytest.raises(ValueError, match='Invalid target_tensor shape'):
        cirq.apply_channel(
            cirq.IdentityGate(3, (1, 2, 3)),
            zero_buffer_args(mismatched, (1, 2, 3), (4, 5, 6)),
        )

    # The tensor axes have sizes (2, 3), but the gate declares qid shape (2, 9).
    square = np.zeros((2, 3, 2, 3))
    with pytest.raises(ValueError, match='Invalid channel qid shape'):
        cirq.apply_channel(
            cirq.IdentityGate(2, (2, 9)),
            zero_buffer_args(square, (0, 1), (2, 3)),
        )
def test_apply_channel_no_protocols_implemented_default():
    """An object with no channel-related protocols yields the supplied default."""

    class Bare:
        pass

    fallback = 'cirq'
    channel_args = cirq.ApplyChannelArgs(
        target_tensor=np.eye(2),
        left_axes=[0],
        right_axes=[1],
        out_buffer=None,
        auxiliary_buffer0=None,
        auxiliary_buffer1=None,
    )
    assert cirq.apply_channel(Bare(), channel_args, fallback) == fallback
def test_apply_channel_apply_unitary_not_implemented():
    """apply_channel raises TypeError when _apply_unitary_ opts out and no default is given."""

    class ApplyUnitaryNotImplemented:
        def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs):
            return NotImplemented

    shape = (2, 2, 2, 2)
    density = np.ones(shape, dtype=np.complex128)
    # make_buffers is a helper defined elsewhere in this module; it returns
    # the three buffers that are unpacked here.
    out, scratch0, scratch1 = make_buffers(shape, dtype=density.dtype)

    with pytest.raises(TypeError):
        cirq.apply_channel(
            ApplyUnitaryNotImplemented(),
            args=cirq.ApplyChannelArgs(
                target_tensor=density,
                left_axes=[1],
                right_axes=[3],
                out_buffer=out,
                auxiliary_buffer0=scratch0,
                auxiliary_buffer1=scratch1,
            ),
        )
def apply_channel(val, rho, left_axes, right_axes, assert_result_is_out_buf=False):
    """Run cirq.apply_channel on `rho` using buffers from make_buffers.

    Args:
        val: The object whose channel is applied.
        rho: The target tensor. Its shape and dtype are used to build the buffers.
        left_axes: Axes of `rho` passed as ApplyChannelArgs.left_axes.
        right_axes: Axes of `rho` passed as ApplyChannelArgs.right_axes.
        assert_result_is_out_buf: If truthy, assert that the returned object
            is the out_buffer that was passed in. Otherwise assert that it is not.

    Returns:
        Whatever cirq.apply_channel returned.
    """
    out, aux0, aux1 = make_buffers(rho.shape, rho.dtype)
    channel_args = cirq.ApplyChannelArgs(
        target_tensor=rho,
        left_axes=left_axes,
        right_axes=right_axes,
        out_buffer=out,
        auxiliary_buffer0=aux0,
        auxiliary_buffer1=aux1,
    )
    result = cirq.apply_channel(val, args=channel_args)

    # The caller states whether the result is expected to be the out buffer itself.
    if not assert_result_is_out_buf:
        assert result is not out
    else:
        assert result is out
    return result