def is_simplifiable_multiplication(node, lhs, rhs):
    """Test whether the Multiplication @p node (@p lhs * @p rhs) can be simplified.

    Simplification requires both operands to have a fixed-point format and
    at least one operand to be a constant whose value is 0 or 1.

    Floating-point products are deliberately excluded: for example 0 * a is 0
    only when a is finite, whereas 0 * a is NaN when a is infinite.
    """
    assert isinstance(node, Multiplication)
    operands = (lhs, rhs)
    if not all(is_fixed_format(operand.get_precision()) for operand in operands):
        return False
    return any(
        is_constant(operand) and operand.get_value() in [0, 1]
        for operand in operands
    )
def is_simplifiable_add(node, node_inputs):
    """Test whether @p node is an Addition that can be simplified.

    An Addition is simplifiable when at least one of its two operands
    (node_inputs[0] / node_inputs[1]) is a constant equal to 0. Any node that
    is not an Addition is reported as not simplifiable.

    NOTE(review): unlike is_simplifiable_multiplication, this predicate is not
    restricted to fixed-point formats. For floating-point operands, x + 0 is
    not exactly x when x is -0 (the sum is +0 under round-to-nearest). Confirm
    that this is acceptable to callers.
    """
    # builtin isinstance, consistent with the other predicates of this module
    if not isinstance(node, Addition):
        return False
    lhs = node_inputs[0]
    rhs = node_inputs[1]
    return (is_constant(lhs) and lhs.get_value() == 0) or \
        (is_constant(rhs) and rhs.get_value() == 0)
def is_simplifiable_logical_op(node, node_inputs):
    """Test whether the logic operation node(node_inputs) can be simplified.

    - LogicalAnd / LogicalOr: simplifiable when either operand is a constant.
    - LogicalNot: simplifiable when its single operand is a constant.
    - any other node: never simplifiable.
    """
    if isinstance(node, (LogicalAnd, LogicalOr)):
        first, second = node_inputs[0], node_inputs[1]
        return is_constant(first) or is_constant(second)
    if isinstance(node, LogicalNot):
        return is_constant(node_inputs[0])
    return False
def propagate_format_to_input(new_format, optree, input_index_list):
    """Propagate @p new_format to the inputs of @p optree listed in @p input_index_list.

    For each listed input:
      - no precision yet: the input receives new_format, and propagation
        recurses into the input indices returned by does_node_propagate_format;
      - precision already equal to new_format: left untouched;
      - non-constant input with another precision: replaced by a Conversion
        to new_format;
      - constant input with another precision: replaced by a copy carrying
        new_format when new_format is fixed-point and the value fits. Otherwise
        an error is reported through Log.report (whether this aborts depends
        on Log.report).
    """
    for input_id in input_index_list:
        current = optree.get_input(input_id)
        current_format = current.get_precision()

        if current_format is None:
            current.set_precision(new_format)
            forwarded_ids = does_node_propagate_format(current)
            propagate_format_to_input(new_format, current, forwarded_ids)
            continue

        if test_format_equality(new_format, current_format):
            # already in the requested format: nothing to do
            continue

        if not is_constant(current):
            optree.set_input(input_id, Conversion(current, precision=new_format))
            continue

        # constant input: re-typed in place (on a copy) rather than converted
        if not is_fixed_point(new_format):
            Log.report(
                Log.Error,
                "format {} during propagation to input {} of {} is not a fixed-point format",
                new_format, current, optree)
        elif format_does_fit(current, new_format):
            # get_str is only evaluated when the Info level will actually be printed
            Log.report(
                Log.Info,
                "Simplify Constant Conversion {} to larger Constant: {}",
                current.get_str(display_precision=True) if Log.is_level_enabled(Log.Info) else "",
                str(new_format))
            resized_cst = current.copy()
            resized_cst.set_precision(new_format)
            optree.set_input(input_id, resized_cst)
        else:
            Log.report(
                Log.Error,
                "Constant is about to be reduced to a too constrained format: {}",
                current.get_str(display_precision=True) if Log.is_level_enabled(Log.Error) else "")
def evaluate_graph_value(optree, input_mapping, memoization_map=None):
    """Evaluate the numerical value of @p optree.

    @p input_mapping maps nodes to known values. @p memoization_map (a fresh
    dict when None) caches already-evaluated nodes. It is shared with the
    recursive calls and updated in place. A node found in memoization_map is
    returned immediately, before input_mapping is consulted.

    Typecast and Conversion nodes are evaluated through
    evaluate_typecast_value / evaluate_conversion_value. Any other non-leaf
    node applies its bare range function to the tuple of its input values.
    """
    if memoization_map is None:
        memoization_map = {}
    if optree in memoization_map:
        return memoization_map[optree]

    def recurse(sub_node):
        # every recursive evaluation shares the same memoization cache
        return evaluate_graph_value(sub_node, input_mapping, memoization_map)

    if optree in input_mapping:
        value = input_mapping[optree]
    elif is_constant(optree):
        value = optree.get_value()
    elif is_typecast(optree):
        value = evaluate_typecast_value(optree, recurse(optree.get_input(0)))
    elif is_conversion(optree):
        value = evaluate_conversion_value(optree, recurse(optree.get_input(0)))
    else:
        operand_values = tuple(recurse(operand) for operand in optree.get_inputs())
        value = optree.apply_bare_range_function(operand_values)

    memoization_map[optree] = value
    Log.report(LOG_RUNTIME_EVAL_ERROR, "node {} value has been evaluated to: {}", optree.get_tag(), value)
    return value
def split_vectorial_op(node, output_vsize=2):
    """Split vectorial node @p node into a list of sub-vectors of size @p output_vsize.

    The vector size of @p node's format must be a non-zero multiple of
    @p output_vsize (checked by an assertion). The result lists the
    sub-vectors in element order:
      - constant nodes are split into new Constant sub-vectors;
      - other nodes are split into SubVectorExtract nodes indexing @p node
        with ML_Integer constant indices.

    For ML_Bool vectors, the boolean bitwidth of @p node's format is forwarded
    to the sub-vector format in both cases.
    """
    node_format = node.get_precision()
    input_vsize = node_format.get_vector_size()
    scalar_format = node_format.get_scalar_format()
    assert output_vsize > 0 and input_vsize % output_vsize == 0, \
        "vector size {} is not a multiple of {}".format(input_vsize, output_vsize)

    # the same sub-vector format applies to constant and non-constant splits.
    # Previously the constant path dropped the boolean bitwidth.
    bool_specifier = node_format.boolean_bitwidth if scalar_format is ML_Bool else None
    sub_vector_fmt = vectorize_format(scalar_format, output_vsize, bool_specifier=bool_specifier)
    sub_count = input_vsize // output_vsize

    if is_constant(node):
        cst_values = node.get_value()
        return [
            Constant(
                [cst_values[sub_id * output_vsize + j] for j in range(output_vsize)],
                precision=sub_vector_fmt)
            for sub_id in range(sub_count)
        ]

    return [
        SubVectorExtract(
            node,
            *(Constant(sub_id * output_vsize + j, precision=ML_Integer)
              for j in range(output_vsize)),
            precision=sub_vector_fmt)
        for sub_id in range(sub_count)
    ]
def simplify_logical_op(node):
    """Simplify a LogicalAnd / LogicalOr / LogicalNot node with constant operand(s).

    Returns the simplified node, or None when no simplification applies.

    LogicalAnd:  false operand -> False, both true -> True,
                 true operand -> the other operand,
                 two constants -> element-wise folded Constant.
    LogicalOr:   both false -> False, true operand -> True,
                 false operand -> the other operand,
                 two constants -> element-wise folded Constant.
    LogicalNot:  constant operand -> negated Constant.
    """
    precision = node.get_precision()

    def fold(lhs, rhs, combine):
        # constant-fold two constants, element-wise for vector formats
        if precision.is_vector_format():
            return Constant(
                [combine(sub_lhs, sub_rhs) for sub_lhs, sub_rhs in zip(lhs.get_value(), rhs.get_value())],
                precision=precision)
        return Constant(combine(lhs.get_value(), rhs.get_value()), precision=precision)

    if isinstance(node, LogicalAnd):
        lhs = node.get_input(0)
        rhs = node.get_input(1)
        if is_false(lhs) or is_false(rhs):
            # FIXME: manage vector constant
            return generate_uniform_cst(False, precision)
        elif is_true(lhs) and is_true(rhs):
            # FIXME: manage vector constant
            return generate_uniform_cst(True, precision)
        elif is_true(lhs):
            return rhs
        elif is_true(rhs):
            return lhs
        elif is_constant(lhs) and is_constant(rhs):
            # mirrors the LogicalOr folding below
            return fold(lhs, rhs, lambda a, b: a and b)
    elif isinstance(node, LogicalOr):
        lhs = node.get_input(0)
        rhs = node.get_input(1)
        if is_false(lhs) and is_false(rhs):
            # FIXME: manage vector constant
            return generate_uniform_cst(False, precision)
        elif is_true(lhs) or is_true(rhs):
            # FIXME: manage vector constant
            return generate_uniform_cst(True, precision)
        elif is_false(lhs):
            # False or x == x (mirrors True and x == x above)
            return rhs
        elif is_false(rhs):
            return lhs
        elif is_constant(lhs) and is_constant(rhs):
            return fold(lhs, rhs, lambda a, b: a or b)
    elif isinstance(node, LogicalNot):
        op = node.get_input(0)
        if is_constant(op):
            # only support simplification of LogicalNot(Constant)
            if not op.get_precision().is_vector_format():
                return Constant(not op.get_value(), precision=precision)
            return Constant([not elt_value for elt_value in op.get_value()], precision=precision)
    return None
def transform_to_physical_reg(self, color_map, linearized_program):
    """Replace each MachineRegister of @p linearized_program by its PhysicalRegister.

    The program is walked as linearized_program.inputs (presumably basic
    blocks), each holding instruction nodes in its own .inputs. Register
    mapping is delegated to self.get_physical_reg(color_map, reg). Supported
    instruction nodes are RegisterAssign, ConditionalBranch and TableStore.
    Operands that are neither MachineRegister nor constant (or, for
    ConditionalBranch, not a MachineRegister) raise NotImplementedError.
    Instructions of other types are left untouched.
    """
    def to_physical(operand):
        # registers are remapped, constants are kept (re-set as-is), anything else is unsupported
        if isinstance(operand, MachineRegister):
            return self.get_physical_reg(color_map, operand)
        if is_constant(operand):
            return operand
        raise NotImplementedError

    for basic_block in linearized_program.inputs:
        for insn in basic_block.inputs:
            if isinstance(insn, RegisterAssign):
                # destination register
                insn.set_input(0, self.get_physical_reg(color_map, insn.get_input(0)))
                # source operand(s)
                source = insn.get_input(1)
                if is_leaf_node(source):
                    insn.set_input(1, to_physical(source))
                else:
                    for operand_id, operand in enumerate(source.inputs):
                        source.set_input(operand_id, to_physical(operand))
            elif isinstance(insn, ConditionalBranch):
                condition = insn.get_input(0)
                if not isinstance(condition, MachineRegister):
                    raise NotImplementedError
                insn.set_input(0, self.get_physical_reg(color_map, condition))
            elif isinstance(insn, TableStore):
                for operand_id, operand in enumerate(insn.inputs):
                    insn.set_input(operand_id, to_physical(operand))
def simplify_multiplication(node):
    """Simplify a multiplication between two fixed-point operands when one is a constant in {0, 1}.

    Returns:
      - Constant(0) with the node's precision if either operand is the constant 0,
      - the other operand if either operand is the constant 1,
      - None when no simplification applies.
    The lhs is examined first, then the rhs.

    NOTE(review): the operand returned for a "* 1" simplification keeps its
    own precision, which may differ from the node's precision. Confirm that
    callers handle (or do not need) a Conversion.
    """
    assert isinstance(node, Multiplication)
    lhs = node.get_input(0)
    rhs = node.get_input(1)
    assert is_fixed_format(lhs.get_precision())
    assert is_fixed_format(rhs.get_precision())

    # each operand is checked independently: a constant lhs outside {0, 1}
    # must not prevent simplification on a constant rhs (e.g. 3 * 0)
    for cst_operand, other_operand in ((lhs, rhs), (rhs, lhs)):
        if not is_constant(cst_operand):
            continue
        value = cst_operand.get_value()
        if value == 0:
            return Constant(0, precision=node.get_precision())
        if value == 1:
            return other_operand
    # no simplification found
    return None
def format_does_fit(cst_optree, new_format):
    """Test whether constant @p cst_optree fits into the fixed-point format @p new_format.

    The constant's value is compared against the minimal fixed-point format
    able to hold it (determine_minimal_fixed_format_cst). It fits when
    new_format has enough integer bits and fractional bits, and a signed
    value is not squeezed into an unsigned format.
    """
    assert is_constant(cst_optree)
    assert is_fixed_point(new_format)
    # min_format is a dummy yardstick format, only used to check the fit
    min_format = determine_minimal_fixed_format_cst(cst_optree.get_value())
    target_signed = new_format.get_signed()
    value_signed = min_format.get_signed()
    # an unsigned value placed into a signed format costs one extra integer
    # bit of new_format (presumably the sign bit)
    sign_bias = 1 if target_signed and not value_signed else 0
    if new_format.get_integer_size() - sign_bias < min_format.get_integer_size():
        return False
    if new_format.get_frac_size() < min_format.get_frac_size():
        return False
    return target_signed or not value_signed
def generate_insn_from_node(self, node):
    """Build the asmde.Instruction corresponding to @p node.

    - RegisterAssign: defines the allocatable sub-registers of its
      destination (input 0). It uses the allocatable sub-registers of the
      source registers of its value (input 1), or none when the value is a
      constant.
    - TableStore: defines nothing and uses the allocatable sub-registers of
      its source registers.
    Any other node is reported through Log.report and raises
    NotImplementedError.
    """
    def allocatable(regs):
        # flatten each register into its allocatable sub-registers
        return [
            sub_reg
            for reg in regs
            for sub_reg in self.generate_allocatable_register(reg)
        ]

    if isinstance(node, RegisterAssign):
        defs = allocatable([node.get_input(0)])
        value_node = node.get_input(1)
        uses = [] if is_constant(value_node) else allocatable(extract_src_regs_from_node(value_node))
        return asmde.Instruction(node, dbg_object=node, def_list=defs, use_list=uses)

    if isinstance(node, TableStore):
        # TODO/FIXME: may need the generation of a shaddow dependency chain
        # to maintain TableStore/TableLoad relative order when required
        uses = allocatable(extract_src_regs_from_node(node))
        return asmde.Instruction(node, dbg_object=node, use_list=uses)

    Log.report(
        Log.Error,
        "node unsupported in AssemblySynthesizer.generate_insn_from_node: {}", node)
    raise NotImplementedError
def legalize_mp_3elt_comparison(optree):
    """Transform a comparison on 3-limb ML_Compound_FP_Format operands into comparisons on sub-fields.

    Non-constant operands are first renormalized (Normalize_33 on their
    hi/me/lo limbs, rebuilt with BuildFromComponent). The comparison is then
    expanded limb by limb:
      - Equal:    all three limb pairs are equal;
      - NotEqual: at least one limb pair differs;
      - Less/Greater(OrEqual): lexicographic comparison on (hi, me, lo). The
        strict specifier is used on hi and me, and the original specifier is
        used on lo.
    Unsupported specifiers are reported through Log.report(Log.Error, ...),
    and the function then returns None.
    """
    specifier = optree.specifier
    lhs = optree.get_input(0)
    rhs = optree.get_input(1)

    def renormalize(op):
        # constants are left as-is, other operands are renormalized
        if is_constant(op):
            return op
        limbs = Normalize_33(op.hi, op.me, op.lo, precision=op.precision.get_limb_precision(0))
        return BuildFromComponent(*limbs, precision=op.precision)

    def compare(lhs_limb, rhs_limb, limb_specifier):
        return Comparison(lhs_limb, rhs_limb, specifier=limb_specifier, precision=ML_Bool)

    def logical_and(op0, op1):
        return LogicalAnd(op0, op1, precision=ML_Bool)

    def logical_or(op0, op1):
        return LogicalOr(op0, op1, precision=ML_Bool)

    # TODO/FIXME: assume than multi-limb operand are normalized
    if specifier == Comparison.Equal:
        lhs, rhs = renormalize(lhs), renormalize(rhs)
        return logical_and(
            compare(lhs.hi, rhs.hi, Comparison.Equal),
            logical_and(
                compare(lhs.me, rhs.me, Comparison.Equal),
                compare(lhs.lo, rhs.lo, Comparison.Equal)))
    elif specifier == Comparison.NotEqual:
        lhs, rhs = renormalize(lhs), renormalize(rhs)
        return logical_or(
            compare(lhs.hi, rhs.hi, Comparison.NotEqual),
            logical_or(
                compare(lhs.me, rhs.me, Comparison.NotEqual),
                compare(lhs.lo, rhs.lo, Comparison.NotEqual)))
    elif specifier in [Comparison.LessOrEqual, Comparison.GreaterOrEqual, Comparison.Greater, Comparison.Less]:
        strict_specifier = {
            Comparison.Less: Comparison.Less,
            Comparison.Greater: Comparison.Greater,
            Comparison.LessOrEqual: Comparison.Less,
            Comparison.GreaterOrEqual: Comparison.Greater
        }[specifier]
        lhs, rhs = renormalize(lhs), renormalize(rhs)
        # hi strictly ordered, or hi equal and (me strictly ordered, or me equal and lo ordered)
        return logical_or(
            compare(lhs.hi, rhs.hi, strict_specifier),
            logical_and(
                compare(lhs.hi, rhs.hi, Comparison.Equal),
                logical_or(
                    compare(lhs.me, rhs.me, strict_specifier),
                    logical_and(
                        compare(lhs.me, rhs.me, Comparison.Equal),
                        compare(lhs.lo, rhs.lo, specifier)))))
    else:
        # message previously named legalize_mp_2elt_comparison
        Log.report(Log.Error, "unsupported specifier {} in legalize_mp_3elt_comparison", specifier)
        return None