def test_apply_mixture_bad_args():
    """Check that `cirq.apply_mixture` raises ValueError on inconsistent shapes.

    Two cases are exercised, each expecting a different error message:

    * A target tensor whose dimensions along the left axes (1, 2, 3) are
      (1, 2, 3) but along the right axes (4, 5, 6) are (3, 1, 2).
    * A (2, 3, 2, 3) target tensor paired with a gate whose qid shape is
      (2, 9).
    """
    # Each row: target shape, IdentityGate arguments, left axes, right axes,
    # and the message fragment expected in the raised ValueError.
    failing_cases = (
        (
            # Same shape as (3,) + (1, 2, 3) + (3, 1, 2) + (3,).
            (3, 1, 2, 3, 3, 1, 2, 3),
            (3, (1, 2, 3)),
            (1, 2, 3),
            (4, 5, 6),
            'Invalid target_tensor shape',
        ),
        (
            (2, 3, 2, 3),
            (2, (2, 9)),
            (0, 1),
            (2, 3),
            'Invalid mixture qid shape',
        ),
    )

    for tensor_shape, gate_params, left, right, message in failing_cases:
        target = np.zeros(tensor_shape)
        with pytest.raises(ValueError, match=message):
            # Everything is built inside the context manager so that the
            # scope covered by `pytest.raises` stays the same.
            gate = cirq.IdentityGate(*gate_params)
            mixture_args = cirq.ApplyMixtureArgs(
                target,
                np.zeros_like(target),
                np.zeros_like(target),
                np.zeros_like(target),
                left,
                right,
            )
            cirq.apply_mixture(gate, mixture_args, default=np.array([]))
def assert_apply_mixture_returns(
    val: Any,
    rho: np.ndarray,
    left_axes: Iterable[int],
    right_axes: Optional[Iterable[int]],
    assert_result_is_out_buf: bool = False,
    expected_result: Optional[np.ndarray] = None,
) -> None:
    """Apply `val` to `rho` via `cirq.apply_mixture` and check the outcome.

    Fresh buffers matching `rho`'s shape and dtype are obtained from
    `make_buffers` and handed to `cirq.ApplyMixtureArgs` together with the
    given axes.

    Args:
        val: The object passed as the first argument of `cirq.apply_mixture`.
        rho: The tensor used as `target_tensor`.
        left_axes: Forwarded as `left_axes` of `cirq.ApplyMixtureArgs`.
        right_axes: Forwarded as `right_axes` of `cirq.ApplyMixtureArgs`;
            may be None.
        assert_result_is_out_buf: If True, the returned object must be the
            output buffer itself; otherwise it must be a different object.
        expected_result: Array the result is compared against with
            `np.testing.assert_array_almost_equal`. The comparison always
            runs, so callers are expected to supply a value even though the
            parameter defaults to None.

    Raises:
        AssertionError: If the identity check on the result fails or the
            result does not match `expected_result`.
    """
    out_buf, buf0, buf1 = make_buffers(rho.shape, rho.dtype)
    mixture_args = cirq.ApplyMixtureArgs(
        target_tensor=rho,
        left_axes=left_axes,
        right_axes=right_axes,
        out_buffer=out_buf,
        auxiliary_buffer0=buf0,
        auxiliary_buffer1=buf1,
    )
    result = cirq.apply_mixture(val, args=mixture_args)

    # Identity (not equality) matters here: the check is about which buffer
    # object was returned.
    if assert_result_is_out_buf:
        assert result is out_buf
    else:
        assert result is not out_buf

    np.testing.assert_array_almost_equal(result, expected_result)
def test_apply_mixture_apply_unitary_not_implemented():
    """Check the TypeError raised when `_apply_unitary_` returns NotImplemented.

    The helper class defines only `_apply_unitary_`, which returns
    NotImplemented. Calling `cirq.apply_mixture` on it without a `default`
    is expected to raise a TypeError whose message contains 'has no'.
    """

    class ApplyUnitaryNotImplemented:
        def _apply_unitary_(self, args: cirq.ApplyUnitaryArgs):
            return NotImplemented

    density = np.ones((2, 2, 2, 2), dtype=np.complex128)
    # Buffers share the shape and dtype of the target tensor.
    result_buf, scratch0, scratch1 = make_buffers(density.shape, dtype=density.dtype)

    with pytest.raises(TypeError, match='has no'):
        candidate = ApplyUnitaryNotImplemented()
        mixture_args = cirq.ApplyMixtureArgs(
            target_tensor=density,
            left_axes=[1],
            right_axes=[3],
            out_buffer=result_buf,
            auxiliary_buffer0=scratch0,
            auxiliary_buffer1=scratch1,
        )
        cirq.apply_mixture(candidate, args=mixture_args)
def test_apply_mixture_no_protocols_implemented_default():
    """Check that `default` is returned for an object with no protocol methods.

    `NoProtocols` defines nothing at all, so `cirq.apply_mixture` is expected
    to hand back the supplied `default` value.
    """

    class NoProtocols:
        pass

    fallback = 'cirq'
    arg_fields = {
        'target_tensor': np.eye(2),
        'left_axes': [0],
        'right_axes': [1],
        'out_buffer': None,
        'auxiliary_buffer0': None,
        'auxiliary_buffer1': None,
    }
    args = cirq.ApplyMixtureArgs(**arg_fields)

    outcome = cirq.apply_mixture(NoProtocols(), args, default=fallback)

    assert outcome == fallback