def solve_format_ArithOperation(
        optree,
        integer_size_func=lambda lhs_prec, rhs_prec: None,
        frac_size_func=lambda lhs_prec, rhs_prec: None,
        signed_func=lambda lhs, lhs_prec, rhs, rhs_prec: False):
    """ Determine the fixed-point format of a generic 2-operand arithmetic
        operation (e.g. Multiplication, Addition, Subtraction).

        - both operands ML_Integer: the abstract result format ML_Integer is
          returned;
        - an ML_Integer operand mixed with a concrete one: the integer
          sub-graph is evaluated as a constant and given that constant's
          format;
        - both (resolved) operand formats fixed-point: the result format is
          fixed_point(integer_size_func(...), frac_size_func(...),
          signed=signed_func(...));
        - otherwise the node's current precision is returned unchanged.
    """
    lhs, rhs = optree.get_input(0), optree.get_input(1)
    lhs_prec, rhs_prec = lhs.get_precision(), rhs.get_precision()

    if lhs_prec is ML_Integer and rhs_prec is ML_Integer:
        return ML_Integer

    def _constant_format(operand):
        # evaluate the integer-only sub-graph and use its constant's format
        value = evaluate_cst_graph(operand, input_prec_solver=solve_format_rec)
        return solve_format_Constant(Constant(value))

    if lhs_prec is ML_Integer:
        lhs_prec = _constant_format(lhs)
    if rhs_prec is ML_Integer:
        rhs_prec = _constant_format(rhs)

    if not (is_fixed_point(lhs_prec) and is_fixed_point(rhs_prec)):
        return optree.get_precision()

    # sizes and signedness are fully delegated to the caller's callbacks
    return fixed_point(
        integer_size_func(lhs_prec, rhs_prec),
        frac_size_func(lhs_prec, rhs_prec),
        signed=signed_func(lhs, lhs_prec, rhs, rhs_prec))
def mantissa_extraction_modifier_from_fields(op, field_op, exp_is_zero, tag="mant_extr"):
    """ Rewrite a MantissaExtraction node as a sub-graph of basic operations.

        Assumes the <field_op> bit-field and the <exp_is_zero> flag are
        already available. The implicit digit (0 when exp_is_zero holds,
        1 otherwise) is concatenated with field_op, the latter reinterpreted
        as a std_logic_vector of the field size; the result is a
        std_logic_vector of the mantissa size of op's base format.
    """
    base_format = op.get_precision().get_base_format()
    bit_format = ML_StdLogic

    implicit_digit = Select(
        exp_is_zero,
        Constant(0, precision=bit_format),
        Constant(1, precision=bit_format),
        precision=bit_format,
        tag=tag + "_implicit_digit",
    )
    field_bits = TypeCast(
        field_op,
        precision=ML_StdLogicVectorFormat(base_format.get_field_size()))
    mantissa_format = ML_StdLogicVectorFormat(base_format.get_mantissa_size())
    return Concatenation(implicit_digit, field_bits, precision=mantissa_format)
def generate_scheme(self):
    """ Build a test scheme made of nested ConditionBlocks driven by the
        comparisons x > y ("comp") and x == y ("comp_eq"), ending with a
        fallback Return(x * y). """
    fmt = self.precision
    # declaring function input variables
    var_x = self.implementation.add_input_variable("x", fmt)
    var_y = self.implementation.add_input_variable("y", fmt)
    cst_5 = Constant(5, precision=fmt)
    cst_7 = Constant(7, precision=fmt)

    is_greater = Comparison(var_x, var_y, specifier=Comparison.Greater, precision=ML_Bool, tag="comp")
    is_equal = Comparison(var_x, var_y, specifier=Comparison.Equal, precision=ML_Bool, tag="comp_eq")

    # nodes are created in the same order as nested constructor calls would
    return_y = Return(var_y, precision=fmt)
    nested_eq_block = ConditionBlock(
        is_equal,
        Return(var_x + var_y * cst_5 - cst_7, precision=fmt))
    greater_block = ConditionBlock(is_greater, return_y, nested_eq_block)
    equal_block = ConditionBlock(is_equal, Return(cst_7 * var_y, precision=fmt))
    fallback_return = Return(var_x * var_y, precision=fmt)
    return Statement(greater_block, equal_block, fallback_return)
def generate_scheme(self):
    """ main scheme generation

        Output "vr_out" is the sum of x shifted by constant amounts (right by
        9, left by 7) and by the dynamic 3-bit unsigned amount "s" (right and
        left). """
    Log.report(Log.Info, "input_precision is {}".format(self.input_precision))
    Log.report(Log.Info, "output_precision is {}".format(self.output_precision))

    amount_format = fixed_point(3, 0, signed=False)
    # declaring main input signals
    input_x = self.implementation.add_input_signal("x", self.input_precision)
    shift_amount = self.implementation.add_input_signal("s", amount_format)

    # NOTE(review): the constant right shift uses 9 while the left shift
    # uses 7 — confirm the intended right-shift amount
    cst_9 = Constant(9, precision=ML_Integer)
    cst_7 = Constant(7, precision=ML_Integer)

    shifted_terms = [
        BitLogicRightShift(input_x, cst_9),
        BitLogicLeftShift(input_x, cst_7),
        BitLogicRightShift(input_x, shift_amount),
        BitLogicLeftShift(input_x, shift_amount),
    ]
    result = shifted_terms[0] + shifted_terms[1] + shifted_terms[2] + shifted_terms[3]

    # output
    self.implementation.add_output_signal("vr_out", result)
    return [self.implementation]
def generate_embedded_testbench(self, tc_list, io_map, input_signals, output_signals, time_step, test_fname="test.input"):
    """ Generate a testbench whose input and expected output values are
        embedded directly in the generated code.

        :param tc_list: iterable of (input_values, output_values) pairs, one
                        per test case
        :param io_map: mapping tag -> testbench signal
        :param output_signals: mapping output tag -> signal to check
        :param time_step: simulation time step
        :return: single-element list holding the "testbench" CodeEntity
        input_signals and test_fname are not used by this implementation.
    """
    component = self.implementation.get_component_object()
    dut_instance = component(io_map=io_map, tag="tested_entity")

    case_checks = Statement()
    for case_index, (in_values, out_values) in enumerate(tc_list):
        case_checks.add(self.implement_test_case(
            io_map, in_values, output_signals, out_values, time_step,
            index=case_index))

    reset_sequence = self.get_reset_statement(io_map, time_step)
    testbench = CodeEntity("testbench")

    # end of test
    end_of_test = Assert(
        Constant(0, precision=ML_Bool),
        " \"end of test, no error encountered \"",
        severity=Assert.Warning)
    # infinite end loop
    idle_forever = WhileLoop(
        Constant(1, precision=ML_Bool),
        Statement(Wait(time_step * (self.stage_num + 2))))
    stimuli_process = Process(reset_sequence, case_checks, end_of_test, idle_forever)
    bench_body = Statement(dut_instance, stimuli_process)

    if self.pipelined:
        half_period = time_step / 2
        assert (half_period * 2) == time_step
        # adding clock process for pipelined bench
        clock_process = Process(Statement(
            ReferenceAssign(io_map["clk"], Constant(1, precision=ML_StdLogic)),
            Wait(half_period),
            ReferenceAssign(io_map["clk"], Constant(0, precision=ML_StdLogic)),
            Wait(half_period),
        ))
        bench_body.push(clock_process)

    testbench.add_process(bench_body)
    return [testbench]
def test_ref_assign(self):
    """ StaticVectorizer on predicated ReferenceAssign: along the most
        likely path (a > b taken) the linearized result must be the
        Constant 11. """
    var_a = Variable("a")
    var_b = Variable("b")
    var_c = Variable("c")

    init_a = ReferenceAssign(var_a, Constant(3))
    likely_cond = (var_a > var_b).modify_attributes(likely=True)
    taken_branch = Statement(
        ReferenceAssign(var_b, var_a),
        ReferenceAssign(var_a, Constant(11)),
        Return(var_a))
    scheme = Statement(
        init_a,
        ConditionBlock(likely_cond, taken_branch),
        ReferenceAssign(var_a, Constant(7)),
        Return(var_b))

    vectorized_path = StaticVectorizer().extract_vectorizable_path(scheme, fallback_policy)
    linearized = instanciate_variable(
        vectorized_path.linearized_optree, vectorized_path.variable_mapping)

    passed = isinstance(linearized, Constant) and linearized.get_value() == 11
    if not passed:
        print("test UT_StaticVectorizer failure")
        print("scheme: {}".format(scheme.get_str()))
        print("linearized_most_likely_path: {}".format(linearized))
    self.assertTrue(passed)
def subnormalize(x_list, factor, precision=None, fma=True):
    """ Round the most significant component of a multi-component number as
        if its exponent were exponent(x_list[0]) + factor, emulating field
        subnormalization when that scaled exponent is below precision.get_emin().

        x_list is a multi-component number with components ordered from the
        most to the least significant. x_list[0] must be the rounded
        evaluation of (x_list[0] + x_list[1] + ...)

        :param x_list: list of nodes, most significant first
        :param factor: exponent offset added to exponent(x_list[0])
        :param precision: floating-point format of the components
                          (required despite the None default)
        :param fma: unused, kept for interface compatibility
        :return: floating-point node equal to x_list[0] rounded (to nearest,
                 ties to even) to field_size - delta fraction bits, the delta
                 dropped trailing bits being cleared, where
                 delta = clamp(emin - (exponent + factor), 0, field_size)
    """
    x_hi = x_list[0]
    int_precision = precision.get_integer_format()
    ex = ExponentExtraction(x_hi, precision=int_precision)
    scaled_ex = ex + factor

    CI0 = Constant(0, precision=int_precision)
    CI1 = Constant(1, precision=int_precision)
    # number of trailing field bits to drop: difference between the minimal
    # exponent and the scaled exponent, clamped to [0, field_size]
    # (0 when the scaled result is normal)
    delta = Min(
        Max(precision.get_emin() - scaled_ex, CI0),
        Constant(precision.get_field_size(), precision=int_precision))

    casted_int_x = TypeCast(x_hi, precision=int_precision)
    # weight of the last kept bit (parity bit): 2^delta
    lsb_weight = BitLogicLeftShift(CI1, delta, precision=int_precision)
    # half-ulp constant 2^(delta-1), added to the integer encoding to round.
    # Computed as 2^delta >> 1 so that delta == 0 yields 0 instead of an
    # undefined shift by -1.
    round_cst = BitLogicRightShift(lsb_weight, CI1, precision=int_precision)
    pre_rounded_value = TypeCast(casted_int_x + round_cst, precision=precision)

    # sticky: bits strictly below the rounding bit (mask 2^(delta-1) - 1).
    # When delta == 0 the mask is all-ones, but then round_cst is 0 and both
    # Select alternatives below are x_hi, so the value is irrelevant.
    sticky = BitLogicAnd(casted_int_x, round_cst - CI1, precision=int_precision)
    for x_op in x_list[1:]:
        sticky = BitLogicOr(
            sticky, TypeCast(x_op, precision=int_precision),
            precision=int_precision)
    # TODO: the sign of the low components (x_list[1:]) relative to x_hi is
    # not taken into account when breaking ties

    parity_bit = BitLogicAnd(casted_int_x, lsb_weight, precision=int_precision)
    # no bits below the rounding bit and an even kept part: keep x_hi (its
    # trailing bits are truncated below), which implements ties-to-even;
    # otherwise use the half-ulp-incremented value
    inc_select = LogicalAnd(Equal(sticky, CI0), Equal(parity_bit, CI0))
    rounded_value = Select(inc_select, x_hi, pre_rounded_value, precision=precision)

    # cleaning trailing bits: shift right THEN left so that the low delta
    # bits (not the sign/exponent bits) are cleared
    int_rounded = TypeCast(rounded_value, precision=int_precision)
    cleaned = BitLogicLeftShift(
        BitLogicRightShift(int_rounded, delta, precision=int_precision),
        delta, precision=int_precision)
    return TypeCast(cleaned, precision=precision)
def subnormalize_multi(x_list, factor, precision=None, fma=True):
    """ Round a multi-component number as if its exponent were
        exponent(x_list[0]) + factor, managing field subnormalization.

        x_list is a multi-component number with components ordered from the
        most to the least significant. x_list[0] must be the rounded
        evaluation of (x_list[0] + x_list[1] + ...)

        :param x_list: list of 2 or 3 nodes, most significant first
        :param factor: exponent offset added to exponent(x_list[0])
        :param precision: floating-point format of the components
                          (required despite the None default)
        :param fma: unused, kept for interface compatibility
        :return: list [rounded_hi, 0, ...] with len(x_list) elements
        :raises NotImplementedError: when len(x_list) is not 2 or 3
                                     (after a Log.Error report)
    """
    x_hi = x_list[0]
    int_precision = precision.get_integer_format()
    ex = ExponentExtraction(x_hi, precision=int_precision)
    scaled_ex = Addition(ex, factor, precision=int_precision)
    CI0 = Constant(0, precision=int_precision)

    # number of trailing field bits to drop: emin_normal - scaled_ex clamped
    # to [0, field_size]. The clamp must be Min(Max(d, 0), F): the reverse
    # nesting Max(Min(d, 0), F) always evaluates to F.
    emin_gap = Subtraction(
        Constant(precision.get_emin_normal(), precision=int_precision),
        scaled_ex, precision=int_precision)
    delta = Min(
        Max(emin_gap, CI0, precision=int_precision),
        Constant(precision.get_field_size(), precision=int_precision),
        precision=int_precision)

    round_factor_exp = Addition(delta, ex, precision=int_precision)
    round_factor = ExponentInsertion(round_factor_exp, precision=precision)
    # to force a rounding as if x_hi was of precision p - delta
    # we use round_factor as follows:
    # o(o(round_factor + x_hi) - round_factor)
    # NOTE(review): this trick assumes x_hi > 0 and delta >= 1. For
    # delta == 0 (normal result) round_factor + x_hi lands in the next binade
    # and one extra bit is dropped; a negative x_hi is not handled — confirm.
    num_components = len(x_list)
    if num_components == 2:
        rounded_x_hi = Subtraction(
            Add112(round_factor, x_list[0], x_list[1], precision=precision)[0],
            round_factor, precision=precision)
    elif num_components == 3:
        rounded_x_hi = Subtraction(
            Add113(round_factor, x_list[0], x_list[1], x_list[2], precision=precision)[0],
            round_factor, precision=precision)
    else:
        Log.report(Log.Error, "len of x_list: {} is not supported in subnormalize_multi", len(x_list))
        raise NotImplementedError
    return [rounded_x_hi] + [
        Constant(0, precision=precision) for _ in range(num_components - 1)
    ]
def legalize_comp_sign(node):
    """ Rewrite a Test.CompSign node as comparisons against zero combined
        with logical operations: the result holds when both operands are
        >= 0 or both operands are <= 0. """
    # TODO/IDEA: could also be implemented by two 2 copy sign with 1.0 and valuda
    # comparison
    lhs, rhs = node.get_input(0), node.get_input(1)
    lhs_zero = Constant(0, precision=lhs.get_precision())
    rhs_zero = Constant(0, precision=rhs.get_precision())
    both_non_negative = LogicalAnd(lhs >= lhs_zero, rhs >= rhs_zero)
    both_non_positive = LogicalAnd(lhs <= lhs_zero, rhs <= rhs_zero)
    return LogicalOr(both_non_negative, both_non_positive)
def dirty_multi_node_expand(node, precision, mem_map=None, fma=True):
    """ Dirty expand a node into Hi and Lo parts, storing already processed
        temporary values in mem_map.

        :param node: Constant, Addition or Multiplication node (operands of
                     Addition/Multiplication are expanded recursively)
        :param precision: format of the hi/lo components
        :param mem_map: optional memo dict node -> expansion result; filled
                        in place and shared with the recursive calls
        :param fma: forwarded to the Mul2xx meta-operations
        :return: for a Constant, a pair (hi, lo) where lo is None when the
                 constant is exactly representable as hi; otherwise the
                 result of the selected Add2xx / Mul2xx meta-operation
    """
    # `mem_map or {}` would replace an EMPTY dict (from the caller or from a
    # recursive call) by a fresh one, silently disabling memoization
    if mem_map is None:
        mem_map = {}
    if node in mem_map:
        return mem_map[node]

    if isinstance(node, Constant):
        value = node.get_value()
        value_hi = sollya.round(value, precision.sollya_object, sollya.RN)
        value_lo = sollya.round(value - value_hi, precision.sollya_object, sollya.RN)
        # untagged constants would otherwise fail on None + "hi"
        tag_prefix = node.get_tag() or ""
        ch = Constant(value_hi, tag=tag_prefix + "hi", precision=precision)
        if value_lo != 0:
            cl = Constant(value_lo, tag=tag_prefix + "lo", precision=precision)
        else:
            cl = None
            Log.report(Log.Info, "simplified constant")
        result = ch, cl
        mem_map[node] = result
        return result

    # Case of Addition or Multiplication nodes:
    # 1. retrieve inputs
    # 2. dirty convert inputs recursively
    # 3. forward to the right metamacro
    assert isinstance(node, (Addition, Multiplication))
    op1h, op1l = dirty_multi_node_expand(node.get_input(0), precision, mem_map, fma)
    op2h, op2l = dirty_multi_node_expand(node.get_input(1), precision, mem_map, fma)

    if isinstance(node, Addition):
        if op1l is not None and op2l is not None:
            result = Add222(op1h, op1l, op2h, op2l)
        elif op2l is not None:
            # only the second operand carries a lo part
            result = Add212(op1h, op2h, op2l)
        elif op1l is not None:
            # only the first operand carries a lo part
            result = Add212(op2h, op1h, op1l)
        else:
            result = Add211(op1h, op2h)
    else:
        if op1l is not None and op2l is not None:
            result = Mul222(op1h, op1l, op2h, op2l, fma=fma)
        elif op2l is not None:
            result = Mul212(op1h, op2h, op2l, fma=fma)
        elif op1l is not None:
            result = Mul212(op2h, op1h, op1l, fma=fma)
        else:
            result = Mul211(op1h, op2h, fma=fma)
    mem_map[node] = result
    return result
def get_reset_statement(self, io_map, time_step):
    """ Build the reset sequence of the testbench.

        When self.reset_pipeline is set, the reset signal
        (io_map[self.reset_name]) is driven to its active level (0 if
        self.negate_reset else 1) for 3 time steps, then to the opposite
        level for 3 more time steps. Every recirculation signal is then
        assigned 0.
    """
    statement = Statement()
    if self.reset_pipeline:
        # TODO: fix pipeline register reset
        active_level = 0 if self.negate_reset else 1
        inactive_level = 1 - active_level
        reset_sig = io_map[self.reset_name]
        # waiting 3 time steps after each level accounts for synchronous reset
        for level in (active_level, inactive_level):
            statement.add(ReferenceAssign(reset_sig, Constant(level, precision=ML_StdLogic)))
            statement.add(Wait(time_step * 3))
    for recirculate_signal in self.recirculate_signal_map.values():
        statement.add(ReferenceAssign(
            io_map[recirculate_signal.get_tag()],
            Constant(0, precision=ML_StdLogic)))
    return statement
def check_true(element):
    """ Return a boolean node that holds when element is "true": element
        itself when its precision is ML_Bool, otherwise element != 0. """
    element_format = element.get_precision()
    if element_format is not ML_Bool:
        return NotEqual(element, Constant(0, precision=element_format))
    return element
def check_false(element):
    """ Return a boolean node that holds when element is "false":
        LogicalNot(element) when its precision is ML_Bool, otherwise
        element == 0. """
    element_format = element.get_precision()
    if element_format is not ML_Bool:
        return Equal(element, Constant(0, precision=element_format))
    return LogicalNot(element, precision=ML_Bool)
def implement_test_case(self, io_map, input_values, output_signals, output_values, time_step, index=None):
    """ Implement the test case check and assertion whose I/Os values are
        described in input_values and output_values dict.

        :param io_map: mapping tag -> testbench signal
        :param input_values: mapping input tag -> value to drive
        :param output_signals: mapping output tag -> signal to check
        :param output_values: mapping output tag -> expected value
        :param time_step: simulation time step; outputs are checked after
                          time_step * (self.stage_num + 2)
        :param index: optional test-case number (testbench generators call
                      this method with index=...); when given, it prefixes
                      every assertion message
        :return: Statement driving inputs, waiting, then checking outputs
    """
    test_statement = Statement()
    input_msg = ""
    case_prefix = "" if index is None else "test #{}: ".format(index)
    # Adding input setting
    for input_tag in input_values:
        input_signal = io_map[input_tag]
        # FIXME: correct value generation depending on signal precision
        input_value = input_values[input_tag]
        test_statement.add(get_input_assign(input_signal, input_value))
        input_msg += get_input_msg(input_tag, input_signal, input_value)
    # wait for the result to cross every pipeline stage
    test_statement.add(Wait(time_step * (self.stage_num + 2)))
    # Adding output value comparison
    for output_tag in output_signals:
        output_signal = output_signals[output_tag]
        output_value = output_values[output_tag]
        output_cst_value = Constant(output_value, precision=output_signal.get_precision())
        value_msg = get_output_value_msg(output_signal, output_value)
        test_pass_cond, check_statement = get_output_check_statement(output_signal, output_tag, output_cst_value)
        test_statement.add(check_statement)
        assert_statement = Assert(
            test_pass_cond,
            "\"{case_prefix}unexpected value for inputs {input_msg}, output {output_tag}, expecting {value_msg}, got: \"".format(
                case_prefix=case_prefix, input_msg=input_msg,
                output_tag=output_tag, value_msg=value_msg),
            severity=Assert.Failure
        )
        test_statement.add(assert_statement)
    return test_statement
def get_input_assign(input_signal, input_value):
    """ Return a ReferenceAssign driving input_signal with input_value,
        wrapped in a Constant of the signal's precision. """
    value_cst = Constant(input_value, precision=input_signal.get_precision())
    return ReferenceAssign(input_signal, value_cst)
def generate_scheme(self):
    """ Build the scheme computing 2 + x * x in a custom fixed-point
        accumulator format, preceded by the runtime error evaluation
        statements generated by runtime_error_eval for x = 0.125. """
    var_x = self.implementation.add_input_variable("x", FIXED_FORMAT)
    # declaring specific interval for input variable <x>
    var_x.set_interval(Interval(-1, 1))

    acc_format = ML_Custom_FixedPoint_Format(6, 58, False)
    cst_two = Constant(2, precision=acc_format, tag="C2")
    square = Multiplication(var_x, var_x, precision=acc_format, tag="mul")
    result = Addition(cst_two, square, precision=acc_format, tag="add")

    input_mapping = {var_x: var_x.get_precision().round_sollya_object(0.125)}
    error_eval_map = runtime_error_eval.generate_error_eval_graph(result, input_mapping)

    # dummy scheme to make functionnal code generation
    scheme = Statement()
    for error_statement in error_eval_map.values():
        scheme.add(error_statement)
    scheme.add(Return(result))
    return scheme
def generate_node_eval_error(optree, input_mapping, node_error_map, node_value_map):
    """ Recursively build error-display statements for optree and every node
        below it (nodes of input_mapping are skipped).

        For each node the error node is Subtraction(node, expected), where
        expected is a Constant holding node_value_map[node] in the node's
        precision. node_error_map (node -> statement returned by
        get_printf_value) is filled in place.
    """
    if optree in node_error_map or optree in input_mapping:
        return
    # placeholder marking the node as visited before recursing, so that it
    # is not processed twice
    node_error_map[optree] = None

    # recursive on node inputs
    if not is_leaf_node(optree):
        for op in optree.get_inputs():
            generate_node_eval_error(op, input_mapping, node_error_map, node_value_map)

    expected_value = node_value_map[optree]
    assert expected_value is not None
    precision = optree.get_precision()
    expected_node = Constant(expected_value, precision=precision)
    # signed error; wrapping it in Abs() may require a signed type when
    # optree/expected_node are unsigned
    error_node = Subtraction(optree, expected_node, precision=precision)
    node_error_map[optree] = get_printf_value(optree, error_node, expected_node)
def generate_test_wrapper(self, tensor_descriptors, input_tables, output_tables):
    """ Build the "test_wrapper" function: it runs the tensor test loop
        produced by self.get_tensor_test_wrapper (with
        self.generate_tensor_check_loop as checker), prints a success message
        naming the tested function, then returns 0.

        :return: FunctionGroup holding the single test_wrapper CodeFunction
    """
    wrapper = CodeFunction("test_wrapper", output_format=ML_Int32)
    impl_function = self.implementation.get_function_object()
    impl_name = self.implementation.get_name()

    # NOTE(review): this report_failure function object is built but not
    # referenced afterwards in this method
    report_op = FunctionOperator("report_failure")
    report_failure_fn = FunctionObject("report_failure", [], ML_Void, report_op)

    success_printf_op = FunctionOperator(
        "printf",
        arg_map={0: "\"test successful %s\\n\"" % impl_name},
        void_function=True,
        require_header=["stdio.h"])
    success_printf_fn = FunctionObject("printf", [], ML_Void, success_printf_op)

    # accumulate element number
    element_count = Variable("acc_num", precision=ML_Int64, var_type=Variable.Local)
    tensor_loop = self.get_tensor_test_wrapper(
        impl_function, tensor_descriptors, input_tables, output_tables,
        element_count, self.generate_tensor_check_loop)

    # common test scheme between scalar and vector functions
    wrapper.set_scheme(Statement(
        tensor_loop,
        success_printf_fn(),
        Return(Constant(0, precision=ML_Int32))))
    return FunctionGroup([wrapper])
def generate_scheme(self):
    """ main scheme generation

        Computes vr_out = Conversion(max(0, x - y) + x) over three pipeline
        stages, with x, y and the output in fixed_point(3, width - 3).
    """
    Log.report(Log.Info, "width parameter is {}".format(self.width))
    int_size = 3
    frac_size = self.width - int_size

    input_precision = fixed_point(int_size, frac_size)
    output_precision = fixed_point(int_size, frac_size)

    # declaring main input variable
    var_x = self.implementation.add_input_signal("x", input_precision)
    var_y = self.implementation.add_input_signal("y", input_precision)
    var_x.set_attributes(debug=debug_fixed)
    var_y.set_attributes(debug=debug_fixed)

    sub = var_x - var_y

    self.implementation.start_new_stage()
    # clamp the difference at zero
    pre_result = Max(0, sub)

    self.implementation.start_new_stage()
    result = Conversion(pre_result + var_x, precision=output_precision)

    self.implementation.add_output_signal("vr_out", result)
    return [self.implementation]
def fixed_point_position_legalizer(optree, input_prec_solver=default_prec_solver):
    """ Legalize a FixedPointPosition node into an ML_Integer Constant.

        The constant value depends on the node's alignment mode, the literal
        position (input 1) and the fixed-point format of input 0 as resolved
        by input_prec_solver. A non fixed-point format is reported with
        Log.Error.
    """
    assert isinstance(optree, FixedPointPosition)
    fixed_input = optree.get_input(0)
    fixed_precision = input_prec_solver(fixed_input)
    if not is_fixed_point(fixed_precision):
        error_msg = "in fixed_point_position_legalizer: precision of {} should be fixed-point but is {}".format(
            fixed_input, fixed_precision)
        Log.report(Log.Error, error_msg)

    position = optree.get_input(1).get_value()
    align = optree.get_align()
    bit_size = fixed_precision.get_bit_size()
    frac_size = fixed_precision.get_frac_size()
    integer_size = fixed_precision.get_integer_size()
    # alignment mode -> resolved constant value
    resolved_values = {
        FixedPointPosition.FromLSBToLSB: position,
        FixedPointPosition.FromMSBToLSB: bit_size - 1 - position,
        FixedPointPosition.FromPointToLSB: frac_size + position,
        FixedPointPosition.FromPointToMSB: integer_size - position,
    }
    cst_value = resolved_values[align]

    # display value
    Log.report(
        Log.LogLevel("FixedPoint"),
        "fixed-point position {tag} has been resolved to {value}".format(
            tag=optree.get_tag(), value=cst_value))
    resolved = Constant(cst_value, precision=ML_Integer)
    forward_attributes(optree, resolved)
    return resolved
def lower_node(self, node):
    """ Lower an assignment of a MaterializeConstant (input 1) to a register
        (input 0) into a SequentialBlock of self.num_chunk RegisterAssign
        nodes, one per sub-register from split_register. Each sub-register
        receives the matching consecutive slice of the vector constant. """
    dst_reg = node.get_input(0)
    materialize_op = node.get_input(1)
    assert isinstance(materialize_op, MaterializeConstant)
    src_value = materialize_op.get_input(0)
    src_format = src_value.get_precision()

    sub_regs = split_register(dst_reg, self.num_chunk)
    chunk_size = src_format.get_vector_size() // self.num_chunk
    chunk_format = vectorize_format(src_format.get_scalar_format(), chunk_size)
    src_elements = src_value.get_value()

    sub_values = []
    for chunk_id in range(self.num_chunk):
        first = chunk_id * chunk_size
        chunk_elements = [src_elements[first + offset] for offset in range(chunk_size)]
        sub_values.append(MaterializeConstant(
            Constant(chunk_elements, precision=chunk_format),
            precision=chunk_format))

    # single RegisterAssign is lowered into a sequence
    # of sub-register assign
    assigns = [RegisterAssign(sub_reg, sub_value)
               for sub_reg, sub_value in zip(sub_regs, sub_values)]
    return SequentialBlock(*assigns)
def generate_field_extraction(optree, precision, lo_index, hi_index):
    """ Extract the bit-field of optree spanning bits lo_index to hi_index,
        BOTH included (hi_index - lo_index + 1 bits), right-aligned and in
        <precision>.

        optree is first TypeCast to precision when its precision differs.
        The right shift is skipped when lo_index is 0, and the mask is
        skipped when the field width equals precision's bit size. Shift
        amount and mask are built with vectorize_cst.
    """
    if optree.precision != precision:
        optree = TypeCast(optree, precision=precision)
    field_width = hi_index - lo_index + 1
    result = optree
    if lo_index != 0:
        result = BitLogicRightShift(
            optree,
            Constant(vectorize_cst(lo_index, precision), precision=precision),
            precision=precision)
    if field_width != precision.get_bit_size():
        mask = Constant(
            vectorize_cst(2**field_width - 1, precision),
            precision=precision)
        result = BitLogicAnd(result, mask, precision=precision)
    return result
def simplify_logical_op(node):
    """ Simplify a LogicalAnd / LogicalOr / LogicalNot node.

        Applies constant absorption (x and False -> False, x or True -> True),
        neutral-element removal (x and True -> x, x or False -> x) and
        constant folding of two constant operands (elementwise for vector
        precisions).

        :return: the simplified replacement node, or None when no
                 simplification applies
    """
    def _fold(lhs, rhs, combine):
        # constant-fold two constant operands, elementwise for vectors
        if node.get_precision().is_vector_format():
            return Constant(
                [combine(sub_lhs, sub_rhs)
                 for sub_lhs, sub_rhs in zip(lhs.get_value(), rhs.get_value())],
                precision=node.get_precision())
        return Constant(combine(lhs.get_value(), rhs.get_value()),
                        precision=node.get_precision())

    if isinstance(node, LogicalAnd):
        lhs = node.get_input(0)
        rhs = node.get_input(1)
        if is_false(lhs) or is_false(rhs):
            # FIXME: manage vector constant
            return generate_uniform_cst(False, node.get_precision())
        elif is_true(lhs) and is_true(rhs):
            # FIXME: manage vector constant
            return generate_uniform_cst(True, node.get_precision())
        elif is_true(lhs):
            return rhs
        elif is_true(rhs):
            return lhs
        elif is_constant(lhs) and is_constant(rhs):
            return _fold(lhs, rhs, lambda a, b: a and b)
    elif isinstance(node, LogicalOr):
        lhs = node.get_input(0)
        rhs = node.get_input(1)
        if is_false(lhs) and is_false(rhs):
            # FIXME: manage vector constant
            return generate_uniform_cst(False, node.get_precision())
        elif is_true(lhs) or is_true(rhs):
            # FIXME: manage vector constant
            return generate_uniform_cst(True, node.get_precision())
        elif is_false(lhs):
            # False is the neutral element of LogicalOr
            return rhs
        elif is_false(rhs):
            return lhs
        elif is_constant(lhs) and is_constant(rhs):
            return _fold(lhs, rhs, lambda a, b: a or b)
    elif isinstance(node, LogicalNot):
        op = node.get_input(0)
        if is_constant(op):
            # only support simplification of LogicalNot(Constant)
            if not op.get_precision().is_vector_format():
                return Constant(not op.get_value(), precision=node.get_precision())
            return Constant(
                [not elt_value for elt_value in op.get_value()],
                precision=node.get_precision())
    return None
def legalize_invsqrt_seed(optree):
    """ Legalize a ReciprocalSquareRootSeed node into a table-based
        approximation evaluated in ML_Binary32 (converted from/to the node's
        precision when it differs).

        input  = 1.m * 2^e
        approx = 2^(-e/2) * approx_invsqrt(1.m) * correction, where
        correction is 1.0 when e % 2 == 0, 2**0.5 when e % 2 == -1 and
        2**-0.5 otherwise. The -1 case implies Division/Modulo truncating
        toward zero (presumably C semantics in the generated code — TODO
        confirm).
    """
    assert isinstance(optree, ReciprocalSquareRootSeed)
    op_prec = optree.get_precision()
    approx_prec = ML_Binary32
    op_input = optree.get_input(0)
    if op_prec != approx_prec:
        op_input = Conversion(op_input, precision=approx_prec)

    # TODO: fix integer precision selection
    # as we are in a late code generation stage, every node's precision
    # must be set
    int_prec = ML_Int32
    op_exp = ExponentExtraction(op_input, tag="op_exp", debug=debug_multi, precision=int_prec)
    neg_half_exp = Division(
        Negation(op_exp, precision=int_prec),
        Constant(2, precision=int_prec),
        precision=int_prec)
    approx_exp = ExponentInsertion(neg_half_exp, tag="approx_exp", debug=debug_multi, precision=approx_prec)
    op_exp_parity = Modulo(op_exp, Constant(2, precision=int_prec), precision=int_prec)
    approx_exp_correction = Select(
        Equal(op_exp_parity, Constant(0, precision=int_prec)),
        Constant(1.0, precision=approx_prec),
        Select(
            Equal(op_exp_parity, Constant(-1, precision=int_prec)),
            Constant(S2**0.5, precision=approx_prec),
            Constant(S2**-0.5, precision=approx_prec),
            precision=approx_prec),
        precision=approx_prec, tag="approx_exp_correction", debug=debug_multi)

    table_index = invsqrt_approx_table.get_index_function()(op_input)
    table_index.set_attributes(tag="invsqrt_index", debug=debug_multi)
    approx = Multiplication(
        TableLoad(invsqrt_approx_table, table_index, precision=approx_prec),
        Multiplication(approx_exp_correction, approx_exp, precision=approx_prec),
        tag="invsqrt_approx", debug=debug_multi, precision=approx_prec)
    if approx_prec != op_prec:
        return Conversion(approx, precision=op_prec)
    return approx
def Split(a, precision=None):
    """ Veltkamp splitting of a into (ah, al) with a = ah + al, as used by
        Dekker's TwoMul (assuming round-to-nearest and no overflow).

        The splitting constant is 2^ceil(p/2) + 1, where p is the mantissa
        size (implicit bit included) of a's format: 4097 for ML_Binary32,
        134217729 for ML_Binary64; other formats use the same formula.

        :param a: node to split
        :param precision: precision of the generated operation nodes
        :return: tuple (ah, al)
    """
    a_format = a.get_precision()
    known_split_cst = {ML_Binary32: 4097, ML_Binary64: 134217729}
    if a_format in known_split_cst:
        cst_value = known_split_cst[a_format]
    else:
        # generic Veltkamp constant: 2^ceil(p/2) + 1
        cst_value = 2**((a_format.get_mantissa_size() + 1) // 2) + 1
    s = Constant(cst_value, precision=a_format, tag='fp_split')
    c = Multiplication(s, a, precision=precision)
    # ah = (a - c) + c, i.e. c - (c - a): the high part of a
    tmp = Subtraction(a, c, precision=precision)
    ah = Addition(tmp, c, precision=precision)
    al = Subtraction(a, ah, precision=precision)
    return ah, al
def get_ordered_arg_tuple(self, tensor_descriptors, input_tables, output_tables):
    """ Return the argument tuple of the tested function:
        (input A, input B, output, n, m, p), where p and n are read from the
        first input descriptor's sdim[0] and sdim[1], m from the second input
        descriptor's sdim[0]; sizes are ML_Int32 Constants. """
    input_descs, _output_descs = tensor_descriptors
    a_dims = input_descs[0].sdim
    b_dims = input_descs[1].sdim
    p, n = a_dims[0], a_dims[1]
    m = b_dims[0]

    index_format = ML_Int32
    size_args = tuple(Constant(dim, precision=index_format) for dim in (n, m, p))
    return (input_tables[0], input_tables[1], output_tables[0]) + size_args
def compute_sqrt(vx, init_approx, num_iter, debug_lftolx = None, precision = ML_Binary64):
    """ Build the node graph of a square root of vx from init_approx.

        num_iter NR_Iteration steps refine the approximation (final_approx
        presumably approximates 1/sqrt(vx), since S = vx * final_approx is
        used as the square-root estimate — TODO confirm). A final sequence
        of FMA/FMSN corrections produces the result, whose last FMA uses
        ML_GlobalRoundMode.

        :param debug_lftolx: unused
        :param precision: format propagated to the result graph
        :return: the result node "NR_Result"
    """
    C_half = Constant(0.5, precision=precision)
    h = C_half * vx
    h.set_attributes(tag="h", debug=debug_multi, silent=True, rounding_mode=ML_RoundToNearest)

    current_approx = init_approx
    # Newton-Raphson refinement iterations
    for i in range(num_iter):
        new_iteration = NR_Iteration(vx, current_approx, h, C_half)
        current_approx = new_iteration.get_new_approx()
        current_approx.set_attributes(tag="iter_%d" % i, debug=debug_multi)

    final_approx = current_approx
    final_approx.set_attributes(tag="final_approx", debug=debug_multi)

    # multiplication correction iteration
    # to get correctly rounded full square root
    # The Attributes defaults are class-level state: restore them even if
    # building a node raises, so later graphs are not built silent/RN.
    Attributes.set_default_silent(True)
    Attributes.set_default_rounding_mode(ML_RoundToNearest)
    try:
        S = vx * final_approx
        t5 = final_approx * h
        H = C_half * final_approx
        d = FMSN(S, S, vx)
        t6 = FMSN(t5, final_approx, C_half)
        S1 = FMA(d, H, S)
        H1 = FMA(t6, H, H)
        d1 = FMSN(S1, S1, vx)
        pR = FMA(d1, H1, S1)
        d_last = FMSN(pR, pR, vx, silent=True, tag="d_last")

        S.set_attributes(tag="S")
        t5.set_attributes(tag="t5")
        H.set_attributes(tag="H")
        d.set_attributes(tag="d")
        t6.set_attributes(tag="t6")
        S1.set_attributes(tag="S1")
        H1.set_attributes(tag="H1")
        d1.set_attributes(tag="d1")
    finally:
        Attributes.unset_default_silent()
        Attributes.unset_default_rounding_mode()

    R = FMA(d_last, H1, pR, rounding_mode=ML_GlobalRoundMode, tag="NR_Result", debug=debug_multi)

    # set precision
    propagate_format(R, precision)
    propagate_format(S1, precision)
    propagate_format(H1, precision)
    propagate_format(d1, precision)

    return R
def expand_kernel_expr(kernel, iterator_format=ML_Int32):
    """ Expand a kernel expression into the corresponding MDL graph.

        - NDRange: delegated to expand_ndrange;
        - Sum: expanded into a Loop accumulating elt_operation into a local
          "acc" variable, returned as PlaceHolder(acc, loop);
        - Read/WriteAccessor: delegated to expand_accessor;
        - leaf nodes: returned unchanged;
        - any other node: its inputs are expanded in place, node returned.
        iterator_format is currently unused.
    """
    if isinstance(kernel, NDRange):
        return expand_ndrange(kernel)

    if isinstance(kernel, Sum):
        iter_range = kernel.index_iter_range
        var_iter = iter_range.var_index
        acc_format = kernel.precision
        # TODO/FIXME to be uniquified
        acc = Variable("acc", var_type=Variable.Local, precision=acc_format)
        # TODO/FIXME implement proper acc init
        if acc_format.is_vector_format():
            zero = Constant([0] * acc_format.get_vector_size(), precision=acc_format)
        else:
            zero = Constant(0, precision=acc_format)
        loop_init = Statement(
            ReferenceAssign(var_iter, iter_range.first_index),
            ReferenceAssign(acc, zero))
        loop_cond = var_iter <= iter_range.last_index
        loop_body = Statement(
            ReferenceAssign(
                acc,
                Addition(acc, expand_kernel_expr(kernel.elt_operation), precision=acc_format)),
            # loop iterator increment
            ReferenceAssign(var_iter, var_iter + iter_range.index_step))
        return PlaceHolder(acc, Loop(loop_init, loop_cond, loop_body))

    if isinstance(kernel, (ReadAccessor, WriteAccessor)):
        return expand_accessor(kernel)

    if is_leaf_node(kernel):
        return kernel

    # vanilla metalibm ops are left unmodified (except
    # recursive expansion)
    for input_index, operand in enumerate(kernel.inputs):
        kernel.set_input(input_index, expand_kernel_expr(operand))
    return kernel
def get_index_node(self, vx):
    """ Build the table index node of vx:
        ((int(vx) - int_coding(2^low_exp_value)) >> (field_size - field_bits))
        & (2^(exp_bits + field_bits) - 1),
        all computed in vx's integer format. """
    assert vx.precision is self.precision
    int_format = vx.precision.get_integer_format()
    index_width = self.exp_bits + self.field_bits

    # building an index mask from the index width
    index_mask = Constant(2**index_width - 1, precision=int_format)
    shift_amount = Constant(
        vx.get_precision().get_field_size() - self.field_bits,
        precision=int_format)
    exp_offset = Constant(
        self.precision.get_integer_coding(S2**self.low_exp_value),
        precision=int_format)

    int_x = TypeCast(vx, precision=int_format)
    offset_x = Subtraction(int_x, exp_offset, precision=int_format)
    shifted_x = BitLogicRightShift(offset_x, shift_amount, precision=int_format)
    return BitLogicAnd(shifted_x, index_mask, precision=int_format)
def Split(a):
    """ Veltkamp splitting of a into (ah, al) with a = ah + al, as used by
        Dekker's TwoMul (assuming round-to-nearest and no overflow).

        The splitting constant must match a's format: 2^ceil(p/2) + 1 with p
        the mantissa size (implicit bit included), i.e. 4097 for
        ML_Binary32 and 134217729 for ML_Binary64. A fixed 4097 would produce
        an invalid split for binary64 inputs.

        :return: tuple (ah, al)
    """
    a_format = a.get_precision()
    known_split_cst = {ML_Binary32: 4097, ML_Binary64: 134217729}
    if a_format in known_split_cst:
        cst_value = known_split_cst[a_format]
    else:
        # generic Veltkamp constant: 2^ceil(p/2) + 1
        cst_value = 2**((a_format.get_mantissa_size() + 1) // 2) + 1
    s = Constant(cst_value, precision=a_format, tag='fp_split')
    c = Multiplication(s, a)
    # ah = (a - c) + c, i.e. c - (c - a): the high part of a
    tmp = Subtraction(a, c)
    ah = Addition(tmp, c)
    al = Subtraction(a, ah)
    return ah, al