def grad_fn(inputs, variables, outputs, output_grads):
  """Custom gradient function."""
  # fn, vs_ctr and params are captured from the enclosing scope: the wrapped
  # function, a one-element list holding its variable scope, and the diet
  # variable parameters.
  del outputs  # recomputing below
  with common_layers.fn_device_dependency("diet_grad",
                                          output_grads[0].device) as out_dep:
    with tf.variable_scope(vs_ctr[0], reuse=True):
      outputs = fn(*inputs)

    variables = [common_layers.underlying_variable_ref(v) for v in variables]
    dequantized_variables = [
        params.dequantized[v.name][-1] for v in variables
    ]

    grads = tf.gradients(outputs, inputs + dequantized_variables,
                         output_grads)
    grad_inputs = grads[:len(inputs)]
    grad_variables = grads[len(inputs):]

    opt = _create_diet_optimizer(params)

    # Apply grad_variables here
    var_updates = []
    for v, dv in zip(variables, grad_variables):
      with tf.variable_scope(vs_ctr[0].name):
        opt.create_slots(v)
        update_op = opt.update_variable(v, dv)
        var_updates.append(update_op)

    with tf.control_dependencies(var_updates):
      grad_inputs = [tf.identity(dx) for dx in grad_inputs]

    out_dep.append(grad_inputs)

    return grad_inputs, [None] * len(variables)
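# A grad_fn with the signature above is the contract expected by
# tensor2tensor's common_layers.fn_with_custom_grad decorator: it receives the
# forward inputs, the variables the forward pass created, the forward outputs,
# and the incoming output gradients, and returns (grad_inputs, grad_variables).
# Below is a minimal sketch of how such a grad_fn gets wired up, assuming a
# diet.py-style enclosing function; the names _fn_with_diet_vars_sketch and
# make_diet_var_getter follow that module but are assumptions here, not part
# of the snippet above.

def _fn_with_diet_vars_sketch(fn, args, params):
  """Call fn(*args), storing its variables quantized per params."""
  vs_ctr = []  # forward() records its variable scope here

  def grad_fn(inputs, variables, outputs, output_grads):
    ...  # the function defined above

  @common_layers.fn_with_custom_grad(grad_fn)
  def forward(*inputs):
    with tf.variable_scope(None, default_name="diet") as vs:
      vs_ctr.append(vs)  # remembered so grad_fn can recompute under reuse=True
      return fn(*inputs)

  # make_diet_var_getter is assumed to return a custom getter that quantizes
  # variables on creation and dequantizes them on read.
  with tf.variable_scope(
      None, default_name="diet", custom_getter=make_diet_var_getter(params)):
    return forward(*args)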
def custom_grad_fn(inputs, variables, ys, grad_ys):
  """Custom gradient fn for a block of reversible residual layers."""
  # f, g, num_layers, layer_scopes, f_side_input and g_side_input are captured
  # from the enclosing scope; LAYER_RE, _rev_layer_backward and _acc_grads are
  # module-level helpers, and xrange comes from six.moves.
  side_inputs = inputs[2:]
  f_side_idxs = [None] * len(f_side_input)
  g_side_idxs = [None] * len(g_side_input)
  assert len(side_inputs) == len(f_side_input) + len(g_side_input)

  for i, t in enumerate(side_inputs):
    if t in f_side_input:
      f_side_idxs[f_side_input.index(t)] = i
    elif t in g_side_input:
      g_side_idxs[g_side_input.index(t)] = i
    else:
      assert False

  f_vars = [[] for _ in range(num_layers)]
  g_vars = [[] for _ in range(num_layers)]
  f_vars_idxs = [[] for _ in range(num_layers)]
  g_vars_idxs = [[] for _ in range(num_layers)]

  for i, t in enumerate(variables):
    ref = common_layers.underlying_variable_ref(t)

    # Use the name to identify the layer number and function (f or g).
    regex = LAYER_RE.match(ref.name)
    layer_no = int(regex.group(1))
    fn_name = regex.group(2)
    if fn_name == "f":
      f_vars[layer_no].append(ref)
      f_vars_idxs[layer_no].append(i)
    else:
      assert fn_name == "g"
      g_vars[layer_no].append(ref)
      g_vars_idxs[layer_no].append(i)

  f_var_grads = []
  g_var_grads = []
  f_side_grads = []
  g_side_grads = []

  # Reverse variable containers to go backward.
  layer_scopes.reverse()
  f_vars.reverse()
  g_vars.reverse()
  f.reverse()
  g.reverse()

  for i in xrange(num_layers):
    with tf.variable_scope(layer_scopes[i], reuse=True):
      ys, grad_ys, f_ret, g_ret = _rev_layer_backward(
          ys, grad_ys, f[i], g[i], f_vars[i], f_side_input, g_vars[i],
          g_side_input)

      grad_f_vars, grad_f_side = f_ret
      grad_g_vars, grad_g_side = g_ret
      f_var_grads.append(grad_f_vars)
      g_var_grads.append(grad_g_vars)
      f_side_grads.append(grad_f_side)
      g_side_grads.append(grad_g_side)

  # Accumulate layer gradients for f_side_input and g_side_input.
  acc_f_side_grads = _acc_grads(*f_side_grads)
  acc_g_side_grads = _acc_grads(*g_side_grads)

  # Use the stored idxs to put gradients in the passed-in order.
  side_input_grads = [None] * len(side_inputs)
  variable_grads = [None] * len(variables)

  # Variable gradients were collected in reverse layer order. Reverse to
  # match idxs.
  f_var_grads.reverse()
  g_var_grads.reverse()
  for idxs, grads in list(zip(f_vars_idxs, f_var_grads)) + list(
      zip(g_vars_idxs, g_var_grads)):
    for i, grad in zip(idxs, grads):
      variable_grads[i] = grad

  for i, grad in zip(f_side_idxs, acc_f_side_grads):
    side_input_grads[i] = grad
  for i, grad in zip(g_side_idxs, acc_g_side_grads):
    side_input_grads[i] = grad

  grad_x1, grad_x2 = grad_ys
  return [grad_x1, grad_x2] + side_input_grads, variable_grads
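# Each _rev_layer_backward call in the loop above inverts one reversible
# residual layer, y1 = x1 + f(x2) and y2 = x2 + g(y1), so activations never
# need to be stored: the inputs are reconstructed from the outputs and the
# layer is re-differentiated on the recomputed subgraph. The following is a
# minimal sketch of that step, assuming no side inputs and omitting the
# per-variable gradients the real helper also returns; the name
# _rev_layer_backward_sketch is illustrative.

def _rev_layer_backward_sketch(ys, grad_ys, f, g):
  y1, y2 = ys
  dy1, dy2 = grad_ys

  # Reconstruct the layer inputs from its outputs.
  y1_stop = tf.stop_gradient(y1)
  g_y1 = g(y1_stop)
  x2 = tf.stop_gradient(y2 - g_y1)  # x2 = y2 - g(y1)
  f_x2 = f(x2)
  x1 = y1 - f_x2                    # x1 = y1 - f(x2)

  # Gradient wrt x1 equals the total gradient wrt y1: the direct path plus
  # the path through g into y2.
  grad_x1 = dy1 + tf.gradients(g_y1, y1_stop, grad_ys=dy2)[0]
  # Gradient wrt x2: the direct path through y2 plus the path through f into
  # y1 (which already carries grad_x1).
  grad_x2 = dy2 + tf.gradients(f_x2, x2, grad_ys=grad_x1)[0]

  return (x1, x2), (grad_x1, grad_x2)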
def _efficient_grad_fn(self, inputs, variables, ys, grad_ys):
  """Custom gradient fn for a block of reversible residual layers."""
  # _rev_layer_backward, _acc_grads and LAYER_RE are module-level helpers;
  # xrange comes from six.moves.
  side_inputs = inputs[2:]
  f_side_idxs = [None] * len(self.f_side_input)
  g_side_idxs = [None] * len(self.g_side_input)
  assert len(side_inputs) == len(self.f_side_input) + len(self.g_side_input)

  for i, t in enumerate(side_inputs):
    if t in self.f_side_input:
      f_side_idxs[self.f_side_input.index(t)] = i
    elif t in self.g_side_input:
      g_side_idxs[self.g_side_input.index(t)] = i
    else:
      assert False

  f_vars = [[] for _ in range(self.num_layers)]
  g_vars = [[] for _ in range(self.num_layers)]
  f_vars_idxs = [[] for _ in range(self.num_layers)]
  g_vars_idxs = [[] for _ in range(self.num_layers)]

  for i, t in enumerate(variables):
    ref = common_layers.underlying_variable_ref(t)

    # Use the name to identify the layer number and function (f or g).
    regex = LAYER_RE.match(ref.name)
    layer_no = int(regex.group(1))
    fn_name = regex.group(2)
    if fn_name == "f":
      f_vars[layer_no].append(ref)
      f_vars_idxs[layer_no].append(i)
    else:
      assert fn_name == "g"
      g_vars[layer_no].append(ref)
      g_vars_idxs[layer_no].append(i)

  f_var_grads = []
  g_var_grads = []
  f_side_grads = []
  g_side_grads = []

  # Reverse variable containers to go backward.
  f_vars.reverse()
  g_vars.reverse()
  f = list(self.f)
  g = list(self.g)
  f.reverse()
  g.reverse()

  for i in xrange(self.num_layers):
    ys, grad_ys, f_ret, g_ret = _rev_layer_backward(
        ys, grad_ys, f[i], g[i], f_vars[i], self.f_side_input, g_vars[i],
        self.g_side_input)

    grad_f_vars, grad_f_side = f_ret
    grad_g_vars, grad_g_side = g_ret
    f_var_grads.append(grad_f_vars)
    g_var_grads.append(grad_g_vars)
    f_side_grads.append(grad_f_side)
    g_side_grads.append(grad_g_side)

  # Accumulate layer gradients for f_side_input and g_side_input.
  acc_f_side_grads = _acc_grads(*f_side_grads)
  acc_g_side_grads = _acc_grads(*g_side_grads)

  # Use the stored idxs to put gradients in the passed-in order.
  side_input_grads = [None] * len(side_inputs)
  variable_grads = [None] * len(variables)

  # Variable gradients were collected in reverse layer order. Reverse to
  # match idxs.
  f_var_grads.reverse()
  g_var_grads.reverse()
  for idxs, grads in list(zip(f_vars_idxs, f_var_grads)) + list(
      zip(g_vars_idxs, g_var_grads)):
    for i, grad in zip(idxs, grads):
      variable_grads[i] = grad

  for i, grad in zip(f_side_idxs, acc_f_side_grads):
    side_input_grads[i] = grad
  for i, grad in zip(g_side_idxs, acc_g_side_grads):
    side_input_grads[i] = grad

  grad_x1, grad_x2 = grad_ys
  return [grad_x1, grad_x2] + side_input_grads, variable_grads
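# Both gradient functions above rely on LAYER_RE exposing two capture groups:
# group(1) is the layer index and group(2) identifies f or g. The pattern is
# defined elsewhere in the module; a plausible definition, assuming variables
# are created under scopes like ".../revlayer_<i>/f/...", would be:

import re

LAYER_RE = re.compile(r".*revlayer_([0-9]*)/([fg])/.*")

# Example: LAYER_RE.match("revblock/revlayer_3/f/dense/kernel:0").groups()
# returns ("3", "f"), which the loops above turn into layer_no=3, fn_name="f".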