def fused_bn_reduce_grad(data0, data1, data2, data3, data4, data5, data6, data7,
                         layout='NHWC', out_dtype='float16', target=utils.CUDA):
    """Fused batch-norm reduce-grad kernel composed from topi primitives.

    Accepts eight input tensors; ``data3`` and ``data7`` are 4-D activations
    (transposed to NHWC internally when ``layout == 'NCHW'``), the rest are
    per-channel / scalar-like terms that get broadcast to (n, h, w, c).
    NOTE(review): the exact BN-gradient roles of data0..data7 are not visible
    from this chunk — confirm against the callers before relying on them.

    Raises:
        NotImplementedError: for any layout other than 'NHWC' / 'NCHW'.

    Returns:
        The fused gradient tensor, cast to ``out_dtype`` and transposed back
        to NCHW when that was the input layout.
    """
    if layout == 'NCHW':
        data3 = topi.transpose(data3, (0, 2, 3, 1))
        data7 = topi.transpose(data7, (0, 2, 3, 1))
    elif layout != 'NHWC':
        raise NotImplementedError('Layout not supported {} '.format(layout))

    n, h, w, c = data3.shape
    num_elements = n * h * w  # reduction size over N*H*W
    inter_dtype = 'float32'   # accumulate in fp32 for precision

    # Scale term: (data4 * data5) / (N*H*W), broadcast to the full 4-D shape.
    scale_term = topi.multiply(data4, data5)
    scale_term = topi.divide(scale_term, num_elements)
    scale_term = topi.expand_dims(scale_term, axis=0, num_newaxis=3)
    scale_term = topi.broadcast_to(scale_term, (n, h, w, c))

    # First difference term: data3 * (N*H*W) - broadcast(data2).
    data3 = topi.cast(data3, inter_dtype)
    data2 = topi.expand_dims(data2, axis=0, num_newaxis=3)
    data2 = topi.broadcast_to(data2, (n, h, w, c))
    diff_a = topi.multiply(data3, num_elements)
    diff_a = topi.subtract(diff_a, data2)

    # Second difference term: data1 * (data7 - data6/(N*H*W)) / data0.
    data1 = topi.expand_dims(data1, axis=0, num_newaxis=3)
    data1 = topi.broadcast_to(data1, (n, h, w, c))
    data7 = topi.cast(data7, inter_dtype)
    diff_b = topi.divide(data6, num_elements)
    diff_b = topi.subtract(data7, diff_b)
    diff_b = topi.multiply(data1, diff_b)
    diff_b = topi.divide(diff_b, data0)

    # Combine, rescale, and cast down to the requested output dtype.
    result = topi.subtract(diff_a, diff_b)
    result = topi.multiply(result, scale_term)
    result = topi.cast(result, out_dtype)

    if layout == "NCHW":
        result = topi.transpose(result, (0, 3, 1, 2))
    return result
def fused_bn_follow(data0, data1, data2, data3, data4, target=utils.CUDA):
    """Fused batch-norm "follow" computation.

    Inputs (length 5):
        data0: param0 beta
        data1: param1 gamma
        data2: param2 BNupdate: xi_variance
        data3: param6 BNreduce: xi_mean
        data4: param7 xi_conv2d, layout (N, C, H, W)

    Output:
        beta + gamma * xi_variance * (xi - xi_mean/(N*H*W))
    """
    n, h, w, c = data4.shape
    num_elements = n * h * w
    inter_dtype = 'float32'  # do the arithmetic in fp32

    x = topi.cast(data4, inter_dtype)

    # Mean term: data3 / (N*H*W), broadcast up to the activation shape.
    mean = topi.divide(data3, num_elements)
    mean = topi.expand_dims(mean, axis=0, num_newaxis=3)
    mean = topi.broadcast_to(mean, (n, h, w, c))

    # Normalize, scale by variance term and gamma, then shift by beta.
    centered = topi.subtract(x, mean)
    scaled = topi.multiply(centered, data2)
    scaled = topi.multiply(scaled, data1)
    return topi.add(scaled, data0)
def StridedSlice(inputs, attrs):
    """Lower a TensorFlow-style strided slice to a tvm.compute.

    Fix: removed the dead local ``has_ellipsis`` (assigned ``False`` and
    never read or updated anywhere in the function).

    Args:
        inputs: list whose first element is the input tensor.
        attrs: dict with "begin", "end", "strides" lists and optional
            "begin_mask", "end_mask", "ellipsis_mask", "new_axis_mask",
            "shrink_axis_mask" integer bitmasks.

    Returns:
        A tvm.compute tensor holding the sliced result.
    """
    in_tensor = inputs[0]
    shape = in_tensor.shape
    begin = list(attrs["begin"])
    end = list(attrs["end"])
    strides = list(attrs["strides"])
    slice_len = len(begin)

    # Each mask is decoded into a little-endian string of bits so that
    # position i of the string answers "is bit i of the mask set?".
    # ``bin(x)[-1:1:-1]`` reverses the binary digits and drops the '0b' prefix.
    begin_pos = [0] if "begin_mask" not in attrs else bin(
        int(attrs["begin_mask"]))[-1:1:-1]
    end_pos = [0] if "end_mask" not in attrs else bin(int(
        attrs["end_mask"]))[-1:1:-1]
    ellipsis_pos = [0] if "ellipsis_mask" not in attrs else bin(
        int(attrs["ellipsis_mask"]))[-1:1:-1]
    new_axis_pos = [0] if "new_axis_mask" not in attrs else bin(
        int(attrs["new_axis_mask"]))[-1:1:-1]
    shrink_axis_pos = [0] if "shrink_axis_mask" not in attrs else bin(
        int(attrs["shrink_axis_mask"]))[-1:1:-1]

    out_shape = []
    i, j = 0, 0  # i walks the slice spec, j walks the input dimensions
    shrink_axes = []
    while i < slice_len or j < len(shape):
        # Spec exhausted: remaining dimensions pass through untouched.
        if j >= slice_len or i >= slice_len:
            out_shape.append(shape[j])
            begin.append(0)
            end.append(shape[j])
            strides.append(1)
            i += 1
            j += 1
            continue
        # Ellipsis axis: take the whole dimension.
        if i < len(ellipsis_pos) and ellipsis_pos[i] == '1':
            out_shape.append(shape[j])
            begin[i] = 0
            end[i] = shape[j]
            strides[i] = 1
            i += 1
            j += 1
            continue
        # new_axis: insert a length-1 dimension; the input dim index j
        # does not advance.
        if i < len(new_axis_pos) and new_axis_pos[i] == '1':
            out_shape.append(1)
            begin[i] = 1
            end[i] = 1
            strides[i] = 1
            in_tensor = akg_topi.expand_dims(in_tensor, i, 1)
            i += 1
            continue
        # shrink_axis: remember it; the dimension is removed after slicing.
        if i < len(shrink_axis_pos) and shrink_axis_pos[i] == '1':
            shrink_axes.append(i)
            i += 1
            j += 1
            continue
        # Normalize negative indices, then clamp into the valid range
        # (end may legitimately become -1 for negative strides).
        if int(begin[i]) < 0:
            begin[i] += shape[j]
        if int(end[i]) < 0:
            end[i] += shape[j]
        if int(begin[i]) < 0:
            begin[i] = 0
        elif int(begin[i]) >= int(shape[j]):
            begin[i] = shape[j] - 1
        if int(end[i]) < 0:
            end[i] = -1
        elif int(end[i]) >= int(shape[j]):
            end[i] = shape[j]
        # begin_mask / end_mask override the bounds with the full extent
        # in the direction of the stride.
        if i < len(begin_pos) and begin_pos[i] == '1':
            begin[i] = shape[j] - 1 if int(strides[i]) < 0 else 0
        if i < len(end_pos) and end_pos[i] == '1':
            end[i] = -1 if int(strides[i]) < 0 else shape[j]
        # Ceil-divide the span by the stride to get the output extent.
        out_idx = (end[i] - begin[i]) // strides[i]
        if not int(out_idx * strides[i]) == int(end[i] - begin[i]):
            out_idx += 1
        out_shape.append(out_idx)
        i += 1
        j += 1

    def get_old_indices(indices, idx):
        # Re-insert the fixed coordinate of a shrunk axis.
        old_indices = list(indices)
        old_indices.insert(idx, begin[idx])
        return old_indices

    def compute_func(in_tensor_, new_shape_, shrink_axis_):
        # Drop one shrunk axis by pinning it to its begin coordinate.
        return tvm.compute(
            new_shape_,
            lambda *indices: in_tensor_(*get_old_indices(
                indices, shrink_axis_)))

    # Remove shrunk axes from highest to lowest so earlier indices stay valid.
    for shrink_axis in reversed(shrink_axes):
        new_shape = list(in_tensor.shape)
        new_shape.pop(shrink_axis)
        if not new_shape:
            # Everything shrunk away: emit a single-element tensor.
            return tvm.compute([1], lambda *i: in_tensor(
                *[b + idx * s for b, s, idx in zip(begin, strides, i)]))
        in_tensor = compute_func(in_tensor, new_shape, shrink_axis)
        begin.pop(shrink_axis)
        strides.pop(shrink_axis)

    return tvm.compute(
        out_shape,
        lambda *i: in_tensor(
            *[b + idx * s for b, s, idx in zip(begin, strides, i)]))