コード例 #1
0
 def FirstMatched(builder, produced_blob_object, consumer_op_arg_parallel_attr):
     """Dispatch to the first boxing method whose condition holds.

     Each entry of ``boxing_methods`` is tried in order. That name is resolved
     outside this function, presumably from an enclosing scope's list of
     candidate methods (confirm). For each entry, its condition hob (obtained
     via ``enable_if.get_condition_hob``) is evaluated against a single
     ``BoxingHobContext`` built from the producer blob object and the
     consumer's parallel attribute.

     Returns:
         The result of calling the first matching method with
         ``(builder, produced_blob_object, consumer_op_arg_parallel_attr)``,
         or ``None`` when no method's condition is satisfied.
     """
     hob_context = BoxingHobContext(produced_blob_object, consumer_op_arg_parallel_attr)
     for candidate in boxing_methods:
         if enable_if.get_condition_hob(candidate)(hob_context):
             return candidate(
                 builder, produced_blob_object, consumer_op_arg_parallel_attr
             )
     # No candidate's condition matched; the caller receives None.
     return None
コード例 #2
0
ファイル: boxing_util.py プロジェクト: zhouyuegit/oneflow
def BoxingTo(builder, produced_blob_object, consumer_op_arg_parallel_attr):
    """Box ``produced_blob_object`` for a consumer with the given parallel attr.

    If the ``NoBoxing`` condition holds for this producer/consumer pair, the
    produced blob object is returned unchanged. Otherwise the producer's and
    consumer's ``opt_mirrored_parallel`` settings must be equal, and the
    boxing function chosen by ``enable_if.unique`` from
    ``conditional_function_table`` is invoked.

    Args:
        builder: Passed through unchanged to the selected boxing function.
        produced_blob_object: Blob object to box. It is read for
            ``op_arg_parallel_attr`` and ``op_arg_blob_attr``.
        consumer_op_arg_parallel_attr: Parallel attribute the consumer
            expects. It is read for ``opt_mirrored_parallel``.

    Returns:
        ``produced_blob_object`` itself when no boxing is needed; otherwise
        whatever the selected boxing function returns.

    Raises:
        AssertionError: If producer and consumer ``opt_mirrored_parallel``
            differ.
        NotImplementedError: If ``enable_if.unique`` falls back to the
            ``default`` handler below. Presumably this happens when no boxing
            method matches; the exact trigger is decided by ``enable_if.unique``.
    """
    # Built once and shared by the NoBoxing check and the dispatch below.
    hob_context = BoxingHobContext(produced_blob_object,
                                   consumer_op_arg_parallel_attr)
    if enable_if.get_condition_hob(NoBoxing)(hob_context):
        return produced_blob_object

    producer_opt_mirrored_parallel = (
        produced_blob_object.op_arg_parallel_attr.opt_mirrored_parallel)
    consumer_opt_mirrored_parallel = consumer_op_arg_parallel_attr.opt_mirrored_parallel
    # Explicit raise instead of `assert`: `assert` statements are removed
    # under `python -O`, which would let a mismatch slip through silently.
    # AssertionError is kept so callers see the same exception type as before.
    if not (producer_opt_mirrored_parallel == consumer_opt_mirrored_parallel):
        raise AssertionError(
            "\nproducer_op_arg_parallel_attr: %s\nconsumer_op_arg_parallel_attr: %s"
            % (produced_blob_object.op_arg_parallel_attr,
               consumer_op_arg_parallel_attr))

    def default(get_failed_info, *args, **kwargs):
        # Fallback handed to enable_if.unique. get_failed_info() supplies the
        # diagnostic prefix; any extra positional/keyword args are ignored.
        raise NotImplementedError(
            "%s\n"
            "no boxing method found.\n"
            "logical_blob_name: %s\n"
            "x_arg_attribute: %s\n"
            "consumer_op_arg_parallel_attr: %s\n" % (
                get_failed_info(),
                produced_blob_object.op_arg_blob_attr.logical_blob_name,
                produced_blob_object.op_arg_parallel_attr,
                consumer_op_arg_parallel_attr,
            ))

    # conditional_function_table is only read here, so no `global`
    # declaration is needed to resolve it at module level.
    function = enable_if.unique(
        conditional_function_table,
        context=hob_context,
        default=default,
    )
    return function(builder, produced_blob_object,
                    consumer_op_arg_parallel_attr)