def visitTranspose(self, node:AST.Transpose, args=None):
    """Lower an AST transpose into an explicit element-copy loop.

    The loop walks every index tuple of the input tensor and emits
    ``out[permuted iterators] = inp[iterators]``, where output axis ``k`` is
    indexed by the input iterator of axis ``perm[k]``. When ``node.perm`` is
    None, the axis order is reversed.

    Returns:
        (prog, out_arr): the IR program (output and iterator declarations,
        the input's program, two comments, and the copy loop) and the
        temporary holding the transposed tensor.
    """
    src_prog, src_arr = self.visit(node.expr)
    src_type = node.expr.type
    dst_type = node.type

    src_iters = self.getTempIterators(src_type.dim)
    axes = node.perm
    if axes is None:
        # No explicit permutation given: reverse the axes.
        axes = list(reversed(range(len(src_type.shape))))
    dst_iters = [src_iters[axis] for axis in axes]

    dst_arr = self.getTempVar()
    dst_elt = IRUtil.addIndex(dst_arr, dst_iters)
    src_elt = IRUtil.addIndex(src_arr, src_iters)
    copy_loop = IRUtil.loop(src_type.shape, src_iters, [IR.Assn(dst_elt, src_elt)])

    in_dims = ', '.join(str(d) for d in src_type.shape)
    out_dims = ', '.join(str(d) for d in dst_type.shape)
    header = [
        IR.Comment(str(node.metadata)),
        IR.Comment("transpose(" + src_arr.idf + ", [" + in_dims + "] --> [" + out_dims + "])"),
    ]
    final_prog = IRUtil.prog_merge(src_prog, IR.Prog(header + copy_loop))

    # Each merge prepends, so the iterator declarations end up in reverse
    # order and the output-array declaration ends up first.
    for it in src_iters:
        iter_decl = IR.Prog([IR.Decl(it.idf, Type.Int(), isSecret=False)])
        final_prog = IRUtil.prog_merge(iter_decl, final_prog)
    final_prog = IRUtil.prog_merge(IR.Prog([IR.Decl(dst_arr.idf, dst_type)]), final_prog)

    if not Util.Config.disableTruncOpti:
        # The copy loop only moves elements, so the result reuses the input's scale factor.
        self.scaleFacMapping[dst_arr.idf] = self.scaleFacMapping[src_arr.idf]
    return (final_prog, dst_arr)
def visitUOp(self, node:AST.UOp, args=None):
    """Lower a unary plus/minus expression to IR.

    Unary plus returns the operand's program and expression unchanged. Any
    other operator must be unary minus (asserted). Its result is stored in a
    fresh temporary:

    * integer result type: one declaration (bitlen and secrecy taken from
      ``node.type``) followed by ``tmp = -operand``;
    * tensor result type: an element-wise loop ``tmp[i..] = -operand[i..]``,
      with the temporary's declaration prepended to the program.

    When ``Util.Config.disableTruncOpti`` is falsy, the temporary reuses the
    operand's entry in ``self.scaleFacMapping``.

    Returns:
        (prog, expr): the IR program and the expression holding the result.
    """
    (prog_1, expr_1) = self.visit(node.expr)
    op = node.op
    if op == AST.Operators.ADD:
        return (prog_1, expr_1)
    assert op == AST.Operators.SUB

    typ_2 = node.type
    expr_2 = self.getTempVar()

    if Type.isInt(typ_2):
        comment = IR.Comment(str(node.metadata))
        # Declare the scalar temporary with the result type's bitlen and secrecy.
        decl = IR.Decl(expr_2.idf, node.type, typ_2.bitlen, typ_2.isSecret)
        assign = IR.Assn(expr_2, IRUtil.negate(expr_1))
        prog_2 = IRUtil.prog_merge(prog_1, IR.Prog([comment, decl, assign]))
    else:
        # Fresh loop iterators, one per tensor dimension.
        iters = self.getTempIterators(typ_2.dim)

        # Element-wise: expr_2[iters] = -expr_1[iters]
        expr_1_elt = IRUtil.addIndex(expr_1, iters)
        expr_2_elt = IRUtil.addIndex(expr_2, iters)
        cmdl_assn = IRUtil.loop(typ_2.shape, iters, [IR.Assn(expr_2_elt, IRUtil.negate(expr_1_elt))])
        comment = IR.Comment(str(node.metadata))
        prog_2 = IRUtil.prog_merge(prog_1, IR.Prog([comment] + cmdl_assn))
        # The tensor temporary's declaration goes in front of everything else.
        prog_2 = IRUtil.prog_merge(IR.Prog([IR.Decl(expr_2.idf, node.type)]), prog_2)

    if not Util.Config.disableTruncOpti:
        # The result reuses the operand's scale factor.
        self.scaleFacMapping[expr_2.idf] = self.scaleFacMapping[expr_1.idf]

    return (prog_2, expr_2)
def visitUOp(self, node: AST.UOp, args=None):
    """Lower a unary plus/minus expression to IR.

    * Unary plus: the operand's program and expression are returned as-is.
    * Unary minus (asserted otherwise) on an integer type: no new statements;
      the result is the negated operand expression.
    * Unary minus on a tensor type: a fresh temporary is filled element-wise
      with the negated operand, recorded in ``self.decls`` with its type, and
      its declaration is prepended to the program.

    Returns:
        (prog, expr): the IR program and the expression holding the result.
    """
    operand_prog, operand = self.visit(node.expr)
    op = node.op
    if op == AST.Operators.ADD:
        return (operand_prog, operand)
    assert op == AST.Operators.SUB

    res_type = node.type
    if Type.isInt(res_type):
        # Scalar: negate the expression directly, no temporary needed.
        return (operand_prog, IRUtil.negate(operand))

    # Tensor: declare a fresh result temporary and one iterator per dimension.
    result = self.getTempVar()
    idx = self.getTempIterators(res_type.dim)

    src_elt = IRUtil.addIndex(operand, idx)
    dst_elt = IRUtil.addIndex(result, idx)
    loop_cmds = IRUtil.loop(res_type.shape, idx, [IR.Assn(dst_elt, IRUtil.negate(src_elt))])

    prog = IRUtil.prog_merge(operand_prog, IR.Prog([IR.Comment(str(node.metadata))] + loop_cmds))
    self.decls[result.idf] = [res_type]
    prog = IRUtil.prog_merge(IR.Prog([IR.Decl(result.idf, node.type)]), prog)
    return (prog, result)
def visitIndex(self, node:AST.Index, args=None):
    """Lower an indexing expression ``base[i_0][i_1]...``.

    The base expression is visited first, then each index expression in
    order. Their programs are merged in that same order, and all index
    expressions are attached to the base expression.

    Returns:
        (prog, expr): the merged IR program and the indexed expression.
    """
    (prog_1, expr_1) = self.visit(node.expr)
    prog_3 = prog_1
    # Holds only the index expressions; each index's program is merged into
    # prog_3 separately, so the two are never mixed in one list.
    expr_idx = []
    for curIdx in node.index:
        (prog_cur, expr_cur) = self.visit(curIdx)
        prog_3 = IRUtil.prog_merge(prog_3, prog_cur)
        expr_idx.append(expr_cur)
    expr_3 = IRUtil.addIndex(expr_1, expr_idx)
    return (prog_3, expr_3)
def visitTranspose(self, node: AST.Transpose, args=None):
    """Lower an AST transpose into an explicit element-copy loop.

    The loop walks every index tuple of the input tensor and emits
    ``out[permuted iterators] = inp[iterators]``, where output axis ``k`` is
    indexed by the input iterator of axis ``perm[k]``. When ``node.perm`` is
    None, the axis order is reversed, matching the sibling transpose lowering.

    ``self.decls`` records the output array's type, and each loop iterator is
    recorded as a public Int.

    Returns:
        (prog, out_arr): the IR program and the temporary holding the result.
    """
    (inp_prog, inp_arr) = self.visit(node.expr)
    inp_type = node.expr.type
    out_type = node.type
    inp_iters = self.getTempIterators(inp_type.dim)

    perm = node.perm
    if perm is None:
        # No explicit permutation: reverse the axes instead of crashing on None.
        perm = list(reversed(range(len(inp_type.shape))))
    out_iters = [inp_iters[i] for i in perm]

    out_arr = self.getTempVar()
    out_arr_expr = IRUtil.addIndex(out_arr, out_iters)
    inp_arr_expr = IRUtil.addIndex(inp_arr, inp_iters)
    assign_expr = IR.Assn(out_arr_expr, inp_arr_expr)
    loop = IRUtil.loop(inp_type.shape, inp_iters, [assign_expr])

    # Finalize
    comment1 = IR.Comment(str(node.metadata))
    comment2 = IR.Comment("transpose(" + inp_arr.idf + ", [" + ', '.join(str(e) for e in inp_type.shape) + "] --> [" + ', '.join(str(e) for e in out_type.shape) + "])")
    transpose_prog = IR.Prog([comment1, comment2] + loop)
    final_prog = IRUtil.prog_merge(inp_prog, transpose_prog)

    # Update context
    self.decls[out_arr.idf] = [out_type]

    # Update declarations
    self.decls.update({var.idf: [Type.Int(), 'public'] for var in inp_iters})
    # Each merge prepends, so the output-array declaration ends up first.
    for var in inp_iters:
        final_prog = IRUtil.prog_merge(
            IR.Prog([IR.Decl(var.idf, Type.Int(), isSecret="public")]), final_prog)
    final_prog = IRUtil.prog_merge(
        IR.Prog([IR.Decl(out_arr.idf, out_type)]), final_prog)

    return (final_prog, out_arr)
def visitDecl(self, node:AST.Decl, args=None):
    """Lower a declaration, optionally initialised with constant element values.

    A fresh temporary, flagged ``inputVar = True``, is declared with
    ``node.type``. If ``node.valueList`` is non-empty, one assignment per
    element is emitted. Positions are taken in row-major order of
    ``node.shape`` (last axis varying fastest), and the i-th position gets
    ``node.valueList[i]``:

    * ``AST.Int``: copied as ``IR.Int(value, bitLen)``;
    * ``AST.Float``: scaled by ``2**Util.Config.consSF`` (``np.ldexp``) and
      passed through ``IR.DataType.getInt``.

    The declared bitlen is ``Util.Config.wordLength``, unless some integer
    element has a different bitLen; then the first such bitLen is used.
    Secrecy is "public" only when ``node.isSecret is False``, otherwise
    "secret". The same (type, secrecy, bitlen) triple is stored in
    ``self.decls``.

    Returns:
        (prog, expr): the IR program (declaration first) and the temporary.

    Raises:
        AssertionError: if a value is neither ``AST.Int`` nor ``AST.Float``.
    """
    import itertools

    prog = IR.Prog([])
    expr = self.getTempVar()
    expr.inputVar = True

    # -1 means no integer element with a non-default bitlen has been seen yet.
    specialBitLen = -1
    if node.valueList:
        comment = IR.Comment(str(node.metadata))
        cmds = [comment, IR.Comment('Element assignments for ' + expr.idf)]
        # Every index tuple of the tensor, last axis varying fastest.
        allComb = itertools.product(*(range(dim) for dim in node.shape))
        for i, curComb in enumerate(allComb):
            curVal = node.valueList[i]
            if isinstance(curVal, AST.Int):
                finalVal = IR.Int(curVal.value, curVal.bitLen)
                if specialBitLen == -1 and curVal.bitLen != Util.Config.wordLength:
                    specialBitLen = curVal.bitLen
            elif isinstance(curVal, AST.Float):
                finalVal = IR.DataType.getInt(np.ldexp(curVal.value, Util.Config.consSF))
            else:
                # Explicit raise so the check survives `python -O`.
                raise AssertionError(
                    "visitDecl: valueList elements must be AST.Int or AST.Float, got "
                    + type(curVal).__name__)
            index = [IR.Int(x) for x in curComb]
            cmds.append(IR.Assn(IRUtil.addIndex(expr, index), finalVal))
        # Merge everything at once instead of re-merging the growing program per element.
        prog = IRUtil.prog_merge(prog, IR.Prog(cmds))

    bitLen = Util.Config.wordLength if specialBitLen == -1 else specialBitLen
    secrecy = "public" if node.isSecret is False else "secret"
    self.decls[expr.idf] = [node.type, secrecy, bitLen]
    prog = IRUtil.prog_merge(IR.Prog([IR.Decl(expr.idf, node.type, bitLen, secrecy)]), prog)
    return (prog, expr)
def visitReshape(self, node:AST.Reshape, args=None):
    """Lower a reshape into one output loop nest plus running source counters.

    For a 3-D source with inner dims HH, WW and an output B[N][H][W]:

        t1 = t2 = t3 = 0
        for n in 0:N: for h in 0:H: for w in 0:W:
            B[n][h][w] = A[t1][t2][t3]
            t3++
            if (t3 == WW) { t3 = 0; t2++;
                if (t2 == HH) { t2 = 0; t1++; } }

    Output loops are nested in the order given by ``node.order`` (1-based
    axis numbers) when it is set, otherwise in natural axis order.

    Returns:
        (prog, expr): the IR program and the temporary holding the result.
    """
    (prog_1, expr_1) = self.visit(node.expr)

    typ_1 = node.expr.type
    typ_2 = node.type

    # Declare variables: result temporary, source counters, output iterators.
    expr_2 = self.getTempVar()
    iters_1 = self.getTempIterators(typ_1.dim)
    iters_2 = self.getTempIterators(typ_2.dim)

    # All source counters start at 0.
    cmd1 = [IR.Assn(var, IRUtil.zero) for var in iters_1]

    # Build the increment from the outermost counter inwards. The last counter
    # is incremented first, and on reaching its dimension size it is reset and
    # carries into the counter before it.
    cmd5 = [IRUtil.incCmd(iters_1[0])]
    for i in range(1, typ_1.dim):
        curr_iter = iters_1[i]
        curr_size = IR.Int(typ_1.shape[i])
        cmd5 = [IRUtil.incCmd(curr_iter),
                IR.If(IRUtil.eq(curr_iter, curr_size), [IRUtil.initVarToZero(curr_iter)] + cmd5)]

    # Outer loop: iterators selected in the user-specified (1-based) order.
    if node.order:
        loopShape = [typ_2.shape[o - 1] for o in node.order]
        loopIters = [iters_2[o - 1] for o in node.order]
    else:
        loopShape = typ_2.shape
        loopIters = iters_2

    copy_cmd = IR.Assn(IRUtil.addIndex(expr_2, iters_2), IRUtil.addIndex(expr_1, iters_1))
    loop2 = IRUtil.loop(loopShape, loopIters, [copy_cmd] + cmd5)

    # Finalize
    comment1 = IR.Comment(str(node.metadata))
    comment2 = IR.Comment("reshape(" + expr_1.idf + ", " + ', '.join(str(e) for e in typ_2.shape) + ")")
    reshape_prog = IR.Prog([comment1, comment2] + cmd1 + loop2)
    prog_2 = IRUtil.prog_merge(prog_1, reshape_prog)

    # Each merge prepends, so the result declaration ends up first.
    for var in iters_1:
        prog_2 = IRUtil.prog_merge(IR.Prog([IR.Decl(var.idf, Type.Int(), isSecret=False)]), prog_2)
    for var in iters_2:
        prog_2 = IRUtil.prog_merge(IR.Prog([IR.Decl(var.idf, Type.Int(), isSecret=False)]), prog_2)
    prog_2 = IRUtil.prog_merge(IR.Prog([IR.Decl(expr_2.idf, typ_2)]), prog_2)

    if not Util.Config.disableTruncOpti:
        # The result reuses the source's scale factor.
        self.scaleFacMapping[expr_2.idf] = self.scaleFacMapping[expr_1.idf]

    # Keep `return` and its value on one line; split, it would return None.
    return (prog_2, expr_2)