Example #1
def block_reduced_full_dw(param_grad, scale=1.0, norm="max", group_size=8):

    # norm selector: 0 => max(abs()), 1 => l2 norm
    norm  = 0 if norm.lower() == "max" else 1
    # host-side scalar; a value of zero causes compute for this op to be skipped.
    scale = scalar_constant(scale, dtype=tf.float32)

    assert group_size <= 8

    # backward walk param grad to find BlocksparseMatmulDW ops
    # this should only hit BlocksparseMatmulDWs, BlocksparseMatmulDGs, AddNs or FloatCasts
    ops = get_parents(param_grad, "BlocksparseMatmulDW")
    if len(ops) < 1:
        raise ValueError("BlocksparseMatmulDW op not found")

    # this sorting is dependent on the op names being correctly ordered.
    ops.sort(key=lambda op: op.name.split('/')[-1], reverse=True)

    # use the parent scope for the new ops
    scope = ops[-1].name.split('/')
    scope = '/'.join(scope[0:-1])

    # we're going to be using absolute names, so clear name_scope
    with tf.name_scope(None):
        dw_full = None
        offset  = 0
        while offset < len(ops):

            xs = [op.inputs[0] for op in ops[offset:offset+group_size] ]
            gs = [op.inputs[1] for op in ops[offset:offset+group_size] ]

            # Get the corresponding activation grad op for the last param grad op in the group
            bprop = None
            for consumer in gs[-1].consumers():
                if consumer.type == "BlocksparseMatmulDX":
                    bprop = consumer
                    break
            assert bprop is not None

            # get attributes of first op in group
            up    = ops[offset]
            bsize = up.get_attr("bsize")
            axis  = up.get_attr("axis")
            name  = "%s/block_reduced_full_dw_%03d" % (scope, offset)
            dw_full = [] if dw_full is None else [dw_full]

            dw_full, _, _ = blocksparse_reduced_dw(xs, gs, scale, dw_full, bsize=bsize, norm=norm, axis=axis, name=name)

            # force the dw op to run before any more time steps are processed
            bprop._add_control_input(dw_full.op)

            offset += group_size

    return dw_full
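
A minimal usage sketch for the helper above, assuming the BlocksparseMatMul front end from the blocksparse README and a single block-sparse weight reused across several time steps, so that the gradient graph contains one BlocksparseMatmulDW op per step. The model, sizes, and loss are illustrative only, and the blocksparse GPU ops must be installed.

import numpy as np
import tensorflow as tf
from blocksparse.matmul import BlocksparseMatMul

hidden_size, block_size, timesteps = 4096, 32, 16

# Random block-level sparsity pattern (1 = block present).
layout = np.random.randint(2, size=(hidden_size // block_size, hidden_size // block_size))
bsmm   = BlocksparseMatMul(layout, block_size=block_size)

x = tf.placeholder(tf.float32, shape=[None, hidden_size])
w = tf.get_variable("w", bsmm.w_shape, dtype=tf.float32)

# Reuse the same sparse weight at every step so the gradient graph
# contains one BlocksparseMatmulDW op per step.
h = x
for t in range(timesteps):
    h = tf.nn.relu(bsmm(h, w))

loss       = tf.reduce_sum(h)
param_grad = tf.gradients(loss, [w])[0]

# Rewrite the per-step dw ops into grouped blocksparse_reduced_dw ops.
dw_full = block_reduced_full_dw(param_grad, scale=1.0, norm="max", group_size=8)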
Example #2
def group_param_grads(param_grad, group_size=8, cast32=False):

    assert group_size <= 8

    # backward walk param grad to find BlocksparseMatmulDW ops
    # this should only hit BlocksparseMatmulDWs or AddNs or FloatCasts
    ops = get_parents(param_grad, "BlocksparseMatmulDW")

    # this sorting is dependent on the op names being correctly ordered.
    ops.sort(key=lambda op: op.name.split('/')[-1], reverse=True)
    # for x in ops:
    #     print(x.name)
    # print("")
    # exit()

    # use the parent scope for the new ops
    scope = ops[-1].name.split('/')
    scope = '/'.join(scope[0:-1])

    # we're going to be using absolute names, so clear name_scope
    with tf.name_scope(None):
        offset = 0
        # graph  = tf.get_default_graph()
        while offset < len(ops):

            xs = [op.inputs[0] for op in ops[offset:offset + group_size]]
            gs = [op.inputs[1] for op in ops[offset:offset + group_size]]

            # Get the corresponding activation grad op for the last param grad op in the group
            bprop = None
            for op in gs[-1].consumers():
                if op.type == "BlocksparseMatmulDX":
                    bprop = op
            assert bprop is not None

            # get attributes of first op in group
            up = ops[offset]
            blocks = up.get_attr("blocks")
            bshift = up.get_attr("bshift")
            axis = up.get_attr("axis")
            dtype_dw = up.get_attr("dtype_dw")
            gated_dw = up.get_attr("gated_dw")
            C = up.get_attr("C")
            K = up.get_attr("K")
            bench = up.get_attr("bench") // len(xs)
            lut = up.inputs[2]
            name = "%s/matmul_concat_updat_%03d" % (scope, offset)
            gate = [up.inputs[3]] if len(up.inputs) > 3 else []

            # The first op needs to allocate a new dw tensor
            if offset == 0:
                grad = blocksparse_matmul_dw(xs,
                                             gs,
                                             lut,
                                             gate,
                                             dtype_dw=dtype_dw,
                                             gated_dw=gated_dw,
                                             blocks=blocks,
                                             bshift=bshift,
                                             axis=axis,
                                             C=C,
                                             K=K,
                                             bench=bench,
                                             name=name)
            # subsequent ops can just accumulate in place
            else:
                grad = blocksparse_matmul_dwa(xs,
                                              gs,
                                              lut,
                                              grad,
                                              gate,
                                              gated_dw=gated_dw,
                                              blocks=blocks,
                                              bshift=bshift,
                                              axis=axis,
                                              C=C,
                                              K=K,
                                              bench=bench,
                                              name=name)

            # print(grad.op.name, grad.op.device)

            # force the dw op to run before any more time steps are processed
            add_control_input(bprop, grad.op)

            #print(grad.op.name)

            offset += group_size

    # get the grad back to float32 if requested
    # TODO: splice the graph instead of this hack
    if cast32 and dtype_dw != tf.float32:
        grad = ew.float_cast(grad, dtype=tf.float32)

    return grad
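
A hedged sketch of how this variant might be wired into a training step, reusing the kind of graph sketched under Example #1; loss, w, and the optimizer choice are illustrative assumptions, not part of the library.

# Hypothetical setup: `loss` is a scalar and `w` is a block-sparse weight that
# is reused across several time steps, so its gradient graph contains one
# BlocksparseMatmulDW op per step (see the sketch under Example #1).
param_grad = tf.gradients(loss, [w])[0]

# Fuse the per-step dw ops into grouped dw/dwa ops; cast32=True casts the
# result back to float32 when dtype_dw is a reduced-precision type.
param_grad = group_param_grads(param_grad, group_size=8, cast32=True)

train_op = tf.train.AdamOptimizer(1e-3).apply_gradients([(param_grad, w)])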
Example #3
def group_param_grads(param_grad, group_size=8):

    assert group_size <= 8

    # backward walk param grad to find BlocksparseMatmulDW ops
    # this should only hit BlocksparseMatmulDWs, BlocksparseMatmulDGs, AddNs or FloatCasts
    ops = get_parents(param_grad, "BlocksparseMatmulDW")

    if len(ops) <= 1:
        return param_grad

    # this sorting is dependent on the op names being correctly ordered.
    ops.sort(key=lambda op: op.name.split('/')[-1], reverse=True)
    # for x in ops:
    #     print(x.name)
    # print("")
    # exit()
    segment_size = len(ops)
    if ops[0].get_attr("gate_grad") and len(ops[0].inputs) == 4:
        gate_count = dict()
        max_count  = 0
        for op in ops:
            gate  = op.inputs[3]
            count = gate_count.get(gate, 0) + 1
            gate_count[gate] = count
            max_count = max(max_count, count)
        for count in gate_count.values():
            if count != max_count:
                raise ValueError("Non-uniform gate broadcasting detected.")
        segment_size = max_count
        if  group_size > segment_size:
            group_size = segment_size
        else:
            assert segment_size % group_size == 0
        # nothing to rewrite here.
        if segment_size == 1:
            return param_grad

    # use the parent scope for the new ops
    scope = ops[-1].name.split('/')
    scope = '/'.join(scope[0:-1])

    # we're going to be using absolute names, so clear name_scope
    with tf.name_scope(None):
        dw  = None
        dws = list()
        offset  = 0
        seg_cnt = 0
        while offset < len(ops):

            xs = [op.inputs[0] for op in ops[offset:offset+group_size] ]
            gs = [op.inputs[1] for op in ops[offset:offset+group_size] ]

            # Get the corresponding activation grad op for the last param grad op in the group
            bprop = None
            for consumer in gs[-1].consumers():
                if consumer.type == "BlocksparseMatmulDX":
                    bprop = consumer
                    break
            assert bprop is not None

            # get attributes of first op in group
            up = ops[offset]
            blocks    = up.get_attr("blocks")
            bsize     = up.get_attr("bsize")
            axis      = up.get_attr("axis")
            gated_dw  = up.get_attr("gated_dw")
            gate_grad = up.get_attr("gate_grad")
            C         = up.get_attr("C")
            K         = up.get_attr("K")
            bench     = up.get_attr("bench") // len(xs)
            lut       = up.inputs[2]
            name      = "%s/matmul_concat_updat_%03d" % (scope, offset)
            gate      = [up.inputs[3]] if len(up.inputs) > 3 else []

            # The first op needs to allocate a new dw tensor
            if dw is None:
                dw = blocksparse_matmul_dw(
                    xs, gs, lut, gate, gated_dw=gated_dw,
                    gate_grad=gate_grad, blocks=blocks, bsize=bsize, axis=axis,
                    C=C, K=K, bench=bench, name=name)
            # subsequent ops can just accumulate in place
            else:
                dw = blocksparse_matmul_dwa(
                    xs, gs, lut, dw, gate, gated_dw=gated_dw,
                    gate_grad=gate_grad, blocks=blocks, bsize=bsize, axis=axis,
                    C=C, K=K, bench=bench, name=name)

            # force the dw op to run before any more time steps are processed
            bprop._add_control_input(dw.op)

            seg_cnt += group_size
            offset  += group_size

            if gate_grad and seg_cnt >= segment_size:
                seg_cnt = 0
                dws.append(dw)
                dw = None

        if gate_grad:
            for i, dw in enumerate(dws):
                # for op in ops[i*group_size:(i+1)*group_size]:
                #     print(op.name)
                # print()
                dw_op  = ops[i*segment_size:(i+1)*segment_size][-1]
                dws[i] = group_dg_grads(dw_op, dw, scope)

            # add up the final dw values in groups of 4 for a good mix of performance and memory use
            dw = ew.add_n8_op(dws[0:4]) if len(dws) > 1 else dws[0]
            for i in range(4, len(dws), 4):
                dw = ew.add_n8_op(dws[i:i+4] + [dw])

    # splice in these grad op types sitting on top of the param
    if param_grad.op.type in ("Cast", "FloatCast", "L2NormalizeGradCK", "L2NormalizeGainGradCK"):
        param_grad.op._update_input(0, dw)
        dw = param_grad
    elif param_grad.op.type not in ("AddN", "AddN8", "BlocksparseMatmulDW","BlocksparseMatmulDG"):
        raise ValueError("Unexpected grad op type:", param_grad.op.type, param_grad.op.name)

    return dw
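
The tail of this variant folds the per-segment dw tensors together in groups of four plus a running total (ew.add_n8_op, going by its name, accepts up to eight inputs at a time). A small pure-Python sketch of that reduction, with plain floats standing in for the dw tensors, shows that every segment is accumulated exactly once:

# Mirrors the loop over `dws` above: the first call sums up to four tensors,
# each later call sums up to four more plus the running total.
def add_n_in_groups_of_4(dws):
    total = sum(dws[0:4]) if len(dws) > 1 else dws[0]
    for i in range(4, len(dws), 4):
        total = sum(dws[i:i + 4] + [total])
    return total

dws = [float(i) for i in range(10)]
assert add_n_in_groups_of_4(dws) == sum(dws)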
Example #4
def group_lstm_grads(grads, params, scope="grouped_lstm", group_size=None):

    grad = None
    grad_idx = None
    for i, (g, p) in enumerate(zip(grads, params)):
        if scope in p.name and "kernel" in p.name:
            grad = g
            grad_idx = i
            break
    assert grad is not None

    # backward walk param grad to find dw MatMul ops
    # the walk terminates at each such MatMul op
    ops = list()
    wave = set([grad.op])
    while wave:
        new_wave = set()
        for op in wave:
            for parent in (t.op for t in op.inputs):
                # TN MatMul ops (dw = x^T . g)
                if (parent.type == "MatMul" and parent.get_attr("transpose_a")
                        and not parent.get_attr("transpose_b")):
                    ops.append(parent)
                else:
                    new_wave.add(parent)
        wave = new_wave

    # sort op names descending and split out the lstms (if weights are shared)
    last_lstm = None
    lstms = list()
    ops.sort(key=lambda op: op.name, reverse=True)
    for op in ops:
        # gradients/grouped_lstm/lstm_2/step_00_grad/MatMul_1 => lstm_2
        lstm = op.name.split("/")[-3]
        if last_lstm != lstm:
            lstms.insert(0, list())
            last_lstm = lstm
        lstms[0].append(op)

    # we're going to be using absolute names, so clear name_scope
    with tf.name_scope(None):

        lstm_grads = list()
        for lstm_ops in lstms:

            # default dw op to one big matmul per lstm
            if group_size is None:
                group_size = len(lstm_ops)

            # use the lstm scope for the new ops
            # gradients/grouped_lstm/lstm_2/step_00_grad/MatMul_1 => gradients/grouped_lstm/lstm_2
            scope = lstm_ops[-1].name.split('/')
            scope = '/'.join(scope[0:-2])

            offset = 0
            while offset < len(lstm_ops):

                group = lstm_ops[offset:offset + group_size]
                xs = tf.concat([op.inputs[0] for op in group], axis=0)
                gs = tf.concat([op.inputs[1] for op in group], axis=0)

                mmop = tf.matmul(xs, gs, transpose_a=True, transpose_b=False,
                                 name="%s/dw_%04d" % (scope, offset))
                grad = mmop if offset == 0 else ew.add(
                    grad, mmop, name="%s/add_%04d" % (scope, offset))

                offset += group_size

            lstm_grads.append(grad)

        if len(lstms) > 1:
            # gradients/grouped_lstm/lstm_2/step_00_grad/MatMul_1 => gradients/grouped_lstm
            scope = lstms[0][-1].name.split('/')
            scope = '/'.join(scope[0:-3])
            grads[grad_idx] = tf.add_n(lstm_grads, name="%s/add_n" % scope)
        else:
            grads[grad_idx] = lstm_grads[0]

    # grads list is modified in place
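
Since the rewrite above hinges entirely on op-name conventions, here is a pure-Python sketch of just the partitioning step, with dummy name strings standing in for the dw MatMul ops (the real function operates on tf.Operation objects):

# Dummy names shaped like the comment in the code:
#   gradients/grouped_lstm/lstm_<i>/step_<t>_grad/MatMul_1
names = [
    "gradients/grouped_lstm/lstm_%d/step_%02d_grad/MatMul_1" % (l, t)
    for l in range(2) for t in range(3)
]

last_lstm, lstms = None, []
for name in sorted(names, reverse=True):
    lstm = name.split("/")[-3]      # e.g. "lstm_1"
    if last_lstm != lstm:
        lstms.insert(0, [])         # new group, prepended to keep lstm order
        last_lstm = lstm
    lstms[0].append(name)

# lstms[0] now holds lstm_0's dw matmuls (step_02 .. step_00),
# lstms[1] holds lstm_1's, matching how the function walks each group.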


Example #5
def group_lstm_grads(grads, params, scope="grouped_lstm", group_size=None):

    grad = None
    grad_idx = None
    for i, (g, p) in enumerate(zip(grads, params)):
        if scope in p.name and "kernel" in p.name:
            grad = g
            grad_idx = i
            break
    assert grad is not None

    # backward walk param grad to find dw MatMul ops
    # the walk terminates at each such MatMul op
    ops = list()
    wave = set([grad.op])
    while wave:
        new_wave = set()
        for op in wave:
            for parent in (t.op for t in op.inputs):
                # TN MatMul ops (dw = x^T . g)
                if (parent.type == "MatMul" and parent.get_attr("transpose_a")
                        and not parent.get_attr("transpose_b")):
                    ops.append(parent)
                else:
                    new_wave.add(parent)
        wave = new_wave

    # sort op names descending and split out the lstms (if weights are shared)
    last_lstm = None
    lstms = list()
    ops.sort(key=lambda op: op.name, reverse=True)
    for op in ops:
        # gradients/grouped_lstm/lstm_2/step_00_grad/MatMul_1 => lstm_2
        lstm = op.name.split("/")[-3]
        if last_lstm != lstm:
            lstms.insert(0, list())
            last_lstm = lstm
        lstms[0].append(op)

    # we're going to be using absolute names, so clear name_scope
    with tf.name_scope(None):

        lstm_grads = list()
        for lstm_ops in lstms:

            # default dw op to one big matmul per lstm
            if group_size is None:
                group_size = len(lstm_ops)

            # use the lstm scope for the new ops
            # gradients/grouped_lstm/lstm_2/step_00_grad/MatMul_1 => gradients/grouped_lstm/lstm_2
            scope = lstm_ops[-1].name.split('/')
            scope = '/'.join(scope[0:-2])

            offset = 0
            while offset < len(lstm_ops):

                group = lstm_ops[offset:offset + group_size]
                xs = tf.concat([op.inputs[0] for op in group], axis=0)
                gs = tf.concat([op.inputs[1] for op in group], axis=0)

                mmop = tf.matmul(xs, gs, transpose_a=True, transpose_b=False,
                                 name="%s/dw_%04d" % (scope, offset))
                grad = mmop if offset == 0 else ew.add(
                    grad, mmop, name="%s/add_%04d" % (scope, offset))

                offset += group_size

            lstm_grads.append(grad)

        if len(lstms) > 1:
            # gradients/grouped_lstm/lstm_2/step_00_grad/MatMul_1 => gradients/grouped_lstm
            scope = lstms[0][-1].name.split('/')
            scope = '/'.join(scope[0:-3])
            grads[grad_idx] = tf.add_n(lstm_grads, name="%s/add_n" % scope)
        else:
            grads[grad_idx] = lstm_grads[0]
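
Example #5 repeats the same rewrite; the identity it relies on, that one transpose_a matmul over the concatenated per-step inputs equals the sum of the per-step x_t^T . g_t products, can be checked with a short numpy sketch (shapes are arbitrary illustrative values):

import numpy as np

steps, batch, in_w, out_w = 4, 8, 16, 32
xs = [np.random.randn(batch, in_w).astype(np.float32) for _ in range(steps)]
gs = [np.random.randn(batch, out_w).astype(np.float32) for _ in range(steps)]

# Summing the per-step outer products ...
dw_per_step = sum(x.T @ g for x, g in zip(xs, gs))
# ... equals one big matmul over the time-concatenated activations and grads,
# which is exactly what the grouped tf.matmul above computes.
dw_fused = np.concatenate(xs, axis=0).T @ np.concatenate(gs, axis=0)

assert np.allclose(dw_per_step, dw_fused, atol=1e-4)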