def is_simplifiable_multiplication(node, lhs, rhs):
    """ Test if the Multiplication @p node, with operands @p lhs and @p rhs,
        can be simplified.

        Both operands must have a fixed-point format and at least one of
        them must be a constant whose value is 0 or 1.

        Floating-point products are deliberately excluded because they are
        not easily simplifiable: for example 0 * a = 0 only when a is finite,
        whereas 0 * a = NaN when a is infinite. """
    assert isinstance(node, Multiplication)
    if not is_fixed_format(lhs.get_precision()):
        return False
    if not is_fixed_format(rhs.get_precision()):
        return False
    return any(
        is_constant(operand) and operand.get_value() in (0, 1)
        for operand in (lhs, rhs)
    )
def is_simplifiable_add(node, node_inputs):
    """ Test if the Addition @p node applied to @p node_inputs can be
        simplified, i.e. if one of its two operands is a constant equal to 0.

        Return False for any node which is not an Addition.

        NOTE(review): unlike is_simplifiable_multiplication, no format
        restriction is applied; for floating-point formats x + 0 is not
        strictly x when x is -0. """
    # builtin isinstance (the previous `is_instance` is not a Python builtin
    # and raised NameError)
    if not isinstance(node, Addition):
        return False
    lhs = node_inputs[0]
    rhs = node_inputs[1]
    return (is_constant(lhs) and lhs.get_value() == 0) or (is_constant(rhs) and rhs.get_value() == 0)
def is_simplifiable_logical_op(node, node_inputs):
    """ Test if the logical operation @p node applied to @p node_inputs can
        be simplified.

        LogicalAnd / LogicalOr: simplifiable when at least one of the two
        operands is a constant.
        LogicalNot: simplifiable when its single operand is a constant.
        Any other node is not simplifiable. """
    if isinstance(node, (LogicalAnd, LogicalOr)):
        first_op, second_op = node_inputs[0], node_inputs[1]
        return is_constant(first_op) or is_constant(second_op)
    if isinstance(node, LogicalNot):
        return is_constant(node_inputs[0])
    return False
# Example 4 (score: 0)
def propagate_format_to_input(new_format, optree, input_index_list):
    """ Propagate @p new_format to the inputs of @p optree whose indexes are
        listed in @p input_index_list.

        Each selected input is handled as follows:
        - input without precision: it receives new_format, then the
          propagation continues recursively through the input indexes
          returned by does_node_propagate_format for that input
        - input whose format is already equal to new_format: unchanged
        - constant input in another format: an error is reported if
          new_format is not fixed-point; if the constant value fits in
          new_format, the input is replaced by a copy of the constant
          re-typed to new_format; otherwise an error is reported
        - any other input in another format: it is wrapped into a
          Conversion to new_format """
    for input_id in input_index_list:
        current_input = optree.get_input(input_id)
        current_format = current_input.get_precision()

        if current_format is None:
            current_input.set_precision(new_format)
            next_index_list = does_node_propagate_format(current_input)
            propagate_format_to_input(new_format, current_input, next_index_list)
            continue

        if test_format_equality(new_format, current_format):
            # the input already has the expected format
            continue

        if not is_constant(current_input):
            converted_input = Conversion(current_input, precision=new_format)
            optree.set_input(input_id, converted_input)
            continue

        # constant input with a different format
        if not is_fixed_point(new_format):
            Log.report(
                Log.Error,
                "format {} during propagation to input {} of {} is not a fixed-point format",
                new_format, current_input, optree)
        elif format_does_fit(current_input, new_format):
            # the (possibly costly) string is only built when it is displayed
            cst_descr = current_input.get_str(display_precision=True) if Log.is_level_enabled(Log.Info) else ""
            Log.report(
                Log.Info,
                "Simplify Constant Conversion {} to larger Constant: {}",
                cst_descr,
                str(new_format))
            retyped_cst = current_input.copy()
            retyped_cst.set_precision(new_format)
            optree.set_input(input_id, retyped_cst)
        else:
            cst_descr = current_input.get_str(display_precision=True) if Log.is_level_enabled(Log.Error) else ""
            Log.report(
                Log.Error,
                "Constant is about to be reduced to a too constrained format: {}",
                cst_descr)
# Example 5 (score: 0)
def evaluate_graph_value(optree, input_mapping, memoization_map=None):
    """ Evaluate the numerical value of the graph rooted at @p optree.

        @p input_mapping maps nodes to their already-known value.
        @p memoization_map caches the value of already evaluated nodes; it is
        updated in place, and a fresh dict is used when it is None.

        Evaluation rules:
        - constants evaluate to their own value
        - typecast and conversion nodes are evaluated through
          evaluate_typecast_value / evaluate_conversion_value applied to the
          value of their single input
        - any other node through its apply_bare_range_function, applied to
          the tuple of its evaluated inputs """
    if memoization_map is None:
        memoization_map = {}
    if optree in memoization_map:
        return memoization_map[optree]

    def evaluate_input(node):
        # recursive evaluation sharing the same mapping and cache
        return evaluate_graph_value(node, input_mapping, memoization_map)

    if optree in input_mapping:
        result = input_mapping[optree]
    elif is_constant(optree):
        result = optree.get_value()
    elif is_typecast(optree):
        result = evaluate_typecast_value(optree, evaluate_input(optree.get_input(0)))
    elif is_conversion(optree):
        result = evaluate_conversion_value(optree, evaluate_input(optree.get_input(0)))
    else:
        input_values = tuple(evaluate_input(op) for op in optree.get_inputs())
        result = optree.apply_bare_range_function(input_values)

    memoization_map[optree] = result
    Log.report(LOG_RUNTIME_EVAL_ERROR, "node {} value has been evaluated to: {}", optree.get_tag(), result)
    return result
# Example 6 (score: 0)
def split_vectorial_op(node, output_vsize=2):
    """ Split a vectorial node @p node into a list of sub-vectors, each of
        size @p output_vsize, preserving element order.

        The vector size of @p node must be a multiple of @p output_vsize.
        This is not checked here: trailing elements beyond the last full
        sub-vector are dropped.

        Return a list of Constant nodes when @p node is a constant, otherwise
        a list of SubVectorExtract nodes selecting consecutive element
        indexes of @p node. """
    input_vsize = node.get_precision().get_vector_size()
    scalar_format = node.get_precision().get_scalar_format()

    # boolean vector formats carry an explicit element bitwidth; it must be
    # forwarded to the sub-vector format for constants as well as for
    # extractions (the constant branch used to drop it)
    bool_specifier = None
    if scalar_format is ML_Bool:
        bool_specifier = node.get_precision().boolean_bitwidth
    sub_vector_fmt = vectorize_format(scalar_format, output_vsize, bool_specifier=bool_specifier)

    # element indexes of each sub-vector
    sub_ranges = [
        range(sub_id * output_vsize, (sub_id + 1) * output_vsize)
        for sub_id in range(input_vsize // output_vsize)
    ]

    if is_constant(node):
        cst_value = node.get_value()
        return [
            Constant([cst_value[index] for index in sub_range], precision=sub_vector_fmt)
            for sub_range in sub_ranges
        ]

    return [
        SubVectorExtract(
            node,
            *tuple(Constant(index, precision=ML_Integer) for index in sub_range),
            precision=sub_vector_fmt)
        for sub_range in sub_ranges
    ]
def _fold_logical_constants(node, lhs, rhs, scalar_op):
    """ Fold the binary logical operation @p node between constants @p lhs
        and @p rhs.

        Return a Constant with @p node's precision whose value is
        scalar_op(lhs_value, rhs_value). The operation is applied element-wise
        when @p node has a vector format. """
    precision = node.get_precision()
    if precision.is_vector_format():
        return Constant(
            [scalar_op(sub_lhs, sub_rhs) for sub_lhs, sub_rhs in zip(lhs.get_value(), rhs.get_value())],
            precision=precision)
    return Constant(scalar_op(lhs.get_value(), rhs.get_value()), precision=precision)


def simplify_logical_op(node):
    """ Simplify the LogicOperation @p node when some of its operands are
        constants.

        LogicalAnd:
          False & x, x & False -> False
          True & True          -> True
          True & x, x & True   -> x
          cst & cst            -> folded constant (element-wise for vectors)
        LogicalOr:
          False | False        -> False
          True | x, x | True   -> True
          cst | cst            -> folded constant (element-wise for vectors)
          False | x, x | False -> x
        LogicalNot:
          not cst              -> folded constant (element-wise for vectors)

        Return the simplified node, or None when no simplification applies. """
    if isinstance(node, LogicalAnd):
        lhs = node.get_input(0)
        rhs = node.get_input(1)
        if is_false(lhs) or is_false(rhs):
            # FIXME: manage vector constant
            return generate_uniform_cst(False, node.get_precision())
        elif is_true(lhs) and is_true(rhs):
            # FIXME: manage vector constant
            return generate_uniform_cst(True, node.get_precision())
        elif is_true(lhs):
            return rhs
        elif is_true(rhs):
            return lhs
        elif is_constant(lhs) and is_constant(rhs):
            # remaining constant cases (e.g. vectors with mixed values),
            # folded like LogicalOr does
            return _fold_logical_constants(node, lhs, rhs, lambda a, b: a and b)
    elif isinstance(node, LogicalOr):
        lhs = node.get_input(0)
        rhs = node.get_input(1)
        if is_false(lhs) and is_false(rhs):
            # FIXME: manage vector constant
            return generate_uniform_cst(False, node.get_precision())
        elif is_true(lhs) or is_true(rhs):
            # FIXME: manage vector constant
            return generate_uniform_cst(True, node.get_precision())
        elif is_constant(lhs) and is_constant(rhs):
            return _fold_logical_constants(node, lhs, rhs, lambda a, b: a or b)
        elif is_false(lhs):
            # neutral element of Or, mirroring True & x -> x for LogicalAnd
            return rhs
        elif is_false(rhs):
            return lhs
    elif isinstance(node, LogicalNot):
        op = node.get_input(0)
        if is_constant(op):
            # only support simplification of LogicalNot(Constant)
            if not op.get_precision().is_vector_format():
                return Constant(not op.get_value(), precision=node.get_precision())
            else:
                return Constant(
                    [not elt_value for elt_value in op.get_value()],
                    precision=node.get_precision())
    return None
# Example 8 (score: 0)
 def transform_to_physical_reg(self, color_map, linearized_program):
     """ Rewrite in place the operands of the nodes of linearized_program.

         Each MachineRegister operand is replaced by the register returned by
         self.get_physical_reg(color_map, <that machine register>).

         linearized_program.inputs is iterated, and for each of its elements
         (presumably basic blocks) the nodes listed in its own inputs are
         processed.

         Supported nodes:
         - RegisterAssign: the destination (input 0) is always translated.
           The source value (input 1) is translated itself when it is a
           leaf node, operand by operand otherwise.
         - ConditionalBranch: its input 0 must be a MachineRegister.
         - TableStore: every operand is translated.

         Constant operands are set back unchanged. Any other operand kind
         raises NotImplementedError. Other node classes are not modified. """
     def remap_operand(owner, index, operand):
         # translate one operand of owner: registers are mapped, constants
         # are kept, anything else is unsupported
         if isinstance(operand, MachineRegister):
             owner.set_input(index, self.get_physical_reg(color_map, operand))
         elif is_constant(operand):
             owner.set_input(index, operand)
         else:
             raise NotImplementedError

     for basic_block in linearized_program.inputs:
         for node in basic_block.inputs:
             if isinstance(node, RegisterAssign):
                 dst_reg = node.get_input(0)
                 node.set_input(0, self.get_physical_reg(color_map, dst_reg))
                 value_node = node.get_input(1)
                 if is_leaf_node(value_node):
                     remap_operand(node, 1, value_node)
                 else:
                     for index, operand in enumerate(value_node.inputs):
                         remap_operand(value_node, index, operand)
             elif isinstance(node, ConditionalBranch):
                 cond_op = node.get_input(0)
                 if not isinstance(cond_op, MachineRegister):
                     raise NotImplementedError
                 node.set_input(0, self.get_physical_reg(color_map, cond_op))
             elif isinstance(node, TableStore):
                 for index, operand in enumerate(node.inputs):
                     remap_operand(node, index, operand)
def simplify_multiplication(node):
    """ Simplify a multiplication node between two fixed-point operands when
        one of them is a constant in {0, 1}.

          0 * x, x * 0 -> Constant(0) with the node's precision
          1 * x -> x, x * 1 -> x

        The lhs operand is examined first. The rhs operand is still examined
        when lhs is a constant outside {0, 1} (e.g. 3 * 0), matching the cases
        accepted by is_simplifiable_multiplication.

        Return the simplified node, or None when no simplification applies. """
    assert isinstance(node, Multiplication)
    lhs = node.get_input(0)
    rhs = node.get_input(1)
    assert is_fixed_format(lhs.get_precision())
    assert is_fixed_format(rhs.get_precision())

    for cst_op, other_op in ((lhs, rhs), (rhs, lhs)):
        if not is_constant(cst_op):
            continue
        value = cst_op.get_value()
        if value == 0:
            return Constant(0, precision=node.get_precision())
        if value == 1:
            return other_op
    # no simplification found
    return None
# Example 10 (score: 0)
def format_does_fit(cst_optree, new_format):
    """ Test if the value of the constant @p cst_optree fits into the
        fixed-point precision @p new_format.

        The minimal fixed-point format of the value is computed first, then
        new_format must satisfy all of the following:
        - provide at least as many fractional bits;
        - provide at least as many integer bits, plus one extra integer bit
          when new_format is signed but the minimal format is not;
        - be signed whenever the minimal format is signed. """
    assert is_constant(cst_optree)
    assert is_fixed_point(new_format)
    # reference (dummy) format: the smallest fixed-point format holding the value
    min_format = determine_minimal_fixed_format_cst(cst_optree.get_value())
    target_signed = new_format.get_signed()
    min_signed = min_format.get_signed()

    sign_bias = 1 if (target_signed and not min_signed) else 0
    if new_format.get_integer_size() - sign_bias < min_format.get_integer_size():
        return False
    if new_format.get_frac_size() < min_format.get_frac_size():
        return False
    return target_signed or not min_signed
# Example 11 (score: 0)
    def generate_insn_from_node(self, node):
        """ Generate the asmde.Instruction corresponding to @p node.

            - RegisterAssign: defines the allocatable sub-registers of its
              destination (input 0) and uses those of the registers
              extracted from its value (input 1); a constant value uses no
              register.
            - TableStore: only uses the allocatable sub-registers of the
              registers extracted from the node.

            Any other node is reported as an error and raises
            NotImplementedError. """
        def expand_regs(reg_list):
            # flatten the allocatable sub-registers of every register of reg_list
            return [
                sub_reg
                for reg in reg_list
                for sub_reg in self.generate_allocatable_register(reg)
            ]

        if isinstance(node, RegisterAssign):
            def_regs = expand_regs([node.get_input(0)])
            value_node = node.get_input(1)
            if is_constant(value_node):
                use_regs = []
            else:
                use_regs = expand_regs(extract_src_regs_from_node(value_node))
            return asmde.Instruction(
                node, dbg_object=node, def_list=def_regs, use_list=use_regs)

        if isinstance(node, TableStore):
            # TODO/FIXME: a shadow dependency chain may be needed to maintain
            # the TableStore/TableLoad relative order when required
            use_regs = expand_regs(extract_src_regs_from_node(node))
            return asmde.Instruction(node, dbg_object=node, use_list=use_regs)

        Log.report(
            Log.Error,
            "node unsupported in AssemblySynthesizer.generate_insn_from_node: {}",
            node)
        raise NotImplementedError
# Example 12 (score: 0)
def legalize_mp_3elt_comparison(optree):
    """ Transform a comparison between ML_Compound_FP_Format (3-limb)
        operands into a boolean combination of comparisons between their
        hi / me / lo limbs.

        Non-constant operands are renormalized (Normalize_33) before the
        limb-wise comparison; constant operands are used as they are.

        - Equal: all limbs are equal
        - NotEqual: at least one limb differs
        - Less / Greater / LessOrEqual / GreaterOrEqual: lexicographic
          comparison on (hi, me, lo). The strict specifier is used on hi and
          me; the original (possibly non-strict) specifier is only applied
          to lo.

        Any other specifier is reported as an error. """
    def renormalize(op):
        # constants are kept, other operands are rebuilt from renormalized limbs
        if is_constant(op):
            return op
        return BuildFromComponent(
            *Normalize_33(op.hi, op.me, op.lo, precision=op.precision.get_limb_precision(0)),
            precision=op.precision)

    def limb_cmp(lhs_limb, rhs_limb, limb_specifier):
        # boolean comparison between two limbs
        return Comparison(lhs_limb, rhs_limb, specifier=limb_specifier, precision=ML_Bool)

    specifier = optree.specifier
    lhs = optree.get_input(0)
    rhs = optree.get_input(1)
    # TODO/FIXME: assume that multi-limb constant operands are normalized
    # (only non-constant operands are renormalized below)
    if specifier == Comparison.Equal:
        lhs, rhs = renormalize(lhs), renormalize(rhs)
        return LogicalAnd(
            limb_cmp(lhs.hi, rhs.hi, Comparison.Equal),
            LogicalAnd(
                limb_cmp(lhs.me, rhs.me, Comparison.Equal),
                limb_cmp(lhs.lo, rhs.lo, Comparison.Equal),
                precision=ML_Bool
            ),
            precision=ML_Bool
        )
    elif specifier == Comparison.NotEqual:
        lhs, rhs = renormalize(lhs), renormalize(rhs)
        return LogicalOr(
            limb_cmp(lhs.hi, rhs.hi, Comparison.NotEqual),
            LogicalOr(
                limb_cmp(lhs.me, rhs.me, Comparison.NotEqual),
                limb_cmp(lhs.lo, rhs.lo, Comparison.NotEqual),
                precision=ML_Bool
            ),
            precision=ML_Bool
        )
    elif specifier in [Comparison.LessOrEqual, Comparison.GreaterOrEqual, Comparison.Greater, Comparison.Less]:
        strict_specifier = {
            Comparison.Less: Comparison.Less,
            Comparison.Greater: Comparison.Greater,
            Comparison.LessOrEqual: Comparison.Less,
            Comparison.GreaterOrEqual: Comparison.Greater
        }[specifier]
        lhs, rhs = renormalize(lhs), renormalize(rhs)
        # lhs <op> rhs  <=>  hi <strict> hi
        #                 or (hi == hi and (me <strict> me
        #                                   or (me == me and lo <op> lo)))
        return LogicalOr(
            limb_cmp(lhs.hi, rhs.hi, strict_specifier),
            LogicalAnd(
                limb_cmp(lhs.hi, rhs.hi, Comparison.Equal),
                LogicalOr(
                    limb_cmp(lhs.me, rhs.me, strict_specifier),
                    LogicalAnd(
                        limb_cmp(lhs.me, rhs.me, Comparison.Equal),
                        limb_cmp(lhs.lo, rhs.lo, specifier),
                        precision=ML_Bool
                    ),
                    precision=ML_Bool
                ),
                precision=ML_Bool
            ),
            precision=ML_Bool
        )
    else:
        Log.report(Log.Error, "unsupported specifier {} in legalize_mp_3elt_comparison", specifier)