def test_parameterized_repeat_side_effects_when_not_using_rep_ids():
    """Symbolic repetitions with ``use_repetition_ids=False`` keep plain key names.

    Key queries and key remapping succeed and report the un-prefixed key
    (``'m'``). Operations that must unroll the circuit, or that try to attach
    repetition ids, still raise ``ValueError``.
    """
    qubit = cirq.LineQubit(0)
    body = cirq.FrozenCircuit(
        cirq.X(qubit).with_classical_controls('c'),
        cirq.measure(qubit, key='m'),
    )
    op = cirq.CircuitOperation(
        body,
        repetitions=sympy.Symbol('a'),
        use_repetition_ids=False,
    )

    assert cirq.control_keys(op) == {cirq.MeasurementKey('c')}
    assert cirq.parameter_names(op.with_params({'a': 1})) == {'a'}
    assert {str(key) for key in cirq.measurement_key_objs(op)} == {'m'}
    assert cirq.measurement_key_names(op) == {'m'}
    remapped = cirq.with_measurement_key_mapping(op, {'m': 'm2'})
    assert cirq.measurement_key_names(remapped) == {'m2'}

    # Anything that needs a concrete repetition count must refuse to run.
    unroll_error = 'Cannot unroll circuit due to nondeterministic repetitions'
    for needs_unrolling in (op.mapped_circuit, lambda: cirq.decompose(op)):
        with pytest.raises(ValueError, match=unroll_error):
            needs_unrolling()

    # Repetition ids cannot be combined with a symbolic repetition count.
    rep_id_error = 'repetition ids with parameterized repetitions'
    for adds_rep_ids in (
        lambda: op.with_repetition_ids(['x', 'y']),
        lambda: op.repeat(repetition_ids=['x', 'y']),
    ):
        with pytest.raises(ValueError, match=rep_id_error):
            adds_rep_ids()
def test_keys_under_parent_path():
    """Parent key paths and repetition indices compose in serialized key names."""
    qubit = cirq.LineQubit(0)
    base = cirq.CircuitOperation(cirq.FrozenCircuit(cirq.measure(qubit, key='A')))
    assert cirq.measurement_key_names(base) == {'A'}

    under_b = base.with_key_path(('B',))
    assert cirq.measurement_key_names(under_b) == {'B:A'}

    # Repeating places the repetition index between the parent path and the key.
    repeated = under_b.repeat(2)
    assert cirq.measurement_key_names(repeated) == {'B:0:A', 'B:1:A'}
def test_measurement_keys():
    """A Moment reports the measurement keys and measurement status of its operations.

    NOTE(review): ``test_measurement_keys`` is defined more than once in this
    chunk. If these really share one module, only the last definition is
    collected by pytest. Confirm they come from separate test files.
    """
    a, b = cirq.NamedQubit('a'), cirq.NamedQubit('b')

    without_measurement = cirq.Moment(cirq.X(a), cirq.X(b))
    assert cirq.measurement_key_names(without_measurement) == set()
    assert not cirq.is_measurement(without_measurement)

    with_measurement = cirq.Moment(cirq.measure(a, b, key='foo'))
    assert cirq.measurement_key_names(with_measurement) == {'foo'}
    assert cirq.is_measurement(with_measurement)
def test_decompose_repeated_nested_measurements():
    """Nested, repeated, key-remapped CircuitOperations yield correctly prefixed keys.

    Each of three nesting levels measures one key, remaps the keys of the level
    below, and repeats twice with repetition ids ``zero``/``one``. The final
    keys carry one repetition-id prefix per enclosing repetition. Decomposition
    and ``mapped_circuit(deep=True)`` must both produce the measurements in the
    order listed below.
    """
    # Details of this test described at
    # https://tinyurl.com/measurement-repeated-circuitop#heading=h.sbgxcsyin9wt.
    qubit = cirq.LineQubit(0)

    innermost = (
        cirq.CircuitOperation(cirq.FrozenCircuit(cirq.measure(qubit, key='A')))
        .with_measurement_key_mapping({'A': 'B'})
        .repeat(2, ['zero', 'one'])
    )
    middle = (
        cirq.CircuitOperation(cirq.FrozenCircuit(cirq.measure(qubit, key='P'), innermost))
        .with_measurement_key_mapping({'B': 'C', 'P': 'Q'})
        .repeat(2, ['zero', 'one'])
    )
    outermost = (
        cirq.CircuitOperation(cirq.FrozenCircuit(cirq.measure(qubit, key='X'), middle))
        .with_measurement_key_mapping({'C': 'D', 'X': 'Y'})
        .repeat(2, ['zero', 'one'])
    )

    ordered_keys = [
        'zero:Y',
        'zero:zero:Q',
        'zero:zero:zero:D',
        'zero:zero:one:D',
        'zero:one:Q',
        'zero:one:zero:D',
        'zero:one:one:D',
        'one:Y',
        'one:zero:Q',
        'one:zero:zero:D',
        'one:zero:one:D',
        'one:one:Q',
        'one:one:zero:D',
        'one:one:one:D',
    ]
    assert cirq.measurement_key_names(outermost) == set(ordered_keys)

    # One measurement per key, appended in the expected execution order.
    flat_circuit = cirq.Circuit()
    for serialized in ordered_keys:
        flat_circuit.append(
            cirq.measure(qubit, key=cirq.MeasurementKey.parse_serialized(serialized))
        )
    assert cirq.Circuit(cirq.decompose(outermost)) == flat_circuit
    assert cirq.measurement_key_names(flat_circuit) == set(ordered_keys)

    # Verify that mapped_circuit gives the same operations.
    assert outermost.mapped_circuit(deep=True) == flat_circuit
def test_with_measurement_key_mapping():
    """Remapping renames matching keys and returns the same object when none match."""
    qubit = cirq.LineQubit(0)
    measurement = cirq.measure(qubit, key='m')

    renamed = cirq.with_measurement_key_mapping(measurement, {'m': 'k'})
    assert cirq.measurement_key_names(renamed) == {'k'}

    # A mapping that mentions no existing key returns the identical object.
    untouched = cirq.with_measurement_key_mapping(measurement, {'x': 'k'})
    assert untouched is measurement
def test_with_key_path_prefix():
    """Prefixes prepend to key paths; an empty prefix is an identity no-op."""
    qubit = cirq.LineQubit(0)
    measurement = cirq.measure(qubit, key='m')

    prefixed = cirq.with_key_path_prefix(measurement, ('a', 'b'))
    assert cirq.measurement_key_names(prefixed) == {'a:b:m'}

    # An empty prefix hands back the very same object, already prefixed or not.
    for candidate in (prefixed, measurement):
        assert cirq.with_key_path_prefix(candidate, ()) is candidate

    # Operations without measurement keys do not support the protocol.
    assert cirq.with_key_path_prefix(cirq.X(qubit), ('a', 'b')) is NotImplemented
def test_measurement_key_mapping():
    """with_measurement_key_mapping delegates to ``_with_measurement_key_mapping_``."""

    class MultiKeyGate:
        """Test double owning an arbitrary set of measurement key names."""

        def __init__(self, keys):
            self._keys = set(keys)

        def _measurement_key_names_(self):
            return self._keys

        def _with_measurement_key_mapping_(self, key_map):
            # Every owned key must be mapped; entries for other keys are ignored.
            if any(key not in key_map for key in self._keys):
                raise ValueError('missing keys')
            return MultiKeyGate([key_map[key] for key in self._keys])

    assert cirq.measurement_key_names(MultiKeyGate([])) == set()
    assert cirq.measurement_key_names(MultiKeyGate(['a'])) == {'a'}

    gate_ab = MultiKeyGate(['a', 'b'])
    assert cirq.measurement_key_names(gate_ab) == {'a', 'b'}

    # Plain renaming, partial identity, and swapping.
    remap_cases = [
        ({'a': 'c', 'b': 'd'}, {'c', 'd'}),
        ({'a': 'a', 'b': 'c'}, {'a', 'c'}),
        ({'a': 'b', 'b': 'a'}, {'a', 'b'}),
    ]
    for mapping, expected_names in remap_cases:
        remapped = cirq.with_measurement_key_mapping(gate_ab, mapping)
        assert cirq.measurement_key_names(remapped) == expected_names

    with pytest.raises(ValueError):
        cirq.with_measurement_key_mapping(gate_ab, {'a': 'c'})

    assert cirq.with_measurement_key_mapping(cirq.X, {'a': 'c'}) is NotImplemented

    gate_cdx = cirq.with_measurement_key_mapping(gate_ab, {'a': 'c', 'b': 'd', 'x': 'y'})
    assert cirq.measurement_key_names(gate_cdx) == {'c', 'd'}
def test_mapped_circuit_keeps_keys_under_parent_path():
    """mapped_circuit() prefixes every keyed operation kind with the parent key path."""
    qubit = cirq.LineQubit(0)
    keyed_ops = [
        cirq.measure(qubit, key='A'),
        cirq.measure_single_paulistring(cirq.X(qubit), key='B'),
        cirq.MixedUnitaryChannel.from_mixture(cirq.bit_flip(0.5), key='C').on(qubit),
        cirq.KrausChannel.from_channel(cirq.phase_damp(0.5), key='D').on(qubit),
    ]
    with_path = cirq.CircuitOperation(cirq.FrozenCircuit(*keyed_ops)).with_key_path(('X',))
    assert cirq.measurement_key_names(with_path.mapped_circuit()) == {
        'X:A',
        'X:B',
        'X:C',
        'X:D',
    }
def test_measurement_key_path():
    """with_key_path delegates to ``_with_key_path_``, prefixing every key."""

    class MultiKeyGate:
        """Test double holding parsed MeasurementKey objects."""

        def __init__(self, keys):
            # Set comprehension rather than set([...]): no throwaway list (ruff C403).
            self._keys = {cirq.MeasurementKey.parse_serialized(key) for key in keys}

        def _measurement_key_names_(self):
            return {str(key) for key in self._keys}

        def _with_key_path_(self, path):
            return MultiKeyGate([str(key._with_key_path_(path)) for key in self._keys])

    assert cirq.measurement_key_names(MultiKeyGate([])) == set()
    assert cirq.measurement_key_names(MultiKeyGate(['a'])) == {'a'}

    mkg_ab = MultiKeyGate(['a', 'b'])
    assert cirq.measurement_key_names(mkg_ab) == {'a', 'b'}

    mkg_cd = cirq.with_key_path(mkg_ab, ('c', 'd'))
    assert cirq.measurement_key_names(mkg_cd) == {'c:d:a', 'c:d:b'}

    # Objects without _with_key_path_ do not support the protocol.
    assert cirq.with_key_path(cirq.X, ('c', 'd')) is NotImplemented
def test_tagged_measurement():
    """Tags do not hide measurements and survive measurement-key remapping."""
    tagged_phase = cirq.GlobalPhaseOperation(coefficient=-1.0).with_tags('tag0')
    assert not cirq.is_measurement(tagged_phase)

    qubit = cirq.LineQubit(0)
    tagged = cirq.measure(qubit, key='m').with_tags('tag')
    assert cirq.is_measurement(tagged)

    remapped = cirq.with_measurement_key_mapping(tagged, {'m': 'k'})
    assert remapped.tags == ('tag',)
    assert cirq.is_measurement(remapped)
    assert cirq.measurement_key_names(remapped) == {'k'}

    # A mapping for an unrelated key leaves the operation equal to the original.
    assert cirq.with_measurement_key_mapping(tagged, {'x': 'k'}) == tagged
def test_measurement_key_enumerable_deprecated():
    """Key magic methods returning lists (not frozensets) emit a deprecation warning."""

    class Deprecated:
        def _measurement_key_objs_(self):
            return [cirq.MeasurementKey('key')]

        def _measurement_key_names_(self):
            return ['key']

    with cirq.testing.assert_deprecated('frozenset', deadline='v0.16'):
        key_objs = cirq.measurement_key_objs(Deprecated())
        assert key_objs == {cirq.MeasurementKey('key')}

    with cirq.testing.assert_deprecated('frozenset', deadline='v0.16'):
        key_names = cirq.measurement_key_names(Deprecated())
        assert key_names == {'key'}
def test_with_measurement_keys():
    """CircuitOperation key maps drop unused entries and reject key collisions."""
    q0, q1 = cirq.LineQubit.range(2)
    op_base = cirq.CircuitOperation(
        cirq.FrozenCircuit(
            cirq.X(q0),
            cirq.measure(q1, key='mb'),
            cirq.measure(q0, key='ma'),
        )
    )

    op_with_keys = op_base.with_measurement_key_mapping({'ma': 'pa', 'x': 'z'})
    assert op_with_keys.base_operation() == op_base
    # The 'x' entry names no key in the circuit and is not retained.
    assert op_with_keys.measurement_key_map == {'ma': 'pa'}
    assert cirq.measurement_key_names(op_with_keys) == {'pa', 'mb'}
    assert cirq.with_measurement_key_mapping(op_base, {'ma': 'pa'}) == op_with_keys

    # Two measurement keys cannot be mapped onto the same target string.
    with pytest.raises(ValueError):
        op_base.with_measurement_key_mapping({'ma': 'mb'})
def test_parameterized_repeat_side_effects():
    """Symbolic repetition counts block unrolling until the symbol is resolved."""
    qubit = cirq.LineQubit(0)
    op = cirq.CircuitOperation(
        cirq.FrozenCircuit(
            cirq.X(qubit).with_classical_controls('c'),
            cirq.measure(qubit, key='m'),
        ),
        repetitions=sympy.Symbol('a'),
    )

    # Control keys can be calculated because they only "lift" if there's a matching
    # measurement, in which case they're not returned here.
    assert cirq.control_keys(op) == {cirq.MeasurementKey('c')}

    # "local" params do not bind to the repetition param.
    assert cirq.parameter_names(op.with_params({'a': 1})) == {'a'}

    # Check errors that require unrolling the circuit.
    unroll_error = 'Cannot unroll circuit due to nondeterministic repetitions'
    for needs_unrolling in (
        lambda: cirq.measurement_key_objs(op),
        lambda: cirq.measurement_key_names(op),
        op.mapped_circuit,
        lambda: cirq.decompose(op),
    ):
        with pytest.raises(ValueError, match=unroll_error):
            needs_unrolling()

    # Not compatible with repetition ids
    rep_id_error = 'repetition ids with parameterized repetitions'
    for adds_rep_ids in (
        lambda: op.with_repetition_ids(['x', 'y']),
        lambda: op.repeat(repetition_ids=['x', 'y']),
    ):
        with pytest.raises(ValueError, match=rep_id_error):
            adds_rep_ids()

    # TODO(daxfohl): This should work, but likely requires a new protocol that returns *just* the
    # name of the measurement keys. (measurement_key_names returns the full serialized string).
    with pytest.raises(ValueError, match=unroll_error):
        cirq.with_measurement_key_mapping(op, {'m': 'm2'})

    # Everything should work once resolved
    op = cirq.resolve_parameters(op, {'a': 2})
    assert {str(key) for key in cirq.measurement_key_objs(op)} == {'0:m', '1:m'}
    unrolled = cirq.Circuit(
        cirq.X(qubit).with_classical_controls('c'),
        cirq.measure(qubit, key=cirq.MeasurementKey.parse_serialized('0:m')),
        cirq.X(qubit).with_classical_controls('c'),
        cirq.measure(qubit, key=cirq.MeasurementKey.parse_serialized('1:m')),
    )
    assert op.mapped_circuit() == unrolled
    assert cirq.decompose(op) == cirq.decompose(unrolled)
def test_measurement_keys():
    """Exercise measurement-key discovery via decomposition, magic methods and containers.

    Also checks that the old ``cirq.measurement_keys`` function and the old
    ``_measurement_keys_`` magic method each emit a deprecation warning
    (deadline v0.13).

    NOTE(review): ``test_measurement_keys`` is defined more than once in this
    chunk. If these really share one module, only the last definition is
    collected by pytest. Confirm they come from separate test files.
    """

    # Exposes measurement keys only through _decompose_; with
    # allow_decompose=False the assertions below expect it to report none.
    class Composite(cirq.Gate):
        def _decompose_(self, qubits):
            yield cirq.measure(qubits[0], key='inner1')
            yield cirq.measure(qubits[1], key='inner2')
            yield cirq.reset(qubits[0])

        def num_qubits(self) -> int:
            return 2

    # Reports its keys directly through the current magic method.
    class MeasurementKeysGate(cirq.Gate):
        def _measurement_key_names_(self):
            return ['a', 'b']

        def num_qubits(self) -> int:
            return 1

    # Uses the older _measurement_keys_ magic method, expected to be deprecated.
    class DeprecatedMagicMethod(cirq.Gate):
        def _measurement_keys_(self):
            return ['a', 'b']

        def num_qubits(self) -> int:
            return 1

    a, b = cirq.LineQubit.range(2)
    assert cirq.is_measurement(Composite())
    # Each deprecation context wraps exactly one call so that the warning it
    # checks for comes from that call alone.
    with cirq.testing.assert_deprecated(deadline="v0.13"):
        assert cirq.measurement_keys(Composite()) == {'inner1', 'inner2'}
    with cirq.testing.assert_deprecated(deadline="v0.13"):
        assert cirq.measurement_key_names(
            DeprecatedMagicMethod()) == {'a', 'b'}
    with cirq.testing.assert_deprecated(deadline="v0.13"):
        assert cirq.measurement_key_names(
            DeprecatedMagicMethod().on(a)) == {'a', 'b'}
    assert cirq.measurement_key_names(Composite()) == {'inner1', 'inner2'}
    assert cirq.measurement_key_names(Composite().on(
        a, b)) == {'inner1', 'inner2'}
    # Without decomposition, Composite exposes no keys.
    assert not cirq.is_measurement(Composite(), allow_decompose=False)
    assert cirq.measurement_key_names(Composite(),
                                      allow_decompose=False) == set()
    assert cirq.measurement_key_names(Composite().on(a, b),
                                      allow_decompose=False) == set()

    # Objects with no measurements (including None and empty lists) yield no keys.
    assert cirq.measurement_key_names(None) == set()
    assert cirq.measurement_key_names([]) == set()
    assert cirq.measurement_key_names(cirq.X) == set()
    assert cirq.measurement_key_names(cirq.X(a)) == set()
    assert cirq.measurement_key_names(None, allow_decompose=False) == set()
    assert cirq.measurement_key_names([], allow_decompose=False) == set()
    assert cirq.measurement_key_names(cirq.X, allow_decompose=False) == set()
    assert cirq.measurement_key_names(cirq.measure(a, key='out')) == {'out'}
    assert cirq.measurement_key_names(cirq.measure(a, key='out'),
                                      allow_decompose=False) == {'out'}

    assert cirq.measurement_key_names(
        cirq.Circuit(cirq.measure(a, key='a'),
                     cirq.measure(b, key='2'))) == {'a', '2'}
    assert cirq.measurement_key_names(MeasurementKeysGate()) == {'a', 'b'}
    assert cirq.measurement_key_names(
        MeasurementKeysGate().on(a)) == {'a', 'b'}
def test_measurement_keys():
    """Exercise measurement_key_names and the deprecation of legacy spellings.

    Expects the ``_measurement_keys_`` magic method to be deprecated with
    deadline v0.13, and any use of the ``allow_decompose`` argument to be
    deprecated with deadline v0.14.

    NOTE(review): ``test_measurement_keys`` is defined more than once in this
    chunk. If these really share one module, only the last definition is
    collected by pytest. Confirm they come from separate test files.
    """

    # Reports its keys directly through the current magic method.
    class MeasurementKeysGate(cirq.Gate):
        def _measurement_key_names_(self):
            return ['a', 'b']

        def num_qubits(self) -> int:
            return 1

    # Uses the older _measurement_keys_ magic method, expected to be deprecated.
    class DeprecatedMagicMethod(cirq.Gate):
        def _measurement_keys_(self):
            return ['a', 'b']

        def num_qubits(self) -> int:
            return 1

    a, b = cirq.LineQubit.range(2)
    # Each deprecation context wraps exactly one call so that the warning it
    # checks for comes from that call alone.
    with cirq.testing.assert_deprecated(deadline="v0.13"):
        assert cirq.measurement_key_names(DeprecatedMagicMethod()) == {'a', 'b'}
    with cirq.testing.assert_deprecated(deadline="v0.13"):
        assert cirq.measurement_key_names(DeprecatedMagicMethod().on(a)) == {'a', 'b'}

    # Objects with no measurements (including None and empty lists) yield no keys.
    assert cirq.measurement_key_names(None) == set()
    assert cirq.measurement_key_names([]) == set()
    assert cirq.measurement_key_names(cirq.X) == set()
    assert cirq.measurement_key_names(cirq.X(a)) == set()
    # Passing allow_decompose is expected to trigger a deprecation warning.
    with cirq.testing.assert_deprecated(deadline="v0.14"):
        assert cirq.measurement_key_names(None, allow_decompose=False) == set()
    with cirq.testing.assert_deprecated(deadline="v0.14"):
        assert cirq.measurement_key_names([], allow_decompose=False) == set()
    with cirq.testing.assert_deprecated(deadline="v0.14"):
        assert cirq.measurement_key_names(cirq.X, allow_decompose=False) == set()
    assert cirq.measurement_key_names(cirq.measure(a, key='out')) == {'out'}
    with cirq.testing.assert_deprecated(deadline="v0.14"):
        assert cirq.measurement_key_names(cirq.measure(a, key='out'), allow_decompose=False) == {
            'out'
        }

    assert cirq.measurement_key_names(
        cirq.Circuit(cirq.measure(a, key='a'), cirq.measure(b, key='2'))
    ) == {'a', '2'}
    assert cirq.measurement_key_names(MeasurementKeysGate()) == {'a', 'b'}
    assert cirq.measurement_key_names(MeasurementKeysGate().on(a)) == {'a', 'b'}