def visitArgMax(self, node: AST.ArgMax, args=None):
    """Lower an ArgMax node to a call of the rank-suffixed ``ArgMax`` function.

    Visits ``node.expr`` and ``node.dim``, allocates a fresh temporary for
    the result and emits a call named
    ``"ArgMax" + self.varNameDelim + str(len(node.type.shape))`` whose
    arguments are, in order: the output dimensions, the dimensions in
    ``node.inShape``, the input expression, the dimension expression and
    the output temporary.

    Returns:
        ``(prog, expr)`` where ``prog`` starts with the declaration of the
        temporary and ``expr`` is that temporary.
    """
    (prog_1, expr1) = self.visit(node.expr)
    (prog_2, expr2) = self.visit(node.dim)

    tmpExpr = self.getTempVar()
    outputShape = node.type.shape

    funcArgsList = OrderedDict()
    for ii, curDim in enumerate(outputShape):
        funcArgsList[IR.Int(curDim, 32)] = "OutputShape_" + str(ii)
    # These are the dimensions of the input; they used to carry the
    # "OutputShape_" label by mistake.
    for ii, curDim in enumerate(node.inShape):
        funcArgsList[IR.Int(curDim, 32)] = "InputShape_" + str(ii)
    funcArgsList[expr1] = "inArr"
    funcArgsList[expr2] = "dim"
    funcArgsList[tmpExpr] = "outArr"

    if not (Util.Config.disableTruncOpti):
        # TODO -- is this the right thing to do?
        self.scaleFacMapping[tmpExpr.idf] = 0

    funcCall = IR.FuncCall(
        "ArgMax" + self.varNameDelim + str(len(outputShape)), funcArgsList)
    comment = IR.Comment(str(node.metadata))
    prog_3 = IRUtil.prog_merge(prog_1, prog_2, IR.Prog([comment, funcCall]))
    prog_3 = IRUtil.prog_merge(
        IR.Prog([IR.Decl(tmpExpr.idf, node.type)]), prog_3)
    return (prog_3, tmpExpr)
def visitReduce(self, node: AST.Reduce, args=None):
    """Lower a Reduce node (sum or mean) to a ``ReduceSum``/``ReduceMean`` call.

    The emitted function name is the base name followed by the output rank
    and the input rank, each separated by ``self.varNameDelim``. When
    truncation optimisation is enabled, the result inherits the scale
    factor recorded for the input expression.

    Returns:
        ``(prog, expr)``: the merged program (result declaration first) and
        the temporary that receives the reduction.
    """
    (inputProg, inputExpr) = self.visit(node.expr)
    (dimProg, dimExpr) = self.visit(node.dim)
    resultVar = self.getTempVar()

    assert (node.op in [AST.Operators.ADD, AST.Operators.Mean])
    if node.op == AST.Operators.ADD:
        reduceName = "ReduceSum"
    elif node.op == AST.Operators.Mean:
        reduceName = "ReduceMean"

    if not Util.Config.disableTruncOpti:
        self.scaleFacMapping[resultVar.idf] = self.scaleFacMapping[inputExpr.idf]

    outDims = node.type.shape
    inDims = node.expr.type.shape

    callArgs = OrderedDict()
    for axis, extent in enumerate(outDims):
        callArgs[IR.Int(extent, 32)] = "OutputShape_" + str(axis)
    for axis, extent in enumerate(inDims):
        callArgs[IR.Int(extent, 32)] = "InputShape_" + str(axis)
    callArgs[inputExpr] = "inputArr"
    callArgs[dimExpr] = "dimension"
    callArgs[resultVar] = "outArr"

    fullName = (reduceName + self.varNameDelim + str(len(outDims))
                + self.varNameDelim + str(len(inDims)))
    body = IR.Prog([IR.Comment(str(node.metadata)),
                    IR.FuncCall(fullName, callArgs)])
    merged = IRUtil.prog_merge(inputProg, dimProg, body)
    declProg = IR.Prog([IR.Decl(resultVar.idf, node.type)])
    return (IRUtil.prog_merge(declProg, merged), resultVar)
def visitTranspose(self, node: AST.Transpose, args=None):
    """Lower a Transpose node to an explicit copy loop nest.

    One iterator is created per input dimension; the output is indexed by
    those iterators permuted by ``node.perm`` (the reversed axis order when
    ``node.perm`` is None). The loop nest assigns each input element to the
    permuted output position.

    Returns:
        ``(prog, expr)``: the program with the output and iterator
        declarations ahead of the loop, and the output temporary.
    """
    (srcProg, srcArr) = self.visit(node.expr)
    srcType = node.expr.type
    dstType = node.type

    srcIters = self.getTempIterators(srcType.dim)
    order = node.perm
    if order is None:
        order = list(reversed(range(len(srcType.shape))))
    dstIters = [srcIters[axis] for axis in order]

    dstArr = self.getTempVar()
    dstElem = IRUtil.addIndex(dstArr, dstIters)
    srcElem = IRUtil.addIndex(srcArr, srcIters)
    copyLoop = IRUtil.loop(srcType.shape, srcIters, [IR.Assn(dstElem, srcElem)])

    # Finalize
    metaComment = IR.Comment(str(node.metadata))
    srcDims = ', '.join(str(e) for e in srcType.shape)
    dstDims = ', '.join(str(e) for e in dstType.shape)
    shapeComment = IR.Comment(
        "transpose(" + srcArr.idf + ", [" + srcDims + "] --> [" + dstDims + "])")
    result = IRUtil.prog_merge(
        srcProg, IR.Prog([metaComment, shapeComment] + copyLoop))

    # Each declaration is prepended, so later iterators end up earlier.
    for it in srcIters:
        iterDecl = IR.Decl(it.idf, Type.Int(), isSecret=False)
        result = IRUtil.prog_merge(IR.Prog([iterDecl]), result)
    result = IRUtil.prog_merge(IR.Prog([IR.Decl(dstArr.idf, dstType)]), result)

    if not Util.Config.disableTruncOpti:
        self.scaleFacMapping[dstArr.idf] = self.scaleFacMapping[srcArr.idf]
    return (result, dstArr)
def visitUOp(self, node: AST.UOp, args=None):
    """Lower a unary plus/minus node.

    Unary plus returns the operand unchanged. Unary minus on an int returns
    the negated expression with the operand's program; on any other type it
    emits a loop that writes the element-wise negation into a fresh
    temporary, which is also recorded in ``self.decls``.

    Returns:
        ``(prog, expr)`` for the lowered operation.
    """
    (operandProg, operandExpr) = self.visit(node.expr)

    if node.op == AST.Operators.ADD:
        return (operandProg, operandExpr)
    assert node.op == AST.Operators.SUB

    resultType = node.type

    # e : Int
    if Type.isInt(resultType):
        return (operandProg, IRUtil.negate(operandExpr))

    # e : Tensor(), or Tensor(..) -- needs fresh variables
    resultVar = self.getTempVar()
    loopIters = self.getTempIterators(resultType.dim)

    srcElem = IRUtil.addIndex(operandExpr, loopIters)
    dstElem = IRUtil.addIndex(resultVar, loopIters)
    negateLoop = IRUtil.loop(
        resultType.shape, loopIters,
        [IR.Assn(dstElem, IRUtil.negate(srcElem))])

    metaComment = IR.Comment(str(node.metadata))
    result = IRUtil.prog_merge(
        operandProg, IR.Prog([metaComment] + negateLoop))

    self.decls[resultVar.idf] = [resultType]
    result = IRUtil.prog_merge(
        IR.Prog([IR.Decl(resultVar.idf, node.type)]), result)
    return (result, resultVar)
def visitBopElemWiseOp(self, node: AST.BOp, args=None):
    """Lower an element-wise multiply/divide to an ``ElemWise*`` call.

    The emitted call is named ``ElemWiseMul`` or ``ElemWiseDiv`` followed by
    ``self.varNameDelim`` and the output rank. Its arguments are the output
    dimensions, both operands, the fresh result temporary and
    ``Util.Config.consSF`` (labelled ``shrC``). The temporary is recorded in
    ``self.decls``.

    Returns:
        ``(prog, expr)``: the merged program (result declaration first) and
        the result temporary.
    """
    (prog_1, expr_1) = self.visit(node.expr1)
    (prog_2, expr_2) = self.visit(node.expr2)

    if node.op == AST.Operators.ElemWiseMul:
        op_ir = IR.Op.Op['.*']
        funcName = "ElemWiseMul"
    elif node.op == AST.Operators.ElemWiseDiv:
        op_ir = IR.Op.Op['./']
        funcName = "ElemWiseDiv"
    else:
        # Fail explicitly rather than with an unbound-name error further down.
        assert False, "visitBopElemWiseOp: unsupported operator"
    # TODO : rem this, passing +ve -- ? (applies to the ElemWiseDiv case)
    shr_n3 = IR.Int(Util.Config.consSF)

    typ_3 = node.type
    expr_3 = self.getTempVar()
    cmd0 = IR.Comment(expr_1.idf + ' ' + op_ir.name + ' ' + expr_2.idf)
    outputShape = typ_3.shape

    argsDict = OrderedDict()
    for ii, curDimSize in enumerate(outputShape):
        argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii)
    argsDict[expr_1] = "A"
    argsDict[expr_2] = "B"
    argsDict[expr_3] = "C"
    argsDict[shr_n3] = "shrC"
    # TODO : for consistency, add shrA and shrB ?

    funcCall = IR.FuncCall(
        funcName + self.varNameDelim + str(len(outputShape)), argsDict)
    prog_3 = IRUtil.prog_merge(prog_1, prog_2, IR.Prog([cmd0, funcCall]))

    self.decls[expr_3.idf] = [typ_3]
    prog_3 = IRUtil.prog_merge(
        IR.Prog([IR.Decl(expr_3.idf, node.type)]), prog_3)
    return (prog_3, expr_3)
def visitTransp(self, node: AST.Transp, args=None):
    """Lower a 2-D transpose node to a ``Transpose`` function call.

    The call is named ``"Transpose" + self.varNameDelim + <rank>`` and takes
    the two dimensions of ``node.type.shape``, the operand and a fresh
    result temporary. The temporary is recorded in ``self.decls`` and its
    declaration is placed at the front of the returned program.

    Returns:
        ``(prog, expr)``: the merged program and the result temporary.
    """
    (prog_1, expr_1) = self.visit(node.expr)

    # decl fresh vars
    expr_2 = self.getTempVar()

    typ_2 = node.type
    [I, J] = typ_2.shape

    cmd0 = IR.Comment(expr_1.idf + "^T")
    comment = IR.Comment(str(node.metadata))

    argsList = OrderedDict()
    argsList[IR.Int(I, 32)] = "I"
    argsList[IR.Int(J, 32)] = "J"
    argsList[expr_1] = "A"
    argsList[expr_2] = "B"
    funcCall = IR.FuncCall(
        "Transpose" + self.varNameDelim + str(len(typ_2.shape)), argsList)

    prog_for = IR.Prog([cmd0, comment, funcCall])
    prog_2 = IRUtil.prog_merge(prog_1, prog_for)

    self.decls[expr_2.idf] = [typ_2]
    # The declaration used to be merged into an undefined name ``prog``,
    # which raised at runtime and never reached the returned program.
    prog_2 = IRUtil.prog_merge(IR.Prog([IR.Decl(expr_2.idf, typ_2)]), prog_2)
    return (prog_2, expr_2)
def visitUOp(self, node: AST.UOp, args=None):
    """Lower a unary plus/minus node, tracking scale factors.

    Unary plus returns the operand unchanged. Unary minus always writes
    into a fresh temporary: for ints via a declaration plus a single
    assignment, otherwise via an element-wise negation loop with the
    declaration prepended. When truncation optimisation is enabled, the
    result inherits the operand's recorded scale factor.

    Returns:
        ``(prog, expr)`` for the lowered operation.
    """
    (prog_1, expr_1) = self.visit(node.expr)
    op = node.op
    if op == AST.Operators.ADD:
        return (prog_1, expr_1)
    assert op == AST.Operators.SUB

    typ_2 = node.type
    expr_2 = self.getTempVar()

    if Type.isInt(typ_2):
        # The unused ``node.expr.bitlen`` read was removed; the declaration
        # takes its bit length from the result type.
        comment = IR.Comment(str(node.metadata))
        decl = IR.Decl(expr_2.idf, node.type, typ_2.bitlen, typ_2.isSecret)
        assign = IR.Assn(expr_2, IRUtil.negate(expr_1))
        prog_2 = IRUtil.prog_merge(prog_1, IR.Prog([comment, decl, assign]))
    else:
        # decl fresh vars
        iters = self.getTempIterators(typ_2.dim)

        # cmdl_assn
        expr_1_elt = IRUtil.addIndex(expr_1, iters)
        expr_2_elt = IRUtil.addIndex(expr_2, iters)
        cmdl_assn = IRUtil.loop(
            typ_2.shape, iters,
            [IR.Assn(expr_2_elt, IRUtil.negate(expr_1_elt))])
        comment = IR.Comment(str(node.metadata))
        prog_2 = IRUtil.prog_merge(prog_1, IR.Prog([comment] + cmdl_assn))
        prog_2 = IRUtil.prog_merge(
            IR.Prog([IR.Decl(expr_2.idf, node.type)]), prog_2)

    if not (Util.Config.disableTruncOpti):
        self.scaleFacMapping[expr_2.idf] = self.scaleFacMapping[expr_1.idf]
    return (prog_2, expr_2)
def visitBopMul2DTensor(self, node: AST.BOp, args=None):
    """Lower a 2-D matrix product to a ``MatMulCSF2D`` call.

    The call receives the three dimensions ``I``, ``J``, ``K`` read from
    the operand shapes, both operands, a fresh result temporary and
    ``Util.Config.consSF`` (labelled ``shrT``). The temporary is recorded
    in ``self.decls``.

    Returns:
        ``(prog, expr)``: the merged program (result declaration first) and
        the result temporary.
    """
    (prog_1, expr_1) = self.visit(node.expr1)
    (prog_2, expr_2) = self.visit(node.expr2)

    # decl fresh vars
    expr_3 = self.getTempVar()

    typ_1 = node.expr1.type
    typ_2 = node.expr2.type
    typ_3 = node.type

    [I, J] = typ_1.shape
    [J_rhs, K] = typ_2.shape
    # The shared dimension used to be overwritten silently; make a mismatch
    # visible instead of emitting a call with inconsistent sizes.
    assert J == J_rhs, "MatMul inner dimensions differ: {} vs {}".format(J, J_rhs)

    shrT = Util.Config.consSF

    cmd0 = IR.Comment(expr_1.idf + ' * ' + expr_2.idf)
    funcCallArgsDict = OrderedDict()
    funcCallArgsDict[IR.Int(I, 32)] = "I"
    funcCallArgsDict[IR.Int(J, 32)] = "J"
    funcCallArgsDict[IR.Int(K, 32)] = "K"
    funcCallArgsDict[expr_1] = "A"
    funcCallArgsDict[expr_2] = "B"
    funcCallArgsDict[expr_3] = "C"
    funcCallArgsDict[IR.Int(shrT)] = "shrT"

    funcCall = IR.FuncCall("MatMulCSF2D", funcCallArgsDict)
    comment = IR.Comment(str(node.metadata))
    prog_3 = IRUtil.prog_merge(
        prog_1, prog_2, IR.Prog([comment, cmd0, funcCall]))

    self.decls[expr_3.idf] = [typ_3]
    prog_3 = IRUtil.prog_merge(
        IR.Prog([IR.Decl(expr_3.idf, node.type)]), prog_3)
    return (prog_3, expr_3)
def visitBopMulInt(self, node: AST.BOp, args=None):
    """Lower an integer multiplication, with scale-factor bookkeeping.

    The product is assigned to a freshly declared temporary. With
    truncation optimisation disabled, a truncate call on the result by
    ``Util.Config.consSF`` follows the multiplication. Otherwise any operand
    whose recorded scale exceeds ``self.scaleFac`` is truncated down to it
    before the multiplication, and the result's scale is recorded as
    ``2 * self.scaleFac``.

    Returns:
        ``(prog, expr)``: the complete program and the result temporary.
    """
    (prog_1, expr_1) = self.visit(node.expr1)
    (prog_2, expr_2) = self.visit(node.expr2)

    expr_3 = self.getTempVar()
    comment = IR.Comment(str(node.metadata))
    # An unused ``bitlen = node.expr.bitlen`` read was removed here: this
    # node is accessed through ``expr1``/``expr2`` everywhere else, and the
    # declaration takes its bit length from ``node.type``.
    decl = IR.Decl(expr_3.idf, node.type, node.type.bitlen, node.type.isSecret)
    assign = IR.Assn(expr_3, IRUtil.mul(expr_1, expr_2))
    prog_3 = IRUtil.prog_merge(prog_1, prog_2, IR.Prog([comment, decl, assign]))

    progExtraBefore = IR.Prog([])
    progExtraAfter = IR.Prog([])
    if (Util.Config.disableTruncOpti):
        progExtraAfter = self.addTruncateFunctionCall(
            node, "MulInt", expr_3, Util.Config.consSF)
    else:
        expr1_sf = self.scaleFacMapping[expr_1.idf]
        expr2_sf = self.scaleFacMapping[expr_2.idf]
        if (expr1_sf > self.scaleFac):
            progExtraBefore = self.addTruncateFunctionCall(
                node.expr1, "MulInt", expr_1, expr1_sf - self.scaleFac)
            self.scaleFacMapping[expr_1.idf] = self.scaleFac
        if (expr2_sf > self.scaleFac):
            progExtraBefore = IRUtil.prog_merge(
                progExtraBefore,
                self.addTruncateFunctionCall(
                    node.expr2, "MulInt", expr_2, expr2_sf - self.scaleFac))
            self.scaleFacMapping[expr_2.idf] = self.scaleFac
        self.scaleFacMapping[expr_3.idf] = 2 * self.scaleFac

    prog_3 = IRUtil.prog_merge(progExtraBefore, prog_3, progExtraAfter)
    return (prog_3, expr_3)
def visitFusedBatchNorm(self, node: AST.FusedBatchNorm, args=None):
    """Lower a fused batch-norm node to a ``FusedBatchNorm`` call.

    The function name is suffixed with four ranks (output, input,
    multiplier, addend) joined by ``self.varNameDelim``. The output and
    input ranks are both taken from ``node.type.shape``. The arguments are
    the output dimensions, the three operand expressions,
    ``Util.Config.consSF`` and the fresh result temporary, which is also
    recorded in ``self.decls``.

    Returns:
        ``(prog, expr)``: the merged program (result declaration first) and
        the result temporary.
    """
    (inputProg, inputExpr) = self.visit(node.expr)
    (scaleProg, scaleExpr) = self.visit(node.multExpr)
    (shiftProg, shiftExpr) = self.visit(node.addExpr)
    resultVar = self.getTempVar()

    callArgs = OrderedDict()
    for axis, extent in enumerate(node.type.shape):
        callArgs[IR.Int(extent, 32)] = "elem_" + str(axis)
    callArgs[inputExpr] = "expr"
    callArgs[scaleExpr] = "multExpr"
    callArgs[shiftExpr] = "addExpr"
    callArgs[IR.Int(Util.Config.consSF, 32)] = "consSF"
    callArgs[resultVar] = "returnExpr"

    callName = self.varNameDelim.join([
        "FusedBatchNorm",
        str(len(node.type.shape)),  # one for output
        str(len(node.type.shape)),  # one for input
        str(len(node.multExpr.type.shape)),
        str(len(node.addExpr.type.shape)),
    ])
    call = IR.FuncCall(callName, callArgs)

    metaComment = IR.Comment(str(node.metadata))
    merged = IRUtil.prog_merge(
        inputProg, scaleProg, shiftProg, IR.Prog([metaComment, call]))

    self.decls[resultVar.idf] = [node.type]
    declProg = IR.Prog([IR.Decl(resultVar.idf, node.type)])
    return (IRUtil.prog_merge(declProg, merged), resultVar)
def visitFusedBatchNorm(self, node: AST.FusedBatchNorm, args=None):
    """Lower a fused batch-norm node, with scale-factor bookkeeping.

    With truncation optimisation disabled, the call is told to scale the
    multiplication down by ``self.scaleFac`` and not to scale the addend.
    With it enabled, the input and multiplier are truncated down to
    ``self.scaleFac`` beforehand when their recorded scale is larger. The
    addend's recorded scale is raised to ``2 * self.scaleFac``, with the
    difference passed to the call, and the result is recorded at that scale.

    Returns:
        ``(prog, expr)``: the merged program (result declaration first) and
        the result temporary.
    """
    (inputProg, inputExpr) = self.visit(node.expr)
    (scaleProg, scaleExpr) = self.visit(node.multExpr)
    (shiftProg, shiftExpr) = self.visit(node.addExpr)
    resultVar = self.getTempVar()

    callArgs = OrderedDict()
    for axis, extent in enumerate(node.type.shape):
        callArgs[IR.Int(extent, 32)] = "elem_" + str(axis)
    callArgs[inputExpr] = "expr"
    callArgs[scaleExpr] = "multExpr"
    callArgs[shiftExpr] = "addExpr"

    preTruncProg = IR.Prog([])
    scaleDownBy = self.scaleFac
    shiftUpBy = 0
    if not Util.Config.disableTruncOpti:
        # TruncOpti is on
        scaleDownBy = 0
        shiftUpBy = 0

        inputSf = self.scaleFacMapping[inputExpr.idf]
        scaleSf = self.scaleFacMapping[scaleExpr.idf]
        shiftSf = self.scaleFacMapping[shiftExpr.idf]

        if inputSf > self.scaleFac:
            # Scale down needed
            preTruncProg = self.addTruncateFunctionCall(
                node.expr, "FusedBatchNorm", inputExpr,
                inputSf - self.scaleFac)
            self.scaleFacMapping[inputExpr.idf] = self.scaleFac

        if scaleSf > self.scaleFac:
            # Scale down needed
            truncScale = self.addTruncateFunctionCall(
                node.multExpr, "FusedBatchNorm", scaleExpr,
                scaleSf - self.scaleFac)
            preTruncProg = IRUtil.prog_merge(preTruncProg, truncScale)
            self.scaleFacMapping[scaleExpr.idf] = self.scaleFac

        resultSf = 2 * self.scaleFac
        assert (resultSf >= shiftSf)
        if resultSf > shiftSf:
            shiftUpBy = resultSf - shiftSf
            self.scaleFacMapping[shiftExpr.idf] += shiftUpBy

        self.scaleFacMapping[resultVar.idf] = resultSf

    callArgs[IR.Int(scaleDownBy, 32)] = "multExprScaleDownSf"
    callArgs[IR.Int(shiftUpBy, 32)] = "addExprScaleUpSf"
    callArgs[resultVar] = "returnExpr"

    callName = self.varNameDelim.join([
        "FusedBatchNorm",
        str(len(node.type.shape)),  # one for output
        str(len(node.type.shape)),  # one for input
        str(len(node.multExpr.type.shape)),
        str(len(node.addExpr.type.shape)),
    ])
    call = IR.FuncCall(callName, callArgs)

    metaComment = IR.Comment(str(node.metadata))
    merged = IRUtil.prog_merge(
        inputProg, scaleProg, shiftProg, preTruncProg,
        IR.Prog([metaComment, call]))
    declProg = IR.Prog([IR.Decl(resultVar.idf, node.type)])
    return (IRUtil.prog_merge(declProg, merged), resultVar)
def visitFloorLike(self, node: AST.Func, args=None):
    """Lower Floor/Shape/Relu/ClearMem* nodes, with scale bookkeeping.

    The emitted function name is the operator name followed by
    ``self.varNameDelim`` and, for tensor inputs, the input rank. The
    arguments are the input dimensions (tensor inputs only), the input
    expression and, for tensor results, the output temporary. Floor also
    gets ``Util.Config.consSF`` as ``curScale``. Relu gets a ``consSF``
    amount and a ``doTruncation`` flag, which is set only when truncation
    optimisation is on and the input's recorded scale exceeds
    ``self.scaleFac``.

    Returns:
        ``(prog, expr)``: the program (with a declaration prepended for
        tensor results) and the temporary allocated for the result.
    """
    (argProg, argExpr) = self.visit(node.expr)
    resultVar = self.getTempVar()

    op = node.op
    if op == AST.Operators.Floor:
        baseName = "Floor"
    elif op == AST.Operators.Shape:
        baseName = "Shape"
    elif op == AST.Operators.RELU:
        baseName = "Relu"
    elif op == AST.Operators.ClearMemSecret:
        baseName = "ClearMemSecret"
    elif op == AST.Operators.ClearMemPublic:
        baseName = "ClearMemPublic"
    else:
        assert False

    callArgs = OrderedDict()
    srcType = node.expr.type
    if Type.isTensor(srcType):
        for axis, extent in enumerate(srcType.shape):
            callArgs[IR.Int(extent, 32)] = "inShape_" + str(axis)
    callArgs[argExpr] = "inArr"

    if Type.isTensor(node.type):
        callArgs[resultVar] = "outArr"

    if op == AST.Operators.Floor:
        callArgs[IR.Int(Util.Config.consSF, 32)] = "curScale"

    trailingProg = IR.Prog([])
    if Util.Config.disableTruncOpti:
        if op == AST.Operators.RELU:
            callArgs[IR.Int(Util.Config.consSF, 32)] = "consSF"
            callArgs[IR.Bool(False)] = "doTruncation"
    else:
        resultSf = self.scaleFacMapping[argExpr.idf]
        if op == AST.Operators.RELU:
            callArgs[IR.Int(resultSf - self.scaleFac, 32)] = "consSF"
            if resultSf > self.scaleFac:
                # If it can't tolerate one more mult operation, then scale
                # down here
                resultSf = self.scaleFac
                callArgs[IR.Bool(True)] = "doTruncation"
            else:
                callArgs[IR.Bool(False)] = "doTruncation"
        self.scaleFacMapping[resultVar.idf] = resultSf

    metaComment = IR.Comment(str(node.metadata))
    rankSuffix = ""
    if Type.isTensor(srcType):
        rankSuffix = str(len(srcType.shape))

    call = IR.FuncCall(baseName + self.varNameDelim + rankSuffix, callArgs)
    result = IRUtil.prog_merge(argProg, IR.Prog([metaComment, call]))
    if Type.isTensor(node.type):
        result = IRUtil.prog_merge(
            IR.Prog([IR.Decl(resultVar.idf, node.type)]), result)

    result = IRUtil.prog_merge(result, trailingProg)
    return (result, resultVar)
def visitBopAddOrSub(self, node: AST.BOp, args=None):
    """Lower an addition or subtraction.

    For int results the operands are combined directly into an
    ``IR.IntBop``. Otherwise a ``MatAdd``/``MatSub`` call is emitted, with
    ``BroadCast`` appended to the name when the second operand's dimension
    count differs from the result's. A first operand that would need
    broadcasting is rejected by an assertion. The call receives the
    dimensions of the first operand, the second operand and the output, then
    the operands and a fresh result temporary, which is recorded in
    ``self.decls``.

    Returns:
        ``(prog, expr)`` for the lowered operation.
    """
    (prog_1, expr_1) = self.visit(node.expr1)
    (prog_2, expr_2) = self.visit(node.expr2)

    # op_ir, typ_3
    op = node.op
    if (op == AST.Operators.ADD):
        op_ir = IR.Op.Op['+']
        funcName = "MatAdd"
    elif (op == AST.Operators.SUB):
        op_ir = IR.Op.Op['-']
        funcName = "MatSub"
    else:
        assert False

    typ_3 = node.type

    # e : Int
    if Type.isInt(typ_3):
        prog_3 = IRUtil.prog_merge(prog_1, prog_2)
        expr_3 = IR.IntBop(expr_1, op_ir, expr_2)
    # e : Tensor() -- float, or Tensor(..)
    else:
        ## TODO : Hack for techfest
        if (node.type.dim != node.expr1.type.dim):
            # This needs broadcast of expr1
            assert False  # For now this shouldn't occur
        if (node.type.dim != node.expr2.type.dim):
            # This needs broadcast of expr2
            funcName += 'BroadCast'

        # decl fresh vars
        expr_3 = self.getTempVar()

        cmd0 = IR.Comment(expr_1.idf + ' ' + op_ir.name + ' ' + expr_2.idf)
        outputShape = typ_3.shape
        argsDict = OrderedDict()
        inp1_shape = node.expr1.type.shape
        inp2_shape = node.expr2.type.shape
        for ii, curDimSize in enumerate(inp1_shape):
            argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii)
        for ii, curDimSize in enumerate(inp2_shape):
            argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii)
        for ii, curDimSize in enumerate(outputShape):
            argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii)
        argsDict[expr_1] = "A"
        argsDict[expr_2] = "B"
        argsDict[expr_3] = "C"
        funcCall = IR.FuncCall(
            funcName + self.varNameDelim + str(len(outputShape)), argsDict)
        comment = IR.Comment(str(node.metadata))
        prog_3 = IRUtil.prog_merge(
            prog_1, prog_2, IR.Prog([comment, cmd0, funcCall]))

        self.decls[expr_3.idf] = [typ_3]
        prog_3 = IRUtil.prog_merge(
            IR.Prog([IR.Decl(expr_3.idf, node.type)]), prog_3)

    return (prog_3, expr_3)
def visitUninterpFuncCall(self, node: AST.UninterpFuncCall, args=None):
    """Lower a call to an uninterpreted (externally defined) function.

    The emitted function name is ``node.funcName`` followed by the output
    rank and then the rank of every tensor argument with a non-empty shape,
    all joined by ``self.varNameDelim``.

    Argument policy:
      * First the output tensor sizes are inserted.
      * Then, for each input, its shape is inserted (when
        ``node.outputDiffInpDims`` is 1, or when it is 0 and that shape has
        not been seen yet), followed by the input expression itself.
      * The result temporary is inserted last.

    A call with no arguments is supported: the program then consists only
    of the call and the result declaration.

    Returns:
        ``(prog, expr)``: the merged program (result declaration first) and
        the result temporary.
    """
    progList = []
    exprList = []
    for curArg in node.argsList:
        (progN, exprN) = self.visit(curArg)
        progList.append(progN)
        exprList.append(exprN)

    returnExpr = self.getTempVar()

    funcName = node.funcName
    funcName += self.varNameDelim + str(len(node.outputShape))
    for curArg in node.argsList:
        if Type.isTensor(curArg.type):
            curShape = curArg.type.shape
            # If len(shape) == 0 : that means its a float - no need to
            # qualify the function name with 0 in that case, since its
            # essentially become an int.
            if (len(curShape) > 0):
                funcName += self.varNameDelim + str(len(curShape))
        ### TODO : WRONG -- TEMP FIX -- right now if random strings like int
        # are passed, its being set as datatype int -- int datatype in
        # uninterpreted func call is being used in a hacky way right now
        # -- fix this later

    funcArgsList = OrderedDict()
    tensorShapesFound = {}
    outputShape = node.type.shape
    for ii, curDim in enumerate(outputShape):
        funcArgsList[IR.Int(curDim, 32)] = "OutputShape_" + str(ii)
    tensorShapesFound[tuple(outputShape)] = True

    for ii, curArg in enumerate(node.argsList):
        if node.outputDiffInpDims < 2 and Type.isTensor(curArg.type):
            curInpShape = curArg.type.shape
            if ((node.outputDiffInpDims == 1)
                    or (node.outputDiffInpDims == 0
                        and tuple(curInpShape) not in tensorShapesFound)):
                for jj, curDim in enumerate(curInpShape):
                    funcArgsList[IR.Int(curDim, 32)] = (
                        "Input_" + str(ii) + self.varNameDelim + str(jj))
                tensorShapesFound[tuple(curInpShape)] = True
        funcArgsList[exprList[ii]] = "inpExpr_" + str(ii)
    funcArgsList[returnExpr] = "output"

    comment = IR.Comment(str(node.metadata))
    # Previously ``progList[0]`` raised IndexError for a zero-argument call.
    progFinal = progList[0] if progList else IR.Prog([])
    for nextProg in progList[1:]:
        progFinal = IRUtil.prog_merge(progFinal, nextProg)
    progFinal = IRUtil.prog_merge(
        progFinal, IR.Prog([comment, IR.FuncCall(funcName, funcArgsList)]))

    secrecy = "public" if node.isSecret is False else "secret"
    self.decls[returnExpr.idf] = [node.type, secrecy]
    progFinal = IRUtil.prog_merge(
        IR.Prog([IR.Decl(returnExpr.idf, node.type, isSecret=secrecy)]),
        progFinal)
    return (progFinal, returnExpr)
def visitFloorLike(self, node: AST.Func, args=None):
    """Lower Floor/Shape/Relu/ClearMemSecret/ClearMemPublic nodes.

    The emitted function name is the operator name followed by
    ``self.varNameDelim`` and, for tensor inputs, the input rank. The
    arguments are the input dimensions (tensor inputs only), the input
    expression and, for tensor results, the output temporary, which is then
    also recorded in ``self.decls`` and declared. Floor additionally
    receives ``Util.Config.consSF`` as ``curScale``.

    Returns:
        ``(prog, expr)``: the program and the temporary allocated for the
        result.
    """
    (argProg, argExpr) = self.visit(node.expr)
    resultVar = self.getTempVar()

    op = node.op
    if op == AST.Operators.Floor:
        baseName = "Floor"
    elif op == AST.Operators.Shape:
        baseName = "Shape"
    elif op == AST.Operators.RELU:
        baseName = "Relu"
    elif op == AST.Operators.ClearMemSecret:
        baseName = "ClearMemSecret"
    elif op == AST.Operators.ClearMemPublic:
        baseName = "ClearMemPublic"
    else:
        assert False

    callArgs = OrderedDict()
    srcType = node.expr.type
    if Type.isTensor(srcType):
        for axis, extent in enumerate(srcType.shape):
            callArgs[IR.Int(extent, 32)] = "inShape_" + str(axis)
    callArgs[argExpr] = "inArr"

    if Type.isTensor(node.type):
        callArgs[resultVar] = "outArr"
        self.decls[resultVar.idf] = [node.type]

    if op == AST.Operators.Floor:
        callArgs[IR.Int(Util.Config.consSF)] = "curScale"

    metaComment = IR.Comment(str(node.metadata))
    rankSuffix = ""
    if Type.isTensor(srcType):
        rankSuffix = str(len(srcType.shape))

    call = IR.FuncCall(baseName + self.varNameDelim + rankSuffix, callArgs)
    result = IRUtil.prog_merge(argProg, IR.Prog([metaComment, call]))
    if Type.isTensor(node.type):
        result = IRUtil.prog_merge(
            IR.Prog([IR.Decl(resultVar.idf, node.type)]), result)
    return (result, resultVar)
def visitPool(self, node: AST.Pool, args=None):
    """Lower a pooling node to a call of the function named ``node.poolType``.

    The call receives, in order: the four output dimensions, the filter
    size, the four padding amounts and the two strides (all read from
    ``node.options``), the four input dimensions, the input expression and
    a fresh output temporary. The first and last dimensions of input and
    output are asserted to be equal. The temporary is recorded in
    ``self.decls``.

    Returns:
        ``(prog, expr)``: the merged program (output declaration first) and
        the output temporary.
    """
    (inputProg, inputExpr) = self.visit(node.expr)

    [N, H, W, CI] = node.expr.type.shape
    [N1, outH, outW, CI1] = node.type.shape
    assert (N == N1 and CI == CI1)

    [outputVar] = self.getTempVars(1)
    metaComment = IR.Comment(str(node.metadata))

    keys = AST.PaddingKeysDict
    opts = node.options
    # (value, label) pairs in the exact order the call expects them.
    intArgs = [
        (N1, "N1"),
        (outH, "outH"),
        (outW, "outW"),
        (CI1, "CI1"),
        (opts[keys.FH], "FH"),
        (opts[keys.FW], "FW"),
        (opts[keys.zPadHLeft], "zPadHLeft"),
        (opts[keys.zPadHRight], "zPadHRight"),
        (opts[keys.zPadWLeft], "zPadWLeft"),
        (opts[keys.zPadWRight], "zPadWRight"),
        (opts[keys.strideH], "strideH"),
        (opts[keys.strideW], "strideW"),
        (N, "N"),
        (H, "H"),
        (W, "W"),
        (CI, "CI"),
    ]

    callArgs = OrderedDict()
    for value, label in intArgs:
        callArgs[IR.Int(value, 32)] = label
    callArgs[inputExpr] = "input"
    callArgs[outputVar] = "output"

    poolProg = IR.Prog([metaComment, IR.FuncCall(node.poolType, callArgs)])
    result = IRUtil.prog_merge(inputProg, poolProg)

    self.decls[outputVar.idf] = [node.type]
    result = IRUtil.prog_merge(
        IR.Prog([IR.Decl(outputVar.idf, node.type)]), result)
    return (result, outputVar)
def fixOuputScale(self, res: (IR.Prog, IR.Expr), compiler: IRBuilderCSF):
    """Bring the final output back to the configured scale factor.

    Looks up the scale recorded for the output expression in
    ``compiler.scaleFacMapping``. If it is ``-1`` or already equals
    ``Util.Config.consSF``, ``res`` is returned unchanged (as a new tuple).
    If it is larger, a ``ScaleDown<rank>`` call on the expression is
    appended, taking the output dimensions, the expression and the excess
    scale. A smaller scale is rejected by an assertion.

    Args:
        res: ``(prog, expr)`` pair produced by the compiler.
        compiler: the IR builder whose ``scaleFacMapping`` and ``typeInfo``
            describe ``expr``.

    Returns:
        ``(prog, expr)`` with the scale-down call appended when needed.
    """
    prog = res[0]
    expr = res[1]
    output_scale = compiler.scaleFacMapping[expr.idf]

    if output_scale == -1 or output_scale == Util.Config.consSF:
        return (prog, expr)
    elif output_scale > Util.Config.consSF:
        scale_down = output_scale - Util.Config.consSF

        # Named ``expr_type`` so the builtin ``type`` is not shadowed.
        expr_type = compiler.typeInfo[expr.idf]
        if Type.isTensor(expr_type):
            output_shape = expr_type.shape
        elif Type.isInt(expr_type):
            output_shape = []
        else:
            # Previously this case fell through to an unbound-name error.
            assert False, "fixOuputScale: output is neither an int nor a tensor"

        argsDict = OrderedDict()
        for ii, curDimSize in enumerate(output_shape):
            argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii)
        funcName = "ScaleDown" + str(len(output_shape))
        argsDict[expr] = "expr"
        argsDict[IR.Int(scale_down, 32)] = "consSF"

        funcCall = IR.FuncCall(funcName, argsDict)
        prog = IRUtil.prog_merge(prog, IR.Prog([funcCall]))
        return (prog, expr)
    else:
        assert False, "Scale up shouldnt be required of final output {} -> {}. We lost precision somewhere".format(
            output_scale, Util.Config.consSF)
def visitDecl(self, node: AST.Decl, args=None):
    """Lower a declaration node, emitting element assignments if given.

    A fresh temporary is allocated and flagged with ``inputVar = True``.
    When ``node.valueList`` is non-empty, one assignment is emitted per
    index combination of ``node.shape``, in the order the nested helper
    produces them. Int values keep their own bit length; float values are
    scaled by ``2 ** Util.Config.consSF`` via ``np.ldexp`` and converted
    with ``IR.DataType.getInt``. The first int whose bit length differs
    from ``Util.Config.wordLength`` sets the bit length used for the
    declaration.

    Returns:
        ``(prog, expr)``: the program (declaration first) and the temporary.
    """
    def collectIndices(prefix, remainingDims, out):
        # Depth-first enumeration of every index tuple for the given dims.
        if remainingDims == []:
            out.append(prefix)
            return
        for idx in range(remainingDims[0]):
            collectIndices(prefix + [idx], remainingDims[1:], out)

    prog = IR.Prog([])
    expr = self.getTempVar()
    expr.inputVar = True

    # If there is a valueList, then add assignment commands
    specialBitLen = -1
    if node.valueList:
        # Add assignment statements for each element of the tensor in a
        # different array
        metaComment = IR.Comment(str(node.metadata))
        header = IR.Prog([
            metaComment,
            IR.Comment('Element assignments for ' + expr.idf),
        ])
        prog = IRUtil.prog_merge(prog, header)

        indexCombos = []
        collectIndices([], node.shape, indexCombos)
        for pos, combo in enumerate(indexCombos):
            value = node.valueList[pos]
            irValue = None
            if isinstance(value, AST.Int):
                irValue = IR.Int(value.value, value.bitLen)
                if (specialBitLen == -1
                        and value.bitLen != Util.Config.wordLength):
                    specialBitLen = value.bitLen
            elif isinstance(value, AST.Float):
                irValue = IR.DataType.getInt(
                    np.ldexp(value.value, Util.Config.consSF))
            else:
                # Assuming the elements can only be either int or floats
                assert False
            target = IRUtil.addIndex(expr, [IR.Int(idx) for idx in combo])
            prog = IRUtil.prog_merge(prog, IR.Prog([IR.Assn(target, irValue)]))

    secrecy = "public" if node.isSecret is False else "secret"
    if specialBitLen == -1:
        declBitLen = Util.Config.wordLength
    else:
        declBitLen = specialBitLen

    self.decls[expr.idf] = [node.type, secrecy, declBitLen]
    declProg = IR.Prog([IR.Decl(expr.idf, node.type, declBitLen, secrecy)])
    prog = IRUtil.prog_merge(declProg, prog)
    return (prog, expr)
def visitBopMulScalar1DTensor(self, node: AST.BOp, args=None):
    """Lower a scalar-by-tensor product, with scale-factor bookkeeping.

    Whichever operand has ``dim == 0`` or is an int is passed as ``A``, and
    the other operand as ``B``; the output shape is taken from ``B``'s type.
    With truncation optimisation disabled, a truncate call on the result by
    ``Util.Config.consSF`` follows the ``ScalarMul`` call. Otherwise any
    operand whose recorded scale exceeds ``self.scaleFac`` is truncated
    down to it beforehand, and the result's scale is recorded as
    ``2 * self.scaleFac``.

    Returns:
        ``(prog, expr)``: the complete program (result declaration first)
        and the result temporary.
    """
    (prog_1, expr_1) = self.visit(node.expr1)
    (prog_2, expr_2) = self.visit(node.expr2)

    typ_1 = node.expr1.type
    typ_2 = node.expr2.type

    # a represents the scalar and b the tensor. The ``isIntMult`` and
    # ``typ_3`` locals were computed here but never read, so they are gone.
    if typ_1.dim == 0 or Type.isInt(typ_1):
        a, b = expr_1, expr_2
        outputShape = typ_2.shape
    else:
        a, b = expr_2, expr_1
        outputShape = typ_1.shape

    # decl fresh vars
    expr_3 = self.getTempVar()
    cmd0 = IR.Comment(expr_1.idf + ' * ' + expr_2.idf)

    funcCallArgsDict = OrderedDict()
    for ii, curDimSize in enumerate(outputShape):
        funcCallArgsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii)
    funcCallArgsDict[a] = "A"
    funcCallArgsDict[b] = "B"
    funcCallArgsDict[expr_3] = "C"

    progExtraBefore = IR.Prog([])
    progExtraAfter = IR.Prog([])
    if (Util.Config.disableTruncOpti):
        progExtraAfter = self.addTruncateFunctionCall(
            node, "ScalarMul", expr_3, Util.Config.consSF)
    else:
        expr1_sf = self.scaleFacMapping[expr_1.idf]
        expr2_sf = self.scaleFacMapping[expr_2.idf]
        if (expr1_sf > self.scaleFac):
            progExtraBefore = self.addTruncateFunctionCall(
                node.expr1, "ScalarMul", expr_1, expr1_sf - self.scaleFac)
            self.scaleFacMapping[expr_1.idf] = self.scaleFac
        if (expr2_sf > self.scaleFac):
            progExtraBefore = IRUtil.prog_merge(
                progExtraBefore,
                self.addTruncateFunctionCall(
                    node.expr2, "ScalarMul", expr_2, expr2_sf - self.scaleFac))
            self.scaleFacMapping[expr_2.idf] = self.scaleFac
        self.scaleFacMapping[expr_3.idf] = 2 * self.scaleFac

    funcCall = IR.FuncCall(
        'ScalarMul' + self.varNameDelim + str(len(outputShape)),
        funcCallArgsDict)
    prog_3 = IRUtil.prog_merge(
        prog_1, prog_2, progExtraBefore, IR.Prog([cmd0, funcCall]))
    prog_3 = IRUtil.prog_merge(
        IR.Prog([IR.Decl(expr_3.idf, node.type)]), prog_3, progExtraAfter)
    return (prog_3, expr_3)
def visitBopMulInt(self, node: AST.BOp, args=None):
    """Lower an integer multiplication to an inline product expression.

    No temporary is introduced: the operand programs are merged and the
    product expression built by ``IRUtil.mul`` is returned.

    Returns:
        ``(prog, expr)``: the merged operand programs and the product.
    """
    (lhsProg, lhsExpr) = self.visit(node.expr1)
    (rhsProg, rhsExpr) = self.visit(node.expr2)

    mergedProg = IRUtil.prog_merge(lhsProg, rhsProg)
    product = IRUtil.mul(lhsExpr, rhsExpr)
    return (mergedProg, product)
def visitReduce(self, node: AST.Reduce, args=None):
    """Lower a Reduce node (sum, product or mean) to a ``Reduce*`` call.

    The emitted function name is ``ReduceSum``/``ReduceMul``/``ReduceMean``
    followed by the output rank and the input rank, separated by
    ``self.varNameDelim``. The arguments are the output dimensions, the
    input dimensions, the input, the dimension expression and the result
    temporary. ``ReduceMul`` additionally receives ``Util.Config.consSF``
    as ``ScalingFactor``. The temporary is recorded in ``self.decls``.

    Returns:
        ``(prog, expr)``: the merged program (result declaration first) and
        the result temporary.
    """
    (prog_1, expr1) = self.visit(node.expr)
    (prog_2, expr2) = self.visit(node.dim)

    returnExpr = self.getTempVar()

    assert (node.op in
            [AST.Operators.ADD, AST.Operators.MUL, AST.Operators.Mean])
    scalingFac = None
    if (node.op == AST.Operators.ADD):
        funcName = "ReduceSum"
    elif (node.op == AST.Operators.MUL):
        funcName = "ReduceMul"
        scalingFac = Util.Config.consSF
    elif (node.op == AST.Operators.Mean):
        funcName = "ReduceMean"
    else:
        print("Unknown node.op in AST.Reduce.", file=sys.stderr)
        assert False

    funcArgsList = OrderedDict()
    outputShape = node.type.shape
    for ii, curDim in enumerate(outputShape):
        funcArgsList[IR.Int(curDim, 32)] = "OutputShape_" + str(ii)

    inputShape = node.expr.type.shape
    for ii, curDim in enumerate(inputShape):
        funcArgsList[IR.Int(curDim, 32)] = "InputShape_" + str(ii)

    funcArgsList[expr1] = "inputArr"
    funcArgsList[expr2] = "dimension"
    funcArgsList[returnExpr] = "outArr"
    # Test against None, not truthiness: a scale of 0 must still be passed,
    # otherwise ReduceMul would be called with one argument fewer.
    if scalingFac is not None:
        funcArgsList[IR.Int(scalingFac)] = "ScalingFactor"

    funcCall = IR.FuncCall(
        funcName + self.varNameDelim + str(len(outputShape))
        + self.varNameDelim + str(len(inputShape)), funcArgsList)
    comment = IR.Comment(str(node.metadata))
    prog_3 = IRUtil.prog_merge(prog_1, prog_2, IR.Prog([comment, funcCall]))

    self.decls[returnExpr.idf] = [node.type]
    prog_3 = IRUtil.prog_merge(
        IR.Prog([IR.Decl(returnExpr.idf, node.type)]), prog_3)
    return (prog_3, returnExpr)
def visitLet(self, node: AST.Let, args=None):
    """Lower a let-binding by substituting the bound name in the body.

    The declaration and the body are lowered separately; every occurrence
    of the bound identifier in the body's program and expression is then
    replaced by the declaration's expression via ``subst``.

    Returns:
        ``(prog, expr)``: the declaration program followed by the
        substituted body program, and the substituted body expression.
    """
    (prog_1, expr_1) = self.visit(node.decl)
    idf = node.name.name

    (prog_2, expr_2) = self.visit(node.expr)
    prog_2 = prog_2.subst(idf, expr_1)
    expr_2 = expr_2.subst(idf, expr_1)

    prog_3 = IRUtil.prog_merge(prog_1, prog_2)
    return (prog_3, expr_2)
def visitBopConv(self, node: AST.BOp, args=None):
    """Lower a 2-D convolution to a ``Conv2DCSF`` call.

    The call receives the input dimensions, the filter dimensions, the four
    padding amounts and two strides from ``node.options``, the input and
    filter expressions, ``Util.Config.consSF`` and a fresh output
    temporary. The channel count passed is the one read from the filter
    shape, because that unpacking happens last. The temporary is recorded
    in ``self.decls``.

    Returns:
        ``(prog, expr)``: the merged program (output declaration first) and
        the output temporary.
    """
    (inputProg, inputExpr) = self.visit(node.expr1)
    (filterProg, filterExpr) = self.visit(node.expr2)

    [N, H, W, CI] = node.expr1.type.shape
    [FH, FW, CI, CO] = node.expr2.type.shape

    outputVar = self.getTempVar()
    opComment = IR.Comment(inputExpr.idf + ' # ' + filterExpr.idf)

    keys = AST.PaddingKeysDict
    opts = node.options
    # (value, label) pairs in the exact order the call expects them.
    intArgs = [
        (N, "N"),
        (H, "H"),
        (W, "W"),
        (CI, "CI"),
        (FH, "FH"),
        (FW, "FW"),
        (CO, "CO"),
        (opts[keys.zPadHLeft], "zPadHLeft"),
        (opts[keys.zPadHRight], "zPadHRight"),
        (opts[keys.zPadWLeft], "zPadWLeft"),
        (opts[keys.zPadWRight], "zPadWRight"),
        (opts[keys.strideH], "strideH"),
        (opts[keys.strideW], "strideW"),
    ]

    callArgs = OrderedDict()
    for value, label in intArgs:
        callArgs[IR.Int(value, 32)] = label
    callArgs[inputExpr] = "input"
    callArgs[filterExpr] = "filter"
    callArgs[IR.Int(Util.Config.consSF, 32)] = "consSF"
    callArgs[outputVar] = "output"

    convProg = IR.Prog([opComment, IR.FuncCall("Conv2DCSF", callArgs)])
    result = IRUtil.prog_merge(inputProg, filterProg, convProg)

    self.decls[outputVar.idf] = [node.type]
    result = IRUtil.prog_merge(
        IR.Prog([IR.Decl(outputVar.idf, node.type)]), result)
    return (result, outputVar)
def visitFloat(self, node: AST.Float, args=None):
    """Lower a float literal to a declared fixed-point integer temporary.

    The value is scaled with ``np.ldexp`` by the exponent that
    ``self.get_expnt`` returns for its absolute value, converted with
    ``IR.DataType.getInt`` and passed to the declaration as the initial
    value. When truncation optimisation is enabled, the temporary's scale
    is recorded as ``self.scaleFac``.

    Returns:
        ``(prog, expr)``: the declaration followed by a descriptive
        comment, and the temporary.
    """
    value = node.value
    exponent = self.get_expnt(abs(value))
    fixedPoint = IR.DataType.getInt(np.ldexp(value, exponent))

    tempVar = self.getTempVar()
    note = IR.Comment(
        'Float to int : {0} to {1}, isSecret = {2}.'.format(
            str(value), str(fixedPoint), node.isSecret))
    declaration = IR.Decl(
        tempVar.idf, node.type, -1, node.isSecret, [fixedPoint])
    result = IRUtil.prog_merge(IR.Prog([declaration]), IR.Prog([note]))

    if not Util.Config.disableTruncOpti:
        self.scaleFacMapping[tempVar.idf] = self.scaleFac
    return (result, tempVar)
def visitBopMulScalar1DTensor(self, node: AST.BOp, args=None):
    """Lower a scalar-by-tensor product to a ``ScalarMul`` call.

    Whichever operand has ``dim == 0`` or is an int is passed as ``A``, and
    the other operand as ``B``; the output shape is taken from ``B``'s type.
    The final ``shr3`` argument is ``Util.Config.consSF``, or zero when the
    scalar operand is an int. The result temporary is recorded in
    ``self.decls``.

    Returns:
        ``(prog, expr)``: the merged program (result declaration first) and
        the result temporary.
    """
    (prog_1, expr_1) = self.visit(node.expr1)
    (prog_2, expr_2) = self.visit(node.expr2)

    typ_1 = node.expr1.type
    typ_2 = node.expr2.type
    typ_3 = node.type

    isIntMult = False
    if typ_1.dim == 0 or Type.isInt(typ_1):
        a, b = expr_1, expr_2
        outputShape = typ_2.shape
        isIntMult = (Type.isInt(typ_1))
    else:
        a, b = expr_2, expr_1
        outputShape = typ_1.shape
        isIntMult = (Type.isInt(typ_2))

    # a represents the scalar and b the tensor
    shr3 = IR.Int(Util.Config.consSF)
    if isIntMult:
        # If multiplying with an int, then sf = 0. This must be an IR
        # expression like every other key in the argument dict; it used to
        # be the raw Python int 0.
        shr3 = IR.Int(0)

    # decl fresh vars
    expr_3 = self.getTempVar()

    cmd0 = IR.Comment(expr_1.idf + ' * ' + expr_2.idf)
    funcCallArgsDict = OrderedDict()
    for ii, curDimSize in enumerate(outputShape):
        funcCallArgsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii)
    funcCallArgsDict[a] = "A"
    funcCallArgsDict[b] = "B"
    funcCallArgsDict[expr_3] = "C"
    funcCallArgsDict[shr3] = "shr3"

    funcCall = IR.FuncCall(
        'ScalarMul' + self.varNameDelim + str(len(outputShape)),
        funcCallArgsDict)
    prog_3 = IRUtil.prog_merge(prog_1, prog_2, IR.Prog([cmd0, funcCall]))

    self.decls[expr_3.idf] = [typ_3]
    prog_3 = IRUtil.prog_merge(
        IR.Prog([IR.Decl(expr_3.idf, node.type)]), prog_3)
    return (prog_3, expr_3)
def visitIndex(self, node: AST.Index, args=None):
    """Lower an indexing expression.

    The indexed expression and each index in ``node.index`` are lowered in
    turn. The resulting programs are merged in that order, and the index
    expressions are attached to the base expression with
    ``IRUtil.addIndex``.

    Returns:
        ``(prog, expr)``: the merged program and the indexed expression.
    """
    (prog_1, expr_1) = self.visit(node.expr)

    # Two separate lists: the old chained assignment
    # ``prog_idx = expr_idx = []`` bound both names to one list, so
    # programs and expressions ended up interleaved in both.
    prog_idx = []
    expr_idx = []
    for curIdx in node.index:
        (prog_cur, expr_cur) = self.visit(curIdx)
        prog_idx.append(prog_cur)
        expr_idx.append(expr_cur)

    # Merge the index programs as programs, like every other prog_merge
    # call, instead of passing a flattened list of commands.
    prog_3 = IRUtil.prog_merge(prog_1, *prog_idx)
    expr_3 = IRUtil.addIndex(expr_1, expr_idx)
    return (prog_3, expr_3)
def visitInt(self, node: AST.Int, args=None):
    """Lower an integer literal to a declared temporary.

    The declaration carries ``node.bitLen`` (or ``-1`` when that is falsy),
    the secrecy flag and the literal as its initial value. When truncation
    optimisation is enabled, the temporary's scale is recorded as
    ``self.scaleFac`` for scaled ints and ``0`` otherwise.

    Returns:
        ``(prog, expr)``: the declaration followed by a descriptive
        comment, and the temporary.
    """
    literal = node.value
    note = IR.Prog(
        [IR.Comment('Int node, isSecret = {0}.'.format(node.isSecret))])
    tempVar = self.getTempVar()

    declBitLen = node.bitLen if node.bitLen else -1
    declaration = IR.Decl(
        tempVar.idf, node.type, declBitLen, node.isSecret, [literal])
    result = IRUtil.prog_merge(IR.Prog([declaration]), note)

    if not Util.Config.disableTruncOpti:
        if node.isScaled:
            self.scaleFacMapping[tempVar.idf] = self.scaleFac
        else:
            self.scaleFacMapping[tempVar.idf] = 0
    return (result, tempVar)
def visitBopElemWiseOp(self, node: AST.BOp, args=None):
    """Lower an element-wise multiply/divide, with scale bookkeeping.

    The emitted call is named ``ElemWiseMul`` or ``ElemWiseDiv`` followed by
    ``self.varNameDelim`` and the output rank, and takes the output
    dimensions, both operands and a fresh result temporary. With truncation
    optimisation disabled, a truncate call on the result by
    ``Util.Config.consSF`` follows. Otherwise any operand whose recorded
    scale exceeds ``self.scaleFac`` is truncated down to it beforehand, and
    the result's scale is recorded as ``2 * self.scaleFac``.

    Returns:
        ``(prog, expr)``: the complete program (result declaration first)
        and the result temporary.
    """
    (prog_1, expr_1) = self.visit(node.expr1)
    (prog_2, expr_2) = self.visit(node.expr2)

    if (node.op == AST.Operators.ElemWiseMul):
        op_ir = IR.Op.Op['.*']
        funcName = "ElemWiseMul"
    elif (node.op == AST.Operators.ElemWiseDiv):
        op_ir = IR.Op.Op['./']
        funcName = "ElemWiseDiv"
    else:
        # Fail explicitly rather than with an unbound-name error further down.
        assert False, "visitBopElemWiseOp: unsupported operator"

    typ_3 = node.type
    expr_3 = self.getTempVar()
    cmd0 = IR.Comment(expr_1.idf + ' ' + op_ir.name + ' ' + expr_2.idf)
    outputShape = typ_3.shape

    argsDict = OrderedDict()
    for ii, curDimSize in enumerate(outputShape):
        argsDict[IR.Int(curDimSize, 32)] = "size_" + str(ii)
    argsDict[expr_1] = "A"
    argsDict[expr_2] = "B"
    argsDict[expr_3] = "C"

    progExtraBefore = IR.Prog([])
    progExtraAfter = IR.Prog([])
    if (Util.Config.disableTruncOpti):
        progExtraAfter = self.addTruncateFunctionCall(
            node, funcName, expr_3, Util.Config.consSF)
    else:
        expr1_sf = self.scaleFacMapping[expr_1.idf]
        expr2_sf = self.scaleFacMapping[expr_2.idf]
        if (expr1_sf > self.scaleFac):
            progExtraBefore = self.addTruncateFunctionCall(
                node.expr1, funcName, expr_1, expr1_sf - self.scaleFac)
            self.scaleFacMapping[expr_1.idf] = self.scaleFac
        if (expr2_sf > self.scaleFac):
            progExtraBefore = IRUtil.prog_merge(
                progExtraBefore,
                self.addTruncateFunctionCall(
                    node.expr2, funcName, expr_2, expr2_sf - self.scaleFac))
            self.scaleFacMapping[expr_2.idf] = self.scaleFac
        self.scaleFacMapping[expr_3.idf] = 2 * self.scaleFac

    funcCall = IR.FuncCall(
        funcName + self.varNameDelim + str(len(outputShape)), argsDict)
    prog_3 = IRUtil.prog_merge(
        prog_1, prog_2, progExtraBefore, IR.Prog([cmd0, funcCall]))
    prog_3 = IRUtil.prog_merge(
        IR.Prog([IR.Decl(expr_3.idf, node.type)]), prog_3, progExtraAfter)
    return (prog_3, expr_3)
def visitFloat(self, node: AST.Float, args=None):
    """Lower a float literal to a fixed-point integer.

    The value is scaled with ``np.ldexp`` by the exponent that
    ``self.get_expnt`` returns for its absolute value and converted with
    ``IR.DataType.getInt``. A public literal becomes an inline ``IR.Int``.
    A secret literal becomes a fresh temporary, which is recorded in
    ``self.decls`` and declared at the front of the program.

    Returns:
        ``(prog, expr)``: a program holding a descriptive comment (preceded
        by the declaration for secret values) and the resulting expression.
    """
    value = node.value
    exponent = self.get_expnt(abs(value))
    fixedPoint = IR.DataType.getInt(np.ldexp(value, exponent))

    result = IR.Prog([IR.Comment(
        'Float to int : {0} to {1}, isSecret = {2}.'.format(
            str(value), str(fixedPoint), node.isSecret))])

    if not node.isSecret:
        return (result, IR.Int(fixedPoint))

    tempVar = self.getTempVar()
    self.decls[tempVar.idf] = [node.type]
    result = IRUtil.prog_merge(
        IR.Prog([IR.Decl(tempVar.idf, node.type)]), result)
    return (result, tempVar)
def visitFloat(self, node: AST.Float, args=None):
    """Lower a float literal to a fixed-point integer.

    The value is scaled with ``np.ldexp`` by the exponent that
    ``self.get_expnt`` returns for its absolute value and converted with
    ``IR.DataType.getInt``. A public literal becomes an inline ``IR.Int``.
    A secret literal becomes a fresh temporary, which is recorded in
    ``self.decls`` and declared at the front of the returned program.

    Returns:
        ``(prog, expr)``: a program holding a descriptive comment (preceded
        by the declaration for secret values) and the resulting expression.
    """
    r = node.value
    p = self.get_expnt(abs(r))
    k = IR.DataType.getInt(np.ldexp(r, p))

    comment = IR.Comment('Float to int : ' + str(r) + ' to ' + str(k))
    # Build the program up front: the secret branch used to read ``prog``
    # before it was assigned, and the return value dropped the declaration.
    prog = IR.Prog([comment])

    if not (node.isSecret):
        expr = IR.Int(k)
    else:
        expr = self.getTempVar()
        self.decls[expr.idf] = [node.type]
        prog = IRUtil.prog_merge(
            IR.Prog([IR.Decl(expr.idf, node.type)]), prog)

    return (prog, expr)