def EList2BinaryExpr(elist=[]):
    """Build an expression from a short list of parsed elements.

    Two shapes are recognised:

    * ``[BinaryOp("-"), Expr]`` -- a unary minus; it is rewritten as
      ``[ConstantExpr(-1.0), String2BinaryOp("*"), Expr]`` and handled
      recursively as a multiplication.
    * ``[ArithmeticExpr, BinaryOp, ArithmeticExpr]`` -- an infix binary
      expression.  When ``coalesceConstBinaryExpr`` accepts the operands the
      result is ``computeConstBinaryExpr(op, lhs, rhs)``; a ``"^"`` operator
      (whose exponent must be a ConstantExpr) is expanded with
      ``ExpandConstPowerExpression``; otherwise a plain ``EXPR.BinaryExpr``
      is returned.

    Any other length, or a 3-element list with the wrong element types,
    yields None.  An empty list trips the leading assertion.
    """
    assert len(elist) > 0

    looks_like_unary_minus = (len(elist) == 2
                              and isinstance(elist[0], EXPR.BinaryOp)
                              and elist[0].label == "-"
                              and isinstance(elist[1], EXPR.Expr))
    if looks_like_unary_minus:
        # -e  ==>  (-1.0) * e
        times = String2BinaryOp("*")
        return EList2BinaryExpr([EXPR.ConstantExpr(-1.0), times, elist[1]])

    if len(elist) != 3:
        return None

    lhs, bop, rhs = elist
    well_typed = (isinstance(lhs, EXPR.ArithmeticExpr)
                  and isinstance(bop, EXPR.BinaryOp)
                  and isinstance(rhs, EXPR.ArithmeticExpr))
    if not well_typed:
        return None

    if coalesceConstBinaryExpr(lhs, rhs):
        return computeConstBinaryExpr(bop, lhs, rhs)

    if bop.label == "^":
        # only constant (integral, non-negative) exponents can be expanded
        assert isinstance(rhs, EXPR.ConstantExpr)
        return ExpandConstPowerExpression(bop, lhs, rhs.value())

    return EXPR.BinaryExpr(bop, lhs, rhs)
def GroupErrorVarSum(gid, epss=[]):
    """Return the sum of the error variables of group ``gid``.

    One ``GroupErrorVar(gid, k)`` is produced for each position ``k`` in
    ``epss`` (only the length of ``epss`` is used after validation by
    ``checkValidEpsilonList``).  The terms are combined left to right with
    internal ``"+"`` operators (gid -1):
    ``((v0 + v1) + v2) + ...``.

    Returns None when ``epss`` is empty.  ``gid`` must be a plain
    non-negative ``int`` (``bool`` is rejected by the assertion).
    """
    assert type(gid) is int and gid >= 0
    checkValidEpsilonList(epss)

    running_sum = None
    for eps_idx in range(len(epss)):
        if running_sum is not None:
            # fold the next error variable into the left-nested sum
            running_sum = tft_expr.BinaryExpr(tft_expr.BinaryOp(-1, "+"), running_sum, GroupErrorVar(gid, eps_idx))
        else:
            running_sum = GroupErrorVar(gid, eps_idx)
    return running_sum
def ExpandConstPowerExpression(op, e_base, c_power):
    """Expand ``e_base ^ c_power`` into repeated multiplication.

    Args:
        op: the ``"^"`` BinaryOp being expanded; its ``gid`` is reused for
            every generated ``"*"`` operator.
        e_base: the base expression.  The same object is reused as every
            factor (it is not copied).
        c_power: a non-negative integral value (e.g. ``3`` or ``3.0``).

    Returns:
        ``EXPR.ConstantExpr(1.0)`` for power 0, ``e_base`` itself for
        power 1, otherwise the left-nested product
        ``((e_base * e_base) * e_base) * ...`` with ``c_power - 1`` fresh
        ``BinaryOp(op.gid, "*")`` nodes.
    """
    assert ((isinstance(op, EXPR.BinaryOp)) and (op.label == "^"))
    assert (isinstance(e_base, EXPR.Expr))
    assert (float(c_power) == int(c_power))
    c_power = int(c_power)
    assert (c_power >= 0)

    if c_power == 0:
        return EXPR.ConstantExpr(1.0)

    # Build the product iteratively.  The former recursive formulation used
    # one stack frame per exponent step and raised RecursionError for
    # exponents near the interpreter's recursion limit; the tree produced
    # here has exactly the same left-nested shape.
    product = e_base
    for _ in range(c_power - 1):
        product = EXPR.BinaryExpr(EXPR.BinaryOp(op.gid, "*"), product, e_base)
    return product
def String2ConstBinaryExpr(s):
    """Parse a string of the form ``<const><op><const>`` into an expression.

    Each label in ``EXPR.BinaryOpLabels`` is searched for in ``s``; only a
    first occurrence at an index greater than 0 counts (an occurrence at
    index 0, such as a leading sign, is ignored).  The string is split at
    that position, and both sides are parsed with ``String2ConstantExpr``
    and the operator with ``String2BinaryOp``.

    Returns None when:

    * no label is found at a positive index;
    * more than one label is found (the string is not a single binary op);
    * any of the three parts fails to parse.

    Otherwise the result is ``computeConstBinaryExpr`` when
    ``coalesceConstBinaryExpr`` accepts the operands, an expansion via
    ``ExpandConstPowerExpression`` for ``"^"``, or an ``EXPR.BinaryExpr``.
    """
    lhs = None
    bop = None
    rhs = None

    for label in EXPR.BinaryOpLabels:
        pos = s.find(label)
        if pos <= 0:
            # absent (-1) or at the very start (0): not a split point
            continue
        if lhs is not None or bop is not None or rhs is not None:
            # a second operator label: give up
            return None
        lhs = String2ConstantExpr(s[:pos])
        bop = String2BinaryOp(label)
        rhs = String2ConstantExpr(s[pos + len(label):])
        if lhs is None or bop is None or rhs is None:
            return None

    if lhs is None and bop is None and rhs is None:
        return None

    assert (isinstance(lhs, EXPR.ConstantExpr) and isinstance(bop, EXPR.BinaryOp) and isinstance(rhs, EXPR.ConstantExpr))

    if coalesceConstBinaryExpr(lhs, rhs):
        return computeConstBinaryExpr(bop, lhs, rhs)
    if bop.label == "^":
        assert isinstance(rhs, EXPR.ConstantExpr)
        return ExpandConstPowerExpression(bop, lhs, rhs.value())
    return EXPR.BinaryExpr(bop, lhs, rhs)
def OptimizeErrorFormByGroup (eform):
    """Merge the error terms of ``eform`` that share a grouping key.

    Two terms belong to the same group when they have equal ``getGid()``,
    equal ``getContextGid()`` and the same truthiness of
    ``is_precise_opt``.  Terms are identified by their ``index``; an index
    already placed in a group is skipped.  Groups are emitted in the order
    of their first member in ``eform.terms``.

    Each group becomes a single ``ErrorTerm`` whose expression is
    ``abs(t0.absexpr() + t1.absexpr() + ...)``, built with internal
    operators (gid -1), carrying the first member's context gid, gid and
    ``is_precise_opt``.

    Args:
        eform: the ErrorForm to optimize.  Its dict/list bookkeeping
            attributes are shallow-copied into the result rather than
            shared.

    Returns:
        A new ErrorForm with the same ``upper_bound`` and ``M2`` holding
        the merged terms.
    """
    assert(isinstance(eform, ErrorForm))

    opt_eform = ErrorForm(eform.upper_bound, eform.M2)
    # These are important: the new form must inherit the original's
    # bookkeeping, so each table is overwritten with a (shallow) copy.
    opt_eform.gid_counts = eform.gid_counts.copy()      # gid-counts
    opt_eform.gid_weight = eform.gid_weight.copy()      # gid-weight
    opt_eform.gid2epsilons = eform.gid2epsilons.copy()  # gid -> epsilons
    opt_eform.casting_map = eform.casting_map.copy()    # casting map
    opt_eform.eq_gids = eform.eq_gids[:]
    opt_eform.constraints = eform.constraints[:]

    # A set keeps the "already grouped?" test O(1); with a list it was
    # O(len(terms)), making the whole pass cubic in the number of terms.
    handled_etids = set()
    for et in eform.terms:
        if et.index in handled_etids:
            continue
        gid = et.getGid()
        context_gid = et.getContextGid()
        is_precise_opt = et.is_precise_opt
        assert(0 <= gid)

        group = [et]
        handled_etids.add(et.index)
        for et_p in eform.terms:
            # et itself is already in handled_etids, so it is skipped here too
            if et_p.index in handled_etids:
                continue
            if (gid == et_p.getGid()
                    and context_gid == et_p.getContextGid()
                    and bool(is_precise_opt) == bool(et_p.is_precise_opt)):
                group.append(et_p)
                handled_etids.add(et_p.index)

        # Sum the members' absolute-value expressions, then wrap in abs().
        combined_expr = group[0].absexpr()
        for member in group[1:]:
            combined_expr = tft_expr.BinaryExpr(tft_expr.BinaryOp(-1, "+"), combined_expr, member.absexpr())
        combined_expr = tft_expr.UnaryExpr(tft_expr.UnaryOp(-1, "abs"), combined_expr)

        combined_et = ErrorTerm(combined_expr, context_gid, gid, is_precise_opt)
        assert(gid in eform.gid_counts)
        opt_eform.add(combined_et)

    return opt_eform
def MakeBinaryExpr(op_label, op_gid, opd0, opd1, internal=False):
    """Create (and record via AppendCppInst) the expression ``opd0 op_label opd1``.

    Steps, in order:

    1. For non-internal operators, ``TUNE_FOR_ALL`` forces ``op_gid`` to 0.
    2. If ``COALESCE_CONST`` is set and both operands are precise
       ``ConstantExpr``s, ``+ - * /`` are folded into a new ConstantExpr,
       which is recorded and returned.  Any other label ends the program
       through ``sys.exit``.
    3. Unless ``tft_utils.FIX_CONST_TYPE`` is set, each constant-variable
       operand still tagged with ``PRESERVED_CONST_GID`` is rebound to
       ``op_gid``, and the preserved count is decremented.  A different
       existing gid only prints a warning.
    4. Non-internal operators update the gid/casting counters, must have
       ``op_gid >= 0`` and register ``op_gid`` in ``EXTERNAL_GIDS``.
       Internal operators must use ``op_gid == -1``.
    5. A BinaryExpr over ``opd0.copy(not internal)`` and
       ``opd1.copy(not internal)`` is recorded and returned.
       NOTE(review): meaning of the ``copy`` flag is defined in tft_expr.
    """
    global EXTERNAL_GIDS

    external = not internal
    if external and TUNE_FOR_ALL:
        op_gid = 0

    assert type(op_label) is str
    assert type(op_gid) is int
    assert isinstance(opd0, tft_expr.Expr)
    assert isinstance(opd1, tft_expr.Expr)

    if COALESCE_CONST:
        both_precise_consts = (isinstance(opd0, tft_expr.ConstantExpr)
                               and isinstance(opd1, tft_expr.ConstantExpr)
                               and tft_expr.isPreciseConstantExpr(opd0)
                               and tft_expr.isPreciseConstantExpr(opd1))
        if both_precise_consts:
            lhs_val = opd0.value()
            rhs_val = opd1.value()
            folders = {
                "+": lambda a, b: a + b,
                "-": lambda a, b: a - b,
                "*": lambda a, b: a * b,
                "/": lambda a, b: a / b,
            }
            fold = folders.get(op_label)
            if fold is not None:
                folded = tft_expr.ConstantExpr(fold(lhs_val, rhs_val))
                AppendCppInst(folded)
                return folded
            sys.exit("ERROR: unknown Binary Operator: " + op_label)

    # possibly bind the constant type
    for opd in (opd0, opd1):
        if tft_utils.FIX_CONST_TYPE or not tft_expr.isConstVar(opd):
            continue
        if opd.getGid() == tft_expr.PRESERVED_CONST_GID:
            CountGID(tft_expr.PRESERVED_CONST_GID, -1)
            opd.gid = op_gid
        elif opd.getGid() != op_gid:
            print("Warning: conflicting constant type...")

    if external:
        CountGID(op_gid)
        CountCasting(opd0, op_gid)
        CountCasting(opd1, op_gid)
        assert 0 <= op_gid
        if op_gid not in EXTERNAL_GIDS:
            EXTERNAL_GIDS.append(op_gid)
    else:
        assert -1 == op_gid

    ret_expr = tft_expr.BinaryExpr(tft_expr.BinaryOp(op_gid, op_label), opd0.copy(external), opd1.copy(external))
    AppendCppInst(ret_expr)
    return ret_expr