def Sequential(*boxing_methods, exclude=tuple(), middle_verbose=False):
    """Chain several boxing methods into one composed boxing method.

    Every positional argument except the last must be a
    ``boxing_middle.BoxingToMiddle`` wrapper (exposing ``boxing_method`` and
    ``get_middle_op_arg_parallel_attr``). The last argument must be a plain
    boxing method. The chain is folded from right to left through
    ``ComposeBoxing``, so the leftmost wrapper becomes the outermost
    composition.

    Args:
        *boxing_methods: BoxingToMiddle wrappers followed by a final boxing
            method.
        exclude: boxing methods whose condition hobs are OR-ed together and
            then removed (``& ~``) from the composed method's condition hob.
        middle_verbose: when true, a descriptive label for each intermediate
            step is forwarded to ``ComposeBoxing`` as ``middle_verbose_str``.

    Returns:
        The composed boxing function.
    """
    tail = boxing_methods[-1]
    assert not isinstance(tail, boxing_middle.BoxingToMiddle)
    result = tail
    # Visit the wrappers right-to-left, skipping the final plain method.
    for wrapper in reversed(boxing_methods[:-1]):
        assert isinstance(wrapper, boxing_middle.BoxingToMiddle)
        verbose_label = None
        if middle_verbose:
            verbose_label = "middle op_arg_parallel_attr of %s->%s:" % (
                GetBoxingDebugString(wrapper.boxing_method),
                GetBoxingLeftDebugString(result),
            )
        result = ComposeBoxing(
            wrapper.boxing_method,
            result,
            wrapper.get_middle_op_arg_parallel_attr,
            middle_verbose_str=verbose_label,
        )
    if exclude:
        # Union of all excluded conditions, then carve it out of the result.
        excluded_hob = enable_if.get_condition_hob(exclude[0])
        for other in exclude[1:]:
            excluded_hob = excluded_hob | enable_if.get_condition_hob(other)
        current_hob = enable_if.get_condition_hob(result)
        enable_if.set_condition_hob(result, current_hob & ~excluded_hob)
    return result
def ComposeBoxing(
    lhs_boxing, rhs_boxing, get_middle_op_arg_parallel_attr, middle_verbose_str=None
):
    """Build a boxing method that runs ``lhs_boxing`` and then ``rhs_boxing``.

    The composed method's condition is ``boxing_hob.ComposeHob`` applied to
    both operands' condition hobs. When called, it first asks
    ``get_middle_op_arg_parallel_attr`` for an intermediate parallel attr,
    boxes the produced blob to that attr with ``lhs_boxing``, and then boxes
    the intermediate blob to the consumer's attr with ``rhs_boxing``.

    Args:
        lhs_boxing: boxing method applied first.
        rhs_boxing: boxing method applied second.
        get_middle_op_arg_parallel_attr: callable
            ``(builder, produced_blob_object, consumer_op_arg_parallel_attr)``
            returning the intermediate op_arg_parallel_attr.
        middle_verbose_str: optional label forwarded to ``ComposeHob``.

    Returns:
        The composed boxing function. It carries ``__debug_str__``,
        ``__left_debug_str__`` and ``__right_debug_str__`` attributes.
    """
    lhs_hob = enable_if.get_condition_hob(lhs_boxing)
    rhs_hob = enable_if.get_condition_hob(rhs_boxing)
    composed_hob = boxing_hob.ComposeHob(
        lhs_hob,
        rhs_hob,
        get_middle_op_arg_parallel_attr=get_middle_op_arg_parallel_attr,
        middle_verbose_str=middle_verbose_str,
    )

    @enable_if.condition(composed_hob)
    def Composed(builder, produced_blob_object, consumer_op_arg_parallel_attr):
        middle_attr = get_middle_op_arg_parallel_attr(
            builder, produced_blob_object, consumer_op_arg_parallel_attr
        )
        middle_blob_object = lhs_boxing(builder, produced_blob_object, middle_attr)
        return rhs_boxing(builder, middle_blob_object, consumer_op_arg_parallel_attr)

    lhs_debug = GetBoxingDebugString(lhs_boxing)
    rhs_debug = GetBoxingDebugString(rhs_boxing)
    Composed.__debug_str__ = "%s->%s" % (lhs_debug, rhs_debug)
    Composed.__left_debug_str__ = GetBoxingLeftDebugString(lhs_boxing)
    Composed.__right_debug_str__ = GetBoxingRightDebugString(rhs_boxing)
    return Composed
def FirstMatched(builder, produced_blob_object, consumer_op_arg_parallel_attr):
    """Invoke the first boxing method whose condition hob accepts this context.

    Each entry in ``boxing_methods`` is tried in order. Its condition hob is
    evaluated against a ``BoxingHobContext`` built from the arguments, and
    the first accepting method is called with the original arguments.

    NOTE(review): ``boxing_methods`` is a free name in this function; at top
    level it resolves from module scope. This looks like a copy of the closure
    defined inside ``FirstMatchedBoxing``. Confirm it is still needed here.

    Returns:
        Whatever the matching boxing method returns.

    Raises:
        NotImplementedError: if no candidate's condition hob matches. The
            previous code returned ``None`` silently instead, which only
            failed later in whatever consumed the result.
    """
    ctx = BoxingHobContext(produced_blob_object, consumer_op_arg_parallel_attr)
    for candidate in boxing_methods:
        if enable_if.get_condition_hob(candidate)(ctx):
            return candidate(
                builder, produced_blob_object, consumer_op_arg_parallel_attr
            )
    raise NotImplementedError(
        "no boxing method matched.\n"
        "x_arg_attribute: %s\n"
        "consumer_op_arg_parallel_attr: %s\n"
        % (produced_blob_object.op_arg_parallel_attr, consumer_op_arg_parallel_attr)
    )
def FirstMatchedBoxing(*boxing_methods):
    """Combine boxing methods into one that dispatches to the first match.

    The combined method's condition hob is the OR of every candidate's
    condition hob. When it is called, candidates are tried in argument order,
    and the first one whose condition hob accepts the (produced blob, consumer
    attr) context is invoked.

    Args:
        *boxing_methods: one or more boxing methods carrying enable_if
            condition hobs. Passing none raises ``IndexError`` (from
            ``boxing_methods[0]``).

    Returns:
        The combined boxing function, whose ``__debug_str__`` is
        ``"(a | b | ...)"`` built from each candidate's debug string.
    """
    any_hob = enable_if.get_condition_hob(boxing_methods[0])
    for boxing_method in boxing_methods[1:]:
        any_hob = any_hob | enable_if.get_condition_hob(boxing_method)

    @enable_if.condition(any_hob)
    def FirstMatched(builder, produced_blob_object, consumer_op_arg_parallel_attr):
        ctx = BoxingHobContext(produced_blob_object, consumer_op_arg_parallel_attr)
        for candidate in boxing_methods:
            # Use the candidate's own hob (not the combined one) to pick it.
            if enable_if.get_condition_hob(candidate)(ctx):
                return candidate(
                    builder, produced_blob_object, consumer_op_arg_parallel_attr
                )
        # Previously this fell through and returned None, which surfaced only
        # later as an obscure failure in whatever consumed the result.
        raise NotImplementedError(
            "no boxing method matched among %s.\n"
            "x_arg_attribute: %s\n"
            "consumer_op_arg_parallel_attr: %s\n"
            % (
                " | ".join(GetBoxingDebugString(m) for m in boxing_methods),
                produced_blob_object.op_arg_parallel_attr,
                consumer_op_arg_parallel_attr,
            )
        )

    boxing_methods_names = [GetBoxingDebugString(m) for m in boxing_methods]
    FirstMatched.__debug_str__ = "(%s)" % (" | ".join(boxing_methods_names))
    return FirstMatched
def BoxingTo(builder, produced_blob_object, consumer_op_arg_parallel_attr):
    """Box ``produced_blob_object`` so it satisfies the consumer's parallel attr.

    If the ``NoBoxing`` condition hob accepts the context, the produced blob
    object is returned unchanged. Otherwise the producer's and consumer's
    ``opt_mirrored_parallel`` must be equal (asserted). The boxing function
    is then selected from ``conditional_function_table`` via
    ``enable_if.unique`` and called.

    Args:
        builder: passed through to the selected boxing function.
        produced_blob_object: blob object being consumed. Its
            ``op_arg_parallel_attr`` and ``op_arg_blob_attr`` are read.
        consumer_op_arg_parallel_attr: the parallel attr the consumer needs.

    Returns:
        The produced blob object itself, or the result of the selected
        boxing function.

    Raises:
        AssertionError: if producer and consumer ``opt_mirrored_parallel``
            differ.
        NotImplementedError: raised by the ``default`` fallback, presumably
            invoked by ``enable_if.unique`` when no table entry matches.
    """
    # One context serves both the NoBoxing check and the table dispatch;
    # the original built a second, identical context for enable_if.unique.
    hob_context = BoxingHobContext(produced_blob_object, consumer_op_arg_parallel_attr)
    if enable_if.get_condition_hob(NoBoxing)(hob_context):
        return produced_blob_object

    producer_opt_mirrored_parallel = (
        produced_blob_object.op_arg_parallel_attr.opt_mirrored_parallel
    )
    consumer_opt_mirrored_parallel = consumer_op_arg_parallel_attr.opt_mirrored_parallel
    assert producer_opt_mirrored_parallel == consumer_opt_mirrored_parallel, (
        "\nproducer_op_arg_parallel_attr: %s\nconsumer_op_arg_parallel_attr: %s"
        % (produced_blob_object.op_arg_parallel_attr, consumer_op_arg_parallel_attr)
    )

    def default(get_failed_info, *args, **kwargs):
        # Fallback handed to enable_if.unique; reports why no method applied.
        raise NotImplementedError(
            "%s\n"
            "no boxing method found.\n"
            "logical_blob_name: %s\n"
            "x_arg_attribute: %s\n"
            "consumer_op_arg_parallel_attr: %s\n"
            % (
                get_failed_info(),
                produced_blob_object.op_arg_blob_attr.logical_blob_name,
                produced_blob_object.op_arg_parallel_attr,
                consumer_op_arg_parallel_attr,
            )
        )

    # conditional_function_table is only read here, so no `global` statement
    # is needed.
    function = enable_if.unique(
        conditional_function_table,
        context=hob_context,
        default=default,
    )
    return function(builder, produced_blob_object, consumer_op_arg_parallel_attr)