def test_resolve_parameters():
    """Exercise is_parameterized, resolve_parameters and parameter_names.

    Covers objects lacking the protocol methods, objects whose protocol
    methods return NotImplemented, a custom switchable test double, a sympy
    symbol, plain numbers, and tuple/list containers (including empty ones).
    """

    class NoMethod:
        pass

    class ReturnsNotImplemented:
        def _is_parameterized_(self):
            return NotImplemented

        def _resolve_parameters_(self, resolver):
            return NotImplemented

    class SimpleParameterSwitch:
        # Test double: reports itself as parameterized exactly when its
        # parameter equals 0; resolving overwrites the parameter in place
        # with resolver.value_of(...) and returns the same instance.
        def __init__(self, var):
            self.parameter = var

        def _is_parameterized_(self) -> bool:
            return self.parameter == 0

        def _resolve_parameters_(self, resolver: ParamResolver):
            self.parameter = resolver.value_of(self.parameter)
            return self

    assert not cirq.is_parameterized(NoMethod())
    assert not cirq.is_parameterized(ReturnsNotImplemented())
    assert not cirq.is_parameterized(SimpleParameterSwitch('a'))
    assert cirq.is_parameterized(SimpleParameterSwitch(0))

    declines = ReturnsNotImplemented()
    mapping = {'a': 0}
    param_resolver = cirq.ParamResolver(mapping)
    plain = NoMethod()

    # Both a ParamResolver and a bare dict are accepted as resolvers.
    for resolver_like in (param_resolver, mapping):
        assert cirq.resolve_parameters(plain, resolver_like) == plain
    assert cirq.resolve_parameters(declines, param_resolver) == declines
    assert cirq.resolve_parameters(SimpleParameterSwitch(0),
                                   param_resolver).parameter == 0
    for resolver_like in (param_resolver, mapping):
        switched = cirq.resolve_parameters(SimpleParameterSwitch('a'),
                                           resolver_like)
        assert switched.parameter == 0
    assert cirq.resolve_parameters(sympy.Symbol('a'), param_resolver) == 0

    a, b, c = map(sympy.Symbol, 'abc')
    x, y, z = 0, 4, 7
    resolver = {a: x, b: y, c: z}
    numbers = (1, 1.1, 1j)

    # Containers resolve element-wise and keep their container type.
    for container in (tuple, list):
        symbolic = container((a, b, c))
        numeric = container((x, y, z))
        empty = container()
        assert cirq.resolve_parameters(symbolic, resolver) == numeric
        assert cirq.resolve_parameters(numeric, resolver) == numeric
        assert cirq.resolve_parameters(empty, resolver) == empty
    for number in numbers:
        assert cirq.resolve_parameters(number, resolver) == number

    # A container is parameterized if any element is.
    for container in (tuple, list):
        assert not cirq.is_parameterized(container((x, y)))
        assert cirq.is_parameterized(container((a, b)))
        assert cirq.is_parameterized(container((a, x)))
        assert not cirq.is_parameterized(container())
    for number in numbers:
        assert not cirq.is_parameterized(number)

    for container in (tuple, list):
        assert cirq.parameter_names(container((a, b, c))) == {'a', 'b', 'c'}
        assert cirq.parameter_names(container((x, y, z))) == set()
        assert cirq.parameter_names(container()) == set()
    for number in numbers:
        assert cirq.parameter_names(number) == set()
 def _parameter_names_(self) -> AbstractSet[str]:
     """Collect parameter names from the exponent and the phase exponent.

     Returns:
         The union (``|``) of ``cirq.parameter_names(self.exponent)`` and
         ``cirq.parameter_names(self.phase_exponent)``.
     """
     exponent_names = cirq.parameter_names(self.exponent)
     phase_names = cirq.parameter_names(self.phase_exponent)
     return exponent_names | phase_names
# Example 3
 def _parameter_names_(self) -> AbstractSet[str]:
     """Collect parameter names from the ``theta`` and ``phi`` angles.

     Returns:
         The union (``|``) of ``cirq.parameter_names(self.theta)`` and
         ``cirq.parameter_names(self.phi)``.
     """
     theta_names = cirq.parameter_names(self.theta)
     phi_names = cirq.parameter_names(self.phi)
     return theta_names | phi_names
# Example 4
def test_parameterized_repeat_side_effects():
    """Symbolic repetitions block every query that must unroll the circuit.

    Control keys and parameter names remain computable. Measurement keys, the
    mapped circuit, decomposition, key remapping and repetition ids all raise
    ValueError until the repetition count is resolved, after which the
    operation unrolls into two copies with prefixed measurement keys.
    """
    qubit = cirq.LineQubit(0)
    op = cirq.CircuitOperation(
        cirq.FrozenCircuit(
            cirq.X(qubit).with_classical_controls('c'),
            cirq.measure(qubit, key='m')),
        repetitions=sympy.Symbol('a'),
    )

    # Control keys can be calculated because they only "lift" if there's a matching
    # measurement, in which case they're not returned here.
    assert cirq.control_keys(op) == {cirq.MeasurementKey('c')}

    # "local" params do not bind to the repetition param.
    assert cirq.parameter_names(op.with_params({'a': 1})) == {'a'}

    unroll_error = 'Cannot unroll circuit due to nondeterministic repetitions'
    # Each of these queries needs the unrolled circuit.
    for needs_unrolling in (
            cirq.measurement_key_objs,
            cirq.measurement_key_names,
            lambda operation: operation.mapped_circuit(),
            cirq.decompose,
    ):
        with pytest.raises(ValueError, match=unroll_error):
            needs_unrolling(op)

    # Not compatible with repetition ids
    ids_error = 'repetition ids with parameterized repetitions'
    with pytest.raises(ValueError, match=ids_error):
        op.with_repetition_ids(['x', 'y'])
    with pytest.raises(ValueError, match=ids_error):
        op.repeat(repetition_ids=['x', 'y'])

    # TODO(daxfohl): This should work, but likely requires a new protocol that returns *just* the
    # name of the measurement keys. (measurement_key_names returns the full serialized string).
    with pytest.raises(ValueError, match=unroll_error):
        cirq.with_measurement_key_mapping(op, {'m': 'm2'})

    # Everything should work once resolved
    op = cirq.resolve_parameters(op, {'a': 2})
    assert set(map(str, cirq.measurement_key_objs(op))) == {'0:m', '1:m'}
    unrolled = cirq.Circuit(
        cirq.X(qubit).with_classical_controls('c'),
        cirq.measure(qubit, key=cirq.MeasurementKey.parse_serialized('0:m')),
        cirq.X(qubit).with_classical_controls('c'),
        cirq.measure(qubit, key=cirq.MeasurementKey.parse_serialized('1:m')),
    )
    assert op.mapped_circuit() == unrolled
    assert cirq.decompose(op) == cirq.decompose(unrolled)
# Example 5
def test_parameterized_repeat():
    """Successive powers of a CircuitOperation keep it parameterized.

    At each stage the operation is raised to one more exponent. It must then
    report the expected parameter names and have no unitary. It must simulate
    to the expected single-qubit state for each listed resolver, raise
    TypeError for the stage's non-integer resolver, and raise ValueError when
    simulated without any resolver.
    """
    qubit = cirq.LineQubit(0)
    ground, excited = [1, 0], [0, 1]
    integer_error = 'Only integer or sympy repetitions are allowed'
    unresolved_error = 'Circuit contains ops whose symbols were not specified'

    # Each stage: (exponent applied to the previous op, expected parameter
    # names, [(resolver, expected state vector), ...], resolver that must
    # raise TypeError).
    stages = [
        (sympy.Symbol('a'), {'a'},
         [({'a': 0}, ground), ({'a': 1}, excited),
          ({'a': 2}, ground), ({'a': -1}, excited)],
         {'a': 1.5}),
        # Inverting keeps the same parameter and the same expected states.
        (-1, {'a'},
         [({'a': 0}, ground), ({'a': 1}, excited),
          ({'a': 2}, ground), ({'a': -1}, excited)],
         {'a': 1.5}),
        (sympy.Symbol('b'), {'a', 'b'},
         [({'a': 1, 'b': 1}, excited), ({'a': 2, 'b': 1}, ground),
          ({'a': 1, 'b': 2}, ground)],
         {'a': 1.5, 'b': 1}),
        (2.0, {'a', 'b'},
         [({'a': 1, 'b': 1}, ground), ({'a': 1.5, 'b': 1}, excited),
          ({'a': 1, 'b': 1.5}, excited)],
         {'a': 1.5, 'b': 1.5}),
    ]

    op = cirq.CircuitOperation(cirq.FrozenCircuit(cirq.X(qubit)))
    for exponent, expected_names, cases, bad_resolver in stages:
        op = op**exponent
        assert cirq.parameter_names(op) == expected_names
        assert not cirq.has_unitary(op)
        for resolver, expected_state in cases:
            result = cirq.Simulator().simulate(cirq.Circuit(op),
                                               param_resolver=resolver)
            assert np.allclose(result.state_vector(), expected_state)
        with pytest.raises(TypeError, match=integer_error):
            cirq.Simulator().simulate(cirq.Circuit(op),
                                      param_resolver=bad_resolver)
        with pytest.raises(ValueError, match=unresolved_error):
            cirq.Simulator().simulate(cirq.Circuit(op))
# Example 6
 def _parameter_names_(self) -> AbstractSet[str]:
     """Return the parameter names reported by ``cirq`` for ``self.value``."""
     wrapped = self.value
     return cirq.parameter_names(wrapped)
# Example 7
def test_periodic_value_is_parameterized(value, is_parameterized,
                                         parameter_names):
    """Check parameterization queries on ``value`` and that resolving clears them.

    NOTE(review): the arguments are presumably supplied by a
    ``pytest.mark.parametrize`` decorator that is not visible here.
    """
    assert cirq.is_parameterized(value) == is_parameterized
    assert cirq.parameter_names(value) == parameter_names
    # Bind every reported name to 1; afterwards nothing may stay symbolic.
    resolved = cirq.resolve_parameters(value,
                                       dict.fromkeys(parameter_names, 1))
    assert not cirq.is_parameterized(resolved)