def test_flattener_repr():
    """repr shows the param dict, plus get_param_name when one is supplied."""
    plain = flatten_expressions._ParamFlattener({'a': 1})
    assert repr(plain) == "_ParamFlattener({a: 1})"

    named = flatten_expressions._ParamFlattener(
        {'a': 1}, get_param_name=lambda expr: 'x')
    # The function's repr includes a memory address, so only check the prefix.
    expected_prefix = "_ParamFlattener({a: 1}, get_param_name=<function "
    assert repr(named).startswith(expected_prefix)
def test_flattener_value_of():
    """value_of passes numbers through, resolves params, and flattens exprs."""
    c = sympy.Symbol('c')
    flattener = flatten_expressions._ParamFlattener({'c': 5, 'x1': 'x1'})

    # Numbers come back unchanged; a name or Symbol resolves via the dict.
    assert flattener.value_of(9) == 9
    assert flattener.value_of('c') == 5
    assert flattener.value_of(c) == 5

    # Flattening the same expression repeatedly must yield the same symbol.
    for _ in range(2):
        assert flattener.value_of(c / 2 + 1) == sympy.Symbol('<c/2 + 1>')

    # Two expressions can print identically without being equal, i.e.
    # str(expr1) == str(expr2) does not imply expr1 == expr2. Such expressions
    # may evaluate to different values, so they must not be flattened to the
    # same symbol. A '_#' suffix is appended to keep the names distinct.
    divided_by_symbol = c / sympy.Symbol('2 + 1')
    assert (flattener.value_of(divided_by_symbol) ==
            sympy.Symbol('<c/2 + 1>_1'))
    symbol_plus_one = sympy.Symbol('c/2') + 1
    assert (flattener.value_of(symbol_plus_one) ==
            sympy.Symbol('<c/2 + 1>_2'))

    # cirq.flatten applies the same disambiguation. The first occurrence keeps
    # the unsuffixed name.
    flattened = cirq.flatten([c / 2 + 1, sympy.Symbol('c/2') + 1])[0]
    assert flattened == [
        sympy.Symbol('<c/2 + 1>'),
        sympy.Symbol('<c/2 + 1>_1'),
    ]
def test_expr_map_names():
    """flatten names expressions, suffixing names already taken by the dict."""
    x = sympy.Symbol('x')
    # '<x + 2>' is already present (as a value) in the param dict, so the
    # flattened name for x + 2 is expected to receive a '_1' suffix.
    flattener = flatten_expressions._ParamFlattener({'collision': '<x + 2>'})
    syms = flattener.flatten([x + offset for offset in range(3)])
    expected = [sympy.Symbol(name) for name in ('x', '<x + 1>', '<x + 2>_1')]
    assert syms == expected
def test_resolver_new():
    """Passing a _ParamFlattener to cirq.ParamResolver returns the same object."""
    original = flatten_expressions._ParamFlattener({'a': 'b'})
    assert cirq.ParamResolver(original) is original
def test_flattener_new():
    """Building a _ParamFlattener from another gives one with equal params."""
    source = flatten_expressions._ParamFlattener({'a': 'b'})
    rebuilt = flatten_expressions._ParamFlattener(source)
    assert isinstance(rebuilt, flatten_expressions._ParamFlattener)
    assert rebuilt.param_dict == source.param_dict