Beispiel #1
0
	def visitArgMax(self, node:AST.ArgMax, args=None):
		"""Lower an ArgMax node to an ArgMax_<output rank> library call.

		Arguments passed, in order: the output dims, the input dims
		(node.inShape), the input array, the reduction dimension and the
		output array.

		Returns:
			(prog, tmpExpr) where prog declares tmpExpr and performs the call.
		"""
		(prog_1, expr1) = self.visit(node.expr)
		(prog_2, expr2) = self.visit(node.dim)

		tmpExpr = self.getTempVar()
		outputShape = node.type.shape

		funcArgsList = OrderedDict()
		for ii, curDim in enumerate(outputShape):
			funcArgsList[IR.Int(curDim, 32)] = "OutputShape_" + str(ii)
		# These are the input dims, so they carry the "InputShape_" label.
		for ii, curDim in enumerate(node.inShape):
			funcArgsList[IR.Int(curDim, 32)] = "InputShape_" + str(ii)
		funcArgsList[expr1] = "inArr"
		funcArgsList[expr2] = "dim"
		funcArgsList[tmpExpr] = "outArr"

		if not(Util.Config.disableTruncOpti):
			# Recorded with scale factor 0, presumably because the output holds
			# indices rather than fixed-point values. TODO: confirm.
			self.scaleFacMapping[tmpExpr.idf] = 0

		funcCall = IR.FuncCall("ArgMax" + self.varNameDelim + str(len(outputShape)), funcArgsList)
		comment = IR.Comment(str(node.metadata))
		prog_3 = IRUtil.prog_merge(prog_1, prog_2, IR.Prog([comment, funcCall]))
		prog_3 = IRUtil.prog_merge(IR.Prog([IR.Decl(tmpExpr.idf, node.type)]), prog_3)
		return (prog_3, tmpExpr)
Beispiel #2
0
	def visitReduce(self, node:AST.Reduce, args=None):
		"""Lower a sum/mean reduction to a ReduceSum / ReduceMean library call.

		The callee name is suffixed with the output rank and then the input
		rank. Arguments: output dims, input dims, input array, reduction
		dimension, output array. With truncation optimisation enabled the
		output is given the same scale factor as the input.

		Returns:
			(prog, outArr) where prog declares outArr and performs the call.
		"""
		inProg, inArr = self.visit(node.expr)
		dimProg, dimExpr = self.visit(node.dim)
		outArr = self.getTempVar()

		assert(node.op in [AST.Operators.ADD, AST.Operators.Mean])
		if node.op == AST.Operators.ADD:
			baseName = "ReduceSum"
		elif node.op == AST.Operators.Mean:
			baseName = "ReduceMean"

		if not Util.Config.disableTruncOpti:
			# The output keeps the input's scale factor.
			self.scaleFacMapping[outArr.idf] = self.scaleFacMapping[inArr.idf]

		outShape = node.type.shape
		inShape = node.expr.type.shape
		callArgs = OrderedDict()
		for pos, extent in enumerate(outShape):
			callArgs[IR.Int(extent, 32)] = "OutputShape_" + str(pos)
		for pos, extent in enumerate(inShape):
			callArgs[IR.Int(extent, 32)] = "InputShape_" + str(pos)
		callArgs[inArr] = "inputArr"
		callArgs[dimExpr] = "dimension"
		callArgs[outArr] = "outArr"

		calleeName = baseName + self.varNameDelim + str(len(outShape)) + self.varNameDelim + str(len(inShape))
		reduceCall = IR.FuncCall(calleeName, callArgs)
		body = IRUtil.prog_merge(inProg, dimProg, IR.Prog([IR.Comment(str(node.metadata)), reduceCall]))

		declProg = IR.Prog([IR.Decl(outArr.idf, node.type)])
		return (IRUtil.prog_merge(declProg, body), outArr)
Beispiel #3
0
	def visitTranspose(self, node:AST.Transpose, args=None):
		"""Lower a Transpose node to an explicit nested copy loop.

		For every index of the input, writes out[permuted index] = inp[index],
		where the output index lists the input iterators in node.perm order.
		When node.perm is None the axes are reversed.

		Returns:
			(prog, dstArr) where prog declares the iterators and dstArr.
		"""
		(srcProg, srcArr) = self.visit(node.expr)
		srcType = node.expr.type
		dstType = node.type
		srcIters = self.getTempIterators(srcType.dim)

		order = node.perm
		if order is None:
			order = list(reversed(range(len(srcType.shape))))
		dstIters = [srcIters[axis] for axis in order]

		dstArr = self.getTempVar()
		copyStmt = IR.Assn(IRUtil.addIndex(dstArr, dstIters), IRUtil.addIndex(srcArr, srcIters))
		loopCmds = IRUtil.loop(srcType.shape, srcIters, [copyStmt])

		inShapeStr = ', '.join(str(e) for e in srcType.shape)
		outShapeStr = ', '.join(str(e) for e in dstType.shape)
		header = [
			IR.Comment(str(node.metadata)),
			IR.Comment("transpose(" + srcArr.idf + ", [" + inShapeStr + "] --> [" + outShapeStr + "])"),
		]
		result = IRUtil.prog_merge(srcProg, IR.Prog(header + loopCmds))

		# Each iterator declaration is prepended in turn (so they end up in
		# reverse order), then the output array declaration goes in front.
		for it in srcIters:
			result = IRUtil.prog_merge(IR.Prog([IR.Decl(it.idf, Type.Int(), isSecret=False)]), result)
		result = IRUtil.prog_merge(IR.Prog([IR.Decl(dstArr.idf, dstType)]), result)

		if not Util.Config.disableTruncOpti:
			self.scaleFacMapping[dstArr.idf] = self.scaleFacMapping[srcArr.idf]

		return (result, dstArr)
Beispiel #4
0
    def visitUOp(self, node: AST.UOp, args=None):
        """Lower a unary +/- node.

        Unary plus returns the operand unchanged. Unary minus on an Int
        result returns a negated expression with no extra commands; on a
        tensor result it emits an element-wise negation loop into a fresh,
        declared temporary (also recorded in self.decls).
        """
        (prog_1, expr_1) = self.visit(node.expr)
        if node.op == AST.Operators.ADD:
            return (prog_1, expr_1)
        assert node.op == AST.Operators.SUB

        resType = node.type
        if Type.isInt(resType):
            return (prog_1, IRUtil.negate(expr_1))

        resVar = self.getTempVar()
        idxVars = self.getTempIterators(resType.dim)
        srcElt = IRUtil.addIndex(expr_1, idxVars)
        dstElt = IRUtil.addIndex(resVar, idxVars)
        negLoop = IRUtil.loop(resType.shape, idxVars,
                              [IR.Assn(dstElt, IRUtil.negate(srcElt))])
        computeProg = IRUtil.prog_merge(
            prog_1, IR.Prog([IR.Comment(str(node.metadata))] + negLoop))

        self.decls[resVar.idf] = [resType]
        declProg = IR.Prog([IR.Decl(resVar.idf, node.type)])
        return (IRUtil.prog_merge(declProg, computeProg), resVar)
Beispiel #5
0
    def visitBopElemWiseOp(self, node: AST.BOp, args=None):
        """Lower an element-wise multiply/divide to ElemWiseMul/ElemWiseDiv.

        The callee name is suffixed with the output rank. Arguments: the
        output dims, operands A and B, result C, and "shrC" set to
        Util.Config.consSF.

        Raises:
            AssertionError: if node.op is neither ElemWiseMul nor ElemWiseDiv.
        """
        (prog_1, expr_1) = self.visit(node.expr1)
        (prog_2, expr_2) = self.visit(node.expr2)

        if node.op == AST.Operators.ElemWiseMul:
            op_ir = IR.Op.Op['.*']
            funcName = "ElemWiseMul"
        elif node.op == AST.Operators.ElemWiseDiv:
            op_ir = IR.Op.Op['./']
            funcName = "ElemWiseDiv"
        else:
            # Fail fast: op_ir / funcName would otherwise be unbound below.
            assert False, "Unsupported element-wise op: {}".format(node.op)
        # Both ops pass the same shift.
        # TODO: for ElemWiseDiv, confirm that passing a positive consSF is right.
        shr_n3 = IR.Int(Util.Config.consSF)

        typ_3 = node.type
        expr_3 = self.getTempVar()
        cmd0 = IR.Comment(expr_1.idf + ' ' + op_ir.name + ' ' + expr_2.idf)
        outputShape = typ_3.shape
        argsDict = OrderedDict()
        for ii, curDimSize in enumerate(outputShape):
            argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii)
        argsDict[expr_1] = "A"
        argsDict[expr_2] = "B"
        argsDict[expr_3] = "C"
        argsDict[shr_n3] = "shrC"
        # TODO : for consistency, add shrA and shrB ?
        funcCall = IR.FuncCall(
            funcName + self.varNameDelim + str(len(outputShape)), argsDict)
        prog_3 = IRUtil.prog_merge(prog_1, prog_2, IR.Prog([cmd0, funcCall]))
        self.decls[expr_3.idf] = [typ_3]
        prog_3 = IRUtil.prog_merge(IR.Prog([IR.Decl(expr_3.idf, node.type)]),
                                   prog_3)

        return (prog_3, expr_3)
Beispiel #6
0
    def visitTransp(self, node: AST.Transp, args=None):
        """Lower a 2-D transpose to a Transpose_<rank> library call.

        Arguments: I and J (the two dims of node.type), input A and output B.

        Returns:
            (prog, expr_2) where prog declares expr_2 ahead of the call.
        """
        (prog_1, expr_1) = self.visit(node.expr)

        # decl fresh vars
        expr_2 = self.getTempVar()

        typ_2 = node.type
        [I, J] = typ_2.shape
        cmd0 = IR.Comment(expr_1.idf + "^T")
        comment = IR.Comment(str(node.metadata))
        argsList = OrderedDict()
        argsList[IR.Int(I, 32)] = "I"
        argsList[IR.Int(J, 32)] = "J"
        argsList[expr_1] = "A"
        argsList[expr_2] = "B"
        funcCall = IR.FuncCall(
            "Transpose" + self.varNameDelim + str(len(typ_2.shape)), argsList)

        prog_for = IR.Prog([cmd0, comment, funcCall])
        prog_2 = IRUtil.prog_merge(prog_1, prog_for)

        self.decls[expr_2.idf] = [typ_2]
        # Prepend the declaration to the program that is actually returned.
        prog_2 = IRUtil.prog_merge(IR.Prog([IR.Decl(expr_2.idf, typ_2)]), prog_2)
        return (prog_2, expr_2)
Beispiel #7
0
	def visitUOp(self, node:AST.UOp, args=None):
		"""Lower a unary +/- node (scale-factor tracking variant).

		Unary plus returns the operand unchanged. Unary minus always writes
		the result into a fresh temporary:
		  * for an Int result, a scalar declared inline (with the result
		    type's bitlen and isSecret) and assigned the negated operand;
		  * for a tensor result, an element-wise negation loop, with the
		    declaration prepended.
		With truncation optimisation enabled, the result is given the same
		scale factor as the operand.
		"""
		(prog_1, expr_1) = self.visit(node.expr)
		op = node.op
		if op == AST.Operators.ADD:
			return (prog_1, expr_1)
		assert op == AST.Operators.SUB

		typ_2 = node.type
		expr_2 = self.getTempVar()

		if Type.isInt(typ_2):
			comment = IR.Comment(str(node.metadata))
			decl = IR.Decl(expr_2.idf, node.type, typ_2.bitlen, typ_2.isSecret)
			assign = IR.Assn(expr_2, IRUtil.negate(expr_1))
			prog_2 = IRUtil.prog_merge(prog_1, IR.Prog([comment, decl, assign]))
		else:
			# decl fresh vars
			iters = self.getTempIterators(typ_2.dim)

			# cmdl_assn
			expr_1_elt = IRUtil.addIndex(expr_1, iters)
			expr_2_elt = IRUtil.addIndex(expr_2, iters)
			cmdl_assn = IRUtil.loop(typ_2.shape, iters, [IR.Assn(expr_2_elt, IRUtil.negate(expr_1_elt))])
			comment = IR.Comment(str(node.metadata))
			prog_2 = IRUtil.prog_merge(prog_1, IR.Prog([comment] + cmdl_assn))
			prog_2 = IRUtil.prog_merge(IR.Prog([IR.Decl(expr_2.idf, node.type)]), prog_2)

		if not(Util.Config.disableTruncOpti):
			# Negation does not change the recorded scale.
			self.scaleFacMapping[expr_2.idf] = self.scaleFacMapping[expr_1.idf]

		return (prog_2, expr_2)
Beispiel #8
0
	def visitBopMul2DTensor(self, node:AST.BOp, args=None):
		"""Lower a 2-D matrix product A (I x J) * B (J x K) to MatMulCSF2D.

		Arguments: I, J, K, A, B, the result C, and shrT
		(Util.Config.consSF).

		Raises:
			AssertionError: if A's column count differs from B's row count.
		"""
		(prog_1, expr_1) = self.visit(node.expr1)
		(prog_2, expr_2) = self.visit(node.expr2)

		# decl fresh vars
		expr_3 = self.getTempVar()

		typ_1 = node.expr1.type
		typ_2 = node.expr2.type
		typ_3 = node.type

		[I, J] = typ_1.shape
		[J2, K] = typ_2.shape
		# The inner dimensions must agree for the product to be defined.
		assert J == J2, "MatMul inner dimension mismatch: {} vs {}".format(J, J2)

		shrT = Util.Config.consSF

		cmd0 = IR.Comment(expr_1.idf + ' * ' + expr_2.idf)
		funcCallArgsDict = OrderedDict()
		funcCallArgsDict[IR.Int(I, 32)] = "I"
		funcCallArgsDict[IR.Int(J, 32)] = "J"
		funcCallArgsDict[IR.Int(K, 32)] = "K"
		funcCallArgsDict[expr_1] = "A"
		funcCallArgsDict[expr_2] = "B"
		funcCallArgsDict[expr_3] = "C"
		funcCallArgsDict[IR.Int(shrT)] = "shrT"

		funcCall = IR.FuncCall("MatMulCSF2D", funcCallArgsDict)
		comment = IR.Comment(str(node.metadata))
		prog_3 = IRUtil.prog_merge(prog_1, prog_2, IR.Prog([comment, cmd0, funcCall]))
		self.decls[expr_3.idf] = [typ_3]
		prog_3 = IRUtil.prog_merge(IR.Prog([IR.Decl(expr_3.idf, node.type)]), prog_3)

		return (prog_3, expr_3)
Beispiel #9
0
	def visitBopMulInt(self, node:AST.BOp, args=None):
		"""Lower a scalar Int product expr1 * expr2 into a fresh declared temporary.

		With truncation optimisation disabled, the product is truncated by
		Util.Config.consSF after it is computed. With it enabled, each operand
		whose recorded scale factor exceeds self.scaleFac is first truncated
		down to self.scaleFac, and the product is recorded at 2*self.scaleFac.

		Returns:
			(prog, expr_3) where expr_3 holds the product.
		"""
		(prog_1, expr_1) = self.visit(node.expr1)
		(prog_2, expr_2) = self.visit(node.expr2)

		expr_3 = self.getTempVar()
		comment = IR.Comment(str(node.metadata))
		decl = IR.Decl(expr_3.idf, node.type, node.type.bitlen, node.type.isSecret)
		assign = IR.Assn(expr_3, IRUtil.mul(expr_1, expr_2))

		progExtraBefore = IR.Prog([])
		progExtraAfter = IR.Prog([])
		if (Util.Config.disableTruncOpti):
			progExtraAfter = self.addTruncateFunctionCall(node, "MulInt", expr_3, Util.Config.consSF)
		else:
			expr1_sf = self.scaleFacMapping[expr_1.idf]
			expr2_sf = self.scaleFacMapping[expr_2.idf]
			if (expr1_sf > self.scaleFac):
				progExtraBefore = self.addTruncateFunctionCall(node.expr1, "MulInt", expr_1, expr1_sf-self.scaleFac)
				self.scaleFacMapping[expr_1.idf] = self.scaleFac
			if (expr2_sf > self.scaleFac):
				progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "MulInt", expr_2, expr2_sf-self.scaleFac))
				self.scaleFacMapping[expr_2.idf] = self.scaleFac
			self.scaleFacMapping[expr_3.idf] = 2*self.scaleFac

		# Operand truncations must run after prog_1/prog_2 have produced the
		# operands and before the product is formed from them.
		prog_3 = IRUtil.prog_merge(prog_1, prog_2, progExtraBefore, IR.Prog([comment, decl, assign]), progExtraAfter)
		return (prog_3, expr_3)
Beispiel #10
0
    def visitFusedBatchNorm(self, node: AST.FusedBatchNorm, args=None):
        """Lower a fused batch-norm node to a FusedBatchNorm library call.

        Arguments: the output dims, the input, multiplier and addend
        expressions, Util.Config.consSF, and the output. The callee name is
        suffixed with four ranks: output, input (taken from node.type too),
        multExpr and addExpr.

        Returns:
            (prog, outExpr) where prog declares outExpr ahead of the call.
        """
        inProg, inExpr = self.visit(node.expr)
        mulProg, mulExpr = self.visit(node.multExpr)
        addProg, addExpr = self.visit(node.addExpr)
        outExpr = self.getTempVar()

        outShape = node.type.shape
        callArgs = OrderedDict(
            (IR.Int(extent, 32), "elem_" + str(pos))
            for pos, extent in enumerate(outShape))
        callArgs[inExpr] = "expr"
        callArgs[mulExpr] = "multExpr"
        callArgs[addExpr] = "addExpr"
        callArgs[IR.Int(Util.Config.consSF, 32)] = "consSF"
        callArgs[outExpr] = "returnExpr"

        # Ranks: output, input (read from node.type), multiplier, addend.
        ranks = [len(outShape), len(outShape),
                 len(node.multExpr.type.shape), len(node.addExpr.type.shape)]
        calleeName = "FusedBatchNorm"
        for r in ranks:
            calleeName += self.varNameDelim + str(r)
        bnCall = IR.FuncCall(calleeName, callArgs)

        body = IRUtil.prog_merge(
            inProg, mulProg, addProg,
            IR.Prog([IR.Comment(str(node.metadata)), bnCall]))

        self.decls[outExpr.idf] = [node.type]
        declProg = IR.Prog([IR.Decl(outExpr.idf, node.type)])
        return (IRUtil.prog_merge(declProg, body), outExpr)
Beispiel #11
0
	def visitFusedBatchNorm(self, node:AST.FusedBatchNorm, args=None):
		"""Lower a fused batch-norm node to a FusedBatchNorm library call (scale-tracking variant).

		Arguments, in order: the output dims, the input (expr), the multiplier
		(multExpr) and addend (addExpr) expressions, multExprScaleDownSf,
		addExprScaleUpSf, and the output. The callee name is suffixed with four
		ranks: output, input (read from node.type), multExpr and addExpr.

		Scale handling:
		  * Truncation optimisation disabled: multExprScaleDownSf is
		    self.scaleFac and addExprScaleUpSf is 0.
		  * Enabled: the input and multiplier are truncated down to
		    self.scaleFac (via code emitted before the call) when their recorded
		    scale exceeds it; multExprScaleDownSf is 0; addExprScaleUpSf is the
		    gap between 2*self.scaleFac and the addend's recorded scale; the
		    output is recorded at 2*self.scaleFac.

		Returns:
			(prog, returnExpr) where prog declares returnExpr ahead of the call.
		"""
		(prog1, expr1) = self.visit(node.expr)
		(prog2, expr2) = self.visit(node.multExpr)
		(prog3, expr3) = self.visit(node.addExpr)

		returnExpr = self.getTempVar()

		funcArgsList = OrderedDict()
		for ii, elem in enumerate(node.type.shape):
			funcArgsList[IR.Int(elem, 32)] = "elem_"+str(ii)
		funcArgsList[expr1] = "expr"
		funcArgsList[expr2] = "multExpr"
		funcArgsList[expr3] = "addExpr"

		progExtraBefore = IR.Prog([])
		# Defaults used when truncation optimisation is disabled.
		multExprScaleDownSf = self.scaleFac
		addExprScaleUpSf = 0
		if not(Util.Config.disableTruncOpti):
			#TruncOpti is on
			multExprScaleDownSf = 0
			addExprScaleUpSf = 0

			expr_sf = self.scaleFacMapping[expr1.idf]
			multExpr_sf = self.scaleFacMapping[expr2.idf]
			addExpr_sf = self.scaleFacMapping[expr3.idf]
			if (expr_sf > self.scaleFac):
				#Scale down needed
				progExtraBefore = self.addTruncateFunctionCall(node.expr, "FusedBatchNorm", expr1, expr_sf - self.scaleFac)
				self.scaleFacMapping[expr1.idf] = self.scaleFac

			if (multExpr_sf > self.scaleFac):
				#Scale down needed
				progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.multExpr, "FusedBatchNorm", expr2, multExpr_sf - self.scaleFac))
				self.scaleFacMapping[expr2.idf] = self.scaleFac

			# Product of two self.scaleFac-scaled values; the addend must not
			# already exceed this scale.
			final_sf = 2*self.scaleFac
			assert(final_sf >= addExpr_sf)
			if (final_sf > addExpr_sf):
				addExprScaleUpSf = final_sf - addExpr_sf
				# NOTE(review): bumping the addend's recorded scale presumably
				# assumes the library call rescales addExpr in place — confirm.
				self.scaleFacMapping[expr3.idf] += addExprScaleUpSf

			self.scaleFacMapping[returnExpr.idf] = final_sf

		funcArgsList[IR.Int(multExprScaleDownSf, 32)] = "multExprScaleDownSf"
		funcArgsList[IR.Int(addExprScaleUpSf, 32)] = "addExprScaleUpSf"
		funcArgsList[returnExpr] = "returnExpr"
			
		funcCallIR = IR.FuncCall("FusedBatchNorm" + self.varNameDelim 
							  + str(len(node.type.shape)) + self.varNameDelim #one for output
							  + str(len(node.type.shape)) + self.varNameDelim #one for input
							  + str(len(node.multExpr.type.shape)) + self.varNameDelim
							  + str(len(node.addExpr.type.shape)), 
							  funcArgsList)

		comment = IR.Comment(str(node.metadata))
		# Truncations of the inputs go after the input programs and before the call.
		returnProg = IRUtil.prog_merge(prog1, prog2, prog3, progExtraBefore, IR.Prog([comment, funcCallIR]))

		returnProg = IRUtil.prog_merge(IR.Prog([IR.Decl(returnExpr.idf, node.type)]), returnProg)
		return (returnProg, returnExpr)
Beispiel #12
0
	def visitFloorLike(self, node:AST.Func, args=None):
		"""Lower Floor / Shape / Relu / ClearMemSecret / ClearMemPublic to a library call.

		Arguments: the input dims (tensor inputs only), the input array, the
		output array (tensor results only), then op-specific extras:
		  * Floor: "curScale" = Util.Config.consSF.
		  * Relu: "consSF" and "doTruncation". With truncation optimisation
		    disabled these are Util.Config.consSF and False. With it enabled,
		    consSF is the input's recorded scale minus self.scaleFac, and
		    doTruncation is True only when that difference is positive (the
		    output is then recorded at self.scaleFac).
		The callee name is suffixed with the input rank ("" for non-tensor
		inputs). A declaration is emitted only for tensor results.

		Returns:
			(prog, tmpExpr).
		"""
		(prog1, expr1) = self.visit(node.expr)
		tmpExpr = self.getTempVar()

		if node.op == AST.Operators.Floor:
			funcName = "Floor"
		elif node.op == AST.Operators.Shape:
			funcName = "Shape"
		elif node.op == AST.Operators.RELU:
			funcName = "Relu"
		elif node.op == AST.Operators.ClearMemSecret:
			funcName = "ClearMemSecret"
		elif node.op == AST.Operators.ClearMemPublic:
			funcName = "ClearMemPublic"
		else:
			assert False

		argsList = OrderedDict()

		inputType = node.expr.type
		if Type.isTensor(inputType):
			for ii, curDim in enumerate(inputType.shape):
				argsList[IR.Int(curDim, 32)] = "inShape_" + str(ii)
		argsList[expr1] = "inArr"

		if Type.isTensor(node.type):
			argsList[tmpExpr] = "outArr"

		if node.op == AST.Operators.Floor:
			# NOTE(review): with truncation optimisation enabled the input's
			# recorded scale (scaleFacMapping[expr1.idf]) may differ from
			# consSF; confirm Floor should still be told consSF here.
			argsList[IR.Int(Util.Config.consSF,32)] = "curScale"

		if (Util.Config.disableTruncOpti):
			if node.op == AST.Operators.RELU:
				argsList[IR.Int(Util.Config.consSF,32)] = "consSF"
				argsList[IR.Bool(False)] = "doTruncation"
		else:
			final_sf = self.scaleFacMapping[expr1.idf]
			if node.op == AST.Operators.RELU:
				# May be <= 0; doTruncation below is True only when it is positive.
				argsList[IR.Int(final_sf - self.scaleFac,32)] = "consSF"
				if (final_sf > self.scaleFac):
					# If it can't tolerate one more mult operation, scale down here.
					final_sf = self.scaleFac
					argsList[IR.Bool(True)] = "doTruncation"
				else:
					argsList[IR.Bool(False)] = "doTruncation"
			self.scaleFacMapping[tmpExpr.idf] = final_sf

		comment = IR.Comment(str(node.metadata))
		funcNameSuffix = ""
		if Type.isTensor(inputType):
			funcNameSuffix = str(len(inputType.shape))

		progFinal = IRUtil.prog_merge(prog1, IR.Prog([comment, IR.FuncCall(funcName + self.varNameDelim + funcNameSuffix, argsList)]))
		if Type.isTensor(node.type):
			progFinal = IRUtil.prog_merge(IR.Prog([IR.Decl(tmpExpr.idf, node.type)]), progFinal)

		return (progFinal, tmpExpr)
Beispiel #13
0
    def visitBopAddOrSub(self, node: AST.BOp, args=None):
        """Lower a binary + / - node.

        For an Int result, returns an IR.IntBop expression with no extra
        commands. For a tensor result, emits a MatAdd/MatSub call into a
        fresh declared temporary. The callee name gets a "BroadCast" suffix
        when expr2's rank differs from the result's, and then the output
        rank. Arguments: dims of expr1, dims of expr2, output dims, then A,
        B and C.

        Raises:
            AssertionError: for an unsupported operator, or when expr1's rank
                differs from the result's (broadcasting expr1 is unsupported).
        """
        (prog_1, expr_1) = self.visit(node.expr1)
        (prog_2, expr_2) = self.visit(node.expr2)

        op = node.op
        if (op == AST.Operators.ADD):
            op_ir = IR.Op.Op['+']
            funcName = "MatAdd"
        elif (op == AST.Operators.SUB):
            op_ir = IR.Op.Op['-']
            funcName = "MatSub"
        else:
            assert False, "Unsupported add/sub operator: {}".format(op)

        typ_3 = node.type

        # e : Int
        if Type.isInt(typ_3):
            prog_3 = IRUtil.prog_merge(prog_1, prog_2)
            expr_3 = IR.IntBop(expr_1, op_ir, expr_2)
        # e : Tensor() -- float, or Tensor(..)
        else:
            # TODO: the rank-mismatch handling below is a temporary hack.
            if (node.type.dim != node.expr1.type.dim):
                # Broadcasting expr1 is not implemented.
                assert False, "Broadcasting expr1 is not supported"
            if (node.type.dim != node.expr2.type.dim):
                # expr2 needs broadcasting: use the BroadCast variant.
                funcName += 'BroadCast'

            # decl fresh vars
            expr_3 = self.getTempVar()

            cmd0 = IR.Comment(expr_1.idf + ' ' + op_ir.name + ' ' + expr_2.idf)
            outputShape = typ_3.shape
            argsDict = OrderedDict()
            # Dims of expr1, expr2 and the output, in that order.
            # NOTE(review): all three groups share the "size_<i>" labels.
            for shape in (node.expr1.type.shape, node.expr2.type.shape,
                          outputShape):
                for ii, curDimSize in enumerate(shape):
                    argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii)
            argsDict[expr_1] = "A"
            argsDict[expr_2] = "B"
            argsDict[expr_3] = "C"
            funcCall = IR.FuncCall(
                funcName + self.varNameDelim + str(len(outputShape)), argsDict)
            comment = IR.Comment(str(node.metadata))
            prog_3 = IRUtil.prog_merge(prog_1, prog_2,
                                       IR.Prog([comment, cmd0, funcCall]))
            self.decls[expr_3.idf] = [typ_3]
            prog_3 = IRUtil.prog_merge(
                IR.Prog([IR.Decl(expr_3.idf, node.type)]), prog_3)

        return (prog_3, expr_3)
Beispiel #14
0
	def visitUninterpFuncCall(self, node:AST.UninterpFuncCall, args=None):
		"""Lower a call to an uninterpreted (library) function.

		The callee name is node.funcName, suffixed with the output rank
		(len(node.outputShape)) and with the rank of every tensor argument of
		non-zero rank.

		Argument policy:
		  * The output tensor dims come first.
		  * Then, for each argument, its shape dims are inserted (only for
		    tensor arguments, and only when node.outputDiffInpDims < 2),
		    followed by the argument itself. With outputDiffInpDims == 1 every
		    tensor argument's dims are passed. With outputDiffInpDims == 0 they
		    are passed only if that shape has not been seen yet (the output
		    shape counts as seen).
		  * Finally the output.

		The output is declared "public" when node.isSecret is False and
		"secret" otherwise. Calls with no arguments are handled.

		Returns:
			(prog, returnExpr).
		"""
		progList = []
		exprList = []
		for curArg in node.argsList:
			(progN, exprN) = self.visit(curArg)
			progList.append(progN)
			exprList.append(exprN)

		returnExpr = self.getTempVar()

		funcName = node.funcName + self.varNameDelim + str(len(node.outputShape))
		for curArg in node.argsList:
			# A rank-0 tensor behaves like a scalar, so it is not encoded in the name.
			if Type.isTensor(curArg.type) and len(curArg.type.shape) > 0:
				funcName += self.varNameDelim + str(len(curArg.type.shape))
			# TODO: non-tensor args (e.g. arbitrary strings typed as int) are
			# handled in a hacky way right now; fix later.

		funcArgsList = OrderedDict()
		tensorShapesFound = set()
		outputShape = node.type.shape
		for ii, curDim in enumerate(outputShape):
			funcArgsList[IR.Int(curDim, 32)] = "OutputShape_" + str(ii)
		tensorShapesFound.add(tuple(outputShape))
		for ii, curArg in enumerate(node.argsList):
			if node.outputDiffInpDims < 2 and Type.isTensor(curArg.type):
				curInpShape = curArg.type.shape
				if ((node.outputDiffInpDims == 1) or (node.outputDiffInpDims == 0 and tuple(curInpShape) not in tensorShapesFound)):
					for jj, curDim in enumerate(curInpShape):
						funcArgsList[IR.Int(curDim, 32)] = "Input_" + str(ii) + self.varNameDelim + str(jj)
					tensorShapesFound.add(tuple(curInpShape))
			funcArgsList[exprList[ii]] = "inpExpr_" + str(ii)
		funcArgsList[returnExpr] = "output"

		comment = IR.Comment(str(node.metadata))
		# Start from an empty program so a call with no arguments does not
		# fail on progList[0].
		progFinal = IRUtil.prog_merge(IR.Prog([]), *progList, IR.Prog([comment, IR.FuncCall(funcName, funcArgsList)]))

		secrecy = "public" if node.isSecret is False else "secret"
		self.decls[returnExpr.idf] = [node.type, secrecy]
		progFinal = IRUtil.prog_merge(IR.Prog([IR.Decl(returnExpr.idf, node.type, isSecret=secrecy)]), progFinal)
		return (progFinal, returnExpr)
Beispiel #15
0
    def visitFloorLike(self, node: AST.Func, args=None):
        """Lower Floor / Shape / Relu / ClearMemSecret / ClearMemPublic to a call.

        Arguments: input dims (tensor inputs only), the input array, the
        output array (tensor results only; also recorded in self.decls), and
        for Floor a "curScale" of Util.Config.consSF. The callee name is
        suffixed with the input rank ("" for non-tensor inputs).

        Returns:
            (prog, outVar).
        """
        (inProg, inArr) = self.visit(node.expr)
        outVar = self.getTempVar()

        opNames = [
            (AST.Operators.Floor, "Floor"),
            (AST.Operators.Shape, "Shape"),
            (AST.Operators.RELU, "Relu"),
            (AST.Operators.ClearMemSecret, "ClearMemSecret"),
            (AST.Operators.ClearMemPublic, "ClearMemPublic"),
        ]
        for astOp, name in opNames:
            if node.op == astOp:
                funcName = name
                break
        else:
            assert False

        inType = node.expr.type
        inIsTensor = Type.isTensor(inType)
        outIsTensor = Type.isTensor(node.type)

        callArgs = OrderedDict()
        if inIsTensor:
            for pos, extent in enumerate(inType.shape):
                callArgs[IR.Int(extent, 32)] = "inShape_" + str(pos)
        callArgs[inArr] = "inArr"
        if outIsTensor:
            callArgs[outVar] = "outArr"
            self.decls[outVar.idf] = [node.type]
        if node.op == AST.Operators.Floor:
            callArgs[IR.Int(Util.Config.consSF)] = "curScale"

        suffix = str(len(inType.shape)) if inIsTensor else ""
        header = IR.Comment(str(node.metadata))
        callCmd = IR.FuncCall(funcName + self.varNameDelim + suffix, callArgs)
        result = IRUtil.prog_merge(inProg, IR.Prog([header, callCmd]))

        if outIsTensor:
            result = IRUtil.prog_merge(
                IR.Prog([IR.Decl(outVar.idf, node.type)]), result)
        return (result, outVar)
Beispiel #16
0
    def visitPool(self, node: AST.Pool, args=None):
        """Lower a 2-D pooling node to a call of the function named node.poolType.

        Shapes are unpacked as [N, H, W, CI] (input) and [N1, outH, outW, CI1]
        (output); batch and channel extents must match. Arguments, in order:
        output dims, filter size / zero padding / strides from node.options,
        input dims, then the input and output arrays.

        Returns:
            (prog, outArr) where prog declares outArr ahead of the call.
        """
        (inProg, inArr) = self.visit(node.expr)

        [N, H, W, CI] = node.expr.type.shape
        [N1, outH, outW, CI1] = node.type.shape
        assert (N == N1 and CI == CI1)
        [outArr] = self.getTempVars(1)

        comment = IR.Comment(str(node.metadata))
        poolArgs = OrderedDict()
        for label, value in (("N1", N1), ("outH", outH), ("outW", outW),
                             ("CI1", CI1)):
            poolArgs[IR.Int(value, 32)] = label
        # Each option label is also the name of its AST.PaddingKeysDict key.
        for optName in ("FH", "FW", "zPadHLeft", "zPadHRight", "zPadWLeft",
                        "zPadWRight", "strideH", "strideW"):
            optKey = getattr(AST.PaddingKeysDict, optName)
            poolArgs[IR.Int(node.options[optKey], 32)] = optName
        for label, value in (("N", N), ("H", H), ("W", W), ("CI", CI)):
            poolArgs[IR.Int(value, 32)] = label
        poolArgs[inArr] = "input"
        poolArgs[outArr] = "output"

        poolProg = IR.Prog([comment, IR.FuncCall(node.poolType, poolArgs)])
        result = IRUtil.prog_merge(inProg, poolProg)

        self.decls[outArr.idf] = [node.type]
        result = IRUtil.prog_merge(IR.Prog([IR.Decl(outArr.idf, node.type)]),
                                   result)
        return (result, outArr)
Beispiel #17
0
    def fixOuputScale(self, res: (IR.Prog, IR.Expr), compiler: IRBuilderCSF):
        """Bring the final output's scale factor down to Util.Config.consSF.

        Args:
            res: (prog, expr) pair for the compiled program and its output.
            compiler: builder whose scaleFacMapping / typeInfo describe expr.

        Returns:
            (prog, expr). The input is returned unchanged when expr's recorded
            scale is -1 or already equals consSF. Otherwise a
            ScaleDown<rank> call (rank 0 for Int outputs) is appended to prog.

        Raises:
            AssertionError: if the output would need scaling *up*, or its type
                is neither a Tensor nor an Int.
        """
        prog, expr = res
        output_scale = compiler.scaleFacMapping[expr.idf]
        if output_scale == -1 or output_scale == Util.Config.consSF:
            return (prog, expr)
        elif output_scale > Util.Config.consSF:
            scale_down = output_scale - Util.Config.consSF
            out_type = compiler.typeInfo[expr.idf]
            # Tensor is checked first, so a type satisfying both checks
            # still uses its shape.
            if Type.isTensor(out_type):
                output_shape = out_type.shape
            elif Type.isInt(out_type):
                output_shape = []
            else:
                assert False, "Cannot scale down final output of type {}".format(out_type)

            argsDict = OrderedDict()
            for ii, curDimSize in enumerate(output_shape):
                argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii)
            argsDict[expr] = "expr"
            argsDict[IR.Int(scale_down, 32)] = "consSF"
            funcCall = IR.FuncCall("ScaleDown" + str(len(output_shape)), argsDict)
            prog = IRUtil.prog_merge(prog, IR.Prog([funcCall]))
            return (prog, expr)
        else:
            assert False, "Scale up shouldnt be required of final output {} -> {}. We lost precision somewhere".format(
                output_scale, Util.Config.consSF)
Beispiel #18
0
	def visitDecl(self, node:AST.Decl, args=None):
		"""Declare a fresh input variable for a Decl node, optionally initialising it.

		When node.valueList is non-empty, one assignment per element is
		emitted, walking the indices of node.shape in row-major order.
		AST.Int values keep their own bit-length; the first one whose
		bit-length differs from Util.Config.wordLength also becomes the
		declaration's bit-length. AST.Float values are converted with
		IR.DataType.getInt(np.ldexp(value, Util.Config.consSF)).
		The declaration is recorded in self.decls and prepended to the program.

		Returns:
			(prog, outVar) with outVar.inputVar set to True.
		"""
		prog = IR.Prog([])
		outVar = self.getTempVar()
		outVar.inputVar = True

		specialBitLen = -1
		if node.valueList:
			header = [IR.Comment(str(node.metadata)), IR.Comment('Element assignments for ' + outVar.idf)]
			prog = IRUtil.prog_merge(prog, IR.Prog(header))

			# Build every index tuple of node.shape in row-major order.
			indexTuples = [[]]
			for extent in node.shape:
				indexTuples = [prefix + [k] for prefix in indexTuples for k in range(extent)]

			for pos, idx in enumerate(indexTuples):
				value = node.valueList[pos]
				irValue = None
				if isinstance(value, AST.Int):
					irValue = IR.Int(value.value, value.bitLen)
					if specialBitLen == -1 and value.bitLen != Util.Config.wordLength:
						specialBitLen = value.bitLen
				elif isinstance(value, AST.Float):
					irValue = IR.DataType.getInt(np.ldexp(value.value, Util.Config.consSF))
				else:
					# Elements are expected to be only Int or Float.
					assert False
				target = IRUtil.addIndex(outVar, [IR.Int(k) for k in idx])
				prog = IRUtil.prog_merge(prog, IR.Prog([IR.Assn(target, irValue)]))

		bitLen = Util.Config.wordLength if specialBitLen == -1 else specialBitLen
		secrecy = "public" if node.isSecret is False else "secret"
		self.decls[outVar.idf] = [node.type, secrecy, bitLen]
		prog = IRUtil.prog_merge(IR.Prog([IR.Decl(outVar.idf, node.type, bitLen, secrecy)]), prog)

		return (prog, outVar)
Beispiel #19
0
	def visitBopMulScalar1DTensor(self, node:AST.BOp, args=None):
		"""Lower scalar * tensor (in either operand order) to a ScalarMul_<rank> call.

		The scalar operand (rank-0 tensor or Int) is passed as A and the other
		operand as B. The result C has the tensor operand's shape.

		With truncation optimisation disabled, the result is truncated by
		Util.Config.consSF after the call. With it enabled, each operand whose
		recorded scale exceeds self.scaleFac is first truncated down to it,
		and the result is recorded at 2*self.scaleFac.

		Returns:
			(prog, expr_3).
		"""
		(prog_1, expr_1) = self.visit(node.expr1)
		(prog_2, expr_2) = self.visit(node.expr2)

		typ_1 = node.expr1.type
		typ_2 = node.expr2.type

		# Put the scalar operand first.
		# NOTE(review): an Int scalar is treated like a fixed-point value by
		# the scale handling below; confirm this is intended.
		if typ_1.dim == 0 or Type.isInt(typ_1):
			a, b = expr_1, expr_2
			outputShape = typ_2.shape
		else:
			a, b = expr_2, expr_1
			outputShape = typ_1.shape

		# decl fresh vars
		expr_3 = self.getTempVar()
		cmd0 = IR.Comment(expr_1.idf + ' * ' + expr_2.idf)
		funcCallArgsDict = OrderedDict()
		for ii, curDimSize in enumerate(outputShape):
			funcCallArgsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii)
		funcCallArgsDict[a] = "A"
		funcCallArgsDict[b] = "B"
		funcCallArgsDict[expr_3] = "C"
		progExtraBefore = IR.Prog([])
		progExtraAfter = IR.Prog([])
		if (Util.Config.disableTruncOpti):
			progExtraAfter = self.addTruncateFunctionCall(node, "ScalarMul", expr_3, Util.Config.consSF)
		else:
			expr1_sf = self.scaleFacMapping[expr_1.idf]
			expr2_sf = self.scaleFacMapping[expr_2.idf]
			if (expr1_sf > self.scaleFac):
				progExtraBefore = self.addTruncateFunctionCall(node.expr1, "ScalarMul", expr_1, expr1_sf-self.scaleFac)
				self.scaleFacMapping[expr_1.idf] = self.scaleFac
			if (expr2_sf > self.scaleFac):
				progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, "ScalarMul", expr_2, expr2_sf-self.scaleFac))
				self.scaleFacMapping[expr_2.idf] = self.scaleFac
			self.scaleFacMapping[expr_3.idf] = 2*self.scaleFac

		funcCall = IR.FuncCall('ScalarMul' + self.varNameDelim + str(len(outputShape)), funcCallArgsDict)
		# Operand truncations run after the operands are computed and before the call.
		prog_3 = IRUtil.prog_merge(prog_1, prog_2, progExtraBefore, IR.Prog([cmd0, funcCall]))
		prog_3 = IRUtil.prog_merge(IR.Prog([IR.Decl(expr_3.idf, node.type)]), prog_3, progExtraAfter)
		return (prog_3, expr_3)
Beispiel #20
0
    def visitBopMulInt(self, node: AST.BOp, args=None):
        """Lower an Int * Int node to a pure product expression.

        No commands are emitted beyond those produced while lowering the two
        operands (left operand first).
        """
        lhsProg, lhsExpr = self.visit(node.expr1)
        rhsProg, rhsExpr = self.visit(node.expr2)
        mergedProg = IRUtil.prog_merge(lhsProg, rhsProg)
        return (mergedProg, IRUtil.mul(lhsExpr, rhsExpr))
Beispiel #21
0
    def visitReduce(self, node: AST.Reduce, args=None):
        """Lower a Reduce node to ReduceSum / ReduceMul / ReduceMean.

        The callee name is suffixed with the output rank and then the input
        rank. Arguments: output dims, input dims, input array, reduction
        dimension, output array, and for ReduceMul only a "ScalingFactor" of
        Util.Config.consSF.

        Returns:
            (prog, returnExpr) where prog declares returnExpr ahead of the call.
        """
        (prog_1, expr1) = self.visit(node.expr)
        (prog_2, expr2) = self.visit(node.dim)

        returnExpr = self.getTempVar()

        assert (node.op
                in [AST.Operators.ADD, AST.Operators.MUL, AST.Operators.Mean])
        scalingFac = None
        if (node.op == AST.Operators.ADD):
            funcName = "ReduceSum"
        elif (node.op == AST.Operators.MUL):
            funcName = "ReduceMul"
            scalingFac = Util.Config.consSF
        elif (node.op == AST.Operators.Mean):
            funcName = "ReduceMean"
        else:
            print("Unknown node.op in AST.Reduce.", file=sys.stderr)
            assert False

        funcArgsList = OrderedDict()
        outputShape = node.type.shape
        for ii, curDim in enumerate(outputShape):
            funcArgsList[IR.Int(curDim, 32)] = "OutputShape_" + str(ii)

        inputShape = node.expr.type.shape
        for ii, curDim in enumerate(inputShape):
            funcArgsList[IR.Int(curDim, 32)] = "InputShape_" + str(ii)

        funcArgsList[expr1] = "inputArr"
        funcArgsList[expr2] = "dimension"
        funcArgsList[returnExpr] = "outArr"
        # `is not None`, not truthiness: a scale factor of 0 must still be
        # passed so ReduceMul always receives the same argument count.
        if scalingFac is not None:
            funcArgsList[IR.Int(scalingFac)] = "ScalingFactor"

        funcCall = IR.FuncCall(
            funcName + self.varNameDelim + str(len(outputShape)) +
            self.varNameDelim + str(len(inputShape)), funcArgsList)
        comment = IR.Comment(str(node.metadata))
        prog_3 = IRUtil.prog_merge(prog_1, prog_2, IR.Prog([comment,
                                                            funcCall]))

        self.decls[returnExpr.idf] = [node.type]
        prog_3 = IRUtil.prog_merge(
            IR.Prog([IR.Decl(returnExpr.idf, node.type)]), prog_3)
        return (prog_3, returnExpr)
Beispiel #22
0
 def visitLet(self, node: AST.Let, args=None):
     """Lower a let-binding by substituting the declaration into the body.

     ``node.decl`` is visited first. The resulting IR expression is then
     substituted (via ``.subst``) for the name ``node.name.name`` in both the
     program and the expression produced by visiting ``node.expr``.

     Returns:
         (prog, expr): the declaration's program followed by the body's
         program, and the body's expression after substitution.
     """
     (prog_1, expr_1) = self.visit(node.decl)
     idf = node.name.name
     (prog_2, expr_2) = self.visit(node.expr)
     # Replace references to the let-bound name with the IR expression that
     # was computed for its declaration.
     prog_2 = prog_2.subst(idf, expr_1)
     expr_2 = expr_2.subst(idf, expr_1)
     prog_3 = IRUtil.prog_merge(prog_1, prog_2)
     return (prog_3, expr_2)
Beispiel #23
0
    def visitBopConv(self, node: AST.BOp, args=None):
        """Lower a 2-D convolution ``expr1 # expr2`` to a ``Conv2DCSF`` call.

        The input shape is unpacked as ``[N, H, W, CI]`` and the filter shape
        as ``[FH, FW, CI, CO]``. When both define ``CI``, the filter's value is
        the one passed. Padding and stride values are read from
        ``node.options`` using the ``AST.PaddingKeysDict`` keys of the same
        name. ``Util.Config.consSF`` is passed as ``consSF``.

        Returns:
            (prog, returnExpr): a program declaring the result temporary,
            evaluating both operands and making the call, together with that
            temporary.
        """
        inProg, inExpr = self.visit(node.expr1)
        filtProg, filtExpr = self.visit(node.expr2)

        N, H, W, _ = node.expr1.type.shape
        FH, FW, CI, CO = node.expr2.type.shape

        outExpr = self.getTempVar()
        note = IR.Comment(inExpr.idf + ' # ' + filtExpr.idf)

        argMap = OrderedDict()
        # Shape arguments, in the order the library function expects them.
        dimArgs = ((N, "N"), (H, "H"), (W, "W"), (CI, "CI"),
                   (FH, "FH"), (FW, "FW"), (CO, "CO"))
        for dimVal, dimName in dimArgs:
            argMap[IR.Int(dimVal, 32)] = dimName
        # Padding/stride arguments: each name is both the option key attribute
        # on AST.PaddingKeysDict and the argument label.
        padKeys = AST.PaddingKeysDict
        for optName in ("zPadHLeft", "zPadHRight", "zPadWLeft", "zPadWRight",
                        "strideH", "strideW"):
            argMap[IR.Int(node.options[getattr(padKeys, optName)],
                          32)] = optName

        argMap[inExpr] = "input"
        argMap[filtExpr] = "filter"
        argMap[IR.Int(Util.Config.consSF, 32)] = "consSF"
        argMap[outExpr] = "output"

        convCall = IR.FuncCall("Conv2DCSF", argMap)
        bodyProg = IRUtil.prog_merge(inProg, filtProg,
                                     IR.Prog([note, convCall]))

        self.decls[outExpr.idf] = [node.type]
        declProg = IR.Prog([IR.Decl(outExpr.idf, node.type)])
        return (IRUtil.prog_merge(declProg, bodyProg), outExpr)
Beispiel #24
0
	def visitFloat(self, node:AST.Float, args=None):
		"""Lower a float literal to a fixed-point integer in a new temporary.

		The literal is scaled by 2**p, where p = self.get_expnt(abs(value)),
		and converted with IR.DataType.getInt. A fresh temporary is declared
		with that integer as its initial value, -1 as the bit-length argument
		and node.isSecret as its secrecy. An explanatory comment follows the
		declaration. When truncation optimisation is enabled, the temporary's
		scale factor is recorded as self.scaleFac.
		"""
		value = node.value
		fixedPt = IR.DataType.getInt(np.ldexp(value, self.get_expnt(abs(value))))
		tmp = self.getTempVar()
		note = 'Float to int : {0} to {1}, isSecret = {2}.'.format(str(value), str(fixedPt), node.isSecret)
		declProg = IR.Prog([IR.Decl(tmp.idf, node.type, -1, node.isSecret, [fixedPt])])
		prog = IRUtil.prog_merge(declProg, IR.Prog([IR.Comment(note)]))
		if not Util.Config.disableTruncOpti:
			self.scaleFacMapping[tmp.idf] = self.scaleFac
		return (prog, tmp)
Beispiel #25
0
    def visitBopMulScalar1DTensor(self, node: AST.BOp, args=None):
        """Lower a scalar-by-tensor multiplication to a ``ScalarMul`` call.

        The operand passed as ``A`` is ``expr1`` if it has rank 0 or an
        integer type, otherwise ``expr2``. The other operand is passed as
        ``B``, and its shape gives the output shape and the ``size_<i>``
        arguments.

        ``shr3`` is ``Util.Config.consSF``, or 0 when the ``A`` operand has an
        integer type (multiplying by an int adds no scale factor).

        Returns:
            (prog, expr_3): a program declaring ``expr_3``, evaluating both
            operands and making the call, together with the temporary that
            holds the result.
        """
        (prog_1, expr_1) = self.visit(node.expr1)
        (prog_2, expr_2) = self.visit(node.expr2)

        typ_1 = node.expr1.type
        typ_2 = node.expr2.type
        typ_3 = node.type

        # a represents the scalar and b the tensor.
        if typ_1.dim == 0 or Type.isInt(typ_1):
            a, b = expr_1, expr_2
            outputShape = typ_2.shape
            isIntMult = Type.isInt(typ_1)
        else:
            a, b = expr_2, expr_1
            outputShape = typ_1.shape
            isIntMult = Type.isInt(typ_2)

        # Always wrap the shift in IR.Int: it becomes a FuncCall argument next
        # to other IR expressions, so a bare Python int would not match them.
        shr3 = IR.Int(0) if isIntMult else IR.Int(Util.Config.consSF)

        # decl fresh vars
        expr_3 = self.getTempVar()
        cmd0 = IR.Comment(expr_1.idf + ' * ' + expr_2.idf)
        funcCallArgsDict = OrderedDict()
        for ii, curDimSize in enumerate(outputShape):
            funcCallArgsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii)
        funcCallArgsDict[a] = "A"
        funcCallArgsDict[b] = "B"
        funcCallArgsDict[expr_3] = "C"
        funcCallArgsDict[shr3] = "shr3"

        funcCall = IR.FuncCall(
            'ScalarMul' + self.varNameDelim + str(len(outputShape)),
            funcCallArgsDict)
        prog_3 = IRUtil.prog_merge(prog_1, prog_2, IR.Prog([cmd0, funcCall]))

        self.decls[expr_3.idf] = [typ_3]
        prog_3 = IRUtil.prog_merge(IR.Prog([IR.Decl(expr_3.idf, node.type)]),
                                   prog_3)

        return (prog_3, expr_3)
Beispiel #26
0
	def visitIndex(self, node:AST.Index, args=None):
		"""Lower an AST.Index node (an expression plus a list of indices).

		The indexed expression is visited first, then each index in order.

		Returns:
			(prog, expr): the indexed expression's program followed by each
			index's program, and the expression built by IRUtil.addIndex from
			the indexed expression and the list of index expressions.
		"""
		(prog_1, expr_1) = self.visit(node.expr)
		# Keep programs and expressions in two distinct lists; a chained
		# `a = b = []` would bind both names to the same list object.
		prog_idx = []
		expr_idx = []
		for curIdx in node.index:
			(prog_cur, expr_cur) = self.visit(curIdx)
			prog_idx.append(prog_cur)
			expr_idx.append(expr_cur)
		# prog_merge accepts any number of programs; merge them in order.
		prog_3 = IRUtil.prog_merge(prog_1, *prog_idx)
		expr_3 = IRUtil.addIndex(expr_1, expr_idx)

		return (prog_3, expr_3)
Beispiel #27
0
	def visitInt(self, node:AST.Int, args=None):
		"""Lower an integer literal to a freshly declared temporary.

		The temporary is declared with node.value as its initial value. Its
		bit length is node.bitLen when that is truthy, otherwise -1, and its
		secrecy is node.isSecret. An explanatory comment follows the
		declaration. When truncation optimisation is enabled, the temporary's
		scale factor is recorded as self.scaleFac if node.isScaled, else 0.
		"""
		commentProg = IR.Prog([IR.Comment('Int node, isSecret = {0}.'.format(node.isSecret))])
		tmp = self.getTempVar()
		bitlen = node.bitLen if node.bitLen else -1
		declProg = IR.Prog([IR.Decl(tmp.idf, node.type, bitlen, node.isSecret, [node.value])])
		prog = IRUtil.prog_merge(declProg, commentProg)
		if not Util.Config.disableTruncOpti:
			if node.isScaled:
				self.scaleFacMapping[tmp.idf] = self.scaleFac
			else:
				self.scaleFacMapping[tmp.idf] = 0
		return (prog, tmp)
Beispiel #28
0
	def visitBopElemWiseOp(self, node:AST.BOp, args=None):
		"""Lower an element-wise multiply/divide to ElemWiseMul/ElemWiseDiv.

		The called function name is suffixed with the output rank. It receives
		the output dims as size_<i>, followed by the operands A and B and the
		result C.

		Fixed-point handling:
		  * with Util.Config.disableTruncOpti set, a truncation by
		    Util.Config.consSF (via self.addTruncateFunctionCall) is appended
		    after the call;
		  * otherwise, any operand whose recorded scale factor exceeds
		    self.scaleFac is truncated down to self.scaleFac before the call,
		    and the result is recorded with scale factor 2*self.scaleFac.
		    NOTE(review): that result scale is used for division too -- confirm
		    it matches what ElemWiseDiv produces.

		Raises:
			ValueError: if node.op is neither ElemWiseMul nor ElemWiseDiv.
		"""
		(prog_1, expr_1) = self.visit(node.expr1)
		(prog_2, expr_2) = self.visit(node.expr2)

		if (node.op == AST.Operators.ElemWiseMul):
			op_ir = IR.Op.Op['.*']
			funcName = "ElemWiseMul"
		elif (node.op == AST.Operators.ElemWiseDiv):
			op_ir = IR.Op.Op['./']
			funcName = "ElemWiseDiv"
		else:
			# Fail with a clear message rather than an UnboundLocalError on
			# op_ir/funcName further down.
			raise ValueError("Unsupported element-wise operator: {0}".format(node.op))

		typ_3 = node.type
		expr_3 = self.getTempVar()
		cmd0 = IR.Comment(expr_1.idf + ' ' + op_ir.name + ' ' + expr_2.idf)
		outputShape = typ_3.shape
		argsDict = OrderedDict()
		for ii,curDimSize in enumerate(outputShape):
			argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii)
		argsDict[expr_1] = "A"
		argsDict[expr_2] = "B"
		argsDict[expr_3] = "C"
		progExtraBefore = IR.Prog([])
		progExtraAfter = IR.Prog([])
		if (Util.Config.disableTruncOpti):
			progExtraAfter = self.addTruncateFunctionCall(node, funcName, expr_3, Util.Config.consSF)
		else:
			expr1_sf = self.scaleFacMapping[expr_1.idf]
			expr2_sf = self.scaleFacMapping[expr_2.idf]
			# Bring each operand down to self.scaleFac before the call.
			if (expr1_sf > self.scaleFac):
				progExtraBefore = self.addTruncateFunctionCall(node.expr1, funcName, expr_1, expr1_sf-self.scaleFac)
				self.scaleFacMapping[expr_1.idf] = self.scaleFac
			if (expr2_sf > self.scaleFac):
				progExtraBefore = IRUtil.prog_merge(progExtraBefore, self.addTruncateFunctionCall(node.expr2, funcName, expr_2, expr2_sf-self.scaleFac))
				self.scaleFacMapping[expr_2.idf] = self.scaleFac
			self.scaleFacMapping[expr_3.idf] = 2*self.scaleFac

		funcCall = IR.FuncCall(funcName + self.varNameDelim + str(len(outputShape)), argsDict)
		prog_3 = IRUtil.prog_merge(prog_1, prog_2, progExtraBefore, IR.Prog([cmd0, funcCall]))
		prog_3 = IRUtil.prog_merge(IR.Prog([IR.Decl(expr_3.idf, node.type)]), prog_3, progExtraAfter)

		return (prog_3, expr_3)
Beispiel #29
0
	def visitFloat(self, node:AST.Float, args=None):
		"""Lower a float literal to its fixed-point integer representation.

		The literal is scaled by 2**p with p = self.get_expnt(abs(value)) and
		converted with IR.DataType.getInt.
		  * A public literal is returned directly as an IR.Int constant.
		  * A secret literal gets a new temporary, recorded in self.decls and
		    declared ahead of the explanatory comment.

		NOTE(review): the secret temporary is declared without the computed
		integer as an initial value -- confirm it is supplied elsewhere.
		"""
		value = node.value
		fixedPt = IR.DataType.getInt(np.ldexp(value, self.get_expnt(abs(value))))
		prog = IR.Prog([IR.Comment('Float to int : {0} to {1}, isSecret = {2}.'.format(str(value), str(fixedPt), node.isSecret))])
		if node.isSecret:
			tmp = self.getTempVar()
			self.decls[tmp.idf] = [node.type]
			declProg = IR.Prog([IR.Decl(tmp.idf, node.type)])
			return (IRUtil.prog_merge(declProg, prog), tmp)
		return (prog, IR.Int(fixedPt))
Beispiel #30
0
 def visitFloat(self, node: AST.Float, args=None):
     """Lower a float literal to its fixed-point integer representation.

     The literal is scaled by 2**p with p = self.get_expnt(abs(r)) and
     converted with ``IR.DataType.getInt``.
       * A public literal is returned directly as an ``IR.Int`` constant,
         with no declaration.
       * A secret literal gets a new temporary, recorded in ``self.decls``
         and declared ahead of the explanatory comment.

     Returns:
         (prog, expr): the program (comment, plus the declaration for secret
         values) and the expression for the literal.
     """
     r = node.value
     p = self.get_expnt(abs(r))
     k = IR.DataType.getInt(np.ldexp(r, p))
     comment = IR.Comment('Float to int : ' + str(r) + ' to ' + str(k))
     prog = IR.Prog([comment])
     if not (node.isSecret):
         # Public constants are emitted inline; nothing needs declaring.
         expr = IR.Int(k)
     else:
         expr = self.getTempVar()
         self.decls[expr.idf] = [node.type]
         prog = IRUtil.prog_merge(IR.Prog([IR.Decl(expr.idf, node.type)]),
                                  prog)
     return (prog, expr)