# Example #1
def test_parameterized_repeat_side_effects_when_not_using_rep_ids():
    """A symbolic repetition count without repetition ids keeps keys flat.

    With ``use_repetition_ids=False`` the measurement key stays ``'m'`` (no
    per-repetition prefix), so key queries and key remapping work even while
    the repetition count is an unresolved symbol. Unrolling the circuit and
    attaching repetition ids must still fail.
    """
    qubit = cirq.LineQubit(0)
    inner = cirq.FrozenCircuit(
        cirq.X(qubit).with_classical_controls('c'),
        cirq.measure(qubit, key='m'),
    )
    op = cirq.CircuitOperation(
        inner,
        repetitions=sympy.Symbol('a'),
        use_repetition_ids=False,
    )

    assert cirq.control_keys(op) == {cirq.MeasurementKey('c')}
    assert cirq.parameter_names(op.with_params({'a': 1})) == {'a'}
    assert {str(key) for key in cirq.measurement_key_objs(op)} == {'m'}
    assert cirq.measurement_key_names(op) == {'m'}
    remapped = cirq.with_measurement_key_mapping(op, {'m': 'm2'})
    assert cirq.measurement_key_names(remapped) == {'m2'}

    # Anything that needs the concrete unrolled circuit is still rejected.
    unroll_error = 'Cannot unroll circuit due to nondeterministic repetitions'
    with pytest.raises(ValueError, match=unroll_error):
        op.mapped_circuit()
    with pytest.raises(ValueError, match=unroll_error):
        cirq.decompose(op)

    # Repetition ids need a concrete count to line up against.
    rep_id_error = 'repetition ids with parameterized repetitions'
    with pytest.raises(ValueError, match=rep_id_error):
        op.with_repetition_ids(['x', 'y'])
    with pytest.raises(ValueError, match=rep_id_error):
        op.repeat(repetition_ids=['x', 'y'])
# Example #2
def test_keys_under_parent_path():
    """Repeating an op that has a key path inserts the repetition index after the path."""
    q0 = cirq.LineQubit(0)
    base = cirq.CircuitOperation(cirq.FrozenCircuit(cirq.measure(q0, key='A')))
    assert cirq.measurement_key_names(base) == {'A'}

    nested = base.with_key_path(('B',))
    assert cirq.measurement_key_names(nested) == {'B:A'}

    repeated = nested.repeat(2)
    assert cirq.measurement_key_names(repeated) == {'B:0:A', 'B:1:A'}
# Example #3
def test_measurement_keys():
    """A Moment reports exactly the measurement keys of the operations it holds."""
    # NOTE(review): test_measurement_keys is defined again further down this module;
    # Python keeps only the last definition, so pytest will not collect this one.
    qa, qb = cirq.NamedQubit('a'), cirq.NamedQubit('b')

    unitary_moment = cirq.Moment(cirq.X(qa), cirq.X(qb))
    assert cirq.measurement_key_names(unitary_moment) == set()
    assert not cirq.is_measurement(unitary_moment)

    measuring_moment = cirq.Moment(cirq.measure(qa, qb, key='foo'))
    assert cirq.measurement_key_names(measuring_moment) == {'foo'}
    assert cirq.is_measurement(measuring_moment)
# Example #4
def test_decompose_repeated_nested_measurements():
    """Nested, remapped, repeated CircuitOperations flatten in key order.

    Three CircuitOperations are stacked. Each remaps the keys measured inside
    it and is repeated twice with repetition ids 'zero' and 'one'. Both
    ``cirq.decompose`` and ``mapped_circuit(deep=True)`` must produce a flat
    sequence of measurements whose keys carry the full repetition-id path.
    """
    # Details of this test described at
    # https://tinyurl.com/measurement-repeated-circuitop#heading=h.sbgxcsyin9wt.
    q = cirq.LineQubit(0)
    rep_ids = ('zero', 'one')

    innermost = (
        cirq.CircuitOperation(cirq.FrozenCircuit(cirq.measure(q, key='A')))
        .with_measurement_key_mapping({'A': 'B'})
        .repeat(2, list(rep_ids))
    )
    middle = (
        cirq.CircuitOperation(cirq.FrozenCircuit(cirq.measure(q, key='P'), innermost))
        .with_measurement_key_mapping({'B': 'C', 'P': 'Q'})
        .repeat(2, list(rep_ids))
    )
    outermost = (
        cirq.CircuitOperation(cirq.FrozenCircuit(cirq.measure(q, key='X'), middle))
        .with_measurement_key_mapping({'C': 'D', 'X': 'Y'})
        .repeat(2, list(rep_ids))
    )

    # Keys in the exact order the flattened measurements should appear.
    ordered_keys = [
        'zero:Y',
        'zero:zero:Q',
        'zero:zero:zero:D',
        'zero:zero:one:D',
        'zero:one:Q',
        'zero:one:zero:D',
        'zero:one:one:D',
        'one:Y',
        'one:zero:Q',
        'one:zero:zero:D',
        'one:zero:one:D',
        'one:one:Q',
        'one:one:zero:D',
        'one:one:one:D',
    ]
    assert cirq.measurement_key_names(outermost) == set(ordered_keys)

    # Every measurement acts on the same qubit, so each one lands in its own moment.
    flat_circuit = cirq.Circuit(
        cirq.measure(q, key=cirq.MeasurementKey.parse_serialized(k)) for k in ordered_keys
    )

    assert cirq.Circuit(cirq.decompose(outermost)) == flat_circuit
    assert cirq.measurement_key_names(flat_circuit) == set(ordered_keys)

    # Verify that mapped_circuit gives the same operations.
    assert outermost.mapped_circuit(deep=True) == flat_circuit
# Example #5
def test_with_measurement_key_mapping():
    """Remapping renames matching measurement keys and leaves others untouched."""
    q = cirq.LineQubit(0)
    measurement = cirq.measure(q, key='m')

    renamed = cirq.with_measurement_key_mapping(measurement, {'m': 'k'})
    assert cirq.measurement_key_names(renamed) == {'k'}

    # A mapping that names no existing key hands back the very same object.
    unchanged = cirq.with_measurement_key_mapping(measurement, {'x': 'k'})
    assert unchanged is measurement
# Example #6
def test_with_key_path_prefix():
    """Prefixing a key path prepends its components to the measurement key."""
    q = cirq.LineQubit(0)
    measurement = cirq.measure(q, key='m')

    prefixed = cirq.with_key_path_prefix(measurement, ('a', 'b'))
    assert cirq.measurement_key_names(prefixed) == {'a:b:m'}

    # An empty prefix is a no-op that returns the identical object.
    empty_prefix = tuple()
    for candidate in (prefixed, measurement):
        assert cirq.with_key_path_prefix(candidate, empty_prefix) is candidate

    # A plain X operation yields NotImplemented.
    assert cirq.with_key_path_prefix(cirq.X(q), ('a', 'b')) is NotImplemented
def test_measurement_key_mapping():
    """with_measurement_key_mapping uses a custom _with_measurement_key_mapping_."""

    class MultiKeyGate:
        """Minimal value object carrying a set of measurement key names."""

        def __init__(self, keys):
            self._keys = set(keys)

        def _measurement_key_names_(self):
            return self._keys

        def _with_measurement_key_mapping_(self, key_map):
            # Deliberately strict: every current key needs an entry in key_map.
            missing = [key for key in self._keys if key not in key_map]
            if missing:
                raise ValueError('missing keys')
            return MultiKeyGate([key_map[key] for key in self._keys])

    assert cirq.measurement_key_names(MultiKeyGate([])) == set()
    assert cirq.measurement_key_names(MultiKeyGate(['a'])) == {'a'}

    mkg_ab = MultiKeyGate(['a', 'b'])
    assert cirq.measurement_key_names(mkg_ab) == {'a', 'b'}

    # (key mapping, expected key names after remapping): rename, partial
    # identity, and swap.
    remap_cases = [
        ({'a': 'c', 'b': 'd'}, {'c', 'd'}),
        ({'a': 'a', 'b': 'c'}, {'a', 'c'}),
        ({'a': 'b', 'b': 'a'}, {'a', 'b'}),
    ]
    for key_map, expected_names in remap_cases:
        remapped = cirq.with_measurement_key_mapping(mkg_ab, key_map)
        assert cirq.measurement_key_names(remapped) == expected_names

    with pytest.raises(ValueError):
        cirq.with_measurement_key_mapping(mkg_ab, {'a': 'c'})

    assert cirq.with_measurement_key_mapping(cirq.X, {'a': 'c'}) is NotImplemented

    # Entries for keys the gate does not carry have no effect.
    with_extra = cirq.with_measurement_key_mapping(mkg_ab, {'a': 'c', 'b': 'd', 'x': 'y'})
    assert cirq.measurement_key_names(with_extra) == {'c', 'd'}
def test_mapped_circuit_keeps_keys_under_parent_path():
    """mapped_circuit keeps the parent key path for several key-bearing op types."""
    q = cirq.LineQubit(0)
    keyed_ops = [
        cirq.measure(q, key='A'),
        cirq.measure_single_paulistring(cirq.X(q), key='B'),
        cirq.MixedUnitaryChannel.from_mixture(cirq.bit_flip(0.5), key='C').on(q),
        cirq.KrausChannel.from_channel(cirq.phase_damp(0.5), key='D').on(q),
    ]
    scoped = cirq.CircuitOperation(cirq.FrozenCircuit(*keyed_ops)).with_key_path(('X',))

    expected_names = {f'X:{suffix}' for suffix in 'ABCD'}
    assert cirq.measurement_key_names(scoped.mapped_circuit()) == expected_names
def test_measurement_key_path():
    """with_key_path uses a custom _with_key_path_ and prefixes every key."""

    class MultiKeyGate:
        """Minimal value object holding parsed MeasurementKeys."""

        def __init__(self, keys):
            # Set comprehension instead of set([...]): no throwaway list.
            self._keys = {cirq.MeasurementKey.parse_serialized(key) for key in keys}

        def _measurement_key_names_(self):
            return {str(key) for key in self._keys}

        def _with_key_path_(self, path):
            # Rebuild from serialized strings so __init__ re-parses the new keys.
            return MultiKeyGate([str(key._with_key_path_(path)) for key in self._keys])

    assert cirq.measurement_key_names(MultiKeyGate([])) == set()
    assert cirq.measurement_key_names(MultiKeyGate(['a'])) == {'a'}

    mkg_ab = MultiKeyGate(['a', 'b'])
    assert cirq.measurement_key_names(mkg_ab) == {'a', 'b'}

    mkg_cd = cirq.with_key_path(mkg_ab, ('c', 'd'))
    assert cirq.measurement_key_names(mkg_cd) == {'c:d:a', 'c:d:b'}

    assert cirq.with_key_path(cirq.X, ('c', 'd')) is NotImplemented
# Example #10
def test_tagged_measurement():
    """Tags neither create nor hide measurements, and they survive key remapping."""
    tagged_phase = cirq.GlobalPhaseOperation(coefficient=-1.0).with_tags('tag0')
    assert not cirq.is_measurement(tagged_phase)

    q = cirq.LineQubit(0)
    tagged_measure = cirq.measure(q, key='m').with_tags('tag')
    assert cirq.is_measurement(tagged_measure)

    remapped = cirq.with_measurement_key_mapping(tagged_measure, {'m': 'k'})
    assert remapped.tags == ('tag',)
    assert cirq.is_measurement(remapped)
    assert cirq.measurement_key_names(remapped) == {'k'}

    # A mapping for an unrelated key leaves the tagged operation equal to itself.
    assert cirq.with_measurement_key_mapping(tagged_measure, {'x': 'k'}) == tagged_measure
# Example #11
def test_measurement_key_enumerable_deprecated():
    """Key magic methods that return plain lists trigger a 'frozenset' deprecation.

    The returned values still compare equal to the expected sets.
    """

    class Deprecated:
        # Both magic methods return lists instead of frozensets on purpose.
        def _measurement_key_objs_(self):
            return [cirq.MeasurementKey('key')]

        def _measurement_key_names_(self):
            return ['key']

    expected_obj = cirq.MeasurementKey('key')
    with cirq.testing.assert_deprecated('frozenset', deadline='v0.16'):
        assert cirq.measurement_key_objs(Deprecated()) == {expected_obj}

    with cirq.testing.assert_deprecated('frozenset', deadline='v0.16'):
        assert cirq.measurement_key_names(Deprecated()) == {'key'}
def test_with_measurement_keys():
    """CircuitOperation stores only the key remappings that apply to its circuit."""
    q_a, q_b = cirq.LineQubit.range(2)
    circuit = cirq.FrozenCircuit(
        cirq.X(q_a),
        cirq.measure(q_b, key='mb'),
        cirq.measure(q_a, key='ma'),
    )
    base_op = cirq.CircuitOperation(circuit)

    remapped = base_op.with_measurement_key_mapping({'ma': 'pa', 'x': 'z'})
    assert remapped.base_operation() == base_op
    # The entry for 'x', which the circuit never measures, is not kept.
    assert remapped.measurement_key_map == {'ma': 'pa'}
    assert cirq.measurement_key_names(remapped) == {'pa', 'mb'}

    assert cirq.with_measurement_key_mapping(base_op, {'ma': 'pa'}) == remapped

    # Two measurement keys cannot be mapped onto the same target string.
    with pytest.raises(ValueError):
        _ = base_op.with_measurement_key_mapping({'ma': 'mb'})
# Example #13
def test_parameterized_repeat_side_effects():
    """Symbolic repetitions block anything that needs the unrolled circuit.

    Before the repetition count is resolved, measurement-key queries,
    unrolling, decomposition, repetition ids and key remapping all raise
    ValueError. Once resolved, keys get per-repetition prefixes and the op
    unrolls normally.
    """
    q = cirq.LineQubit(0)
    op = cirq.CircuitOperation(
        cirq.FrozenCircuit(cirq.X(q).with_classical_controls('c'), cirq.measure(q, key='m')),
        repetitions=sympy.Symbol('a'),
    )
    unroll_error = 'Cannot unroll circuit due to nondeterministic repetitions'
    rep_id_error = 'repetition ids with parameterized repetitions'

    # Control keys can be calculated because they only "lift" if there's a matching
    # measurement, in which case they're not returned here.
    assert cirq.control_keys(op) == {cirq.MeasurementKey('c')}

    # "local" params do not bind to the repetition param.
    assert cirq.parameter_names(op.with_params({'a': 1})) == {'a'}

    # Check errors that require unrolling the circuit.
    needs_unrolling = (
        cirq.measurement_key_objs,
        cirq.measurement_key_names,
        lambda operation: operation.mapped_circuit(),
        cirq.decompose,
    )
    for action in needs_unrolling:
        with pytest.raises(ValueError, match=unroll_error):
            action(op)

    # Not compatible with repetition ids
    with pytest.raises(ValueError, match=rep_id_error):
        op.with_repetition_ids(['x', 'y'])
    with pytest.raises(ValueError, match=rep_id_error):
        op.repeat(repetition_ids=['x', 'y'])

    # TODO(daxfohl): This should work, but likely requires a new protocol that returns *just* the
    # name of the measurement keys. (measurement_key_names returns the full serialized string).
    with pytest.raises(ValueError, match=unroll_error):
        cirq.with_measurement_key_mapping(op, {'m': 'm2'})

    # Everything should work once resolved
    resolved = cirq.resolve_parameters(op, {'a': 2})
    assert {str(key) for key in cirq.measurement_key_objs(resolved)} == {'0:m', '1:m'}
    unrolled = cirq.Circuit(
        cirq.X(q).with_classical_controls('c'),
        cirq.measure(q, key=cirq.MeasurementKey.parse_serialized('0:m')),
        cirq.X(q).with_classical_controls('c'),
        cirq.measure(q, key=cirq.MeasurementKey.parse_serialized('1:m')),
    )
    assert resolved.mapped_circuit() == unrolled
    assert cirq.decompose(resolved) == cirq.decompose(unrolled)
# Example #14
def test_measurement_keys():
    """measurement_key_names / is_measurement via decomposition and magic methods."""
    # NOTE(review): test_measurement_keys is defined again right after this function;
    # Python keeps only the last definition, so pytest will not collect this one.

    class Composite(cirq.Gate):
        # Measures only through its decomposition; defines no key magic methods.
        def _decompose_(self, qubits):
            yield cirq.measure(qubits[0], key='inner1')
            yield cirq.measure(qubits[1], key='inner2')
            yield cirq.reset(qubits[0])

        def num_qubits(self) -> int:
            return 2

    class MeasurementKeysGate(cirq.Gate):
        def _measurement_key_names_(self):
            return ['a', 'b']

        def num_qubits(self) -> int:
            return 1

    class DeprecatedMagicMethod(cirq.Gate):
        # Uses the older _measurement_keys_ spelling, which is expected to warn.
        def _measurement_keys_(self):
            return ['a', 'b']

        def num_qubits(self) -> int:
            return 1

    a, b = cirq.LineQubit.range(2)
    composite = Composite()
    inner_keys = {'inner1', 'inner2'}

    assert cirq.is_measurement(composite)
    with cirq.testing.assert_deprecated(deadline="v0.13"):
        assert cirq.measurement_keys(composite) == inner_keys
    with cirq.testing.assert_deprecated(deadline="v0.13"):
        assert cirq.measurement_key_names(DeprecatedMagicMethod()) == {'a', 'b'}
    with cirq.testing.assert_deprecated(deadline="v0.13"):
        assert cirq.measurement_key_names(DeprecatedMagicMethod().on(a)) == {'a', 'b'}

    # Keys found only through decomposition vanish when decomposition is disabled.
    composite_op = composite.on(a, b)
    assert cirq.measurement_key_names(composite) == inner_keys
    assert cirq.measurement_key_names(composite_op) == inner_keys
    assert not cirq.is_measurement(composite, allow_decompose=False)
    for val in (composite, composite_op):
        assert cirq.measurement_key_names(val, allow_decompose=False) == set()

    # Values without measurements report no keys.
    for val in (None, [], cirq.X, cirq.X(a)):
        assert cirq.measurement_key_names(val) == set()
    for val in (None, [], cirq.X):
        assert cirq.measurement_key_names(val, allow_decompose=False) == set()

    out = cirq.measure(a, key='out')
    assert cirq.measurement_key_names(out) == {'out'}
    assert cirq.measurement_key_names(out, allow_decompose=False) == {'out'}

    two_measurements = cirq.Circuit(cirq.measure(a, key='a'), cirq.measure(b, key='2'))
    assert cirq.measurement_key_names(two_measurements) == {'a', '2'}
    for val in (MeasurementKeysGate(), MeasurementKeysGate().on(a)):
        assert cirq.measurement_key_names(val) == {'a', 'b'}
def test_measurement_keys():
    """measurement_key_names for magic methods, plain values and circuits.

    The old ``_measurement_keys_`` spelling and the ``allow_decompose``
    argument are both exercised under their deprecation contexts.
    """
    # NOTE(review): this is the last of several test_measurement_keys definitions in
    # this module; it shadows the earlier ones, so only this one is collected.

    class MeasurementKeysGate(cirq.Gate):
        def _measurement_key_names_(self):
            return ['a', 'b']

        def num_qubits(self) -> int:
            return 1

    class DeprecatedMagicMethod(cirq.Gate):
        # Uses the older _measurement_keys_ spelling, which is expected to warn.
        def _measurement_keys_(self):
            return ['a', 'b']

        def num_qubits(self) -> int:
            return 1

    a, b = cirq.LineQubit.range(2)
    with cirq.testing.assert_deprecated(deadline="v0.13"):
        assert cirq.measurement_key_names(DeprecatedMagicMethod()) == {'a', 'b'}
    with cirq.testing.assert_deprecated(deadline="v0.13"):
        assert cirq.measurement_key_names(DeprecatedMagicMethod().on(a)) == {'a', 'b'}

    for val in (None, [], cirq.X, cirq.X(a)):
        assert cirq.measurement_key_names(val) == set()
    # Passing allow_decompose is itself deprecated, so each call gets its own context.
    for val in (None, [], cirq.X):
        with cirq.testing.assert_deprecated(deadline="v0.14"):
            assert cirq.measurement_key_names(val, allow_decompose=False) == set()

    out = cirq.measure(a, key='out')
    assert cirq.measurement_key_names(out) == {'out'}
    with cirq.testing.assert_deprecated(deadline="v0.14"):
        assert cirq.measurement_key_names(out, allow_decompose=False) == {'out'}

    two_measurements = cirq.Circuit(cirq.measure(a, key='a'), cirq.measure(b, key='2'))
    assert cirq.measurement_key_names(two_measurements) == {'a', '2'}
    for val in (MeasurementKeysGate(), MeasurementKeysGate().on(a)):
        assert cirq.measurement_key_names(val) == {'a', 'b'}