def test_list_sweep_str(): assert str(cirq.UnitSweep) == '''Sweep: {}''' assert str(cirq.Linspace('a', start=0, stop=3, length=4)) == '''Sweep: {'a': 0.0} {'a': 1.0} {'a': 2.0} {'a': 3.0}''' assert str(cirq.Linspace('a', start=0, stop=15.75, length=64)) == '''Sweep: {'a': 0.0} {'a': 0.25} {'a': 0.5} {'a': 0.75} {'a': 1.0} ... {'a': 14.75} {'a': 15.0} {'a': 15.25} {'a': 15.5} {'a': 15.75}''' assert str( cirq.ListSweep( cirq.Linspace('a', 0, 3, 4) + cirq.Linspace('b', 1, 2, 2))) == '''Sweep: {'a': 0.0, 'b': 1.0} {'a': 1.0, 'b': 2.0}''' assert str( cirq.ListSweep( cirq.Linspace('a', 0, 3, 4) * cirq.Linspace('b', 1, 2, 2))) == '''Sweep:
def test_list_sweep(r_list):
    """Check keys, length, iteration and param tuples of a ListSweep.

    Args:
        r_list: Fixture supplying the resolvers the sweep is built from
            (presumably four resolvers over keys 'a' and 'b' -- the fixture
            definition is not visible here).
    """
    sweep = cirq.ListSweep(r_list)

    assert sweep.keys == ['a', 'b']
    assert len(sweep) == 4

    # Iterating the sweep yields one resolver per entry.
    resolvers = list(sweep)
    assert len(resolvers) == 4
    assert resolvers[1] == cirq.ParamResolver({'a': 0.5, 'b': 1.5})

    # param_tuples() exposes each entry as a tuple of (key, value) pairs.
    tuples = list(sweep.param_tuples())
    assert len(tuples) == 4
    assert tuples[3] == (('a', -10), ('b', -9))
def test_repr():
    """Check that Product, Zip and ListSweep reprs evaluate to equal values."""
    setup = 'import cirq\nfrom collections import OrderedDict'

    def check(sweep):
        # Every case here is checked with the same setup code.
        cirq.testing.assert_equivalent_repr(sweep, setup_code=setup)

    check(cirq.study.sweeps.Product(cirq.UnitSweep))
    check(cirq.study.sweeps.Zip(cirq.UnitSweep))
    check(cirq.ListSweep(cirq.Linspace('a', start=0, stop=3, length=4)))
def test_symbol_to_string_conversion():
    """A sympy.Symbol resolver key is serialized under its string name."""
    # Expected proto: a ZIP sweep function wrapping a single one-point sweep.
    want = v2.run_context_pb2.Sweep()
    want.sweep_function.function_type = v2.run_context_pb2.SweepFunction.ZIP
    single = want.sweep_function.sweeps.add()
    single.single_sweep.parameter_key = 'a'
    single.single_sweep.points.points.extend([4.0])

    symbol_sweep = cirq.ListSweep(
        [cirq.ParamResolver({sympy.Symbol('a'): 4.0})]
    )
    got = v2.sweep_to_proto(symbol_sweep)

    assert isinstance(got, v2.run_context_pb2.Sweep)
    assert got == want
def test_repr():
    """Check that sweep reprs evaluate back to equal values.

    Covers Product, Zip and ListSweep (with explicit setup code) as well as
    Points and Linspace whose keys contain non-identifier characters.
    """
    setup = 'import cirq\nfrom collections import OrderedDict'

    def check_with_setup(sweep):
        cirq.testing.assert_equivalent_repr(sweep, setup_code=setup)

    check_with_setup(cirq.study.sweeps.Product(cirq.UnitSweep))
    check_with_setup(cirq.study.sweeps.Zip(cirq.UnitSweep))
    check_with_setup(
        cirq.ListSweep(cirq.Linspace('a', start=0, stop=3, length=4))
    )

    # These two rely on assert_equivalent_repr's default setup code.
    points = cirq.Points('zero&pi', [0, 3.14159])
    cirq.testing.assert_equivalent_repr(points)
    linspace = cirq.Linspace('I/10', 0, 1, 10)
    cirq.testing.assert_equivalent_repr(linspace)
def test_equality():
    """Register sweeps with an EqualsTester, one equality group per kind."""
    tester = cirq.testing.EqualsTester()
    tester.add_equality_group(cirq.UnitSweep, cirq.UnitSweep)

    # Simple sweeps with the same key are equal to themselves, but different
    # from each other even if they happen to contain the same points.
    for key in ('a', 'b'):
        tester.make_equality_group(
            lambda key=key: cirq.Linspace(key, 0, 10, 11)
        )
    for key in ('a', 'b'):
        tester.make_equality_group(
            lambda key=key: cirq.Points(key, list(range(11)))
        )

    # Product and Zip sweeps can also be equated.
    def product_ab():
        return cirq.Linspace('a', 0, 5, 6) * cirq.Linspace('b', 10, 15, 6)

    def zip_ab():
        return cirq.Linspace('a', 0, 5, 6) + cirq.Linspace('b', 10, 15, 6)

    def product_of_points_and_zip():
        zipped = cirq.Linspace('b', 0, 5, 6) + cirq.Linspace('c', 10, 15, 6)
        return cirq.Points('a', [1, 2]) * zipped

    tester.make_equality_group(product_ab)
    tester.make_equality_group(zip_ab)
    tester.make_equality_group(product_of_points_and_zip)

    # ListSweep: the same resolvers given as a list, a tuple or a generator
    # all land in one group.
    tester.make_equality_group(
        lambda: cirq.ListSweep([{'var': 1}, {'var': -1}]),
        lambda: cirq.ListSweep(({'var': 1}, {'var': -1})),
        lambda: cirq.ListSweep(r for r in ({'var': 1}, {'var': -1})),
    )
    # Different order, length or key each get their own group.
    tester.make_equality_group(
        lambda: cirq.ListSweep([{'var': -1}, {'var': 1}])
    )
    tester.make_equality_group(lambda: cirq.ListSweep([{'var': 1}]))
    tester.make_equality_group(
        lambda: cirq.ListSweep([{'x': 1}, {'x': -1}])
    )
def test_list_sweep_type_error():
    """ListSweep raises TypeError when an entry is not a resolver."""
    valid_entry = cirq.ParamResolver()
    invalid_entry = 'bad'
    with pytest.raises(TypeError, match='Not a ParamResolver'):
        cirq.ListSweep([valid_entry, invalid_entry])
def test_list_sweep_empty():
    """A ListSweep built from no resolvers reports no keys."""
    empty_sweep = cirq.ListSweep([])
    assert empty_sweep.keys == []
def test_list_sweep_bad_expression():
    """A non-symbol sympy expression used as a key raises TypeError."""
    formula_key = sympy.Symbol('a') + sympy.Symbol('b')
    # Resolver and sweep construction both stay inside the context manager so
    # that whichever of them rejects the key is covered.
    with pytest.raises(TypeError, match='formula'):
        cirq.ListSweep([cirq.ParamResolver({formula_key: 4.0})])
def test_list_sweep_bad_expression():
    """Converting a sweep keyed by a compound expression raises ValueError."""
    formula_key = sympy.Symbol('a') + sympy.Symbol('b')
    resolver = cirq.ParamResolver({formula_key: 4.0})
    bad_sweep = cirq.ListSweep([resolver])

    # Only the proto conversion is expected to fail, not the construction.
    with pytest.raises(ValueError, match='cannot convert'):
        v2.sweep_to_proto(bad_sweep)