def lstmcell(inputs, hx, cx, w_ih, w_hh, b_ih, b_hh, use_bias=True):
    """
    Computes the hidden and state variables of a Long Short-Term Memory (LSTM) cell.

    Args:
        inputs: akg.tvm.Tensor of type float16, float32 with shape [batch, input_size].
        hx: akg.tvm.Tensor for hidden variable from previous cell with shape [batch, hidden_size].
        cx: akg.tvm.Tensor for state variable from previous cell with shape [batch, hidden_size].
        w_ih: akg.tvm.Tensor for input weights with shape [4*hidden_size, input_size].
        w_hh: akg.tvm.Tensor for hidden weights with shape [4*hidden_size, hidden_size].
        b_ih: akg.tvm.Tensor for input bias with shape [4*hidden_size].
        b_hh: akg.tvm.Tensor for hidden bias with shape [4*hidden_size].
        use_bias: bool, whether to add the bias terms in the dense computations.

    Returns:
        hy: akg.tvm.Tensor for hidden variable of current cell.
        cy: akg.tvm.Tensor for state variable of current cell.
    """
    w_i_ih, w_f_ih, w_c_ih, w_o_ih = split(w_ih, 4, 0)
    b_i_ih, b_f_ih, b_c_ih, b_o_ih = split(b_ih, 4)
    w_i_hh, w_f_hh, w_c_hh, w_o_hh = split(w_hh, 4, 0)
    b_i_hh, b_f_hh, b_c_hh, b_o_hh = split(b_hh, 4)

    # gates: [batch, 4*hidden_size], computed gate by gate as inputs*w_ih + hx*w_hh + bias,
    # i.e. ingate, forgetgate, cellgate, outgate.
    i = dense(inputs, w_i_ih, b_i_ih, use_bias) + dense(hx, w_i_hh, b_i_hh, use_bias)
    f = dense(inputs, w_f_ih, b_f_ih, use_bias) + dense(hx, w_f_hh, b_f_hh, use_bias)
    c = dense(inputs, w_c_ih, b_c_ih, use_bias) + dense(hx, w_c_hh, b_c_hh, use_bias)
    o = dense(inputs, w_o_ih, b_o_ih, use_bias) + dense(hx, w_o_hh, b_o_hh, use_bias)

    cy = (sigmoid(f) * cx) + (sigmoid(i) * tanh(c))
    hy = sigmoid(o) * tanh(cy)

    return hy, cy
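
# The sketch below is not part of the original operator set; it is a minimal,
# hypothetical usage example showing how lstmcell could be wired up with
# akg.tvm placeholders whose shapes follow the docstring above. The helper name
# and the concrete sizes are illustrative only.
def _example_lstmcell_usage(batch=16, input_size=32, hidden_size=64, dtype="float16"):
    """Illustrative only: builds placeholder tensors and calls lstmcell."""
    inputs = akg.tvm.placeholder((batch, input_size), dtype, name="inputs")
    hx = akg.tvm.placeholder((batch, hidden_size), dtype, name="hx")
    cx = akg.tvm.placeholder((batch, hidden_size), dtype, name="cx")
    w_ih = akg.tvm.placeholder((4 * hidden_size, input_size), dtype, name="w_ih")
    w_hh = akg.tvm.placeholder((4 * hidden_size, hidden_size), dtype, name="w_hh")
    b_ih = akg.tvm.placeholder((4 * hidden_size,), dtype, name="b_ih")
    b_hh = akg.tvm.placeholder((4 * hidden_size,), dtype, name="b_hh")
    return lstmcell(inputs, hx, cx, w_ih, w_hh, b_ih, b_hh)
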
def tanh_ad(head, in_data):
    """
    Compute the gradient of the tanh operator using automatic differentiation.

    Args:
        head (tvm.tensor.Tensor): Tensor of type float16, float32.
        in_data (tvm.tensor.Tensor): Tensor of type float16, float32.

    Returns:
        tvm.tensor.Tensor with the same shape as in_data.
    """
    in_dtype = in_data.dtype

    # On the cloud environment, casting the data from 'float16' to 'float32' and
    # casting the result back to 'float16' achieves higher precision.
    if in_dtype == 'float16' and not utils.product_is_mini():
        in_data = akg.topi.cast(in_data, "float32")
        head = akg.topi.cast(head, "float32")

    out_data = tanh.tanh(in_data)
    jacs = list(akg.differentiate(out_data, [in_data], head))
    jacs_res = jacs[0]
    if in_dtype == 'float16' and not utils.product_is_mini():
        jacs_res = akg.topi.cast(jacs_res, 'float16')
    return jacs_res
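
# Hypothetical cross-check, not from the original source: since
# d tanh(x)/dx = 1 - tanh(x)^2, the backward result of tanh_ad should match
# head * (1 - tanh(in_data)^2). The NumPy helper below (numpy assumed
# available) is only meant for eyeballing precision against the kernel output.
def _tanh_ad_numpy_reference(head_np, in_np):
    import numpy as np
    t = np.tanh(in_np.astype("float32"))
    return (head_np.astype("float32") * (1.0 - t * t)).astype(in_np.dtype)
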
def gelu_ad_custom(head, in_data):
    """
    Automatic differentiation of gelu with a customized differentiation function.

    To achieve higher precision, the tanh part of the differentiation is
    overridden with a simplified, self-defined calculation.
    """
    dtype = in_data.dtype
    const1 = akg.tvm.const(0.044715, dtype)
    const2 = akg.tvm.const(0.7978845, dtype)
    const3 = akg.tvm.const(0.1070322, dtype)
    tmp0 = akg.topi.multiply(in_data, in_data)
    pow0 = akg.topi.multiply(tmp0, in_data)
    mul0 = pow0 * const1
    add0 = in_data + mul0
    mul1 = add0 * const2
    tanh_res = tanh.tanh(mul1)
    add1 = tanh_res + akg.tvm.const(1, dtype)
    mul2 = add1 * akg.tvm.const(0.5, dtype)
    mul3 = in_data * mul2
    res = mul3

    def gelu_diff(out, inp, head, ad_attrs, new_array_pld):
        # d(mul1)/d(in_data) = 0.7978845 + 3 * 0.044715 * 0.7978845 * x^2
        #                    = 0.7978845 + 0.1070322 * x^2
        temp = tanh_fdiff(head, mul1)
        return [temp * (akg.tvm.const(0.7978845, dtype) + const3 * inp[0] * inp[0])]

    jacs = list(akg.differentiate(res, [in_data], head, None, None,
                                  override={tanh_res: ([in_data], gelu_diff)}))
    return jacs[0]
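
# Hypothetical NumPy reference for the full gelu gradient (not from the original
# source, numpy assumed available). With u = 0.7978845*(x + 0.044715*x^3), the
# tanh-approximated gelu derivative is
#   0.5*(1 + tanh(u)) + 0.5*x*(1 - tanh(u)^2)*(0.7978845 + 0.1070322*x^2),
# which is what gelu_ad_custom should reproduce.
def _gelu_ad_numpy_reference(head_np, x_np):
    import numpy as np
    x = x_np.astype("float32")
    u = 0.7978845 * (x + 0.044715 * x ** 3)
    t = np.tanh(u)
    du = 0.7978845 + 0.1070322 * x * x  # 3 * 0.044715 * 0.7978845 ≈ 0.1070322
    dgelu = 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * du
    return (head_np.astype("float32") * dgelu).astype(x_np.dtype)
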
def rnn_tanh_cell_grad(input, hidden, w_ih, w_hh, b_ih, b_hh, grad):
    """
    Backpropagates the output gradient of a tanh RNN cell, producing dinput (di),
    dhidden (dhid), dweights (dWih, dWhh) and dbias (db).

    Args:
        input: akg.tvm.Tensor of type float16, float32 with shape [batch, input_size].
        hidden: akg.tvm.Tensor for hidden variable from previous cell with shape [batch, hidden_size].
        w_ih: akg.tvm.Tensor for input weights with shape [hidden_size, input_size].
        w_hh: akg.tvm.Tensor for hidden weights with shape [hidden_size, hidden_size].
        b_ih: akg.tvm.Tensor for input bias with shape [hidden_size].
        b_hh: akg.tvm.Tensor for hidden bias with shape [hidden_size].
        grad: akg.tvm.Tensor representing dy with shape [batch, hidden_size].

    Returns:
        di: akg.tvm.Tensor for dy/dinput.
        dhid: akg.tvm.Tensor for dy/dhidden.
        dWih: akg.tvm.Tensor for dy/dw_ih (input weights).
        dWhh: akg.tvm.Tensor for dy/dw_hh (hidden weights).
        db: akg.tvm.Tensor for dy/db.
    """
    batch, input_size = get_shape(input)
    _, hidden_size = get_shape(hidden)

    igates = akg.topi.nn.dense(input, w_ih, b_ih)
    hgates = akg.topi.nn.dense(hidden, w_hh, b_hh)
    h = tanh(igates + hgates)
    # Backprop through tanh: dh = (1 - h^2) * grad.
    dh = (1 - h * h) * grad

    kk = akg.tvm.reduce_axis((0, batch))
    dWih = akg.tvm.compute((hidden_size, input_size),
                           lambda i, j: akg.tvm.sum(input[kk, j] * dh[kk, i], axis=kk),
                           name="dWih")
    kk2 = akg.tvm.reduce_axis((0, batch))
    dWhh = akg.tvm.compute((hidden_size, hidden_size),
                           lambda i, j: akg.tvm.sum(hidden[kk2, j] * dh[kk2, i], axis=kk2),
                           name="dWhh")
    kk3 = akg.tvm.reduce_axis((0, hidden_size))
    di = akg.tvm.compute((batch, input_size),
                         lambda i, j: akg.tvm.sum(w_ih[kk3, j] * dh[i, kk3], axis=kk3),
                         name="di")
    kk4 = akg.tvm.reduce_axis((0, hidden_size))
    dhid = akg.tvm.compute((batch, hidden_size),
                           lambda i, j: akg.tvm.sum(w_hh[kk4, j] * dh[i, kk4], axis=kk4),
                           name="dhid")
    db = akg.topi.sum(dh, 0)

    return di, dhid, dWih, dWhh, db
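
# Hypothetical NumPy reference (not from the original source, numpy assumed
# available) for the gradients computed above: with
# h = tanh(x @ w_ih.T + b_ih + hid @ w_hh.T + b_hh) and dh = (1 - h^2) * grad,
# every gradient is a plain matrix product or reduction over the batch.
def _rnn_tanh_cell_grad_numpy_reference(x, hid, w_ih, w_hh, b_ih, b_hh, grad):
    import numpy as np
    h = np.tanh(x @ w_ih.T + b_ih + hid @ w_hh.T + b_hh)
    dh = (1.0 - h * h) * grad
    dWih = dh.T @ x          # [hidden_size, input_size]
    dWhh = dh.T @ hid        # [hidden_size, hidden_size]
    di = dh @ w_ih           # [batch, input_size]
    dhid = dh @ w_hh         # [batch, hidden_size]
    db = dh.sum(axis=0)      # [hidden_size]
    return di, dhid, dWih, dWhh, db
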
def rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh, use_bias=True):
    """
    RNN cell with tanh non-linearity.

    Args:
        inputs: akg.tvm.Tensor of type float16, float32.
        hidden: akg.tvm.Tensor for hidden variable from previous cell.
        w_ih: akg.tvm.Tensor for input weights.
        w_hh: akg.tvm.Tensor for hidden weights.
        b_ih: akg.tvm.Tensor for input bias.
        b_hh: akg.tvm.Tensor for hidden bias.
        use_bias: bool, whether to add the bias terms in the dense computations.

    Returns:
        h: akg.tvm.Tensor for hidden output variable of current cell.
    """
    igates = dense(inputs, w_ih, b_ih, use_bias)
    hgates = dense(hidden, w_hh, b_hh, use_bias)
    h = tanh(igates + hgates)
    return h
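
# Hypothetical NumPy reference for the forward cell above (not from the original
# source, numpy assumed available): h = tanh(x @ w_ih.T + b_ih + hid @ w_hh.T + b_hh).
def _rnn_tanh_cell_numpy_reference(x, hid, w_ih, w_hh, b_ih, b_hh):
    import numpy as np
    return np.tanh(x @ w_ih.T + b_ih + hid @ w_hh.T + b_hh)
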
def gelu(data):
    """
    gelu activation function.

    .. math:: 0.5 * data * (1 + tanh(sqrt(2/pi) * (data + 0.044715 * data^3)))

    Args:
        data (tvm.tensor.Tensor): tensor with type float16 or float32.

    Returns:
        tvm.tensor.Tensor.
    """
    dtype = data.dtype
    vc_util.ops_dtype_check(dtype, vc_util.DtypeForDavinci.ALL_FLOAT)

    # On the mini environment, float32 input is computed in float16;
    # the result is cast back to float32 at the end.
    cast_back_fp32 = dtype == "float32" and utils.product_is_mini()
    if cast_back_fp32:
        data = akg.tvm.compute(data.shape,
                               lambda *indice: data(*indice).astype("float16"),
                               name='type_cast')
        dtype = "float16"

    tmp0 = akg.topi.multiply(data, data)
    pow0 = akg.topi.multiply(tmp0, data)
    mul0 = pow0 * akg.tvm.const(0.044715, dtype)
    add0 = data + mul0
    mul1 = add0 * akg.tvm.const(0.7978845, dtype)  # sqrt(2/pi) ≈ 0.7978845
    tanh_res = tanh(mul1)
    add1 = tanh_res + akg.tvm.const(1, dtype)
    mul2 = add1 * akg.tvm.const(0.5, dtype)
    mul3 = data * mul2
    res = mul3

    if cast_back_fp32:
        res = akg.tvm.compute(res.shape,
                              lambda *indice: res(*indice).astype("float32"),
                              name='res')
    return res
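
# Hypothetical NumPy reference of the same tanh approximation of gelu (not from
# the original source, numpy assumed available), useful for precision checks
# against the compiled kernel.
def _gelu_numpy_reference(x_np):
    import numpy as np
    x = x_np.astype("float32")
    inner = 0.7978845 * (x + 0.044715 * x ** 3)  # sqrt(2/pi) ≈ 0.7978845
    return (0.5 * x * (1.0 + np.tanh(inner))).astype(x_np.dtype)
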
def lstmcell_grad_h(input, hx, cx, w_ih, w_hh, b_ih, b_hh, dh, dc):
    """
    Backpropagates the hidden-output gradient dh of an LSTM cell, producing dw, db, dcx, dhx and dx.

    Args:
        input: akg.tvm.Tensor of type float16, float32.
        hx: akg.tvm.Tensor for hidden variable from previous cell.
        cx: akg.tvm.Tensor for state variable from previous cell.
        w_ih: akg.tvm.Tensor for input weights.
        w_hh: akg.tvm.Tensor for hidden weights.
        b_ih: akg.tvm.Tensor for input bias.
        b_hh: akg.tvm.Tensor for hidden bias.
        dh: akg.tvm.Tensor for the incoming gradient on the hidden output.
        dc: akg.tvm.Tensor for the incoming gradient on the cell state (not used; recomputed internally from dh).

    Returns:
        dw_ih: akg.tvm.Tensor for dh/dw_ih.
        dw_hh: akg.tvm.Tensor for dh/dw_hh.
        db_ih: akg.tvm.Tensor for dh/db_ih.
        db_hh: akg.tvm.Tensor for dh/db_hh.
        dcx: akg.tvm.Tensor for dh/dcx.
        dhx: akg.tvm.Tensor for dh/dhx.
        dx: akg.tvm.Tensor for dh/dx.
    """
    # recompute the forward-pass quantities
    batch, input_size = get_shape(input)
    _, hidden_size = get_shape(hx)
    xh = akg.topi.concatenate((hx, input), 1)
    whl = [w_ih, w_hh]
    W = concat(whl, 1)  # [4*hidden_size, input_size+hidden_size]

    gates = dense(input, w_ih, b_ih, True) + dense(hx, w_hh, b_hh, True)
    ingate_in, forgetgate_in, cellgate_in, outgate_in = split(gates, 4, 1)
    ingate = sigmoid(ingate_in)
    forgetgate = sigmoid(forgetgate_in)
    cellgate = tanh(cellgate_in)
    outgate = sigmoid(outgate_in)
    cy = (forgetgate * cx) + (ingate * cellgate)
    tanh_cy = tanh(cy)
    # hy = outgate * tanh_cy

    # start of the backward pass
    # head * d(hy)/d(outgate), shape [batch, hidden_size]
    doutgate = dh * tanh_cy
    doutgate_in = outgate * (1 - outgate) * doutgate
    kk = akg.tvm.reduce_axis((0, batch))
    dWo = akg.tvm.compute(
        (hidden_size, hidden_size + input_size),
        lambda i, j: akg.tvm.sum(xh[kk, j] * doutgate_in[kk, i], axis=kk),
        name="dWo")

    dtanh_cy = dh * outgate
    # note: this overwrites the dc argument; only dh drives this backward pass
    dc = (1 - tanh_cy * tanh_cy) * dtanh_cy

    dingate = cellgate * dc
    dingate_in = ingate * (1 - ingate) * dingate
    kk3 = akg.tvm.reduce_axis((0, batch))
    dWi = akg.tvm.compute(
        (hidden_size, hidden_size + input_size),
        lambda i, j: akg.tvm.sum(xh[kk3, j] * dingate_in[kk3, i], axis=kk3),
        name="dWi")

    dforgetgate = dc * cx
    dforgetgate_in = forgetgate * (1 - forgetgate) * dforgetgate
    kk2 = akg.tvm.reduce_axis((0, batch))
    dWf = akg.tvm.compute(
        (hidden_size, hidden_size + input_size),
        lambda i, j: akg.tvm.sum(xh[kk2, j] * dforgetgate_in[kk2, i], axis=kk2),
        name="dWf")

    dcellgate = ingate * dc
    dcellgate_in = (1 - cellgate * cellgate) * dcellgate
    kk4 = akg.tvm.reduce_axis((0, batch))
    dWc = akg.tvm.compute(
        (hidden_size, hidden_size + input_size),
        lambda i, j: akg.tvm.sum(xh[kk4, j] * dcellgate_in[kk4, i], axis=kk4),
        name="dWc")

    dW = akg.topi.concatenate((dWi, dWf, dWc, dWo))
    db = akg.topi.concatenate((dingate_in, dforgetgate_in, dcellgate_in, doutgate_in), 1)

    kk5 = akg.tvm.reduce_axis((0, 4 * hidden_size))
    dxh = akg.tvm.compute(
        (batch, hidden_size + input_size),
        lambda i, j: akg.tvm.sum(W[kk5, j] * db[i, kk5], axis=kk5),
        name="dxh")
    dhx = akg.tvm.compute((batch, hidden_size), lambda i, j: dxh[i, j], name="dhx")
    dx = akg.tvm.compute((batch, input_size), lambda i, j: dxh[i, j + hidden_size], name="dx")

    dcx = forgetgate * dc
    dw_ih = akg.tvm.compute(w_ih.shape, lambda i, j: dW[i, j])
    dw_hh = akg.tvm.compute(w_hh.shape, lambda i, j: dW[i, j + input_size])
    # sum the bias gradient over the batch dimension
    bhr = akg.tvm.reduce_axis((0, batch))
    db_ih = akg.tvm.compute((4 * hidden_size,),
                            lambda i: akg.tvm.sum(db[bhr, i], axis=bhr),
                            name="dbih")
    bir = akg.tvm.reduce_axis((0, batch))
    db_hh = akg.tvm.compute((4 * hidden_size,),
                            lambda i: akg.tvm.sum(db[bir, i], axis=bir),
                            name="dbhh")

    return dw_ih, dw_hh, db_ih, db_hh, dcx, dhx, dx
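
# Hypothetical NumPy sketch of the standard LSTM backward pass along the
# hidden-output (dh) path, assuming the gate order (i, f, g, o) used above;
# numpy is assumed available. It is a sanity reference only and does not mirror
# the xh/W concatenation layout of the compute definitions.
def _lstmcell_grad_h_numpy_reference(x, hx, cx, w_ih, w_hh, b_ih, b_hh, dh):
    import numpy as np

    def sig(v):
        return 1.0 / (1.0 + np.exp(-v))

    gates = x @ w_ih.T + b_ih + hx @ w_hh.T + b_hh
    i, f, g, o = np.split(gates, 4, axis=1)
    i, f, g, o = sig(i), sig(f), np.tanh(g), sig(o)
    cy = f * cx + i * g
    tcy = np.tanh(cy)
    do_in = o * (1.0 - o) * (dh * tcy)   # through the output gate
    dcy = dh * o * (1.0 - tcy * tcy)     # through tanh(cy)
    df_in = f * (1.0 - f) * (dcy * cx)
    di_in = i * (1.0 - i) * (dcy * g)
    dg_in = (1.0 - g * g) * (dcy * i)
    dgates = np.concatenate((di_in, df_in, dg_in, do_in), axis=1)
    dw_ih, dw_hh = dgates.T @ x, dgates.T @ hx
    db = dgates.sum(axis=0)              # shared by b_ih and b_hh
    dx, dhx, dcx = dgates @ w_ih, dgates @ w_hh, dcy * f
    return dw_ih, dw_hh, db, db, dcx, dhx, dx
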
def lstmcell_grad_c(input, hx, cx, w_ih, w_hh, b_ih, b_hh, dc):
    """
    Backpropagates the cell-state gradient dc of an LSTM cell, producing dw, db, dcx, dhx and dx.

    Args:
        input: akg.tvm.Tensor of type float16, float32.
        hx: akg.tvm.Tensor for hidden variable from previous cell.
        cx: akg.tvm.Tensor for state variable from previous cell.
        w_ih: akg.tvm.Tensor for input weights.
        w_hh: akg.tvm.Tensor for hidden weights.
        b_ih: akg.tvm.Tensor for input bias.
        b_hh: akg.tvm.Tensor for hidden bias.
        dc: akg.tvm.Tensor for the incoming gradient on the cell state.

    Returns:
        dw_ih: akg.tvm.Tensor for dc/dw_ih.
        dw_hh: akg.tvm.Tensor for dc/dw_hh.
        db_ih: akg.tvm.Tensor for dc/db_ih.
        db_hh: akg.tvm.Tensor for dc/db_hh.
        dcx: akg.tvm.Tensor for dc/dcx.
        dhx: akg.tvm.Tensor for dc/dhx.
        dx: akg.tvm.Tensor for dc/dx.
    """
    # recompute the forward-pass quantities
    whl = [w_ih, w_hh]
    W = concat(whl, 1)  # [4*hidden_size, input_size+hidden_size]
    b = b_ih + b_hh
    batch, input_size = get_shape(input)
    _, hidden_size = get_shape(hx)
    xh = akg.topi.concatenate((hx, input), 1)
    t = akg.topi.nn.dense(xh, W, b)
    temp_i = akg.tvm.compute((batch, hidden_size), lambda i, j: t(i, j), name="temp_i")
    i = sigmoid(temp_i)
    temp_f = akg.tvm.compute((batch, hidden_size),
                             lambda i, j: t(i, j + hidden_size),
                             name="temp_f")
    f = sigmoid(temp_f)
    temp_c_ = akg.tvm.compute((batch, hidden_size),
                              lambda i, j: t(i, j + 2 * hidden_size),
                              name="temp_c")
    c_ = tanh(temp_c_)

    # start of the backward pass
    # the output gate does not feed the cell state, so its gradient is zero
    dtemp_o = akg.tvm.compute((batch, hidden_size),
                              lambda *i: akg.tvm.const(0, dc.dtype),
                              name="dtemp_o")
    dWo = akg.tvm.compute((hidden_size, hidden_size + input_size),
                          lambda i, j: akg.tvm.const(0, dc.dtype),
                          name="dWo")

    df = dc * cx
    dtemp_f = f * (1 - f) * df
    kk2 = akg.tvm.reduce_axis((0, batch))
    dWf = akg.tvm.compute(
        (hidden_size, hidden_size + input_size),
        lambda i, j: akg.tvm.sum(xh[kk2, j] * dtemp_f[kk2, i], axis=kk2),
        name="dWf")

    di = c_ * dc
    dtemp_i = i * (1 - i) * di
    kk3 = akg.tvm.reduce_axis((0, batch))
    dWi = akg.tvm.compute(
        (hidden_size, hidden_size + input_size),
        lambda i, j: akg.tvm.sum(xh[kk3, j] * dtemp_i[kk3, i], axis=kk3),
        name="dWi")

    dc_ = i * dc
    dtemp_c_ = (1 - c_ * c_) * dc_
    kk4 = akg.tvm.reduce_axis((0, batch))
    dWc = akg.tvm.compute(
        (hidden_size, hidden_size + input_size),
        lambda i, j: akg.tvm.sum(xh[kk4, j] * dtemp_c_[kk4, i], axis=kk4),
        name="dWc")

    dW = akg.topi.concatenate((dWi, dWf, dWc, dWo))
    db = akg.topi.concatenate((dtemp_i, dtemp_f, dtemp_c_, dtemp_o), 1)

    kk5 = akg.tvm.reduce_axis((0, 4 * hidden_size))
    dxh = akg.tvm.compute(
        (batch, hidden_size + input_size),
        lambda i, j: akg.tvm.sum(W[kk5, j] * db[i, kk5], axis=kk5),
        name="dxh")
    dhx = akg.tvm.compute((batch, hidden_size), lambda i, j: dxh[i, j], name="dhx")
    dx = akg.tvm.compute((batch, input_size), lambda i, j: dxh[i, j + hidden_size], name="dx")

    dcx = f * dc
    dw_ih = akg.tvm.compute(w_ih.shape, lambda i, j: dW[i, j])
    dw_hh = akg.tvm.compute(w_hh.shape, lambda i, j: dW[i, j + input_size])
    # sum the bias gradient over the batch dimension
    bhr = akg.tvm.reduce_axis((0, batch))
    db_ih = akg.tvm.compute((4 * hidden_size,),
                            lambda i: akg.tvm.sum(db[bhr, i], axis=bhr),
                            name="dbih")
    bir = akg.tvm.reduce_axis((0, batch))
    db_hh = akg.tvm.compute((4 * hidden_size,),
                            lambda i: akg.tvm.sum(db[bir, i], axis=bir),
                            name="dbhh")

    return dw_ih, dw_hh, db_ih, db_hh, dcx, dhx, dx
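
# Hypothetical NumPy sketch for the cell-state (dc) path, mirroring the helper
# after lstmcell_grad_h (numpy assumed available); since the output gate does
# not feed cy, its contribution is zero.
def _lstmcell_grad_c_numpy_reference(x, hx, cx, w_ih, w_hh, b_ih, b_hh, dc):
    import numpy as np

    def sig(v):
        return 1.0 / (1.0 + np.exp(-v))

    gates = x @ w_ih.T + b_ih + hx @ w_hh.T + b_hh
    i, f, g, _ = np.split(gates, 4, axis=1)
    i, f, g = sig(i), sig(f), np.tanh(g)
    df_in = f * (1.0 - f) * (dc * cx)
    di_in = i * (1.0 - i) * (dc * g)
    dg_in = (1.0 - g * g) * (dc * i)
    do_in = np.zeros_like(di_in)
    dgates = np.concatenate((di_in, df_in, dg_in, do_in), axis=1)
    dw_ih, dw_hh = dgates.T @ x, dgates.T @ hx
    db = dgates.sum(axis=0)
    dx, dhx, dcx = dgates @ w_ih, dgates @ w_hh, dc * f
    return dw_ih, dw_hh, db, db, dcx, dhx, dx
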