def block_reduced_full_dw(param_grad, scale=1.0, norm="max", group_size=8):

    # max(abs()) or l2_norm()
    norm = 0 if norm.lower() == "max" else 1

    # host side scalar, if zero will cause compute for this op to be skipped.
    scale = scalar_constant(scale, dtype=tf.float32)

    assert group_size <= 8

    # backward walk param grad to find BlocksparseMatmulDW ops
    # this should only hit BlocksparseMatmulDWs, BlocksparseMatmulDGs, AddNs or FloatCasts
    ops = get_parents(param_grad, "BlocksparseMatmulDW")

    if len(ops) < 1:
        raise ValueError("BlocksparseMatmulDW op not found")

    # this sorting is dependent on the op names being correctly ordered.
    ops.sort(key=lambda op: op.name.split('/')[-1], reverse=True)

    # use the parent scope for the new ops
    scope = ops[-1].name.split('/')
    scope = '/'.join(scope[0:-1])

    # we're going to be using absolute names, so clear name_scope
    with tf.name_scope(None):

        dw_full = None
        offset  = 0
        while offset < len(ops):

            xs = [op.inputs[0] for op in ops[offset:offset+group_size]]
            gs = [op.inputs[1] for op in ops[offset:offset+group_size]]

            # Get the corresponding activation grad op for the last param grad op in the group
            bprop = None
            for consumer in gs[-1].consumers():
                if consumer.type == "BlocksparseMatmulDX":
                    bprop = consumer
                    break
            assert bprop is not None

            # get attributes of first op in group
            up    = ops[offset]
            bsize = up.get_attr("bsize")
            axis  = up.get_attr("axis")

            name = "%s/block_reduced_full_dw_%03d" % (scope, offset)

            dw_full = [] if dw_full is None else [dw_full]

            dw_full, _, _ = blocksparse_reduced_dw(xs, gs, scale, dw_full, bsize=bsize, norm=norm, axis=axis, name=name)

            # force the dw op before any more time steps are processed
            bprop._add_control_input(dw_full.op)

            offset += group_size

    return dw_full
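# Illustrative sketch: one way block_reduced_full_dw might be driven. It assumes the
# BlocksparseMatMul object (and its w_shape attribute) from blocksparse.matmul and a
# toy graph that reuses one block-sparse weight across several steps, so the backward
# walk from the weight gradient finds per-step BlocksparseMatmulDW ops to reduce.
# The layout, shapes, and loop below are illustrative assumptions only.
def _block_reduced_full_dw_sketch():
    import numpy as np
    import tensorflow as tf
    from blocksparse.matmul import BlocksparseMatMul  # assumed import path

    blocks_per_dim = 4
    block_size     = 32
    hidden         = blocks_per_dim * block_size

    # hypothetical square block-sparsity layout
    layout = np.random.randint(2, size=(blocks_per_dim, blocks_per_dim))
    bsmm   = BlocksparseMatMul(layout, block_size=block_size)

    h0 = tf.get_variable("h0", shape=[16, hidden], dtype=tf.float32)
    w  = tf.get_variable("w", bsmm.w_shape, dtype=tf.float32)

    # reuse the same block-sparse weight at several "time steps"
    h = h0
    for _ in range(4):
        h = bsmm(h, w)
    loss = tf.reduce_sum(h)

    # also request a grad for h0 so every step keeps its BlocksparseMatmulDX op
    w_grad, _ = tf.gradients(loss, [w, h0])

    # per-block max(abs()) reduction of the accumulated dw across all steps
    return block_reduced_full_dw(w_grad, scale=1.0, norm="max", group_size=4)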
def group_param_grads(param_grad, group_size=8, cast32=False):

    assert group_size <= 8

    # backward walk param grad to find BlocksparseMatmulDW ops
    # this should only hit BlocksparseMatmulDWs or AddNs or FloatCasts
    ops = get_parents(param_grad, "BlocksparseMatmulDW")

    # this sorting is dependent on the op names being correctly ordered.
    ops.sort(key=lambda op: op.name.split('/')[-1], reverse=True)

    # use the parent scope for the new ops
    scope = ops[-1].name.split('/')
    scope = '/'.join(scope[0:-1])

    # we're going to be using absolute names, so clear name_scope
    with tf.name_scope(None):

        offset = 0
        while offset < len(ops):

            xs = [op.inputs[0] for op in ops[offset:offset+group_size]]
            gs = [op.inputs[1] for op in ops[offset:offset+group_size]]

            # Get the corresponding activation grad op for the last param grad op in the group
            bprop = None
            for consumer in gs[-1].consumers():
                if consumer.type == "BlocksparseMatmulDX":
                    bprop = consumer
            assert bprop is not None

            # get attributes of first op in group
            up       = ops[offset]
            blocks   = up.get_attr("blocks")
            bshift   = up.get_attr("bshift")
            axis     = up.get_attr("axis")
            dtype_dw = up.get_attr("dtype_dw")
            gated_dw = up.get_attr("gated_dw")
            C        = up.get_attr("C")
            K        = up.get_attr("K")
            bench    = up.get_attr("bench") // len(xs)
            lut      = up.inputs[2]
            name     = "%s/matmul_concat_updat_%03d" % (scope, offset)

            # optional gate input on the first op in the group
            gate = [up.inputs[3]] if len(up.inputs) > 3 else []

            # The first op needs to allocate a new dw tensor
            if offset == 0:
                grad = blocksparse_matmul_dw(
                    xs, gs, lut, gate,
                    dtype_dw=dtype_dw, gated_dw=gated_dw, blocks=blocks, bshift=bshift,
                    axis=axis, C=C, K=K, bench=bench, name=name)
            # subsequent ops can just accumulate in place
            else:
                grad = blocksparse_matmul_dwa(
                    xs, gs, lut, grad, gate,
                    gated_dw=gated_dw, blocks=blocks, bshift=bshift,
                    axis=axis, C=C, K=K, bench=bench, name=name)

            # force the dw op before any more time steps are processed
            add_control_input(bprop, grad.op)

            offset += group_size

    # get the grad back to float32 if requested
    # TODO: splice the graph instead of this hack
    if cast32 and dtype_dw != tf.float32:
        grad = ew.float_cast(grad, dtype=tf.float32)

    return grad
def group_param_grads(param_grad, group_size=8):

    assert group_size <= 8

    # backward walk param grad to find BlocksparseMatmulDW ops
    # this should only hit BlocksparseMatmulDWs, BlocksparseMatmulDGs, AddNs or FloatCasts
    ops = get_parents(param_grad, "BlocksparseMatmulDW")

    if len(ops) <= 1:
        return param_grad

    # this sorting is dependent on the op names being correctly ordered.
    ops.sort(key=lambda op: op.name.split('/')[-1], reverse=True)

    segment_size = len(ops)
    if ops[0].get_attr("gate_grad") and len(ops[0].inputs) == 4:
        gate_count = dict()
        max_count  = 0
        for op in ops:
            gate  = op.inputs[3]
            count = gate_count.get(gate, 0) + 1
            gate_count[gate] = count
            max_count = max(max_count, count)
        for count in gate_count.values():
            if count != max_count:
                raise ValueError("Non-uniform gate broadcasting detected.")
        segment_size = max_count

    if group_size > segment_size:
        group_size = segment_size
    else:
        assert segment_size % group_size == 0

    # nothing to rewrite here.
    if segment_size == 1:
        return param_grad

    # use the parent scope for the new ops
    scope = ops[-1].name.split('/')
    scope = '/'.join(scope[0:-1])

    # we're going to be using absolute names, so clear name_scope
    with tf.name_scope(None):

        dw      = None
        dws     = list()
        offset  = 0
        seg_cnt = 0
        while offset < len(ops):

            xs = [op.inputs[0] for op in ops[offset:offset+group_size]]
            gs = [op.inputs[1] for op in ops[offset:offset+group_size]]

            # Get the corresponding activation grad op for the last param grad op in the group
            bprop = None
            for consumer in gs[-1].consumers():
                if consumer.type == "BlocksparseMatmulDX":
                    bprop = consumer
                    break
            assert bprop is not None

            # get attributes of first op in group
            up        = ops[offset]
            blocks    = up.get_attr("blocks")
            bsize     = up.get_attr("bsize")
            axis      = up.get_attr("axis")
            gated_dw  = up.get_attr("gated_dw")
            gate_grad = up.get_attr("gate_grad")
            C         = up.get_attr("C")
            K         = up.get_attr("K")
            bench     = up.get_attr("bench") // len(xs)
            lut       = up.inputs[2]
            name      = "%s/matmul_concat_updat_%03d" % (scope, offset)

            gate = [up.inputs[3]] if len(up.inputs) > 3 else []

            # The first op needs to allocate a new dw tensor
            if dw is None:
                dw = blocksparse_matmul_dw(
                    xs, gs, lut, gate,
                    gated_dw=gated_dw, gate_grad=gate_grad, blocks=blocks,
                    bsize=bsize, axis=axis, C=C, K=K, bench=bench, name=name)
            # subsequent ops can just accumulate in place
            else:
                dw = blocksparse_matmul_dwa(
                    xs, gs, lut, dw, gate,
                    gated_dw=gated_dw, gate_grad=gate_grad, blocks=blocks,
                    bsize=bsize, axis=axis, C=C, K=K, bench=bench, name=name)

            # force the dw op before any more time steps are processed
            bprop._add_control_input(dw.op)

            seg_cnt += group_size
            offset  += group_size

            if gate_grad and seg_cnt >= segment_size:
                seg_cnt = 0
                dws.append(dw)
                dw = None

        if gate_grad:
            for i, dw in enumerate(dws):
                dw_op  = ops[i*segment_size:(i+1)*segment_size][-1]
                dws[i] = group_dg_grads(dw_op, dw, scope)

            # add up final dw values in groups of 4 for a good mix of performance and memory use
            dw = ew.add_n8_op(dws[0:4]) if len(dws) > 1 else dws[0]
            for i in range(4, len(dws), 4):
                dw = ew.add_n8_op(dws[i:i+4] + [dw])

    # splice in these grad op types sitting on top of the param
    if param_grad.op.type in ("Cast", "FloatCast", "L2NormalizeGradCK", "L2NormalizeGainGradCK"):
        param_grad.op._update_input(0, dw)
        dw = param_grad
    elif param_grad.op.type not in ("AddN", "AddN8", "BlocksparseMatmulDW", "BlocksparseMatmulDG"):
        raise ValueError("Unexpected grad op type:", param_grad.op.type, param_grad.op.name)

    return dw
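# Illustrative sketch of the intended call pattern: after tf.gradients has produced a
# gradient for a block-sparse weight that is reused across time steps, pass that single
# gradient tensor through group_param_grads before handing it to the optimizer. The
# optimizer choice and the arguments below are assumptions for illustration.
def _group_param_grads_sketch(loss, w):
    import tensorflow as tf

    # w is assumed to be a block-sparse weight applied via BlocksparseMatMul at several
    # time steps, so its gradient is a sum of per-step BlocksparseMatmulDW outputs.
    [w_grad] = tf.gradients(loss, [w])

    # rewrite the per-step BlocksparseMatmulDW ops into grouped/accumulating updates
    w_grad = group_param_grads(w_grad, group_size=8)

    opt = tf.train.GradientDescentOptimizer(0.01)
    return opt.apply_gradients([(w_grad, w)])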
def group_lstm_grads(grads, params, scope="grouped_lstm", group_size=None): grad = None grad_idx = None for i, (g, p) in enumerate(zip(grads, params)): if scope in p.name and "kernel" in p.name: grad = g grad_idx = i break assert grad is not None # backward walk param grad to find dw MatMul ops # walk should terminate with each MatMul op ops = list() wave = set([grad.op]) while wave: new_wave = set() for op in wave: for op in (t.op for t in op.inputs): # TN MatMul ops if op.type == "MatMul" and op.get_attr( "transpose_a") and not op.get_attr("transpose_b"): ops.append(op) else: new_wave.add(op) wave = new_wave # sort op names descending and split out the lstms (if weights are shared) last_lstm = None lstms = list() ops.sort(key=lambda op: op.name, reverse=True) for op in ops: # gradients/grouped_lstm/lstm_2/step_00_grad/MatMul_1 => lstm_2 lstm = op.name.split("/")[-3] if last_lstm != lstm: lstms.insert(0, list()) last_lstm = lstm lstms[0].append(op) # we're going to be using absolute names, so clear name_scope with tf.name_scope(None): lstm_grads = list() for lstm_ops in lstms: # default dw op to one big matmul per lstm if group_size is None: group_size = len(lstm_ops) # use the lstm scope for the new ops # gradients/grouped_lstm/lstm_2/step_00_grad/MatMul_1 => gradients/grouped_lstm/lstm_2 scope = lstm_ops[-1].name.split('/') scope = '/'.join(scope[0:-2]) offset = 0 while offset < len(lstm_ops): xs = tf.concat([ op.inputs[0] for op in lstm_ops[offset:offset + group_size] ], axis=0) gs = tf.concat([ op.inputs[1] for op in lstm_ops[offset:offset + group_size] ], axis=0) mmop = tf.matmul(xs, gs, transpose_a=True, transpose_b=False, name="%s/dw_%04d" % (scope, offset)) grad = mmop if offset == 0 else ew.add( grad, mmop, name="%s/add_%04d" % (scope, offset)) offset += group_size lstm_grads.append(grad) if len(lstms) > 1: from blocksparse.ewops import add_n # gradients/grouped_lstm/lstm_2/step_00_grad/MatMul_1 => gradients/grouped_lstm scope = lstms[0][-1].name.split('/') scope = '/'.join(scope[0:-3]) grads[grad_idx] = tf.add_n(lstm_grads, name="%s/add_n" % scope) else: grads[grad_idx] = lstm_grads[0] #grads modified in place # lstm_scopes = dict() # # rediculous amount of code just to be able to re-enter a variable scope without its name being re-numbered. # # https://github.com/tensorflow/tensorflow/pull/14390 # global lstm_scopes # if scope not in lstm_scopes: # with tf.variable_scope(scope) as lstm_scope: # lstm_scopes[scope] = lstm_scope # lstm_scope = lstm_scopes[scope] # with tf.variable_scope(lstm_scope, auxiliary_name_scope=False), tf.name_scope(lstm_scope.original_name_scope): # with tf.variable_scope(weights_scope, reuse=weights_reuse): # w = tf.get_variable('kernel', shape=[in_width + width, 4 * width]) # if bias_scope is None: # b = tf.get_variable('bias', shape=[4 * width]) # if layernorm: # g = tf.get_variable('gain', shape=[4 * width]) # if bias_scope is not None: # with tf.variable_scope(bias_scope, reuse=bias_reuse): # b = tf.get_variable('bias', shape=[4 * width]) # if layernorm: # g = tf.get_variable('gain', shape=[4 * width])
def group_lstm_grads(grads, params, scope="grouped_lstm", group_size=None): grad = None grad_idx = None for i, (g, p) in enumerate(zip(grads, params)): if scope in p.name and "kernel" in p.name: grad = g grad_idx = i break assert grad is not None # backward walk param grad to find dw MatMul ops # walk should terminate with each MatMul op ops = list() wave = set([grad.op]) while wave: new_wave = set() for op in wave: for op in (t.op for t in op.inputs): # TN MatMul ops if op.type == "MatMul" and op.get_attr( "transpose_a") and not op.get_attr("transpose_b"): ops.append(op) else: new_wave.add(op) wave = new_wave # sort op names descending and split out the lstms (if weights are shared) last_lstm = None lstms = list() ops.sort(key=lambda op: op.name, reverse=True) for op in ops: # gradients/grouped_lstm/lstm_2/step_00_grad/MatMul_1 => lstm_2 lstm = op.name.split("/")[-3] if last_lstm != lstm: lstms.insert(0, list()) last_lstm = lstm lstms[0].append(op) # we're going to be using absolute names, so clear name_scope with tf.name_scope(None): lstm_grads = list() for lstm_ops in lstms: # default dw op to one big matmul per lstm if group_size is None: group_size = len(lstm_ops) # use the lstm scope for the new ops # gradients/grouped_lstm/lstm_2/step_00_grad/MatMul_1 => gradients/grouped_lstm/lstm_2 scope = lstm_ops[-1].name.split('/') scope = '/'.join(scope[0:-2]) offset = 0 while offset < len(lstm_ops): xs = tf.concat([ op.inputs[0] for op in lstm_ops[offset:offset + group_size] ], axis=0) gs = tf.concat([ op.inputs[1] for op in lstm_ops[offset:offset + group_size] ], axis=0) mmop = tf.matmul(xs, gs, transpose_a=True, transpose_b=False, name="%s/dw_%04d" % (scope, offset)) grad = mmop if offset == 0 else ew.add( grad, mmop, name="%s/add_%04d" % (scope, offset)) offset += group_size lstm_grads.append(grad) if len(lstms) > 1: from blocksparse.ewops import add_n # gradients/grouped_lstm/lstm_2/step_00_grad/MatMul_1 => gradients/grouped_lstm scope = lstms[0][-1].name.split('/') scope = '/'.join(scope[0:-3]) grads[grad_idx] = tf.add_n(lstm_grads, name="%s/add_n" % scope) else: grads[grad_idx] = lstm_grads[0]