def grad_fn(inputs, variables, outputs, output_grads):
    """Rebuild the forward pass, apply variable updates, return input grads.

    The forward outputs are not reused; `fn` is called again on `inputs`
    inside the variable scope stored in `vs_ctr[0]` (with `reuse=True`).
    Gradients are taken with respect to the inputs and to the last entry of
    `params.dequantized[<variable name>]` for each variable. The variable
    gradients are handed to the optimizer returned by
    `_create_diet_optimizer(params)` right here, so `None` is reported for
    every variable to the caller.

    Args:
      inputs: arguments forwarded to `fn`; concatenated with a list below,
        so presumably a list -- confirm against the caller.
      variables: variables to update; each is passed through `variable_ref`.
      outputs: ignored (deleted immediately).
      output_grads: gradients for the outputs; the device of the first
        element is passed to `fn_device_dependency`.

    Returns:
      A tuple `(grad_inputs, grad_variables)` where `grad_inputs` are
      identity ops gated on the variable updates and `grad_variables` is a
      list of `None` with one entry per variable.
    """
    del outputs  # The forward computation is rebuilt below.
    grad_device = output_grads[0].device
    with fn_device_dependency("diet_grad", grad_device) as out_dep:
        # Re-run fn in the scope recorded during the forward pass.
        with tf.variable_scope(vs_ctr[0], reuse=True):
            recomputed = fn(*inputs)

        var_refs = [variable_ref(var) for var in variables]
        dequantized = [params.dequantized[ref.name][-1] for ref in var_refs]

        # One tf.gradients call covers both inputs and dequantized variables;
        # the result is split back apart by position.
        num_inputs = len(inputs)
        all_grads = tf.gradients(
            recomputed, inputs + dequantized, output_grads)
        input_grads = all_grads[:num_inputs]
        var_grads = all_grads[num_inputs:]

        optimizer = _create_diet_optimizer(params)

        # Apply the variable gradients immediately.
        update_ops = []
        for ref, grad in zip(var_refs, var_grads):
            with tf.variable_scope(vs_ctr[0].name):
                optimizer.create_slots(ref)
            update_ops.append(optimizer.update_variable(ref, grad))

        # Gate the returned input gradients on the updates above.
        with tf.control_dependencies(update_ops):
            input_grads = [tf.identity(grad) for grad in input_grads]

        out_dep.append(input_grads)
        return input_grads, [None] * len(var_refs)
def diet_var_initializer(shape, dtype, partition_info=None):
    """Produce the initial value for a variable of the given shape.

    Samples `tf.random_uniform` between -sqrt(3) and sqrt(3) (a uniform
    distribution over that range has unit variance) and, when
    `params.quantize` is set, passes the sample through `_quantize` with
    `randomize=False`.

    Args:
      shape: shape forwarded to `tf.random_uniform`.
      dtype: unused.
      partition_info: unused.

    Returns:
      The sampled (and possibly quantized) value.
    """
    del dtype, partition_info  # Accepted but not used.

    with fn_device_dependency("diet_init") as out_deps:
        bound = math.sqrt(3)
        init_value = tf.random_uniform(shape, -bound, bound)
        if params.quantize:
            init_value = _quantize(init_value, params, randomize=False)
        out_deps.append(init_value)
        return init_value
def _fn_with_diet_vars(fn, args, params):
    """Invoke `fn(*args)` using diet variables configured by `params`.

    The forward call runs inside a fresh "diet" variable scope whose custom
    getter comes from `make_diet_var_getter(params)`. A custom gradient is
    attached via `fn_with_custom_grad`: it recomputes `fn` in that same
    scope and applies the variable updates itself.

    Args:
      fn: callable invoked as `fn(*inputs)`.
      args: positional arguments for `fn`; the device of `args[0]` is passed
        to `fn_device_dependency`.
      params: settings object; `params.dequantized` is read here and the
        object is also passed to `make_diet_var_getter` and
        `_create_diet_optimizer`.

    Returns:
      Whatever `forward(*args)` produces.
    """
    # One-element holder so the gradient function can reach the variable
    # scope that the forward pass creates.
    scope_holder = []

    def grad_fn(inputs, variables, outputs, output_grads):
        """Recompute `fn`, update the variables, and return input grads."""
        del outputs  # The forward computation is rebuilt below.
        grad_device = output_grads[0].device
        with fn_device_dependency("diet_grad", grad_device) as out_dep:
            with tf.variable_scope(scope_holder[0], reuse=True):
                recomputed = fn(*inputs)

            var_refs = [variable_ref(var) for var in variables]
            dequantized = [
                params.dequantized[ref.name][-1] for ref in var_refs
            ]

            # Differentiate w.r.t. inputs and dequantized variables in one
            # call, then split the result by position.
            num_inputs = len(inputs)
            all_grads = tf.gradients(
                recomputed, inputs + dequantized, output_grads)
            input_grads = all_grads[:num_inputs]
            var_grads = all_grads[num_inputs:]

            optimizer = _create_diet_optimizer(params)

            # Apply the variable gradients immediately.
            update_ops = []
            for ref, grad in zip(var_refs, var_grads):
                with tf.variable_scope(scope_holder[0].name):
                    optimizer.create_slots(ref)
                update_ops.append(optimizer.update_variable(ref, grad))

            # Gate the returned input gradients on the updates above.
            with tf.control_dependencies(update_ops):
                input_grads = [tf.identity(grad) for grad in input_grads]

            out_dep.append(input_grads)
            # Updates were applied above, so report None for every variable.
            return input_grads, [None] * len(var_refs)

    @fn_with_custom_grad(grad_fn, use_global_vars=True)
    def forward(*inputs):
        """Run `fn` in a new diet variable scope and record that scope."""
        diet_getter = make_diet_var_getter(params)
        with tf.variable_scope(
                None, default_name="diet", custom_getter=diet_getter) as scope:
            scope_holder.append(scope)
            return fn(*inputs)

    with fn_device_dependency("diet_forward", args[0].device) as out_dep:
        result = forward(*args)
        out_dep.append(result)
    return result