def test_optimizer(should_quantize):
    """Run one SGD step through a negating quantizer and check the outcome.

    The quantizer maps ``v -> -v``. The loss reads the variable inside
    ``quantized_scope(should_quantize)``, so the gradient the optimizer sees
    changes sign depending on whether quantization is active:

    * quantized:     loss = -latent + 1, gradient -1, so the latent value rises.
    * not quantized: loss =  latent + 1, gradient +1, so the latent value falls.

    The expected numbers assume ``get_var(1.0)`` starts the latent value at
    1.0 (learning rate is 1.0). They also assume a read outside any quantized
    scope returns the latent value.
    """
    quantized = QuantizedVariable.from_variable(
        get_var(1.0), quantizer=lambda v: -v
    )
    sgd = tf.keras.optimizers.SGD(1.0)

    def compute_loss():
        with context.quantized_scope(should_quantize):
            return quantized + 1.0

    @tf.function
    def train_step():
        sgd.minimize(compute_loss, var_list=[quantized])

    train_step()

    if not should_quantize:
        # 1.0 - 1.0 * (+1) == 0.0
        assert evaluate(quantized) == 0.0
        return

    # Latent value: 1.0 - 1.0 * (-1) == 2.0; its quantized (negated) view is -2.0.
    assert evaluate(quantized) == 2.0
    with context.quantized_scope(should_quantize):
        assert evaluate(quantized) == -2.0


def test_method_delegations(distribute_scope):
    """Check that QuantizedVariable forwards tf.Variable methods and attributes.

    The quantizer doubles its input. Inside ``quantized_scope(True)``, every
    read is therefore expected to return twice the latent value (3.5 -> 7,
    4 -> 8, ...). Attributes are compared against ``x.latent_variable``.

    Args:
        distribute_scope: truthy when the test runs in the distributed
            configuration. APIs that DistributedVariable does not support are
            then skipped. Presumably a pytest fixture; confirm against
            conftest.
    """
    x = QuantizedVariable.from_variable(get_var(3.5),
                                        quantizer=lambda v: 2 * v)
    with context.quantized_scope(True):
        evaluate(x.initializer)
        assert evaluate(x.value()) == 7
        assert evaluate(x.read_value()) == 7
        assert x.trainable
        # NOTE(review): if `version` is `packaging.version`, "1.14" compares
        # equal to "1.14.0". That means TF 1.14.0 itself is excluded here;
        # confirm the intended lower bound.
        if version.parse(tf.__version__) > version.parse("1.14"):
            assert x.synchronization == x.latent_variable.synchronization
        assert x.aggregation == x.latent_variable.aggregation
        assert evaluate(x.initialized_value()) == 7
        if not tf.executing_eagerly():
            if not distribute_scope:
                # These functions are not supported for DistributedVariables
                x.load(4.5)
                assert x.eval() == 9  # 2 * 4.5
            assert evaluate(x.initial_value) == 7
            assert x.op == x.latent_variable.op
            assert x.graph == x.latent_variable.graph
        if not distribute_scope:
            # These attributes are not supported for DistributedVariables
            assert x.constraint is None
            assert x.initializer == x.latent_variable.initializer

        def apply_and_read(var, fn, args):
            """Evaluate ``fn(*args)`` for its side effect, then read ``var``."""
            evaluate(fn(*args))
            return evaluate(var)

        assert apply_and_read(x, x.assign, [4]) == 8  # latent 4
        assert apply_and_read(x, x.assign_add, [1]) == 10  # latent 5
        assert apply_and_read(x, x.assign_sub, [1.5]) == 7  # latent 3.5
        assert x.name == x.latent_variable.name
        assert x.device == x.latent_variable.device
        assert x.shape == ()
        assert x.get_shape() == ()
        # Only the set_shape call may legitimately be unsupported. Keep the
        # follow-up assertion outside the `try` so that a NotImplementedError
        # from reading `x.shape` is not silently swallowed.
        try:
            x.set_shape(())
        except NotImplementedError:
            pass
        else:
            assert x.shape == ()
# Example #3
# 0
 def getter(*args, **kwargs):
     """Create a variable via ``old_getter`` and wrap it as a QuantizedVariable.

     All positional and keyword arguments are forwarded unchanged to
     ``old_getter``. ``old_getter`` and ``quantizer`` are free variables from
     an enclosing scope not visible here. Presumably this is installed as a
     custom variable-getter hook; confirm against that scope.
     """
     return QuantizedVariable.from_variable(
         old_getter(*args, **kwargs), quantizer
     )
def test_inheritance(distribute_scope):
    """A QuantizedVariable must pass the type checks callers rely on.

    The wrapper must be an instance of both QuantizedVariable and tf.Variable.
    It must be a DistributedVariable exactly when ``distribute_scope`` is set.
    """
    wrapped = QuantizedVariable.from_variable(get_var(3.0))
    for expected_base in (QuantizedVariable, tf.Variable):
        assert isinstance(wrapped, expected_base)
    assert isinstance(wrapped, DistributedVariable) is distribute_scope  # type: ignore