Ejemplo n.º 1
0
def MakeBroadcastOpArgParallelAttribute(parallel_desc_symbol):
    """Build an OpArgParallelAttribute that uses broadcast SBP.

    Args:
        parallel_desc_symbol: parallel description symbol forwarded unchanged
            to ``OpArgParallelAttribute``.

    Returns:
        An ``OpArgParallelAttribute`` whose ``sbp_parallel`` has its
        ``broadcast_parallel`` sub-message marked present, and whose
        ``opt_mirrored_parallel`` is a freshly constructed, empty
        ``OptMirroredParallel``.
    """
    broadcast_sbp = sbp_parallel_pb.SbpParallel()
    # The broadcast sub-message carries no fields of its own; SetInParent
    # marks it as present so the SBP message records "broadcast".
    broadcast_sbp.broadcast_parallel.SetInParent()
    empty_mirrored = mirrored_parallel_pb.OptMirroredParallel()
    attr_kwargs = dict(
        parallel_desc_symbol=parallel_desc_symbol,
        sbp_parallel=broadcast_sbp,
        opt_mirrored_parallel=empty_mirrored,
    )
    return OpArgParallelAttribute(**attr_kwargs)
Ejemplo n.º 2
0
 def ToInterfaceBlobConf(self):
     """Return a new InterfaceBlobConf describing this blob.

     The conf is filled from ``self.shape_`` (shape dims), ``self.dtype_``
     (converted to its proto dtype) and ``self.is_dynamic``. A single SBP
     entry that splits along axis 0 is appended to ``nd_sbp``.
     """
     conf = inter_face_blob_conf_util.InterfaceBlobConf()
     conf.shape.dim.extend(self.shape_)
     to_proto_dtype = oneflow._oneflow_internal.deprecated.GetProtoDtype4OfDtype
     conf.data_type = to_proto_dtype(self.dtype_)
     conf.is_dynamic = self.is_dynamic
     # Always split on axis 0.
     split_on_axis0 = sbp_parallel_pb.SbpParallel()
     split_on_axis0.split_parallel.axis = 0
     conf.nd_sbp.sbp_parallel.extend([split_on_axis0])
     return conf
Ejemplo n.º 3
0
 def ToInterfaceBlobConf(self):
     """Return a new InterfaceBlobConf describing this blob.

     The conf is filled from ``self.shape_`` (shape dims), ``self.dtype_``
     (converted to its proto dtype) and ``self.is_dynamic``. A single SBP
     entry that splits along axis 0 is appended to ``parallel_distribution``.
     """
     conf = inter_face_blob_conf_util.InterfaceBlobConf()
     conf.shape.dim.extend(self.shape_)
     to_proto_dtype = oneflow_api.deprecated.GetProtoDtype4OfDtype
     conf.data_type = to_proto_dtype(self.dtype_)
     conf.is_dynamic = self.is_dynamic
     # NOTE(chengcheng): batch_axis was removed, so the split axis is always
     # set to 0 to be safe. Setting SBP explicitly will be supported later,
     # or this will be deleted in multi-client mode.
     split_on_axis0 = sbp_parallel_pb.SbpParallel()
     split_on_axis0.split_parallel.axis = 0
     conf.parallel_distribution.sbp_parallel.extend([split_on_axis0])
     return conf
Ejemplo n.º 4
0
def GetInterfaceBlobConf(job_name, lbn, blob_conf=None):
    """Fill an InterfaceBlobConf with the inferred metadata of blob ``lbn``.

    Args:
        job_name: name of the job whose build-and-infer context is queried.
        lbn: logical blob name to look up in that job.
        blob_conf: optional ``InterfaceBlobConf`` to fill in place. A new one
            is created when omitted.

    Returns:
        The filled ``InterfaceBlobConf``. This is ``blob_conf`` itself when
        one was passed in.

    Raises:
        AssertionError: if ``job_name`` or ``lbn`` is not a ``str``, or if
            ``blob_conf`` is given but is not an ``InterfaceBlobConf``.
    """
    assert isinstance(job_name, str)
    assert isinstance(lbn, str)
    if blob_conf is not None:
        assert isinstance(blob_conf, interface_blob_conf_pb.InterfaceBlobConf)
    else:
        blob_conf = interface_blob_conf_pb.InterfaceBlobConf()

    # Query all metadata first, in the same order as before, then write it.
    static_shape = c_api_util.JobBuildAndInferCtx_GetStaticShape(job_name, lbn)
    data_type = c_api_util.JobBuildAndInferCtx_GetDataType(job_name, lbn)
    producer_split_axis = (
        c_api_util.JobBuildAndInferCtx_GetSplitAxisFromProducerView(job_name, lbn)
    )
    dynamic = c_api_util.JobBuildAndInferCtx_IsDynamic(job_name, lbn)

    # NOTE(review): extend() appends, so dims already present in a
    # caller-supplied blob_conf are kept in front of these. Confirm that is
    # intended.
    blob_conf.shape.dim.extend(static_shape)
    blob_conf.data_type = data_type
    # An SBP entry is recorded only when the producer reports a split axis.
    if producer_split_axis is not None:
        split_sbp = sbp_parallel_pb.SbpParallel()
        split_sbp.split_parallel.axis = producer_split_axis
        blob_conf.nd_sbp.sbp_parallel.extend([split_sbp])
    blob_conf.is_dynamic = dynamic
    return blob_conf
Ejemplo n.º 5
0
def BroadcastParallel(builder, produced_blob_object,
                      consumer_op_arg_parallel_attr):
    """Return a new SbpParallel whose broadcast_parallel field is marked present.

    ``builder``, ``produced_blob_object`` and ``consumer_op_arg_parallel_attr``
    are accepted but not read here. Presumably the signature matches other
    callbacks chosen by the caller; confirm this against the call sites.
    """
    broadcast = sbp_parallel_pb.SbpParallel()
    # The broadcast sub-message has no payload; SetInParent only marks it set.
    broadcast.broadcast_parallel.SetInParent()
    return broadcast