コード例 #1
0
def fused_bn_reduce_grad(data0,
                         data1,
                         data2,
                         data3,
                         data4,
                         data5,
                         data6,
                         data7,
                         layout='NHWC',
                         out_dtype='float16',
                         target=utils.CUDA):
    """Fused batch-norm reduce-grad kernel built from topi primitives.

    Internally works in NHWC; NCHW inputs (data3/data7) are transposed in,
    and the result is transposed back before returning.

    Computes, elementwise over (n, h, w, c):
        ((data3 * NHW - data2) - data1 * (data7 - data6 / NHW) / data0)
        * (data4 * data5 / NHW)
    cast to ``out_dtype``, where NHW = n * h * w.

    Raises:
        NotImplementedError: for any layout other than 'NHWC' or 'NCHW'.
    """
    if layout == 'NCHW':
        data3 = topi.transpose(data3, (0, 2, 3, 1))
        data7 = topi.transpose(data7, (0, 2, 3, 1))
    elif layout != 'NHWC':
        raise NotImplementedError('Layout not supported {} '.format(layout))

    n, h, w, c = data3.shape
    num_elements = n * h * w
    work_dtype = 'float32'

    # Scale factor: data4 * data5 / NHW, broadcast from (c,) to (n, h, w, c).
    scale = topi.broadcast_to(
        topi.expand_dims(
            topi.divide(topi.multiply(data4, data5), num_elements),
            axis=0, num_newaxis=3),
        (n, h, w, c))

    # First term: data3 * NHW - data2 (data2 broadcast over n/h/w).
    data3 = topi.cast(data3, work_dtype)
    data2_full = topi.broadcast_to(
        topi.expand_dims(data2, axis=0, num_newaxis=3), (n, h, w, c))
    term_a = topi.subtract(topi.multiply(data3, num_elements), data2_full)

    # Second term: data1 * (data7 - data6 / NHW) / data0.
    data1_full = topi.broadcast_to(
        topi.expand_dims(data1, axis=0, num_newaxis=3), (n, h, w, c))
    data7 = topi.cast(data7, work_dtype)
    term_b = topi.subtract(data7, topi.divide(data6, num_elements))
    term_b = topi.divide(topi.multiply(data1_full, term_b), data0)

    result = topi.multiply(topi.subtract(term_a, term_b), scale)
    result = topi.cast(result, out_dtype)

    if layout == "NCHW":
        result = topi.transpose(result, (0, 3, 1, 2))

    return result
コード例 #2
0
def fused_bn_follow(data0, data1, data2, data3, data4, target=utils.CUDA):
    """Fused batch-norm "follow" transform applied to a conv output.

    Computes, elementwise:
        data0 + data1 * data2 * (data4 - data3 / (n*h*w))
    i.e. beta + gamma * xi_variance * (xi - xi_mean / (N*H*W)).

    Args:
        data0: beta (per-channel shift).
        data1: gamma (per-channel scale).
        data2: xi_variance term from BN update.
        data3: xi_mean reduction (divided by n*h*w here).
        data4: conv2d output, 4-D, unpacked as (n, h, w, c).
            NOTE(review): the original doc claimed (N, C, H, W), but the
            shape is consumed in NHWC order — confirm with callers.
        target: execution target (default CUDA).

    Returns:
        The normalized tensor in float32.
    """
    n, h, w, c = data4.shape
    elem_count = n * h * w

    x = topi.cast(data4, 'float32')

    # Per-channel mean, broadcast from (c,) up to the full (n, h, w, c) shape.
    mean = topi.divide(data3, elem_count)
    mean = topi.broadcast_to(
        topi.expand_dims(mean, axis=0, num_newaxis=3), (n, h, w, c))

    centered = topi.subtract(x, mean)
    scaled = topi.multiply(topi.multiply(centered, data2), data1)

    return topi.add(scaled, data0)
コード例 #3
0
ファイル: topi.py プロジェクト: mindspore-ai/akg
def StridedSlice(inputs, attrs):
    """Compute a TensorFlow-style strided slice of inputs[0].

    attrs carries "begin", "end", "strides" lists plus the optional bit
    masks "begin_mask", "end_mask", "ellipsis_mask", "new_axis_mask" and
    "shrink_axis_mask".  Each mask is decoded into an LSB-first string of
    binary digits so that position i corresponds to slice dimension i.

    Returns a tvm tensor holding the sliced (and possibly rank-changed)
    result.
    """
    in_tensor = inputs[0]
    shape = in_tensor.shape
    begin = list(attrs["begin"])
    end = list(attrs["end"])
    strides = list(attrs["strides"])
    slice_len = len(begin)

    # bin(x)[-1:1:-1] reverses the binary string and drops the '0b' prefix,
    # yielding an LSB-first sequence of '0'/'1' chars: mask[i] tells whether
    # the mask bit for dimension i is set.  Missing masks default to [0]
    # (a one-element list whose int 0 never equals '1').
    begin_pos = [0] if "begin_mask" not in attrs else bin(
        int(attrs["begin_mask"]))[-1:1:-1]
    end_pos = [0] if "end_mask" not in attrs else bin(int(
        attrs["end_mask"]))[-1:1:-1]
    ellipsis_pos = [0] if "ellipsis_mask" not in attrs else bin(
        int(attrs["ellipsis_mask"]))[-1:1:-1]
    new_axis_pos = [0] if "new_axis_mask" not in attrs else bin(
        int(attrs["new_axis_mask"]))[-1:1:-1]
    shrink_axis_pos = [0] if "shrink_axis_mask" not in attrs else bin(
        int(attrs["shrink_axis_mask"]))[-1:1:-1]

    # Walk slice spec (i) and input dims (j) together, normalizing
    # begin/end/strides per dimension and accumulating the output shape.
    out_shape = []
    i, j = 0, 0
    has_ellipsis = False
    shrink_axes = []
    while i < slice_len or j < len(shape):
        # Dimensions beyond the slice spec are taken whole.
        if j >= slice_len or i >= slice_len:
            out_shape.append(shape[j])
            begin.append(0)
            end.append(shape[j])
            strides.append(1)
            i += 1
            j += 1
            continue
        # Ellipsis bit: this spec entry covers the full dimension.
        # NOTE(review): only a single-dimension ellipsis is handled here;
        # has_ellipsis is set nowhere, so multi-dim expansion is not done.
        if i < len(ellipsis_pos) and ellipsis_pos[i] == '1':
            out_shape.append(shape[j])
            begin[i] = 0
            end[i] = shape[j]
            strides[i] = 1
            i += 1
            j += 1
            continue
        # New-axis bit: insert a size-1 dimension; note j is NOT advanced
        # because no input dimension is consumed.
        if i < len(new_axis_pos) and new_axis_pos[i] == '1':
            out_shape.append(1)
            begin[i] = 1
            end[i] = 1
            strides[i] = 1
            in_tensor = akg_topi.expand_dims(in_tensor, i, 1)
            i += 1
            continue
        # Shrink bit: remember the axis; it is removed after slicing below.
        if i < len(shrink_axis_pos) and shrink_axis_pos[i] == '1':
            shrink_axes.append(i)
            i += 1
            j += 1
            continue
        # Normalize negative begin/end to absolute positions, then clamp
        # into the valid range for this dimension.
        if int(begin[i]) < 0:
            begin[i] += shape[j]
        if int(end[i]) < 0:
            end[i] += shape[j]
        if int(begin[i]) < 0:
            begin[i] = 0
        elif int(begin[i]) >= int(shape[j]):
            begin[i] = shape[j] - 1
        if int(end[i]) < 0:
            end[i] = -1
        elif int(end[i]) >= int(shape[j]):
            end[i] = shape[j]
        # begin/end mask bits override the bounds with the full range,
        # direction-aware for negative strides.
        if i < len(begin_pos) and begin_pos[i] == '1':
            begin[i] = shape[j] - 1 if int(strides[i]) < 0 else 0
        if i < len(end_pos) and end_pos[i] == '1':
            end[i] = -1 if int(strides[i]) < 0 else shape[j]
        # Output extent = ceil((end - begin) / stride), done via floor
        # division plus a +1 correction when there is a remainder.
        out_idx = (end[i] - begin[i]) // strides[i]
        if not int(out_idx * strides[i]) == int(end[i] - begin[i]):
            out_idx += 1
        out_shape.append(out_idx)
        i += 1
        j += 1

    def get_old_indices(indices, idx):
        # Map output indices back to pre-shrink indices by re-inserting
        # the (fixed) begin position of the shrunken axis.
        old_indices = list(indices)
        old_indices.insert(idx, begin[idx])
        return old_indices

    def compute_func(in_tensor_, new_shape_, shrink_axis_):
        # Build a tensor with shrink_axis_ removed, reading the source at
        # that axis's begin index.
        return tvm.compute(
            new_shape_, lambda *indices: in_tensor_(*get_old_indices(
                indices, shrink_axis_)))

    # Remove shrunken axes from the highest index down so that earlier
    # axis numbers stay valid while popping.
    for shrink_axis in reversed(shrink_axes):
        new_shape = list(in_tensor.shape)
        new_shape.pop(shrink_axis)
        # All axes shrunk away: emit a single-element tensor gathered at
        # the begin offsets.
        if not new_shape:
            return tvm.compute([1], lambda *i: in_tensor(
                *[b + idx * s for b, s, idx in zip(begin, strides, i)]))
        in_tensor = compute_func(in_tensor, new_shape, shrink_axis)
        begin.pop(shrink_axis)
        strides.pop(shrink_axis)
    # General case: gather element i at begin + i * stride per dimension.
    return tvm.compute(
        out_shape, lambda *i: in_tensor(
            *[b + idx * s for b, s, idx in zip(begin, strides, i)]))