示例#1
0
def Sequential(*boxing_methods, exclude=tuple(), middle_verbose=False):
    """Chain several boxing steps into one composed boxing function.

    Every positional argument except the last must be a
    ``boxing_middle.BoxingToMiddle``; the last must not be. The chain is
    folded from right to left with ``ComposeBoxing``. Starting from the last
    method, each preceding entry's ``boxing_method`` is composed in front of
    the result built so far. That entry's ``get_middle_op_arg_parallel_attr``
    chooses the intermediate parallel attribute.

    Args:
        *boxing_methods: ``BoxingToMiddle`` entries followed by one final
            boxing method.
        exclude: boxing methods whose condition hobs are combined with ``|``;
            the composed method's condition becomes
            ``composed_hob & ~excluded_hob``.
        middle_verbose: when truthy, a descriptive label for each middle step
            is forwarded to ``ComposeBoxing`` as ``middle_verbose_str``.

    Returns:
        The composed boxing function. When ``exclude`` is non-empty, its
        condition hob has been updated in place via
        ``enable_if.set_condition_hob``.
    """
    final_method = boxing_methods[-1]
    assert not isinstance(final_method, boxing_middle.BoxingToMiddle)
    composed = final_method
    # Walk the middle steps from the one nearest the final method backwards.
    for middle_step in reversed(boxing_methods[:-1]):
        assert isinstance(middle_step, boxing_middle.BoxingToMiddle)
        verbose_label = None
        if middle_verbose:
            verbose_label = "middle op_arg_parallel_attr of %s->%s:" % (
                GetBoxingDebugString(middle_step.boxing_method),
                GetBoxingLeftDebugString(composed),
            )
        composed = ComposeBoxing(
            middle_step.boxing_method,
            composed,
            middle_step.get_middle_op_arg_parallel_attr,
            middle_verbose_str=verbose_label,
        )

    if not len(exclude):
        return composed

    # Combine the conditions of all excluded methods ...
    excluded_hob = enable_if.get_condition_hob(exclude[0])
    for excluded_method in exclude[1:]:
        excluded_hob = excluded_hob | enable_if.get_condition_hob(excluded_method)
    # ... and mask them out of the composed method's own condition.
    composed_hob = enable_if.get_condition_hob(composed)
    enable_if.set_condition_hob(composed, composed_hob & ~excluded_hob)
    return composed
示例#2
0
def ComposeBoxing(lhs_boxing,
                  rhs_boxing,
                  get_middle_op_arg_parallel_attr,
                  middle_verbose_str=None):
    """Compose two boxing functions through an intermediate parallel attribute.

    The returned function first calls ``get_middle_op_arg_parallel_attr`` to
    obtain an intermediate op_arg_parallel_attr. It then applies
    ``lhs_boxing`` to move the produced blob object to that attribute, and
    finally applies ``rhs_boxing`` to move the result to the consumer's
    attribute.

    The composed function's condition is ``boxing_hob.ComposeHob`` built from
    both input functions' condition hobs.

    Args:
        lhs_boxing: boxing function applied first.
        rhs_boxing: boxing function applied to the output of ``lhs_boxing``.
        get_middle_op_arg_parallel_attr: callable
            ``(builder, produced_blob_object, consumer_op_arg_parallel_attr)``
            returning the intermediate parallel attribute.
        middle_verbose_str: optional label forwarded to ``ComposeHob``.

    Returns:
        The composed boxing function. It carries ``__debug_str__``,
        ``__left_debug_str__`` and ``__right_debug_str__`` attributes.
    """
    lhs_hob = enable_if.get_condition_hob(lhs_boxing)
    rhs_hob = enable_if.get_condition_hob(rhs_boxing)
    composed_hob = boxing_hob.ComposeHob(
        lhs_hob,
        rhs_hob,
        get_middle_op_arg_parallel_attr=get_middle_op_arg_parallel_attr,
        middle_verbose_str=middle_verbose_str,
    )

    @enable_if.condition(composed_hob)
    def Composed(builder, produced_blob_object, consumer_op_arg_parallel_attr):
        middle_attr = get_middle_op_arg_parallel_attr(
            builder, produced_blob_object, consumer_op_arg_parallel_attr)
        # lhs runs first (producer -> middle), then rhs (middle -> consumer).
        return rhs_boxing(
            builder,
            lhs_boxing(builder, produced_blob_object, middle_attr),
            consumer_op_arg_parallel_attr,
        )

    lhs_debug = GetBoxingDebugString(lhs_boxing)
    rhs_debug = GetBoxingDebugString(rhs_boxing)
    Composed.__debug_str__ = "%s->%s" % (lhs_debug, rhs_debug)
    Composed.__left_debug_str__ = GetBoxingLeftDebugString(lhs_boxing)
    Composed.__right_debug_str__ = GetBoxingRightDebugString(rhs_boxing)
    return Composed
示例#3
0
 def FirstMatched(builder, produced_blob_object, consumer_op_arg_parallel_attr):
     """Run the first entry of ``boxing_methods`` whose condition hob holds.

     ``boxing_methods`` comes from the enclosing scope, which is not shown
     here. Conditions are checked in order against a ``BoxingHobContext``
     built for this producer/consumer pair. Returns ``None`` when no
     condition matches.
     """
     hob_ctx = BoxingHobContext(produced_blob_object, consumer_op_arg_parallel_attr)
     # Lazy generator: conditions are evaluated one at a time, stopping at the first hit.
     matching = (
         candidate
         for candidate in boxing_methods
         if enable_if.get_condition_hob(candidate)(hob_ctx)
     )
     chosen = next(matching, None)
     if chosen is None:
         return None
     return chosen(builder, produced_blob_object, consumer_op_arg_parallel_attr)
示例#4
0
def FirstMatchedBoxing(*boxing_methods):
    """Build a boxing function that dispatches to the first applicable method.

    The returned function's condition combines every method's condition hob
    with ``|``. When called, it re-reads each method's condition hob in order
    and invokes the first one that holds for the current producer/consumer
    pair. It returns ``None`` if none holds.

    Args:
        *boxing_methods: candidate boxing functions, in priority order.
            At least one is required; an empty call raises ``IndexError``.

    Returns:
        The dispatching boxing function. Its ``__debug_str__`` lists the
        candidates as ``"(a | b | ...)"``.
    """
    first_method, remaining_methods = boxing_methods[0], boxing_methods[1:]
    union_hob = enable_if.get_condition_hob(first_method)
    for other_method in remaining_methods:
        union_hob = union_hob | enable_if.get_condition_hob(other_method)

    @enable_if.condition(union_hob)
    def FirstMatched(builder, produced_blob_object, consumer_op_arg_parallel_attr):
        hob_ctx = BoxingHobContext(produced_blob_object, consumer_op_arg_parallel_attr)
        for candidate in boxing_methods:
            # Look the hob up at call time: it may have been replaced after
            # construction (e.g. via enable_if.set_condition_hob).
            if enable_if.get_condition_hob(candidate)(hob_ctx):
                return candidate(
                    builder, produced_blob_object, consumer_op_arg_parallel_attr
                )
        return None

    FirstMatched.__debug_str__ = "(%s)" % " | ".join(
        GetBoxingDebugString(candidate) for candidate in boxing_methods
    )
    return FirstMatched
示例#5
0
def BoxingTo(builder, produced_blob_object, consumer_op_arg_parallel_attr):
    """Box ``produced_blob_object`` for consumption under ``consumer_op_arg_parallel_attr``.

    If the ``NoBoxing`` condition holds for this producer/consumer pair, the
    produced blob object is returned unchanged. Otherwise:

    1. The producer's and consumer's ``opt_mirrored_parallel`` are checked
       for agreement.
    2. ``enable_if.unique`` selects a boxing function from
       ``conditional_function_table``.
    3. The selected function is called.

    Args:
        builder: passed through to the selected boxing function.
        produced_blob_object: the blob object to be boxed.
        consumer_op_arg_parallel_attr: the parallel attribute the consumer
            expects.

    Returns:
        ``produced_blob_object`` itself when no boxing is needed; otherwise
        whatever the selected boxing function returns.

    Raises:
        AssertionError: if producer and consumer ``opt_mirrored_parallel``
            differ.
        NotImplementedError: from the ``default`` fallback handed to
            ``enable_if.unique``. It is presumably used when no registered
            boxing method matches — confirm against ``enable_if.unique``.
    """
    hob_context = BoxingHobContext(produced_blob_object,
                                   consumer_op_arg_parallel_attr)
    if enable_if.get_condition_hob(NoBoxing)(hob_context):
        return produced_blob_object

    producer_opt_mirrored_parallel = (
        produced_blob_object.op_arg_parallel_attr.opt_mirrored_parallel)
    consumer_opt_mirrored_parallel = consumer_op_arg_parallel_attr.opt_mirrored_parallel
    # Explicit raise rather than `assert`, so the check survives `python -O`.
    # AssertionError is kept so callers see the same exception type as before.
    if not (producer_opt_mirrored_parallel == consumer_opt_mirrored_parallel):
        raise AssertionError(
            "\nproducer_op_arg_parallel_attr: %s\nconsumer_op_arg_parallel_attr: %s"
            % (produced_blob_object.op_arg_parallel_attr,
               consumer_op_arg_parallel_attr))

    def default(get_failed_info, *args, **kwargs):
        # Fallback passed to enable_if.unique. It reports the failure details
        # supplied through get_failed_info plus this blob's attributes.
        raise NotImplementedError(
            "%s\n"
            "no boxing method found.\n"
            "logical_blob_name: %s\n"
            "x_arg_attribute: %s\n"
            "consumer_op_arg_parallel_attr: %s\n" % (
                get_failed_info(),
                produced_blob_object.op_arg_blob_attr.logical_blob_name,
                produced_blob_object.op_arg_parallel_attr,
                consumer_op_arg_parallel_attr,
            ))

    # conditional_function_table is a module-level name that is only read
    # here, so no `global` declaration is needed.
    # NOTE(review): a fresh BoxingHobContext is built for the lookup instead of
    # reusing hob_context. Reuse is probably safe, but depends on whether the
    # context caches state — confirm before changing.
    function = enable_if.unique(
        conditional_function_table,
        context=BoxingHobContext(produced_blob_object,
                                 consumer_op_arg_parallel_attr),
        default=default,
    )
    return function(builder, produced_blob_object,
                    consumer_op_arg_parallel_attr)