Esempio n. 1
0
def ConstructNaiveBoxingOpConf(
    produced_blob_object,
    consumer_op_arg_parallel_attr,
    in_parallel_num,
    out_parallel_num,
):
    """Build a CPU ``boxing_conf`` operator for naive boxing and infer it.

    The input side of the boxing op is configured from the producer's SBP:
    ``concat_box`` along the split axis for split inputs, ``concat_box`` on
    axis 0 when there is a single input, and ``add_box`` otherwise (which
    requires a partial-sum input). The output side is always a
    ``split_box`` whose axis comes from the consumer's split SBP (axis 0
    when the consumer is not split). Its part sizes come from
    ``balanced_splitter.BalancedPartNums`` applied to that axis's length.

    Args:
        produced_blob_object: Blob object being boxed. Its
            ``op_arg_parallel_attr.sbp_parallel`` and
            ``op_arg_blob_attr.shape`` are read.
        consumer_op_arg_parallel_attr: Parallel attribute of the consumer.
            Its ``sbp_parallel`` selects the output split axis.
        in_parallel_num: Number of inputs (``in_0`` .. ``in_{n-1}``), each
            bound to ``produced_blob_object``.
        out_parallel_num: Number of outputs / split parts.

    Returns:
        The result of ``op_infer_util.Infer`` on the constructed op.

    Raises:
        AssertionError: (only when assertions are enabled) if the input is
            neither split, single-parallel, nor partial-sum; or if the
            consumer is not split while ``out_parallel_num != 1``.
    """
    # The op and its lbi use fixed names, not names derived from the inputs.
    fixed_op_name = "undefined_boxing_op_name"
    op_conf = op_conf_pb.OperatorConf()
    op_conf.name = fixed_op_name
    op_conf.device_tag = "cpu"

    # Writes through this sub-message reference land in op_conf.boxing_conf.
    boxing = op_conf.boxing_conf
    boxing.lbi.op_name = fixed_op_name
    boxing.lbi.blob_name = "undefined_boxing_blob_name"
    boxing.in_num = in_parallel_num
    boxing.out_num = out_parallel_num

    # Input side: how the in_parallel_num pieces are combined.
    src_sbp = produced_blob_object.op_arg_parallel_attr.sbp_parallel
    if src_sbp.HasField("split_parallel"):
        boxing.concat_box.axis = src_sbp.split_parallel.axis
    elif in_parallel_num == 1:
        boxing.concat_box.axis = 0
    else:
        assert src_sbp.HasField("partial_sum_parallel")
        boxing.add_box.SetInParent()

    # Output side: always a split_box; a non-split consumer is only
    # accepted when there is exactly one output.
    dst_sbp = consumer_op_arg_parallel_attr.sbp_parallel
    if not dst_sbp.HasField("split_parallel"):
        assert out_parallel_num == 1
        split_axis = 0
    else:
        split_axis = dst_sbp.split_parallel.axis
    boxing.split_box.axis = split_axis
    split_dim_len = produced_blob_object.op_arg_blob_attr.shape[split_axis]
    boxing.split_box.part_num.extend(
        balanced_splitter.BalancedPartNums(split_dim_len, out_parallel_num)
    )

    # Every input slot of the boxing op is fed by the same blob object.
    blob_objects_by_bn = {}
    for idx in range(in_parallel_num):
        blob_objects_by_bn["in_%s" % idx] = produced_blob_object
    return op_infer_util.Infer(op_conf, blob_objects_by_bn)
Esempio n. 2
0
 def GetPhysicalOpArgBlobAttr(self, split_axis, parallel_num, parallel_id):
     """Return the blob attribute of one physical part of this blob.

     A fresh copy of ``self.blob_desc`` is made, and its dimension
     ``split_axis`` is overwritten with the length that
     ``balanced_splitter.BalancedPartNums`` assigns to part ``parallel_id``
     when ``self.shape[split_axis]`` is divided into ``parallel_num``
     parts. ``self.blob_desc`` itself is left unmodified.

     Args:
         split_axis: Index of the dimension being partitioned.
         parallel_num: Number of parts the dimension is divided into.
         parallel_id: Index of the part whose length is used.

     Returns:
         A new ``OpArgBlobAttribute`` built from ``self.batch_axis``, the
         modified descriptor copy, and ``self.logical_blob_name``.
     """
     # Work on a copy so the logical descriptor is not mutated.
     physical_desc = blob_desc_pb.BlobDescProto()
     physical_desc.CopyFrom(self.blob_desc)
     part_lengths = balanced_splitter.BalancedPartNums(
         self.shape[split_axis], parallel_num
     )
     physical_desc.body.shape.dim[split_axis] = part_lengths[parallel_id]
     return OpArgBlobAttribute(
         self.batch_axis,
         physical_desc,
         self.logical_blob_name,
     )
Esempio n. 3
0
def ConstructNaiveBoxingOpConf(
    produced_blob_object,
    consumer_op_arg_parallel_attr,
    in_parallel_num,
    out_parallel_num,
):
    """Build a CPU ``boxing_conf`` operator for naive boxing and infer it.

    This variant reads SBP information through accessor methods
    (``has_split_parallel()``, ``split_parallel().axis()``) and passes the
    inputs to inference in a
    ``oneflow._oneflow_internal.deprecated.BnInOp2BlobObject`` container.

    The input side uses ``concat_box`` along the split axis for split
    inputs, ``concat_box`` on axis 0 for a single input, and ``add_box``
    otherwise (which requires a partial-sum input). The output side is a
    ``split_box`` on the consumer's split axis (axis 0 when not split).
    Its part sizes come from ``balanced_splitter.BalancedPartNums``.

    Args:
        produced_blob_object: Blob object being boxed. Its
            ``op_arg_parallel_attr.sbp_parallel`` and
            ``op_arg_blob_attr.shape`` are read.
        consumer_op_arg_parallel_attr: Parallel attribute of the consumer.
            Its ``sbp_parallel`` selects the output split axis.
        in_parallel_num: Number of inputs (``in_0`` .. ``in_{n-1}``), each
            bound to ``produced_blob_object``.
        out_parallel_num: Number of outputs / split parts.

    Returns:
        The result of ``op_infer_util.Infer`` on the constructed op.

    Raises:
        AssertionError: (only when assertions are enabled) if the input is
            neither split, single-parallel, nor partial-sum; or if the
            consumer is not split while ``out_parallel_num != 1``.
    """
    # The op and its lbi use fixed names, not names derived from the inputs.
    fixed_op_name = "undefined_boxing_op_name"
    op_conf = op_conf_pb.OperatorConf()
    op_conf.name = fixed_op_name
    op_conf.device_tag = "cpu"

    # Writes through this sub-message reference land in op_conf.boxing_conf.
    boxing = op_conf.boxing_conf
    boxing.lbi.op_name = fixed_op_name
    boxing.lbi.blob_name = "undefined_boxing_blob_name"
    boxing.in_num = in_parallel_num
    boxing.out_num = out_parallel_num

    # Input side: how the in_parallel_num pieces are combined.
    src_sbp = produced_blob_object.op_arg_parallel_attr.sbp_parallel
    if src_sbp.has_split_parallel():
        boxing.concat_box.axis = src_sbp.split_parallel().axis()
    elif in_parallel_num == 1:
        boxing.concat_box.axis = 0
    else:
        assert src_sbp.has_partial_sum_parallel()
        boxing.add_box.SetInParent()

    # Output side: always a split_box; a non-split consumer is only
    # accepted when there is exactly one output.
    dst_sbp = consumer_op_arg_parallel_attr.sbp_parallel
    if not dst_sbp.has_split_parallel():
        assert out_parallel_num == 1
        split_axis = 0
    else:
        split_axis = dst_sbp.split_parallel().axis()
    boxing.split_box.axis = split_axis
    split_dim_len = produced_blob_object.op_arg_blob_attr.shape[split_axis]
    boxing.split_box.part_num.extend(
        balanced_splitter.BalancedPartNums(split_dim_len, out_parallel_num)
    )

    # Every input slot of the boxing op is fed by the same blob object.
    blob_objects_by_bn = oneflow._oneflow_internal.deprecated.BnInOp2BlobObject()
    for idx in range(in_parallel_num):
        blob_objects_by_bn["in_%s" % idx] = produced_blob_object
    return op_infer_util.Infer(op_conf, blob_objects_by_bn)