def FirstMatched(builder, produced_blob_object, consumer_op_arg_parallel_attr):
    """Run the first boxing method in ``boxing_methods`` whose condition holds.

    A ``BoxingHobContext`` is built once from the producer blob object and
    the consumer's parallel attribute. Each entry of the module-level
    ``boxing_methods`` is then checked, in order, by evaluating its
    ``enable_if`` condition hob against that context.

    Args:
        builder: forwarded unchanged to the selected boxing method.
        produced_blob_object: blob object on the producer side.
        consumer_op_arg_parallel_attr: parallel attribute the consumer needs.

    Returns:
        The result of calling the first matching boxing method, or ``None``
        if no method's condition evaluates truthy.
    """
    hob_ctx = BoxingHobContext(produced_blob_object, consumer_op_arg_parallel_attr)
    for candidate in boxing_methods:
        # Conditions are evaluated lazily; later candidates are never
        # inspected once one matches.
        if enable_if.get_condition_hob(candidate)(hob_ctx):
            return candidate(
                builder, produced_blob_object, consumer_op_arg_parallel_attr
            )
    return None
def BoxingTo(builder, produced_blob_object, consumer_op_arg_parallel_attr):
    """Box ``produced_blob_object`` so it satisfies the consumer's parallel attr.

    If the ``NoBoxing`` condition holds for this producer/consumer pair, the
    produced blob object is returned as-is. Otherwise the producer and
    consumer must agree on ``opt_mirrored_parallel``. The boxing function is
    then chosen from ``conditional_function_table`` via ``enable_if.unique``
    and invoked.

    Args:
        builder: forwarded unchanged to the selected boxing function.
        produced_blob_object: blob object on the producer side; its
            ``op_arg_parallel_attr`` and ``op_arg_blob_attr`` are read.
        consumer_op_arg_parallel_attr: parallel attribute the consumer needs.

    Returns:
        ``produced_blob_object`` when no boxing is required, otherwise the
        value returned by the selected boxing function.

    Raises:
        AssertionError: if producer and consumer disagree on
            ``opt_mirrored_parallel``.
        NotImplementedError: from the local ``default`` fallback, presumably
            when ``enable_if.unique`` finds no matching boxing function
            (``enable_if`` itself is not visible here -- confirm).
    """
    hob_context = BoxingHobContext(produced_blob_object, consumer_op_arg_parallel_attr)
    if enable_if.get_condition_hob(NoBoxing)(hob_context):
        return produced_blob_object

    producer_opt_mirrored_parallel = (
        produced_blob_object.op_arg_parallel_attr.opt_mirrored_parallel
    )
    consumer_opt_mirrored_parallel = consumer_op_arg_parallel_attr.opt_mirrored_parallel
    # Explicit raise rather than `assert` so this check is not stripped under
    # `python -O`; `not (a == b)` keeps the original equality semantics.
    if not (producer_opt_mirrored_parallel == consumer_opt_mirrored_parallel):
        raise AssertionError(
            "\nproducer_op_arg_parallel_attr: %s\nconsumer_op_arg_parallel_attr: %s"
            % (produced_blob_object.op_arg_parallel_attr, consumer_op_arg_parallel_attr)
        )

    def default(get_failed_info, *args, **kwargs):
        # Fallback handed to enable_if.unique; reports why no boxing matched.
        raise NotImplementedError(
            "%s\n"
            "no boxing method found.\n"
            "logical_blob_name: %s\n"
            "x_arg_attribute: %s\n"
            "consumer_op_arg_parallel_attr: %s\n"
            % (
                get_failed_info(),
                produced_blob_object.op_arg_blob_attr.logical_blob_name,
                produced_blob_object.op_arg_parallel_attr,
                consumer_op_arg_parallel_attr,
            )
        )

    # Reuse the context built above instead of constructing an identical one
    # from the same two arguments. NOTE(review): assumes any state cached
    # inside BoxingHobContext is derived purely from those inputs -- confirm.
    function = enable_if.unique(
        conditional_function_table,
        context=hob_context,
        default=default,
    )
    return function(builder, produced_blob_object, consumer_op_arg_parallel_attr)