def _GetInterfaceBlobObject(builder, op_name):
    """Return the blob object backing the interface op ``op_name``.

    In eager mode the blob object is read from the default session's
    ``var_name2var_blob`` table. Otherwise a lazy reference blob object is
    created with ``builder.MakeLazyRefBlobObject`` from the op's attribute
    and its lazy parallel conf.

    Args:
        builder: instruction builder providing ``MakeLazyRefBlobObject``.
        op_name: name of the interface op.

    Returns:
        The blob object for ``op_name``.
    """
    sess = session_ctx.GetDefaultSession()
    if oneflow._oneflow_internal.EagerExecutionEnabled():
        return sess.var_name2var_blob[op_name].blob_object
    op_attribute = sess.OpAttribute4InterfaceOpName(op_name)
    cfg_op_attribute = oneflow._oneflow_internal.deprecated.MakeOpAttributeByString(
        str(op_attribute))
    parallel_conf = sess.ParallelConf4LazyInterfaceOpName(op_name)
    # The session may hand back a non-cfg ParallelConf (it is read through
    # .device_tag / .HasField, i.e. presumably a protobuf message); convert
    # it field by field into the cfg type, including the optional hierarchy.
    if not isinstance(
            parallel_conf,
            oneflow._oneflow_internal.oneflow.core.job.placement.ParallelConf):
        parallel_conf_cfg = placement_cfg.ParallelConf()
        parallel_conf_cfg.set_device_tag(parallel_conf.device_tag)
        for device_name in parallel_conf.device_name:
            parallel_conf_cfg.add_device_name(device_name)
        if parallel_conf.HasField("hierarchy"):
            hierarchy = shape_proto_cfg.ShapeProto()
            for dim in parallel_conf.hierarchy.dim:
                hierarchy.add_dim(dim)
            assert hierarchy.dim_size() > 0
            parallel_conf_cfg.mutable_hierarchy().CopyFrom(hierarchy)
        parallel_conf = parallel_conf_cfg

    blob_object = builder.MakeLazyRefBlobObject(op_name, cfg_op_attribute,
                                                parallel_conf)
    return blob_object
Beispiel #2
0
def GetConcatSplitBoxingParallelDescSymbol(builder, blob_parallel_desc_symbol,
                                           max_parallel_num):
    """Build a CPU placement with one device per machine for concat/split boxing.

    One device index is drawn uniformly from ``[0, max_parallel_num)`` and
    that same index is used on every machine listed in
    ``blob_parallel_desc_symbol.machine_id2device_id_list``.

    Args:
        builder: instruction builder providing ``GetParallelDescSymbol``.
        blob_parallel_desc_symbol: placement whose machines are covered.
        max_parallel_num: exclusive upper bound for the random device index.

    Returns:
        The placement symbol produced by ``builder.GetParallelDescSymbol``.
    """
    shared_device_index = random.randint(0, max_parallel_num - 1)
    cpu_conf = placement_cfg.ParallelConf()
    cpu_conf.set_device_tag("cpu")
    devices_by_machine = blob_parallel_desc_symbol.machine_id2device_id_list
    for machine_id, _unused_device_ids in devices_by_machine.items():
        cpu_conf.add_device_name("@%s:%s" % (machine_id, shared_device_index))
    return builder.GetParallelDescSymbol(cpu_conf)
Beispiel #3
0
def JobBuildAndInferCtx_GetParallelConfFromProducerView(job_name, lbn):
    """Return the producer-side parallel conf of ``lbn`` in job ``job_name``.

    The internal API returns a text-format serialized ParallelConf. It is
    parsed into a protobuf ``placement_pb.ParallelConf`` and then copied
    field by field into a cfg ``placement_cfg.ParallelConf``.

    Args:
        job_name: job name (converted with ``str``).
        lbn: logical blob name (converted with ``str``).

    Returns:
        A ``placement_cfg.ParallelConf`` carrying the device tag, the device
        names and, when present, the hierarchy.
    """
    job_name = str(job_name)
    lbn = str(lbn)
    GetParallelConf = (
        oneflow._oneflow_internal.
        JobBuildAndInferCtx_GetSerializedParallelConfFromProducerView)
    parallel_conf = GetParallelConf(job_name, lbn)
    parallel_conf = text_format.Parse(parallel_conf,
                                      placement_pb.ParallelConf())
    parallel_conf_cfg = placement_cfg.ParallelConf()
    parallel_conf_cfg.set_device_tag(parallel_conf.device_tag)
    for device_name in parallel_conf.device_name:
        parallel_conf_cfg.add_device_name(device_name)
    # The hierarchy is optional in the proto; copy it when set so the
    # returned conf describes the same placement instead of silently
    # dropping it. mutable_hierarchy() is a live view of the cfg field.
    if parallel_conf.HasField("hierarchy"):
        assert len(parallel_conf.hierarchy.dim) > 0
        hierarchy_cfg = parallel_conf_cfg.mutable_hierarchy()
        for dim in parallel_conf.hierarchy.dim:
            hierarchy_cfg.add_dim(dim)
    return parallel_conf_cfg
Beispiel #4
0
def MakeParallelConf(device_tag, machine_device_ids):
    """Build a cfg ``ParallelConf`` from a device tag and device-name strings.

    Args:
        device_tag: device tag to set on the conf (e.g. ``"cpu"``).
        machine_device_ids: list or tuple of strings, each of the form
            ``"<machine>:<device>"`` or ``"<machine>:<first>-<last>"``
            (decimal digits only).

    Returns:
        A ``placement_cfg.ParallelConf`` with ``device_tag`` set and one
        device name per entry of ``machine_device_ids``, in order.

    Raises:
        AssertionError: if ``machine_device_ids`` is not a list/tuple, or if
            any entry is not a ``str`` or does not match the format above.
    """
    assert isinstance(machine_device_ids, (list, tuple))
    # Validate every entry up front so nothing is built from bad input.
    for machine_device_id in machine_device_ids:
        assert isinstance(
            machine_device_id,
            str), "type of machine_device_id (%s) is not string" % type(
                machine_device_id)
        # fullmatch instead of match("^...$"): "$" also matches right before
        # a trailing newline, which let e.g. "0:0\n" pass validation.
        assert re.fullmatch(
            r"\d+:\d+(-\d+)?", machine_device_id) is not None, (
                "machine_device_id: %s is not valid" % machine_device_id)
    parallel_conf = placement_cfg.ParallelConf()
    parallel_conf.set_device_tag(device_tag)
    for machine_device_id in machine_device_ids:
        parallel_conf.add_device_name(machine_device_id)
    return parallel_conf
Beispiel #5
0
def ReplaceDeviceTag(parallel_desc_symbol, device_tag, builder=None):
    """Return a placement like ``parallel_desc_symbol`` but with ``device_tag``.

    Device names and the hierarchy are copied from the source symbol; only
    the device tag changes.

    Args:
        parallel_desc_symbol: source placement symbol; its ``device_tag``
            must differ from ``device_tag``.
        device_tag: the new device tag.
        builder: optional instruction builder. When given, the new conf is
            turned into a symbol via ``builder.GetParallelDescSymbol``;
            otherwise a ``PlacementSymbol`` is constructed directly with the
            source symbol's ``symbol_id``.

    Returns:
        The new placement symbol.
    """
    assert parallel_desc_symbol.device_tag != device_tag
    retagged_conf = placement_cfg.ParallelConf()
    retagged_conf.set_device_tag(device_tag)
    source_conf = parallel_desc_symbol.parallel_conf
    for name in source_conf.device_name():
        retagged_conf.add_device_name(name)

    shape = shape_proto_cfg.ShapeProto()
    for extent in parallel_desc_symbol.hierarchy:
        shape.add_dim(extent)
    assert shape.dim_size() > 0
    retagged_conf.mutable_hierarchy().CopyFrom(shape)

    if builder is not None:
        return builder.GetParallelDescSymbol(retagged_conf)
    return oneflow._oneflow_internal.PlacementSymbol(
        parallel_desc_symbol.symbol_id, retagged_conf)
Beispiel #6
0
def JobBuildAndInferCtx_GetParallelConfFromProducerView(job_name, lbn):
    """Return the producer-side parallel conf of ``lbn`` in job ``job_name``.

    The internal API returns a text-format serialized ParallelConf. It is
    parsed into a protobuf ``placement_pb.ParallelConf`` and then copied
    field by field into a cfg ``placement_cfg.ParallelConf``.

    Args:
        job_name: job name (converted with ``str``).
        lbn: logical blob name (converted with ``str``).

    Returns:
        A ``placement_cfg.ParallelConf`` carrying the device tag, the device
        names and, when present, the hierarchy.
    """
    job_name = str(job_name)
    lbn = str(lbn)
    GetParallelConf = (
        oneflow._oneflow_internal.
        JobBuildAndInferCtx_GetSerializedParallelConfFromProducerView)
    parallel_conf = GetParallelConf(job_name, lbn)
    parallel_conf = text_format.Parse(parallel_conf,
                                      placement_pb.ParallelConf())
    # TODO(oyy) change temporary transformation after python code migrated into cpp code
    parallel_conf_cfg = placement_cfg.ParallelConf()
    parallel_conf_cfg.set_device_tag(parallel_conf.device_tag)
    for device_name in parallel_conf.device_name:
        parallel_conf_cfg.add_device_name(device_name)
    # The hierarchy is optional in the proto; copy it when set so the
    # returned conf describes the same placement instead of silently
    # dropping it. mutable_hierarchy() is a live view of the cfg field.
    if parallel_conf.HasField("hierarchy"):
        assert len(parallel_conf.hierarchy.dim) > 0
        hierarchy_cfg = parallel_conf_cfg.mutable_hierarchy()
        for dim in parallel_conf.hierarchy.dim:
            hierarchy_cfg.add_dim(dim)

    return parallel_conf_cfg
Beispiel #7
0
def RandomParallelIdPerMachine(parallel_desc_symbol,
                               device_tag=None,
                               builder=None):
    """Build a placement with one randomly chosen device on each machine.

    Args:
        parallel_desc_symbol: source placement symbol; its
            ``machine_id2device_id_list`` supplies the candidate devices.
        device_tag: device tag for the new placement; when ``None`` the
            source's ``parallel_conf.device_tag()`` is used.
        builder: optional instruction builder. When given, the new conf is
            turned into a symbol via ``builder.GetParallelDescSymbol``;
            otherwise a ``PlacementSymbol`` is constructed directly with the
            source symbol's ``symbol_id``.

    Returns:
        The new placement symbol.
    """
    tag = (parallel_desc_symbol.parallel_conf.device_tag()
           if device_tag is None else device_tag)
    assert tag is not None
    picked_conf = placement_cfg.ParallelConf()
    picked_conf.set_device_tag(tag)
    candidates = parallel_desc_symbol.machine_id2device_id_list
    for machine_id, device_ids in candidates.items():
        last_index = len(device_ids) - 1
        picked = device_ids[random.randint(0, last_index)]
        picked_conf.add_device_name("@%s:%s" % (machine_id, picked))
    if builder is not None:
        return builder.GetParallelDescSymbol(picked_conf)
    return oneflow._oneflow_internal.PlacementSymbol(
        parallel_desc_symbol.symbol_id, picked_conf)
Beispiel #8
0
def _GetReturnOpConfAndOutLbiAndScope(remote_blob, allow_cpu_return_op=True):
    """Build a Return op for ``remote_blob`` together with its output lbi and scope.

    Args:
        remote_blob: blob to return. Its ``unique_name`` becomes the op's
            input and its ``parallel_conf`` is copied into the scope.
        allow_cpu_return_op: when true, the op's device tag is set to "cpu".

    Returns:
        Tuple ``(op_conf, lbi, scope)``: the Return op conf (named with a
        unique ``"Return_"`` prefix), the logical blob id of its ``"out"``
        blob, and a scope built with a copy of the blob's parallel conf.
    """
    op_conf = op_conf_util.OperatorConf()
    op_conf.name = id_util.UniqueStr("Return_")
    # "in" is a Python keyword, so the field can only be set via setattr.
    setattr(op_conf.return_conf, "in", remote_blob.unique_name)
    op_conf.return_conf.out = "out"
    if allow_cpu_return_op:
        op_conf.device_tag = "cpu"
    lbi = logical_blob_id_util.LogicalBlobId()
    lbi.op_name = op_conf.name
    lbi.blob_name = "out"
    # Copy of the blob's placement; captured by BuildScope below.
    parallel_conf = placement_cfg.ParallelConf()
    parallel_conf.CopyFrom(remote_blob.parallel_conf)

    def BuildScope(old_scope, builder):
        return builder.BuildScopeWithNewParallelConf(old_scope, parallel_conf)

    scope = scope_util.MakeScope(BuildScope)
    return (op_conf, lbi, scope)
Beispiel #9
0
def InterNodeOneToMany(builder, produced_blob_object,
                       consumer_op_arg_parallel_attr):
    """Copy ``produced_blob_object`` to every consumer device, then pack them.

    For each (machine, device) pair in the consumer placement, a
    single-device CPU placement ``"@<machine>:<device>"`` is built and the
    produced blob is copied there with ``builder.Build121To``. The physical
    copies are packed into one logical blob object by
    ``PackPhysicalBoxingBlobObjectsToLogical``.

    Args:
        builder: instruction builder.
        produced_blob_object: blob object to replicate.
        consumer_op_arg_parallel_attr: consumer's parallel attribute; its
            ``parallel_desc_symbol.machine_id2device_id_list`` lists targets.

    Returns:
        Whatever ``PackPhysicalBoxingBlobObjectsToLogical`` returns.
    """

    def _CopyToSingleCpuDevice(machine_id, device_id):
        # One-device CPU placement for this (machine, device) pair.
        single_conf = placement_cfg.ParallelConf()
        single_conf.set_device_tag("cpu")
        single_conf.add_device_name("@%s:%s" % (machine_id, device_id))
        target_symbol = builder.GetParallelDescSymbol(single_conf)
        return builder.Build121To(produced_blob_object, target_symbol)

    targets = (consumer_op_arg_parallel_attr.parallel_desc_symbol.
               machine_id2device_id_list)
    physical_copies = [
        _CopyToSingleCpuDevice(machine_id, device_id)
        for machine_id, device_ids in targets.items()
        for device_id in device_ids
    ]
    return PackPhysicalBoxingBlobObjectsToLogical(
        builder,
        physical_copies,
        consumer_op_arg_parallel_attr,
        produced_blob_object.op_arg_blob_attr,
    )
Beispiel #10
0
 def BoxingToSingleDevice(builder):
     """Box the enclosing scope's ``blob_object`` onto the single device "0:0".

     Builds a one-device placement that has ``blob_object``'s device tag and
     boxes ``blob_object`` onto it, keeping its sbp and mirrored attributes.
     Then it rebinds the enclosing ``consistent_blob_name`` to ``tmp_name``
     and registers the boxed blob object under that name if
     ``blob_register`` has no object for it yet.

     NOTE(review): ``blob_object``, ``tmp_name``, ``blob_register`` and
     ``consistent_blob_name`` are free variables from the enclosing function;
     confirm their meaning there.
     """
     parallel_conf = placement_cfg.ParallelConf()
     parallel_conf.set_device_tag(blob_object.parallel_desc_symbol.device_tag)
     # Single device name "0:0" (no "@" prefix).
     parallel_conf.add_device_name("{}:{}".format(0, 0))
     tmp_parallel_desc_symbol = builder.GetParallelDescSymbol(parallel_conf)
     # Same sbp / mirrored attributes as the source blob, new placement.
     tmp_op_arg_parallel_attr = oneflow._oneflow_internal.OpArgParallelAttribute(
         tmp_parallel_desc_symbol,
         str(blob_object.op_arg_parallel_attr.sbp_parallel),
         str(blob_object.op_arg_parallel_attr.opt_mirrored_parallel),
     )
     # Perform the boxing under the temporary single-device placement.
     with oneflow.scope.placement(
         parallel_conf.device_tag(), list(parallel_conf.device_name()),
     ):
         tmp_blob_object = boxing_util.BoxingTo(
             builder, blob_object, tmp_op_arg_parallel_attr
         )
     # Rebind the enclosing function's variable, not a local one.
     nonlocal consistent_blob_name
     consistent_blob_name = tmp_name
     # Register only if nothing is registered under this name yet.
     if not blob_register.HasObject4BlobName(consistent_blob_name):
         blob_register.SetObject4BlobName(consistent_blob_name, tmp_blob_object)