def test_optimizer(should_quantize):
    """Check that an optimizer step differentiates through the quantizer only in a quantized scope.

    The variable starts at 1.0 and uses the quantizer ``v -> -v``. One SGD step
    with learning rate 1.0 minimizes ``variable + 1.0``:

    * quantized: the loss is ``-latent + 1`` (gradient -1), so the latent value
      moves to 2.0 and reads back as -2.0 inside a quantized scope;
    * not quantized: the loss is ``latent + 1`` (gradient 1), so the latent
      value moves to 0.0.

    Args:
        should_quantize: Whether the loss is computed inside an active quantized scope.
    """
    variable = QuantizedVariable.from_variable(get_var(1.0), quantizer=lambda v: -v)
    optimizer = tf.keras.optimizers.SGD(1.0)

    def compute_loss():
        # The scope decides whether reading `variable` goes through the quantizer.
        with context.quantized_scope(should_quantize):
            return variable + 1.0

    @tf.function
    def train_step():
        optimizer.minimize(compute_loss, var_list=[variable])

    train_step()

    if not should_quantize:
        # Plain gradient step: 1.0 - 1.0 * 1 == 0.0.
        assert evaluate(variable) == 0.0
        return

    # Outside a quantized scope the raw latent value is read.
    assert evaluate(variable) == 2.0
    # Inside the scope the quantizer is applied on read: -(2.0) == -2.0.
    with context.quantized_scope(should_quantize):
        assert evaluate(variable) == -2.0
def test_method_delegations(distribute_scope):
    """Check that QuantizedVariable forwards methods and attributes to its latent variable.

    The quantizer doubles the value on read, so inside a quantized scope a latent
    value of 3.5 reads back as 7. Metadata such as name, device, op, graph,
    synchronization and aggregation must match the wrapped latent variable.

    Args:
        distribute_scope: Fixture flag. It is presumably truthy when the variable
            lives under a distribution strategy; this is inferred from the
            DistributedVariable-specific skips below and should be confirmed
            against the fixture definition.
    """
    qvar = QuantizedVariable.from_variable(get_var(3.5), quantizer=lambda v: 2 * v)
    with context.quantized_scope(True):
        evaluate(qvar.initializer)

        # Reads go through the quantizer: 2 * 3.5 == 7.
        assert evaluate(qvar.value()) == 7
        assert evaluate(qvar.read_value()) == 7
        assert qvar.trainable

        has_synchronization = version.parse(tf.__version__) > version.parse("1.14")
        if has_synchronization:
            assert qvar.synchronization == qvar.latent_variable.synchronization
        assert qvar.aggregation == qvar.latent_variable.aggregation
        assert evaluate(qvar.initialized_value()) == 7

        if not tf.executing_eagerly():
            if not distribute_scope:
                # load() / eval() are not supported for DistributedVariables.
                qvar.load(4.5)
                assert qvar.eval() == 9
            assert evaluate(qvar.initial_value) == 7
            assert qvar.op == qvar.latent_variable.op
            assert qvar.graph == qvar.latent_variable.graph

        if not distribute_scope:
            # constraint / initializer are not supported for DistributedVariables.
            assert qvar.constraint is None
            assert qvar.initializer == qvar.latent_variable.initializer

        # Each mutation updates the latent value; the quantized read doubles it.
        for method_name, operand, expected in (
            ("assign", 4, 8),
            ("assign_add", 1, 10),
            ("assign_sub", 1.5, 7),
        ):
            evaluate(getattr(qvar, method_name)(operand))
            assert evaluate(qvar) == expected

        assert qvar.name == qvar.latent_variable.name
        assert qvar.device == qvar.latent_variable.device
        assert qvar.shape == ()
        assert qvar.get_shape() == ()

        try:
            qvar.set_shape(())
        except NotImplementedError:
            # set_shape may be unimplemented for some variable types / TF
            # versions (assumption); the shape check is skipped in that case.
            pass
        else:
            assert qvar.shape == ()
def getter(*args, **kwargs):
    """Create a variable with ``old_getter`` and wrap it in a QuantizedVariable.

    ``old_getter`` and ``quantizer`` are free variables taken from the enclosing
    scope, which is not visible here. Presumably this function is installed as a
    custom variable getter; verify against the caller.
    """
    return QuantizedVariable.from_variable(old_getter(*args, **kwargs), quantizer)
def test_inheritance(distribute_scope):
    """Check the class hierarchy of a QuantizedVariable built from a plain variable.

    The wrapper must always be both a QuantizedVariable and a ``tf.Variable``. It
    must be a DistributedVariable exactly when ``distribute_scope`` is True.
    """
    wrapped = QuantizedVariable.from_variable(get_var(3.0))
    for expected_base in (QuantizedVariable, tf.Variable):
        assert isinstance(wrapped, expected_base)
    is_distributed = isinstance(wrapped, DistributedVariable)  # type: ignore
    assert is_distributed is distribute_scope