def visitArgMax(self, node: AST.Func):
    """Build the IR for ``argmax`` applied to a 2-D operand.

    The operand ``node.expr`` is visited first; its program is then followed
    by a comment and a call to the ``ArgMax`` routine. The call receives the
    operand as "A", the operand's two shape extents as "I" and "J", and a
    fresh temporary as "index".

    Returns:
        A pair ``(prog, result)``: the concatenated program and the
        temporary passed as "index", which is declared as ``Type.Int()``.
    """
    operand_prog, operand = self.visit(node.expr)

    operand_type = node.expr.type
    # Only rank-2 operands are handled here.
    assert operand_type.dim == 2
    rows, cols = operand_type.shape

    result = self.getTempVar()

    # The operand is explicitly flagged as not being an input variable.
    operand.inputVar = False

    header = IR.Comment('argmax(' + operand.idf + ')')

    # Insertion order of this mapping is kept the same as the argument order.
    call_args = {
        operand: "A",
        IR.Int(rows): "I",
        IR.Int(cols): "J",
        result: "index",
    }
    call = IR.FuncCall("ArgMax", call_args)

    argmax_prog = IR.Prog([header, call])
    combined = IRUtil.concatPrograms(operand_prog, argmax_prog)

    self.decls[result.idf] = Type.Int()

    return (combined, result)
def visitSum(self, node: AST.Sum):
    """Build the IR for a summation that binds ``node.name`` over a range.

    Shape of the emitted program::

        Memset(out, type_out.size())
        name = start
        for (it = 0; it < end - start; it++):
            <program of node.expr>
            for every index idx of the output shape:
                out[idx] = out[idx] + shr(in[idx], height_shr)
            name = inc(name)

    The bound variable is declared as ``Type.Int()`` before the body is
    visited. The output's scale and interval come from
    ``getScaleForTreeSum`` / ``getIntervalForTreeSum`` applied to the body's
    scale / interval and the iteration count ``end - start``.

    Returns:
        A pair ``(prog, result)`` where ``result`` is the accumulator
        temporary, registered in ``decls``, ``scales`` and ``intvs``.
    """
    bound_name = node.name
    self.decls[bound_name] = Type.Int()

    body_prog, body_expr = self.visit(node.expr)

    first = node.start
    last = node.end
    trip_count = last - first

    # Allocation order of temporaries/iterators is significant for naming.
    accum = self.getTempVar()
    accum_type = node.type

    bound_var = IR.Var(bound_name)
    counter = self.getTempIterator()
    elem_iters = self.getTempIterators(accum_type.dim)

    # The third component is returned by the helper but not needed here.
    scale_out, height_shr, _height_noshr = self.getScaleForTreeSum(
        self.scales[body_expr.idf], trip_count)
    intv_out = self.getIntervalForTreeSum(
        self.intvs[body_expr.idf], trip_count)

    # Element-wise accumulation of one iteration's result into the output.
    body_elt = IRUtil.addIndex(body_expr, elem_iters)
    accum_elt = IRUtil.addIndex(accum, elem_iters)

    clear_accum = IR.Memset(accum, accum_type.size())
    shifted_term = IRUtil.shr(body_elt, height_shr)
    accumulate = IR.Assn(accum_elt, IRUtil.add(accum_elt, shifted_term))
    accumulate_loop = IRUtil.loop(accum_type.shape, elem_iters, [accumulate])

    # Per-iteration work: evaluate the body, accumulate, advance the binding.
    advance_bound = IR.Assn(bound_var, IRUtil.inc(bound_var))
    iteration_cmds = body_prog.cmd_l + accumulate_loop + [advance_bound]

    init_bound = IR.Assn(bound_var, IR.Int(first))
    outer_loop = IR.For(
        counter, 0, IRUtil.lt(counter, IR.Int(trip_count)), iteration_cmds)

    result_prog = IR.Prog([clear_accum, init_bound, outer_loop])

    self.decls[accum.idf] = accum_type
    self.scales[accum.idf] = scale_out
    self.intvs[accum.idf] = intv_out

    return (result_prog, accum)
def visitSgn(self, node: AST.Func):
    """Build the IR for ``sgn`` applied to ``node.expr``.

    The operand is indexed with ``IRUtil.zero`` in every dimension and the
    result of ``IRUtil.cond_zero(elt, IRUtil.one, IRUtil.zero)`` is assigned
    to a fresh temporary.
    NOTE(review): the exact predicate ``cond_zero`` encodes is defined in
    IRUtil and is not visible here; confirm before relying on it.

    Returns:
        A pair ``(prog, result)`` where ``result`` is the temporary,
        declared as ``Type.Int()``.
    """
    operand_prog, operand = self.visit(node.expr)

    result = self.getTempVar()

    rank = node.expr.type.dim
    zero_index = [IRUtil.zero] * rank
    operand_elt = IRUtil.addIndex(operand, zero_index)

    header = IR.Comment('sgn(' + operand.idf + ')')
    selected = IRUtil.cond_zero(operand_elt, IRUtil.one, IRUtil.zero)
    store = IR.Assn(result, selected)

    sgn_prog = IR.Prog([header, store])
    combined = IRUtil.concatPrograms(operand_prog, sgn_prog)

    self.decls[result.idf] = Type.Int()

    return (combined, result)
def visitLet(self, node: AST.Let):
    """Build the IR for ``let name = decl in expr``.

    Two cases, chosen by ``Type.isInt`` on the declaration's type:

    * Integer declaration: ``name`` is declared as ``Type.Int()``, the body
      is visited, and an explicit assignment ``name = <decl value>`` is
      placed between the declaration's program and the body's program.
    * Any other declaration: the scale and interval of the declaration's
      value are copied to ``name``. If the declaration node is an
      ``AST.Decl``, ``name`` is also appended to ``globalVars``, its type is
      recorded in ``decls``, and the declaration's expression is renamed to
      ``name`` and flagged as an input variable. The body is then visited
      and every use of ``name`` in it is substituted with the declaration's
      expression.

    Returns:
        A pair ``(prog, result)`` where ``result`` is the body's expression.
    """
    decl_prog, decl_value = self.visit(node.decl)
    decl_type = node.decl.type
    name = node.name

    # Case 1 -- e1 : Int
    if Type.isInt(decl_type):
        self.decls[name] = Type.Int()

        body_prog, body_value = self.visit(node.expr)

        bind = IR.Assn(IR.Var(name), decl_value)
        bind_prog = IR.Prog([bind])

        combined = IRUtil.concatPrograms(decl_prog, bind_prog, body_prog)
        return (combined, body_value)

    # Case 2 -- e1 : Tensor{(),(..)}
    self.scales[name] = self.scales[decl_value.idf]
    self.intvs[name] = self.intvs[decl_value.idf]

    if isinstance(node.decl, AST.Decl):
        self.globalVars.append(name)
        self.decls[name] = node.decl.type
        # The declaration's expression object is mutated in place.
        decl_value.idf = name
        decl_value.inputVar = True

    body_prog, body_value = self.visit(node.expr)

    # Replace references to the bound name with the declaration's expression.
    body_prog = body_prog.subst(name, decl_value)
    body_value = body_value.subst(name, decl_value)

    combined = IRUtil.concatPrograms(decl_prog, body_prog)
    return (combined, body_value)
def visitReshape(self, node: AST.Reshape):
    """Build the IR for ``reshape`` of ``node.expr`` to the shape of ``node``.

    Sketch of the emitted code for a 3-D target, e.g. reshape(A, n, h, w)::

        t1 = t2 = t3 = 0
        for ... over the input iterators, ordered by node.order:
            B[t1][t2][t3] = A[<input iterators>]
            t3++
            if (t3 == <extent 3>):
                t3 = 0
                t2++
                if (t2 == <extent 2>):
                    t2 = 0
                    t1++

    The input is traversed with loop iterators (in the 1-based axis order
    given by ``node.order``), while the output position is tracked with
    integer counter temporaries that are advanced after each element copy.
    Scale and interval are copied unchanged from the input.

    Returns:
        A pair ``(prog, result)`` where ``result`` is the output temporary,
        registered in ``decls``, ``scales`` and ``intvs``. The counter
        temporaries are declared as ``Type.Int()``.
    """
    src_prog, src = self.visit(node.expr)

    src_type = node.expr.type
    dst_type = node.type

    # Scaling information is carried over as-is.
    scale_out = self.scales[src.idf]
    intv_out = self.intvs[src.idf]

    # Allocation order of temporaries/iterators is significant for naming.
    dst = self.getTempVar()
    src_iters = self.getTempIterators(src_type.dim)
    dst_counters = self.getTempVars(dst_type.dim)

    # All output counters start at zero.
    reset_counters = [IR.Assn(counter, IRUtil.zero) for counter in dst_counters]

    # Build the counter-advance commands from the first axis outwards, so
    # that the last axis ends up being incremented unconditionally and each
    # earlier axis is only bumped when the following one wraps around.
    advance = [IRUtil.incCmd(dst_counters[0])]
    for axis in range(1, dst_type.dim):
        counter = dst_counters[axis]
        extent = IR.Int(dst_type.shape[axis])
        bump = IRUtil.incCmd(counter)
        wrap = IR.If(
            IRUtil.eq(counter, extent),
            [IRUtil.initVarToZero(counter)] + advance)
        advance = [bump, wrap]

    # node.order holds 1-based axis numbers of the input.
    axes = [position - 1 for position in node.order]
    loop_shape = [src_type.shape[axis] for axis in axes]
    loop_iters = [src_iters[axis] for axis in axes]

    copy_elt = IR.Assn(
        IRUtil.addIndex(dst, dst_counters),
        IRUtil.addIndex(src, src_iters))
    copy_loop = IRUtil.loop(loop_shape, loop_iters, [copy_elt] + advance)

    # Assemble the final program.
    shape_text = ', '.join(str(e) for e in dst_type.shape)
    header = IR.Comment("reshape(" + src.idf + ", " + shape_text + ")")
    reshape_prog = IR.Prog([header] + reset_counters + copy_loop)
    combined = IRUtil.concatPrograms(src_prog, reshape_prog)

    # Register the output.
    self.decls[dst.idf] = dst_type
    self.scales[dst.idf] = scale_out
    self.intvs[dst.idf] = intv_out

    # Register the counter temporaries.
    self.decls.update({counter.idf: Type.Int() for counter in dst_counters})

    return (combined, dst)
def visitTableExp(self, node: AST.Func):
    """Build the IR for a table-based ``exp`` of ``node.expr``.

    Sketch of the emitted code, with ``x`` the operand indexed at zero in
    every dimension, ``min`` the lower end of ``self.expRange`` at the
    operand's scale, and ``b = self.expB``::

        if ((-x) < min) {
            i = 0;
            j = 0;
        } else {
            y = ((-x) - min) << shl
            i = (y >> shrI) & (2^b - 1)
            j = (y >> shrJ) & (2^b - 1)
        }
        ans = (T[i] >> shr1) * (U[j] >> shr2)

    ``T`` and ``U`` are the two entries returned by ``getExpTable`` for the
    operand's scale.

    Returns:
        A pair ``(prog, result)`` where ``result`` is the output temporary,
        declared with the operand's type and registered in ``scales`` and
        ``intvs``. The three scratch temporaries are declared as
        ``Type.Int()``.
    """
    operand_prog, operand = self.visit(node.expr)

    # TODO: use MAX_VAL_EXP

    operand_type = node.expr.type

    scale_in = self.scales[operand.idf]
    # Looked up (as before) but not otherwise used in this method.
    _intv_in = self.intvs[operand.idf]

    lo, hi = self.expRange
    lo_scaled = int(np.ldexp(lo, -scale_in))
    hi_scaled = int(np.ldexp(hi, -scale_in))

    span_scaled = int(np.ldexp(hi - lo, -scale_in))
    shl = self.getShl(span_scaled)

    # Allocation order of temporaries is significant for naming.
    shifted = self.getTempVar()
    idx_i, idx_j = self.getTempVars(2)
    result = self.getTempVar()

    mask = IR.Int(2**self.expB - 1)
    shr_i = Common.wordLength - self.expB
    shr_j = Common.wordLength - self.expB * 2
    table = self.getExpTable(scale_in)

    scale1 = self.getScale(1)
    scale2 = self.getScale(abs(np.exp(-lo)))

    shr1, shr2 = self.getShrForMul(scale1, scale2)

    zero_index = [IRUtil.zero] * operand_type.dim
    operand_elt = IRUtil.addIndex(operand, zero_index)
    result_elt = IRUtil.addIndex(result, [IRUtil.zero] * operand_type.dim)

    below_range = IRUtil.lt(IRUtil.negate(operand_elt), IR.Int(lo_scaled))

    # Branch taken when the negated operand is below the table range.
    clear_i = IR.Assn(idx_i, IR.Int(0))
    clear_j = IR.Assn(idx_j, IR.Int(0))

    # Otherwise: offset, shift left, then slice out the two table indices.
    offset = IRUtil.sub(IRUtil.negate(operand_elt), IR.Int(lo_scaled))
    store_shifted = IR.Assn(shifted, IRUtil.shl(offset, shl))
    store_i = IR.Assn(
        idx_i, IRUtil.bitAnd(IRUtil.shrUint(shifted, shr_i), mask))
    store_j = IR.Assn(
        idx_j, IRUtil.bitAnd(IRUtil.shrUint(shifted, shr_j), mask))

    pick_indices = IR.If(
        below_range,
        [clear_i, clear_j],
        [store_shifted, store_i, store_j])

    # Product of the two (right-shifted) table entries.
    first_factor = IRUtil.shrUint(IRUtil.addIndex(table[0], [idx_i]), shr1)
    second_factor = IRUtil.shrUint(IRUtil.addIndex(table[1], [idx_j]), shr2)
    store_result = IR.Assn(result_elt, IRUtil.mul(first_factor, second_factor))

    scale_out = self.getScaleForExp(scale1, shr1, scale2, shr2)
    intv_out = self.getIntervalForExp(scale_out, [-lo_scaled, -hi_scaled])

    header = IR.Comment('exp(' + operand.idf + ')')

    exp_prog = IR.Prog([header, pick_indices, store_result])
    combined = IRUtil.concatPrograms(operand_prog, exp_prog)

    self.decls[result.idf] = operand_type
    self.scales[result.idf] = scale_out
    self.intvs[result.idf] = intv_out

    self.decls.update(
        {scratch.idf: Type.Int() for scratch in [shifted, idx_i, idx_j]})

    return (combined, result)