def ConstructNaiveBoxingOpConf(
    produced_blob_object,
    consumer_op_arg_parallel_attr,
    in_parallel_num,
    out_parallel_num,
):
    """Assemble a CPU boxing op conf for ``produced_blob_object`` and run inference on it.

    Input side: if the producer is split, its pieces are concatenated along the
    split axis. A lone input (``in_parallel_num == 1``) is "concatenated" along
    axis 0. Otherwise the input must be partial-sum, and the pieces are added.

    Output side: the result is split along the consumer's split axis into
    ``out_parallel_num`` parts sized by ``balanced_splitter.BalancedPartNums``.
    A non-split consumer is only accepted with a single output, and then uses axis 0.

    Args:
        produced_blob_object: Blob object bound to every ``in_<i>`` input.
            Its ``op_arg_parallel_attr.sbp_parallel`` picks concat vs. add, and
            its ``op_arg_blob_attr.shape`` provides the length being split.
        consumer_op_arg_parallel_attr: Consumer-side parallel attribute. Its
            ``sbp_parallel`` provides the output split axis.
        in_parallel_num: Number of inputs; stored as ``boxing_conf.in_num``.
        out_parallel_num: Number of outputs; stored as ``boxing_conf.out_num``.

    Returns:
        The result of ``op_infer_util.Infer`` on the assembled op conf.

    Raises:
        AssertionError: If assertions are enabled and either
            - the input is not split, not single-input, and not partial-sum, or
            - the output is not split while ``out_parallel_num != 1``.

    NOTE(review): another ``ConstructNaiveBoxingOpConf`` (using ``has_*()``
    accessors) follows later in this chunk. If both live in one module, the
    later one replaces this one at import time. Confirm which is intended.
    """
    op_conf = op_conf_pb.OperatorConf()
    op_conf.name = "undefined_boxing_op_name"
    op_conf.device_tag = "cpu"

    # Writes through this alias land in op_conf.boxing_conf itself.
    boxing = op_conf.boxing_conf
    boxing.lbi.op_name = "undefined_boxing_op_name"
    boxing.lbi.blob_name = "undefined_boxing_blob_name"
    boxing.in_num = in_parallel_num
    boxing.out_num = out_parallel_num

    # How the incoming pieces are merged.
    src_sbp = produced_blob_object.op_arg_parallel_attr.sbp_parallel
    if src_sbp.HasField("split_parallel"):
        boxing.concat_box.axis = src_sbp.split_parallel.axis
    elif in_parallel_num == 1:
        # A single piece needs no reduction; concat of one piece is used instead.
        boxing.concat_box.axis = 0
    else:
        assert src_sbp.HasField("partial_sum_parallel")
        boxing.add_box.SetInParent()

    # How the merged blob is scattered to the consumer.
    dst_sbp = consumer_op_arg_parallel_attr.sbp_parallel
    out_axis = 0
    if dst_sbp.HasField("split_parallel"):
        out_axis = dst_sbp.split_parallel.axis
    else:
        assert out_parallel_num == 1

    split_box = boxing.split_box
    split_box.axis = out_axis
    axis_len = produced_blob_object.op_arg_blob_attr.shape[out_axis]
    split_box.part_num.extend(
        balanced_splitter.BalancedPartNums(axis_len, out_parallel_num))

    # Every input slot is fed by the same produced blob object.
    inputs = {}
    for idx in range(in_parallel_num):
        inputs["in_%s" % idx] = produced_blob_object
    return op_infer_util.Infer(op_conf, inputs)
def GetPhysicalOpArgBlobAttr(self, split_axis, parallel_num, parallel_id):
    """Describe the shard of this blob owned by ``parallel_id``.

    Dimension ``split_axis`` of ``self.shape`` is divided into ``parallel_num``
    parts by ``balanced_splitter.BalancedPartNums``, and the part at index
    ``parallel_id`` becomes that dimension's length in a copy of
    ``self.blob_desc``. ``batch_axis`` and ``logical_blob_name`` are passed
    through unchanged.

    Args:
        split_axis: Index of the dimension being partitioned.
        parallel_num: Number of parts to split that dimension into.
        parallel_id: Which part to describe.

    Returns:
        A new ``OpArgBlobAttribute`` built from the copied descriptor. ``self``
        and ``self.blob_desc`` are left untouched.
    """
    # Lengths of all shards along split_axis; only ours is used below.
    shard_lens = balanced_splitter.BalancedPartNums(
        self.shape[split_axis], parallel_num)

    shard_desc = blob_desc_pb.BlobDescProto()
    shard_desc.CopyFrom(self.blob_desc)
    shard_desc.body.shape.dim[split_axis] = shard_lens[parallel_id]

    return OpArgBlobAttribute(self.batch_axis, shard_desc, self.logical_blob_name)
def ConstructNaiveBoxingOpConf(
    produced_blob_object,
    consumer_op_arg_parallel_attr,
    in_parallel_num,
    out_parallel_num,
):
    """Assemble a CPU boxing op conf for ``produced_blob_object`` and run inference on it.

    This variant reads SBP attributes through accessor methods
    (``has_split_parallel()``, ``split_parallel().axis()``,
    ``has_partial_sum_parallel()``). It passes inputs to inference in a
    ``oneflow._oneflow_internal.deprecated.BnInOp2BlobObject`` map.

    Input side: split pieces are concatenated along the split axis. A lone
    input (``in_parallel_num == 1``) is "concatenated" along axis 0. Otherwise
    the input must be partial-sum, and the pieces are added.

    Output side: the result is split along the consumer's split axis into
    ``out_parallel_num`` parts sized by ``balanced_splitter.BalancedPartNums``.
    A non-split consumer is only accepted with a single output, and then uses axis 0.

    Args:
        produced_blob_object: Blob object bound to every ``in_<i>`` input.
            Its ``op_arg_parallel_attr.sbp_parallel`` picks concat vs. add, and
            its ``op_arg_blob_attr.shape`` provides the length being split.
        consumer_op_arg_parallel_attr: Consumer-side parallel attribute. Its
            ``sbp_parallel`` provides the output split axis.
        in_parallel_num: Number of inputs; stored as ``boxing_conf.in_num``.
        out_parallel_num: Number of outputs; stored as ``boxing_conf.out_num``.

    Returns:
        The result of ``op_infer_util.Infer`` on the assembled op conf.

    Raises:
        AssertionError: If assertions are enabled and either
            - the input is not split, not single-input, and not partial-sum, or
            - the output is not split while ``out_parallel_num != 1``.
    """
    op_conf = op_conf_pb.OperatorConf()
    op_conf.name = "undefined_boxing_op_name"
    op_conf.device_tag = "cpu"

    # Writes through this alias land in op_conf.boxing_conf itself.
    boxing = op_conf.boxing_conf
    boxing.lbi.op_name = "undefined_boxing_op_name"
    boxing.lbi.blob_name = "undefined_boxing_blob_name"
    boxing.in_num = in_parallel_num
    boxing.out_num = out_parallel_num

    # How the incoming pieces are merged.
    src_sbp = produced_blob_object.op_arg_parallel_attr.sbp_parallel
    if src_sbp.has_split_parallel():
        boxing.concat_box.axis = src_sbp.split_parallel().axis()
    elif in_parallel_num == 1:
        # A single piece needs no reduction; concat of one piece is used instead.
        boxing.concat_box.axis = 0
    else:
        assert src_sbp.has_partial_sum_parallel()
        boxing.add_box.SetInParent()

    # How the merged blob is scattered to the consumer.
    dst_sbp = consumer_op_arg_parallel_attr.sbp_parallel
    out_axis = 0
    if dst_sbp.has_split_parallel():
        out_axis = dst_sbp.split_parallel().axis()
    else:
        assert out_parallel_num == 1

    split_box = boxing.split_box
    split_box.axis = out_axis
    axis_len = produced_blob_object.op_arg_blob_attr.shape[out_axis]
    split_box.part_num.extend(
        balanced_splitter.BalancedPartNums(axis_len, out_parallel_num))

    # Every input slot is fed by the same produced blob object.
    inputs = oneflow._oneflow_internal.deprecated.BnInOp2BlobObject()
    for idx in range(in_parallel_num):
        inputs["in_%s" % idx] = produced_blob_object
    return op_infer_util.Infer(op_conf, inputs)