Exemple #1
0
    def visitArgMax(self, node: AST.Func):
        """Emit IR that calls ``ArgMax`` on the 2-D operand of ``node``.

        ``node.expr`` is visited first. Its program is then followed by a
        comment and an ``IR.FuncCall("ArgMax", ...)`` that receives the
        operand, its two shape extents and a fresh temporary for the result.

        Returns:
            A ``(program, expression)`` tuple: the concatenated program and
            the result temporary, which is recorded in ``self.decls`` as
            ``Type.Int()``.
        """
        (operand_prog, operand) = self.visit(node.expr)

        operand_type = node.expr.type
        # Only two-dimensional operands are supported here.
        assert operand_type.dim == 2
        extent_i, extent_j = operand_type.shape

        result = self.getTempVar()

        # The operand is explicitly flagged as not being an input variable.
        operand.inputVar = False

        header = IR.Comment('argmax(' + operand.idf + ')')

        # Keys are the argument expressions; values are the labels they are
        # passed under (presumably the callee's parameter names).
        call_args = {
            operand: "A",
            IR.Int(extent_i): "I",
            IR.Int(extent_j): "J",
            result: "index",
        }
        argmax_call = IR.FuncCall("ArgMax", call_args)

        full_prog = IRUtil.concatPrograms(
            operand_prog, IR.Prog([header, argmax_call]))

        self.decls[result.idf] = Type.Int()

        return (full_prog, result)
Exemple #2
0
    def visitSum(self, node: AST.Sum):
        """Emit IR that accumulates ``node.expr`` while ``node.name`` advances.

        Shape of the generated program, with ``n = node.end - node.start``:

            memset(expr_out, size of the output type)
            var = start
            IR.For over a fresh iterator, from 0 while iterator < n:
                <program of node.expr>        (produces expr_in)
                for every index idx of the output shape:
                    expr_out[idx] = expr_out[idx] + shr(expr_in[idx], height_shr)
                var = inc(var)

        The output scale and interval come from ``getScaleForTreeSum`` and
        ``getIntervalForTreeSum``.

        Returns:
            A ``(program, expression)`` tuple holding the program above and
            the accumulator temporary.
        """
        sum_var_name = node.name
        self.decls[sum_var_name] = Type.Int()

        (body_prog, body_expr) = self.visit(node.expr)

        lower, upper = node.start, node.end
        trip_count = upper - lower

        acc = self.getTempVar()
        acc_type = node.type

        sum_var = IR.Var(sum_var_name)
        counter = self.getTempIterator()
        elem_iters = self.getTempIterators(acc_type.dim)

        # The third component (height without shift) is not needed here.
        (acc_scale, height_shr, _height_noshr) = self.getScaleForTreeSum(
            self.scales[body_expr.idf], trip_count)
        acc_intv = self.getIntervalForTreeSum(
            self.intvs[body_expr.idf], trip_count)

        # Element-wise accumulation of one iteration's output
        body_elem = IRUtil.addIndex(body_expr, elem_iters)
        acc_elem = IRUtil.addIndex(acc, elem_iters)

        clear_acc = IR.Memset(acc, acc_type.size())
        accumulate = IR.Assn(
            acc_elem,
            IRUtil.add(acc_elem, IRUtil.shr(body_elem, height_shr)))
        accumulate_loop = IRUtil.loop(acc_type.shape, elem_iters, [accumulate])

        # Outer loop: evaluate the body, accumulate, then advance the variable
        init_var = IR.Assn(sum_var, IR.Int(lower))
        loop_cond = IRUtil.lt(counter, IR.Int(trip_count))
        advance_var = IR.Assn(sum_var, IRUtil.inc(sum_var))
        outer_loop = IR.For(
            counter, 0, loop_cond,
            body_prog.cmd_l + accumulate_loop + [advance_var])

        sum_prog = IR.Prog([clear_acc, init_var, outer_loop])

        self.decls[acc.idf] = acc_type
        self.scales[acc.idf] = acc_scale
        self.intvs[acc.idf] = acc_intv

        return (sum_prog, acc)
Exemple #3
0
    def visitSgn(self, node: AST.Func):
        """Emit IR for ``sgn`` applied to the first element of ``node.expr``.

        The operand is indexed at zero along every one of its dimensions and
        the result temporary is assigned
        ``IRUtil.cond_zero(element, IRUtil.one, IRUtil.zero)``.

        Returns:
            A ``(program, expression)`` tuple: the operand's program followed
            by the comment and assignment, and the result temporary, which is
            recorded in ``self.decls`` as ``Type.Int()``.
        """
        (operand_prog, operand) = self.visit(node.expr)

        result = self.getTempVar()
        operand_type = node.expr.type

        # Address element [0][0]...[0] of the operand.
        zero_index = [IRUtil.zero] * operand_type.dim
        first_elem = IRUtil.addIndex(operand, zero_index)

        header = IR.Comment('sgn(' + operand.idf + ')')
        select = IRUtil.cond_zero(first_elem, IRUtil.one, IRUtil.zero)
        assign = IR.Assn(result, select)

        full_prog = IRUtil.concatPrograms(
            operand_prog, IR.Prog([header, assign]))

        self.decls[result.idf] = Type.Int()

        return (full_prog, result)
Exemple #4
0
    def visitLet(self, node: AST.Let):
        """Emit IR for a let-binding of ``node.name`` to ``node.decl``.

        Two cases, selected by the type of ``node.decl``:

        * Int: the name is declared as ``Type.Int()``, the body ``node.expr``
          is visited, and an assignment ``name = <decl expression>`` is placed
          between the declaration's program and the body's program.
        * Otherwise (tensor): the scale and interval of the declaration are
          copied under the bound name. If ``node.decl`` is an ``AST.Decl``,
          the name is also registered in ``self.globalVars`` and
          ``self.decls``, and the declaration expression is renamed to it and
          flagged as an input variable. The body is then visited and every
          occurrence of the name in it is substituted by the declaration
          expression.

        Returns:
            A ``(program, expression)`` tuple for the whole let-expression.
        """
        (decl_prog, decl_expr) = self.visit(node.decl)
        decl_type = node.decl.type
        name = node.name

        # e1 : Int
        if Type.isInt(decl_type):
            # Declare the name before visiting the body.
            self.decls[name] = Type.Int()

            (body_prog, body_expr) = self.visit(node.expr)

            binding = IR.Prog([IR.Assn(IR.Var(name), decl_expr)])
            return (IRUtil.concatPrograms(decl_prog, binding, body_prog),
                    body_expr)

        # e1 : Tensor{(),(..)}
        # Copy scale/interval first: decl_expr.idf may be overwritten below.
        self.scales[name] = self.scales[decl_expr.idf]
        self.intvs[name] = self.intvs[decl_expr.idf]

        if isinstance(node.decl, AST.Decl):
            self.globalVars.append(name)
            self.decls[name] = decl_type
            decl_expr.idf = name
            decl_expr.inputVar = True

        (body_prog, body_expr) = self.visit(node.expr)

        # Replace references to the bound name by the declaration expression.
        body_prog = body_prog.subst(name, decl_expr)
        body_expr = body_expr.subst(name, decl_expr)

        return (IRUtil.concatPrograms(decl_prog, body_prog), body_expr)
Exemple #5
0
    def visitReshape(self, node: AST.Reshape):
        """Emit IR that copies ``node.expr`` into a tensor of ``node.type``.

        The loops run over the input iterators, nested in the order given by
        the 1-based axis numbers in ``node.order``. The output position is
        tracked by separate counters that are advanced after every copy, the
        last counter moving fastest. For an output of rank 3 the generated
        program looks like:

            o0 = o1 = o2 = 0
            for <input iterators, ordered by node.order>:
                out[o0][o1][o2] = in[<input iterators>]
                o2++
                if (o2 == shape_out[2]):
                    o2 = 0
                    o1++
                    if (o1 == shape_out[1]):
                        o1 = 0
                        o0++

        The scale and interval of the input are carried over unchanged.

        Returns:
            A ``(program, expression)`` tuple: the input's program followed by
            the reshape program, and the output temporary.
        """
        (src_prog, src) = self.visit(node.expr)

        src_type = node.expr.type
        dst_type = node.type

        # Scale and interval are unaffected by a reshape
        dst_scale = self.scales[src.idf]
        dst_intv = self.intvs[src.idf]

        # Allocate the output tensor, the loop iterators and the counters
        dst = self.getTempVar()
        src_iters = self.getTempIterators(src_type.dim)
        dst_counters = self.getTempVars(dst_type.dim)

        # All output counters start at 0
        reset_counters = [IR.Assn(ctr, IRUtil.zero) for ctr in dst_counters]

        # Build the counter update from the outermost counter inwards: each
        # later counter is incremented first and, when it reaches its extent,
        # is reset and lets the previously built update run.
        advance = [IRUtil.incCmd(dst_counters[0])]
        for axis in range(1, dst_type.dim):
            ctr = dst_counters[axis]
            extent = IR.Int(dst_type.shape[axis])
            wrap_around = IR.If(IRUtil.eq(ctr, extent),
                                [IRUtil.initVarToZero(ctr)] + advance)
            advance = [IRUtil.incCmd(ctr), wrap_around]

        # Loop nest over the input, ordered by node.order (1-based axes)
        axes = [axis_no - 1 for axis_no in node.order]
        nest_shape = [src_type.shape[axis] for axis in axes]
        nest_iters = [src_iters[axis] for axis in axes]

        copy_elem = IR.Assn(IRUtil.addIndex(dst, dst_counters),
                            IRUtil.addIndex(src, src_iters))
        copy_loop = IRUtil.loop(nest_shape, nest_iters, [copy_elem] + advance)

        # Assemble the program
        header = IR.Comment("reshape(" + src.idf + ", " +
                            ', '.join(str(e) for e in dst_type.shape) + ")")
        reshape_prog = IR.Prog([header] + reset_counters + copy_loop)

        full_prog = IRUtil.concatPrograms(src_prog, reshape_prog)

        # Record the output tensor
        self.decls[dst.idf] = dst_type
        self.scales[dst.idf] = dst_scale
        self.intvs[dst.idf] = dst_intv

        # Record the integer counters
        self.decls.update({ctr.idf: Type.Int() for ctr in dst_counters})

        return (full_prog, dst)
Exemple #6
0
    def visitTableExp(self, node: AST.Func):
        """Emit IR for a table-based ``exp`` of the first element of ``node.expr``.

        With ``x`` the operand indexed at zero along every dimension, and
        ``T`` / ``U`` the two tables returned by ``getExpTable``, the
        generated program is:

            if ((-x) < m_scale) {
                i = 0;
                j = 0;
            }
            else {
                y = ((-x) - m_scale) << shl
                i = (y >> shrI) & (2^expB - 1)
                j = (y >> shrJ) & (2^expB - 1)
            }
            ans = (T[i] >> shr1) * (U[j] >> shr2)

        ``m_scale`` / ``M_scale`` are the bounds of ``self.expRange`` passed
        through ``np.ldexp(., -scale_in)`` and truncated to ``int``.

        Returns:
            A ``(program, expression)`` tuple: the operand's program followed
            by the program above, and the result temporary. The result is
            declared with the operand's type; ``y``, ``i`` and ``j`` are
            declared as ``Type.Int()``.
        """
        (prog_in, expr_in) = self.visit(node.expr)

        # TODO: use MAX_VAL_EXP
        type_in = node.expr.type

        scale_in = self.scales[expr_in.idf]

        [m, M] = self.expRange
        m_scale = int(np.ldexp(m, -scale_in))
        M_scale = int(np.ldexp(M, -scale_in))

        # Width of the scaled input range; determines the left shift applied
        # to the offset input.
        range_width = int(np.ldexp(M - m, -scale_in))
        shl = self.getShl(range_width)

        # Temporaries: shifted input (y), the two table indices, the result.
        shifted_in = self.getTempVar()
        [i, j] = self.getTempVars(2)
        expr_out = self.getTempVar()

        mask = IR.Int(2**self.expB - 1)
        shrI = Common.wordLength - self.expB
        shrJ = Common.wordLength - self.expB * 2
        table = self.getExpTable(scale_in)

        scale1 = self.getScale(1)
        scale2 = self.getScale(abs(np.exp(-m)))

        [shr1, shr2] = self.getShrForMul(scale1, scale2)

        # Only element [0]...[0] of the operand and of the result is touched.
        expr_1_elt = IRUtil.addIndex(expr_in, [IRUtil.zero] * type_in.dim)
        expr_2_elt = IRUtil.addIndex(expr_out, [IRUtil.zero] * type_in.dim)

        cond = IRUtil.lt(IRUtil.negate(expr_1_elt), IR.Int(m_scale))

        # Then-branch: both indices are zero.
        cmd2 = IR.Assn(i, IR.Int(0))
        cmd3 = IR.Assn(j, IR.Int(0))

        # Else-branch: offset and shift the negated input, then slice out the
        # two table indices.
        cmd6 = IR.Assn(
            shifted_in,
            IRUtil.shl(IRUtil.sub(IRUtil.negate(expr_1_elt), IR.Int(m_scale)),
                       shl))
        cmd7 = IR.Assn(
            i, IRUtil.bitAnd(IRUtil.shrUint(shifted_in, shrI), mask))
        cmd8 = IR.Assn(
            j, IRUtil.bitAnd(IRUtil.shrUint(shifted_in, shrJ), mask))

        cmd1 = IR.If(cond, [cmd2, cmd3], [cmd6, cmd7, cmd8])

        # Result is the product of the two (shifted) table entries.
        cmd10 = IR.Assn(
            expr_2_elt,
            IRUtil.mul(IRUtil.shrUint(IRUtil.addIndex(table[0], [i]), shr1),
                       IRUtil.shrUint(IRUtil.addIndex(table[1], [j]), shr2)))

        scale_out = self.getScaleForExp(scale1, shr1, scale2, shr2)
        intv_out = self.getIntervalForExp(scale_out, [-m_scale, -M_scale])

        cmd0 = IR.Comment('exp(' + expr_in.idf + ')')

        prog_exp = IR.Prog([cmd0, cmd1, cmd10])

        prog_out = IRUtil.concatPrograms(prog_in, prog_exp)

        self.decls[expr_out.idf] = type_in
        self.scales[expr_out.idf] = scale_out
        self.intvs[expr_out.idf] = intv_out

        self.decls.update(
            {var.idf: Type.Int() for var in [shifted_in, i, j]})

        return (prog_out, expr_out)