def concat_ad(data, axis, wrt_index=0):
    """
    Autodiff of concat with one or more input tensors.

    Args:
        data (list[akg.tvm.tensor.Tensor]): input tensors.
        axis (int): concat axis.
        wrt_index (int): index of the input to differentiate with respect to
            (must be less than len(data)).

    Returns:
        tuple: the jacobian of the concatenation with respect to
        data[wrt_index], and the head placeholder used for differentiation.
    """
    output = concat.concat(data, axis)
    head = akg.tvm.placeholder(output.shape, output.dtype, name="head")
    jacs = list(akg.differentiate(output, [data[wrt_index]], head))
    return jacs[0], head
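
# Usage sketch for concat_ad (hypothetical shapes; assumes this module's
# imports of akg and concat are in scope -- illustrative, not a definitive
# invocation):
def _example_concat_ad():
    x = akg.tvm.placeholder((4, 16), "float16", name="x")
    y = akg.tvm.placeholder((8, 16), "float16", name="y")
    # Differentiate concat([x, y], axis=0) with respect to x.
    jac, head = concat_ad([x, y], 0, wrt_index=0)
    return jac, head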
def segment_max(data, segment_ids, num_segments):
    """
    Computes the max value along segment_ids of an akg.tvm.Tensor.

    Args:
        data: akg.tvm.Tensor of type float16, float32.
        segment_ids: sorted akg.tvm.Tensor of type int32.
        num_segments: the total number of segments.

    Returns:
        akg.tvm.Tensor of the same type as data whose first dimension is
        num_segments; segments absent from segment_ids are filled with zero.
    """
    d_dtype = data.dtype
    vc_util.ops_dtype_check(d_dtype, vc_util.DtypeForDavinci.ALL_FLOAT)
    d_shape = [x.value for x in data.shape]
    vc_util.check_shape(d_shape)
    s_shape = segment_ids.shape
    vc_util.check_shape(s_shape)

    new_segment_ids, idx = gen_ids(segment_ids)

    # Placeholder row for segment ids that never occur in segment_ids.
    output_shape = (1, ) + tuple(d_shape[len(s_shape):])
    zero_data = akg.tvm.compute(output_shape,
                                lambda *i: akg.tvm.const(0.0, d_dtype),
                                name="zero")

    # Split the data into contiguous runs and max-reduce each run.
    data_list = split.split(data, new_segment_ids)
    out = []
    j = 0
    for i in range(0, num_segments):
        if i in idx:
            tmp = reduce_max.reduce_max(data_list[j], 0, True)
            out.append(tmp)
            j = j + 1
        else:
            out.append(zero_data)

    res = concat.concat(out, 0)
    return res
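
# Usage sketch for segment_max (hypothetical shapes; passing the ids as a
# host-side numpy array is an assumption about what gen_ids expects, not
# something this module confirms):
def _example_segment_max():
    import numpy as np
    data = akg.tvm.placeholder((5, 16), "float16", name="data")
    segment_ids = np.array([0, 0, 1, 2, 2], dtype="int32")  # sorted ids
    # Rows 0-1, row 2, and rows 3-4 are max-reduced into three segments.
    return segment_max(data, segment_ids, num_segments=3)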
def unsorted_segment_max(data, segment_ids, num_segments):
    """
    Computes the max value along segment_ids of an akg.tvm.Tensor.

    Args:
        data: akg.tvm.Tensor of type float16, float32.
        segment_ids: akg.tvm.Tensor of type int32; its shape is a prefix of
            data.shape. Ids do not need to be sorted.
        num_segments: the number of distinct segment ids.

    Returns:
        akg.tvm.Tensor of the same type as data whose first dimension is
        num_segments; segments absent from segment_ids are filled with zero.
    """
    d_dtype = data.dtype
    vc_util.ops_dtype_check(d_dtype, vc_util.DtypeForDavinci.ALL_FLOAT)
    d_shape = [x.value for x in data.shape]
    vc_util.check_shape(d_shape)
    s_shape = segment_ids.shape
    vc_util.check_shape(s_shape)

    new_segment_ids, idx = gen_ids(segment_ids)

    # Placeholder row for segment ids that never occur in segment_ids.
    output_shape = (1, ) + tuple(d_shape[len(s_shape):])
    zero_data = akg.tvm.compute(output_shape,
                                lambda *i: akg.tvm.const(0.0, d_dtype),
                                name="zero")

    # Split the data by segment id and regroup the chunks that share an id.
    data_list, new_idx = split_new(data, new_segment_ids, idx, num_segments)
    out = []
    j = 0
    for i in range(0, num_segments):
        if i in new_idx:
            tmp = reduce_max.reduce_max(data_list[j], 0, True)
            out.append(tmp)
            j = j + 1
        else:
            out.append(zero_data)

    res = concat.concat(out, 0)
    return res
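
# Usage sketch for unsorted_segment_max (hypothetical shapes; same host-side
# ids assumption as _example_segment_max above):
def _example_unsorted_segment_max():
    import numpy as np
    data = akg.tvm.placeholder((5, 16), "float16", name="data")
    segment_ids = np.array([2, 0, 0, 1, 2], dtype="int32")  # unsorted ids
    return unsorted_segment_max(data, segment_ids, num_segments=3)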
def split_new(data, new_segment_ids, idx, num_segments):
    """Splits data by new_segment_ids and regroups the chunks that share a
    segment id, so that each returned tensor holds all rows of one segment.

    Returns the grouped tensors and the sorted list of segment ids present.
    """
    data_list = split.split(data, new_segment_ids)
    if not isinstance(data_list, (list, tuple)):
        data_list = [data_list]

    # Collect the chunks belonging to each segment id, then concatenate them.
    groups = dict()
    new_idx = []
    out = []
    for i in range(0, num_segments):
        if i in idx:
            groups[str(i)] = []
            new_idx.append(i)
            for (tmp, tmp_data) in zip(idx, data_list):
                if tmp == i:
                    groups[str(i)].append(tmp_data)
            out.append(concat.concat(groups[str(i)], 0))
    return out, new_idx
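
# Hypothetical sketch of the regrouping semantics with plain Python lists
# (mirrors split_new's bookkeeping, not the tensor version):
def _split_new_sketch():
    idx = [2, 0, 0, 1, 2]                  # segment id of each chunk
    data_list = ["a", "b", "c", "d", "e"]  # stand-ins for tensor chunks
    out = []
    for i in sorted(set(idx)):
        out.append([d for t, d in zip(idx, data_list) if t == i])
    return out  # [['b', 'c'], ['d'], ['a', 'e']]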
def lstmcell_grad_h(input, hx, cx, w_ih, w_hh, b_ih, b_hh, dh, dc):
    """
    Computes the gradients of the LSTM cell's hidden output h with respect to
    the weights, biases, previous states, and input.

    Args:
        input: akg.tvm.Tensor of type float16, float32.
        hx: akg.tvm.Tensor for hidden variable from previous cell.
        cx: akg.tvm.Tensor for state variable from previous cell.
        w_ih: akg.tvm.Tensor for input weights.
        w_hh: akg.tvm.Tensor for hidden weights.
        b_ih: akg.tvm.Tensor for input bias.
        b_hh: akg.tvm.Tensor for hidden bias.
        dh: akg.tvm.Tensor for the gradient flowing into the hidden output.
        dc: akg.tvm.Tensor for the gradient flowing into the cell state.

    Returns:
        dw_ih: akg.tvm.Tensor for dh/dw_ih.
        dw_hh: akg.tvm.Tensor for dh/dw_hh.
        db_ih: akg.tvm.Tensor for dh/db_ih.
        db_hh: akg.tvm.Tensor for dh/db_hh.
        dcx: akg.tvm.Tensor for dh/dcx.
        dhx: akg.tvm.Tensor for dh/dhx.
        dx: akg.tvm.Tensor for dh/dx.
    """
    # Recompute the forward pass.
    batch, input_size = get_shape(input)
    _, hidden_size = get_shape(hx)
    xh = akg.topi.concatenate((hx, input), 1)
    # Match the column order of xh (hx first, then input).
    whl = [w_hh, w_ih]
    W = concat(whl, 1)  # [4*hidden_size, hidden_size+input_size]

    gates = dense(input, w_ih, b_ih, True) + dense(hx, w_hh, b_hh, True)

    ingate_in, forgetgate_in, cellgate_in, outgate_in = split(gates, 4, 1)

    ingate = sigmoid(ingate_in)
    forgetgate = sigmoid(forgetgate_in)
    cellgate = tanh(cellgate_in)
    outgate = sigmoid(outgate_in)

    cy = (forgetgate * cx) + (ingate * cellgate)
    tanh_cy = tanh(cy)
    # hy = outgate * tanh_cy

    # Backward pass starts here.
    # head * dh/do, shape [batch, hidden_size]
    doutgate = dh * tanh_cy
    doutgate_in = outgate * (1 - outgate) * doutgate
    kk = akg.tvm.reduce_axis((0, batch))
    dWo = akg.tvm.compute(
        (hidden_size, hidden_size + input_size),
        lambda i, j: akg.tvm.sum(xh[kk, j] * doutgate_in(kk, i), axis=kk),
        name="dWo")

    # The incoming dc argument is overwritten below: this kernel propagates
    # only the gradient arriving through dh.
    dtanh_cy = dh * outgate
    dc = (1 - tanh_cy * tanh_cy) * dtanh_cy

    dingate = cellgate * dc
    dingate_in = ingate * (1 - ingate) * dingate
    kk3 = akg.tvm.reduce_axis((0, batch))
    dWi = akg.tvm.compute(
        (hidden_size, hidden_size + input_size),
        lambda i, j: akg.tvm.sum(xh[kk3, j] * dingate_in(kk3, i), axis=kk3),
        name="dWi")

    dforgetgate = dc * cx
    dforgetgate_in = forgetgate * (1 - forgetgate) * dforgetgate
    kk2 = akg.tvm.reduce_axis((0, batch))
    dWf = akg.tvm.compute(
        (hidden_size, hidden_size + input_size),
        lambda i, j: akg.tvm.sum(xh[kk2, j] * dforgetgate_in(kk2, i), axis=kk2),
        name="dWf")

    dcellgate = ingate * dc
    dcellgate_in = (1 - cellgate * cellgate) * dcellgate
    kk4 = akg.tvm.reduce_axis((0, batch))
    dWc = akg.tvm.compute(
        (hidden_size, hidden_size + input_size),
        lambda i, j: akg.tvm.sum(xh[kk4, j] * dcellgate_in(kk4, i), axis=kk4),
        name="dWc")

    dW = akg.topi.concatenate((dWi, dWf, dWc, dWo))

    db = akg.topi.concatenate(
        (dingate_in, dforgetgate_in, dcellgate_in, doutgate_in), 1)

    kk5 = akg.tvm.reduce_axis((0, 4 * hidden_size))
    dxh = akg.tvm.compute(
        (batch, hidden_size + input_size),
        lambda i, j: akg.tvm.sum(W[kk5, j] * db[i, kk5], axis=kk5),
        name="dxh")

    # xh holds hx in the first hidden_size columns and input in the rest.
    dhx = akg.tvm.compute((batch, hidden_size),
                          lambda i, j: dxh[i, j],
                          name="dhx")
    dx = akg.tvm.compute((batch, input_size),
                         lambda i, j: dxh[i, j + hidden_size],
                         name="dx")

    dcx = forgetgate * dc
    # dW follows the column order of xh: hx columns first, then input columns.
    dw_hh = akg.tvm.compute(w_hh.shape, lambda i, j: dW[i, j], name="dw_hh")
    dw_ih = akg.tvm.compute(w_ih.shape,
                            lambda i, j: dW[i, j + hidden_size],
                            name="dw_ih")
    # db has shape (batch, 4*hidden_size); reduce over the batch axis.
    bhr = akg.tvm.reduce_axis((0, batch))
    db_ih = akg.tvm.compute((4 * hidden_size, ),
                            lambda i: akg.tvm.sum(db[bhr, i], axis=bhr),
                            name="dbih")
    bir = akg.tvm.reduce_axis((0, batch))
    db_hh = akg.tvm.compute((4 * hidden_size, ),
                            lambda i: akg.tvm.sum(db[bir, i], axis=bir),
                            name="dbhh")

    return dw_ih, dw_hh, db_ih, db_hh, dcx, dhx, dx
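
# Usage sketch for lstmcell_grad_h (hypothetical sizes; placeholders only --
# building and scheduling the kernel with akg is out of scope here):
def _example_lstmcell_grad_h(batch=2, input_size=8, hidden_size=4):
    dtype = "float16"
    x = akg.tvm.placeholder((batch, input_size), dtype, name="x")
    hx = akg.tvm.placeholder((batch, hidden_size), dtype, name="hx")
    cx = akg.tvm.placeholder((batch, hidden_size), dtype, name="cx")
    w_ih = akg.tvm.placeholder((4 * hidden_size, input_size), dtype, name="w_ih")
    w_hh = akg.tvm.placeholder((4 * hidden_size, hidden_size), dtype, name="w_hh")
    b_ih = akg.tvm.placeholder((4 * hidden_size, ), dtype, name="b_ih")
    b_hh = akg.tvm.placeholder((4 * hidden_size, ), dtype, name="b_hh")
    dh = akg.tvm.placeholder((batch, hidden_size), dtype, name="dh")
    dc = akg.tvm.placeholder((batch, hidden_size), dtype, name="dc")
    return lstmcell_grad_h(x, hx, cx, w_ih, w_hh, b_ih, b_hh, dh, dc)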
def lstmcell_grad_c(input, hx, cx, w_ih, w_hh, b_ih, b_hh, dc):
    """
    Computes the gradients of the LSTM cell's state c with respect to the
    weights, biases, previous states, and input.

    Args:
        input: akg.tvm.Tensor of type float16, float32.
        hx: akg.tvm.Tensor for hidden variable from previous cell.
        cx: akg.tvm.Tensor for state variable from previous cell.
        w_ih: akg.tvm.Tensor for input weights.
        w_hh: akg.tvm.Tensor for hidden weights.
        b_ih: akg.tvm.Tensor for input bias.
        b_hh: akg.tvm.Tensor for hidden bias.
        dc: akg.tvm.Tensor for the gradient flowing into the cell state.

    Returns:
        dw_ih: akg.tvm.Tensor for dc/dw_ih.
        dw_hh: akg.tvm.Tensor for dc/dw_hh.
        db_ih: akg.tvm.Tensor for dc/db_ih.
        db_hh: akg.tvm.Tensor for dc/db_hh.
        dcx: akg.tvm.Tensor for dc/dcx.
        dhx: akg.tvm.Tensor for dc/dhx.
        dx: akg.tvm.Tensor for dc/dx.
    """
    # Recompute the forward pass.
    # Match the column order of xh (hx first, then input).
    whl = [w_hh, w_ih]
    W = concat(whl, 1)  # [4*hidden_size, hidden_size+input_size]
    b = b_ih + b_hh
    batch, input_size = get_shape(input)
    _, hidden_size = get_shape(hx)
    xh = akg.topi.concatenate((hx, input), 1)
    t = akg.topi.nn.dense(xh, W, b)
    temp_i = akg.tvm.compute((batch, hidden_size),
                             lambda i, j: t(i, j),
                             name="temp_i")
    i = sigmoid(temp_i)
    temp_f = akg.tvm.compute((batch, hidden_size),
                             lambda i, j: t(i, j + hidden_size),
                             name="temp_f")
    f = sigmoid(temp_f)
    temp_c_ = akg.tvm.compute((batch, hidden_size),
                              lambda i, j: t(i, j + 2 * hidden_size),
                              name="temp_c")
    c_ = tanh(temp_c_)

    # Backward pass starts here. The output gate does not influence the cell
    # state, so its gradients are identically zero.
    dtemp_o = akg.tvm.compute((batch, hidden_size),
                              lambda *i: akg.tvm.const(0.0, dc.dtype),
                              name="dtemp_o")
    dWo = akg.tvm.compute((hidden_size, hidden_size + input_size),
                          lambda i, j: akg.tvm.const(0.0, dc.dtype),
                          name="dWo")

    df = dc * cx
    dtemp_f = f * (1 - f) * df
    kk2 = akg.tvm.reduce_axis((0, batch))
    dWf = akg.tvm.compute(
        (hidden_size, hidden_size + input_size),
        lambda i, j: akg.tvm.sum(xh[kk2, j] * dtemp_f(kk2, i), axis=kk2),
        name="dWf")

    di = c_ * dc
    dtemp_i = i * (1 - i) * di
    kk3 = akg.tvm.reduce_axis((0, batch))
    dWi = akg.tvm.compute(
        (hidden_size, hidden_size + input_size),
        lambda i, j: akg.tvm.sum(xh[kk3, j] * dtemp_i(kk3, i), axis=kk3),
        name="dWi")

    dc_ = i * dc
    dtemp_c_ = (1 - c_ * c_) * dc_
    kk4 = akg.tvm.reduce_axis((0, batch))
    dWc = akg.tvm.compute(
        (hidden_size, hidden_size + input_size),
        lambda i, j: akg.tvm.sum(xh[kk4, j] * dtemp_c_(kk4, i), axis=kk4),
        name="dWc")

    dW = akg.topi.concatenate((dWi, dWf, dWc, dWo))

    db = akg.topi.concatenate((dtemp_i, dtemp_f, dtemp_c_, dtemp_o), 1)

    kk5 = akg.tvm.reduce_axis((0, 4 * hidden_size))
    dxh = akg.tvm.compute(
        (batch, hidden_size + input_size),
        lambda i, j: akg.tvm.sum(W[kk5, j] * db[i, kk5], axis=kk5),
        name="dxh")

    # xh holds hx in the first hidden_size columns and input in the rest.
    dhx = akg.tvm.compute((batch, hidden_size),
                          lambda i, j: dxh[i, j],
                          name="dhx")
    dx = akg.tvm.compute((batch, input_size),
                         lambda i, j: dxh[i, j + hidden_size],
                         name="dx")

    dcx = f * dc
    # dW follows the column order of xh: hx columns first, then input columns.
    dw_hh = akg.tvm.compute(w_hh.shape, lambda i, j: dW[i, j], name="dw_hh")
    dw_ih = akg.tvm.compute(w_ih.shape,
                            lambda i, j: dW[i, j + hidden_size],
                            name="dw_ih")
    # db has shape (batch, 4*hidden_size); reduce over the batch axis.
    bhr = akg.tvm.reduce_axis((0, batch))
    db_ih = akg.tvm.compute((4 * hidden_size, ),
                            lambda i: akg.tvm.sum(db[bhr, i], axis=bhr),
                            name="dbih")
    bir = akg.tvm.reduce_axis((0, batch))
    db_hh = akg.tvm.compute((4 * hidden_size, ),
                            lambda i: akg.tvm.sum(db[bir, i], axis=bir),
                            name="dbhh")

    return dw_ih, dw_hh, db_ih, db_hh, dcx, dhx, dx
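
# Usage sketch for lstmcell_grad_c (hypothetical sizes, mirroring
# _example_lstmcell_grad_h above; dh is not needed since only the cell-state
# gradient flows in):
def _example_lstmcell_grad_c(batch=2, input_size=8, hidden_size=4):
    dtype = "float16"
    x = akg.tvm.placeholder((batch, input_size), dtype, name="x")
    hx = akg.tvm.placeholder((batch, hidden_size), dtype, name="hx")
    cx = akg.tvm.placeholder((batch, hidden_size), dtype, name="cx")
    w_ih = akg.tvm.placeholder((4 * hidden_size, input_size), dtype, name="w_ih")
    w_hh = akg.tvm.placeholder((4 * hidden_size, hidden_size), dtype, name="w_hh")
    b_ih = akg.tvm.placeholder((4 * hidden_size, ), dtype, name="b_ih")
    b_hh = akg.tvm.placeholder((4 * hidden_size, ), dtype, name="b_hh")
    dc = akg.tvm.placeholder((batch, hidden_size), dtype, name="dc")
    return lstmcell_grad_c(x, hx, cx, w_ih, w_hh, b_ih, b_hh, dc)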