예제 #1
0
def Interpret(op_attribute, parallel_conf, blob_register):
    """Dispatch an op to the interpreter routine matching its op conf type.

    Mirrored-cast ops are dispatched before ``parallel_conf`` is inspected.
    For every other op, ``parallel_conf`` may be a text-format string (parsed
    into a ``placement_pb.ParallelConf``) or an existing
    ``placement_pb.ParallelConf``.

    Args:
        op_attribute: op attribute whose ``op_conf`` selects the handler.
        parallel_conf: text-format string or ``placement_pb.ParallelConf``.
        blob_register: forwarded unchanged to the selected handler.

    Returns:
        Whatever the selected handler returns.
    """
    op_conf = op_attribute.op_conf
    if op_conf.HasField("cast_to_mirrored_conf") or op_conf.HasField(
            "cast_from_mirrored_conf"):
        return MirroredCast(op_attribute, blob_register)

    # isinstance also accepts str subclasses, which would otherwise fall
    # through to the ParallelConf assertion below and fail.
    if isinstance(parallel_conf, str):
        parallel_conf = text_format.Parse(parallel_conf,
                                          placement_pb.ParallelConf())
    else:
        assert isinstance(parallel_conf, placement_pb.ParallelConf)

    if op_conf.HasField("distribute_split_conf") or op_conf.HasField(
            "distribute_clone_conf"):
        return DistributeSplitOrClone(op_attribute, parallel_conf,
                                      blob_register)
    if op_conf.HasField("distribute_concat_conf") or op_conf.HasField(
            "distribute_add_conf"):
        return DistributeConcatOrAdd(op_attribute, parallel_conf,
                                     blob_register)
    if op_conf.HasField("variable_conf"):
        return _FindOrCreateVarBlobObject(op_attribute, parallel_conf,
                                          blob_register)
    if op_conf.HasField("foreign_watch_conf"):
        return _Watch(op_attribute, parallel_conf, blob_register)
    return _NaiveInterpret(op_attribute, parallel_conf, blob_register)
예제 #2
0
def MakeScopeSymbol(job_conf_str, parallel_conf_str, is_mirrored):
    """Create an initial scope from text-format configs and return its symbol id.

    Args:
        job_conf_str: text-format ``JobConfigProto``.
        parallel_conf_str: text-format ``ParallelConf``; its device tag and
            device names are passed to ``scope_util.MakeInitialScope``.
        is_mirrored: forwarded to ``scope_util.MakeInitialScope``.

    Returns:
        The ``symbol_id`` of the scope built by ``scope_util.MakeInitialScope``.
    """
    parsed_job_conf = text_format.Parse(job_conf_str,
                                        job_conf_pb.JobConfigProto())
    parsed_placement = text_format.Parse(parallel_conf_str,
                                         placement_pb.ParallelConf())
    initial_scope = scope_util.MakeInitialScope(
        parsed_job_conf,
        parsed_placement.device_tag,
        list(parsed_placement.device_name),
        is_mirrored,
    )
    return initial_scope.symbol_id
예제 #3
0
def GetConcatSplitBoxingParallelDescSymbol(builder, blob_parallel_desc_symbol,
                                           max_parallel_num):
    """Return a CPU parallel-desc symbol with one device on each machine.

    One rank is drawn uniformly from ``[0, max_parallel_num - 1]`` and that
    same rank is used as the device id on every machine listed in
    ``blob_parallel_desc_symbol.machine_id2device_id_list``.

    Args:
        builder: provides ``GetParallelDescSymbol``.
        blob_parallel_desc_symbol: source of the machine ids.
        max_parallel_num: exclusive upper bound of the drawn rank.

    Returns:
        The symbol returned by ``builder.GetParallelDescSymbol``.
    """
    chosen_rank = random.randint(0, max_parallel_num - 1)
    cpu_conf = placement_pb.ParallelConf()
    cpu_conf.device_tag = "cpu"
    machine_ids = [
        machine_id for machine_id, _ in
        blob_parallel_desc_symbol.machine_id2device_id_list.items()
    ]
    for machine_id in machine_ids:
        cpu_conf.device_name.append("%s:%s" % (machine_id, chosen_rank))
    return builder.GetParallelDescSymbol(cpu_conf)
예제 #4
0
def MakeParallelDescSymbol(parallel_conf_str):
    """Parse a text-format ``ParallelConf`` and return its parallel-desc symbol id.

    The symbol is requested from the builder inside ``vm_util.LogicalRun``;
    the callback stores the id so it can be returned afterwards. If the
    callback never runs, ``None`` is returned.
    """
    parallel_conf = text_format.Parse(parallel_conf_str,
                                      placement_pb.ParallelConf())
    result_cell = [None]  # written by the callback below

    def BuildInstruction(builder):
        result_cell[0] = builder.GetParallelDescSymbol(parallel_conf).symbol_id

    vm_util.LogicalRun(BuildInstruction)
    return result_cell[0]
예제 #5
0
def JobBuildAndInferCtx_GetParallelConfFromProducerView(job_name, lbn):
    """Return the producer-view ``ParallelConf`` of ``lbn`` within ``job_name``.

    Both arguments are coerced to ``str``. The internal API returns a
    text-format ``ParallelConf``, which is parsed into a
    ``placement_pb.ParallelConf``.
    """
    job_name_str = str(job_name)
    lbn_str = str(lbn)
    fetch_serialized = (
        oneflow._oneflow_internal.
        JobBuildAndInferCtx_GetSerializedParallelConfFromProducerView)
    serialized = fetch_serialized(job_name_str, lbn_str)
    return text_format.Parse(serialized, placement_pb.ParallelConf())
예제 #6
0
def ReplaceDeviceTag(parallel_desc_symbol, device_tag, builder=None):
    """Return a parallel-desc symbol with the same device names but a new tag.

    Args:
        parallel_desc_symbol: source symbol; its ``device_tag`` must differ
            from ``device_tag`` (asserted).
        device_tag: device tag for the new conf.
        builder: when given, the new conf is registered through
            ``builder.GetParallelDescSymbol``. Otherwise a
            ``symbol_util.ParallelDescSymbol`` reusing the source
            ``symbol_id`` is constructed directly.
    """
    assert parallel_desc_symbol.device_tag != device_tag
    retagged_conf = placement_pb.ParallelConf()
    retagged_conf.device_tag = device_tag
    retagged_conf.device_name.extend(
        parallel_desc_symbol.parallel_conf.device_name)
    if builder is not None:
        return builder.GetParallelDescSymbol(retagged_conf)
    return symbol_util.ParallelDescSymbol(parallel_desc_symbol.symbol_id,
                                          retagged_conf, device_tag)
예제 #7
0
def JobBuildAndInferCtx_GetParallelConfFromProducerView(job_name, lbn):
    """Return the producer-view ``ParallelConf`` of ``lbn`` within ``job_name``.

    The internal API returns a pair of text-format strings: the parallel conf
    and an ``ErrorProto``.

    Raises:
        JobBuildAndInferError: if the parsed error has ``error_type`` set.
    """
    job_name, lbn = str(job_name), str(lbn)
    fetch_serialized = (
        oneflow_internal.JobBuildAndInferCtx_GetSerializedParallelConfFromProducerView
    )
    serialized_conf, serialized_error = fetch_serialized(job_name, lbn)
    parsed_error = text_format.Parse(serialized_error, error_util.ErrorProto())
    if not parsed_error.HasField("error_type"):
        return text_format.Parse(serialized_conf, placement_pb.ParallelConf())
    raise JobBuildAndInferError(parsed_error)
예제 #8
0
def JobBuildAndInferCtx_GetParallelConfFromProducerView(job_name, lbn):
    """Return the producer-view parallel conf as a ``placement_cfg.ParallelConf``.

    The internal API yields a text-format ``ParallelConf`` proto. It is parsed,
    and its device tag and device names are copied into a cfg object.
    """
    job_name_str = str(job_name)
    lbn_str = str(lbn)
    fetch_serialized = (
        oneflow._oneflow_internal.
        JobBuildAndInferCtx_GetSerializedParallelConfFromProducerView)
    proto_conf = text_format.Parse(fetch_serialized(job_name_str, lbn_str),
                                   placement_pb.ParallelConf())

    cfg_conf = placement_cfg.ParallelConf()
    cfg_conf.set_device_tag(proto_conf.device_tag)
    for name in proto_conf.device_name:
        cfg_conf.add_device_name(name)
    return cfg_conf
예제 #9
0
def JobBuildAndInferCtx_GetParallelConfFromProducerView(job_name, lbn):
    """Return the producer-view parallel conf as a ``placement_cfg.ParallelConf``.

    ``job_name`` and ``lbn`` are coerced to ``str``. The text-format
    ``ParallelConf`` returned by ``oneflow_api`` is parsed, and its device tag
    and device names are copied into a cfg object.
    """
    job_name, lbn = str(job_name), str(lbn)
    fetch_serialized = (
        oneflow_api.JobBuildAndInferCtx_GetSerializedParallelConfFromProducerView
    )
    serialized = fetch_serialized(job_name, lbn)
    proto_conf = text_format.Parse(serialized, placement_pb.ParallelConf())

    # TODO(oyy) change temporary transformation after python code migrated into cpp code
    cfg_conf = placement_cfg.ParallelConf()
    cfg_conf.set_device_tag(proto_conf.device_tag)
    for name in proto_conf.device_name:
        cfg_conf.add_device_name(name)
    return cfg_conf
예제 #10
0
def MakeParallelConf(device_tag, machine_device_ids):
    """Build a ``placement_pb.ParallelConf`` from a device tag and device ids.

    Args:
        device_tag: value stored in ``ParallelConf.device_tag``.
        machine_device_ids: list or tuple of strings matching
            ``<digits>:<digits>`` optionally followed by ``-<digits>``
            (presumably a device-id range; confirm with callers).

    Returns:
        A ``placement_pb.ParallelConf`` with ``device_tag`` set and the ids
        appended, in order, to ``device_name``.

    Raises:
        AssertionError: if ``machine_device_ids`` is not a list/tuple, or an
            element is not a ``str`` or does not match the format above.
    """
    assert isinstance(machine_device_ids, (list, tuple))
    # Raw string: "\d" in a plain literal is an invalid escape sequence
    # (DeprecationWarning, SyntaxWarning since Python 3.12).
    id_pattern = re.compile(r"^\d+:\d+(-\d+)?$")
    device_names = []
    for machine_device_id in machine_device_ids:
        assert isinstance(
            machine_device_id,
            str), "type of machine_device_id (%s) is not string" % type(
                machine_device_id)
        assert id_pattern.match(machine_device_id) is not None, (
            "machine_device_id: %s is not valid" % machine_device_id)
        device_names.append(machine_device_id)

    parallel_conf = placement_pb.ParallelConf()
    parallel_conf.device_tag = device_tag
    parallel_conf.device_name.extend(device_names)
    return parallel_conf
예제 #11
0
def MakeParallelConf(device_tag, machine_device_ids):
    """Build a ``placement_pb.ParallelConf`` from a device tag and device ids.

    Args:
        device_tag: value stored in ``ParallelConf.device_tag``.
        machine_device_ids: a sized iterable (``collections.abc.Sized``) of
            strings matching ``<digits>:<digits>`` optionally followed by
            ``-<digits>`` (presumably a device-id range; confirm with callers).

    Returns:
        A ``placement_pb.ParallelConf`` with ``device_tag`` set and the ids
        appended, in order, to ``device_name``.

    Raises:
        AssertionError: if ``machine_device_ids`` is not sized, or an element
            is not a ``str`` or does not match the format above.
    """
    # ``collections.Sized`` was removed in Python 3.10; the ABC now lives
    # only in ``collections.abc``.
    from collections.abc import Sized

    assert isinstance(machine_device_ids, Sized)
    # Raw string: "\d" in a plain literal is an invalid escape sequence.
    id_pattern = re.compile(r"^\d+:\d+(-\d+)?$")
    device_names = []
    for machine_device_id in machine_device_ids:
        assert isinstance(
            machine_device_id,
            str), "type of machine_device_id (%s) is not string" % type(
                machine_device_id)
        assert id_pattern.match(machine_device_id) is not None, (
            "machine_device_id: %s is not valid" % machine_device_id)
        # The pattern admits exactly one ":", so splitting on it and
        # re-joining the two parts reproduces the id unchanged.
        device_names.append(machine_device_id)

    parallel_conf = placement_pb.ParallelConf()
    parallel_conf.device_tag = device_tag
    parallel_conf.device_name.extend(device_names)
    return parallel_conf
예제 #12
0
def RandomParallelIdPerMachine(parallel_desc_symbol,
                               device_tag=None,
                               builder=None):
    """Pick one random device per machine and return the resulting symbol.

    Args:
        parallel_desc_symbol: source of ``machine_id2device_id_list``.
        device_tag: tag for the new conf; defaults to the source conf's tag.
            Must not end up ``None`` (asserted).
        builder: when given, the new conf is registered through
            ``builder.GetParallelDescSymbol``. Otherwise a
            ``symbol_util.ParallelDescSymbol`` reusing the source
            ``symbol_id`` is constructed directly.
    """
    device_tag = (parallel_desc_symbol.parallel_conf.device_tag
                  if device_tag is None else device_tag)
    assert device_tag is not None

    picked_conf = placement_pb.ParallelConf()
    picked_conf.device_tag = device_tag
    machine_to_devices = parallel_desc_symbol.machine_id2device_id_list
    for machine_id, candidates in machine_to_devices.items():
        picked = candidates[random.randint(0, len(candidates) - 1)]
        picked_conf.device_name.append("%s:%s" % (machine_id, picked))

    if builder is not None:
        return builder.GetParallelDescSymbol(picked_conf)
    return symbol_util.ParallelDescSymbol(parallel_desc_symbol.symbol_id,
                                          picked_conf, device_tag)
예제 #13
0
def JobBuildAndInferCtx_GetParallelConfFromProducerView(job_name, lbn):
    """Return the producer-view parallel conf as a ``placement_cfg.ParallelConf``.

    The internal API returns a pair of text-format strings: the parallel conf
    and an ``ErrorProto``. On success, the parsed conf's device tag and device
    names are copied into a cfg object.

    Raises:
        JobBuildAndInferError: if the parsed error has ``error_type`` set.
    """
    job_name, lbn = str(job_name), str(lbn)
    fetch_serialized = (
        oneflow_internal.JobBuildAndInferCtx_GetSerializedParallelConfFromProducerView
    )
    serialized_conf, serialized_error = fetch_serialized(job_name, lbn)
    parsed_error = text_format.Parse(serialized_error, error_util.ErrorProto())
    if parsed_error.HasField("error_type"):
        raise JobBuildAndInferError(parsed_error)

    proto_conf = text_format.Parse(serialized_conf,
                                   placement_pb.ParallelConf())
    # TODO(oyy) change temporary transformation after python code migrated into cpp code
    cfg_conf = placement_cfg.ParallelConf()
    cfg_conf.set_device_tag(proto_conf.device_tag)
    for name in proto_conf.device_name:
        cfg_conf.add_device_name(name)
    return cfg_conf
예제 #14
0
def InterNodeOneToMany(builder, produced_blob_object,
                       consumer_op_arg_parallel_attr):
    """Copy a blob to every consumer device, then pack the copies.

    For each (machine_id, device_id) of the consumer's parallel desc, a
    single-device ``"cpu"`` ``ParallelConf`` is built and
    ``builder.Build121To`` copies ``produced_blob_object`` onto it. The copies
    are then passed to ``PackPhysicalBoxingBlobObjectsToLogical``.
    """
    machine_to_devices = (consumer_op_arg_parallel_attr.parallel_desc_symbol.
                          machine_id2device_id_list)

    def _CopyToSingleCpuDevice(machine_id, device_id):
        # Build a one-device placement and copy the produced blob onto it.
        single_conf = placement_pb.ParallelConf()
        single_conf.device_tag = "cpu"
        single_conf.device_name.append("%s:%s" % (machine_id, device_id))
        target_symbol = builder.GetParallelDescSymbol(single_conf)
        return builder.Build121To(produced_blob_object, target_symbol)

    copies = []
    for machine_id, device_ids in machine_to_devices.items():
        for device_id in device_ids:
            copies.append(_CopyToSingleCpuDevice(machine_id, device_id))

    return PackPhysicalBoxingBlobObjectsToLogical(
        builder,
        copies,
        consumer_op_arg_parallel_attr,
        produced_blob_object.op_arg_blob_attr,
    )
예제 #15
0
 def BoxingToSingleDevice(builder):
     """Box ``blob_object`` onto device "0:0" and register the result.

     Closure: reads ``blob_object``, ``self`` and ``blob_register`` from the
     enclosing scope (not visible here) and assigns the enclosing
     ``consistent_blob_name`` via ``nonlocal``.

     Args:
         builder: used for ``GetParallelDescSymbol`` and passed to
             ``boxing_util.BoxingTo``.
     """
     # Temporary placement: one device, "0:0", with the blob's device tag.
     parallel_conf = placement_pb.ParallelConf()
     parallel_conf.device_tag = blob_object.parallel_desc_symbol.device_tag
     parallel_conf.device_name.append("{}:{}".format(0, 0))
     tmp_parallel_desc_symbol = builder.GetParallelDescSymbol(parallel_conf)
     # Reuse the blob's sbp_parallel and opt_mirrored_parallel; only the
     # parallel desc symbol is replaced.
     tmp_op_arg_parallel_attr = op_arg_util.OpArgParallelAttribute(
         tmp_parallel_desc_symbol,
         blob_object.op_arg_parallel_attr.sbp_parallel,
         blob_object.op_arg_parallel_attr.opt_mirrored_parallel,
     )
     # BoxingTo is called under a placement scope built from
     # self.parallel_conf's device tag and device names.
     with oneflow.scope.placement(
         self.parallel_conf.device_tag, list(self.parallel_conf.device_name)
     ):
         tmp_blob_object = boxing_util.BoxingTo(
             builder, blob_object, tmp_op_arg_parallel_attr
         )
     # Expose the derived name to the enclosing function.
     nonlocal consistent_blob_name
     consistent_blob_name = "{}-consistent".format(self.logical_blob_name)
     # Register only if no object is stored under this name yet; an existing
     # registration is left untouched.
     if not blob_register.HasObject4BlobName(consistent_blob_name):
         blob_register.SetObject4BlobName(
             consistent_blob_name, tmp_blob_object
         )
예제 #16
0
def _GetReturnOpConfAndOutLbiAndScope(remote_blob, allow_cpu_return_op=True):
    """Build a ``Return_`` op conf for ``remote_blob`` with its output LBI and scope.

    Args:
        remote_blob: blob whose ``unique_name`` feeds the return op and whose
            ``parallel_conf`` is copied into the new scope.
        allow_cpu_return_op: when true, the op conf's ``device_tag`` is
            set to ``"cpu"``.

    Returns:
        ``(op_conf, lbi, scope)``: the return op conf, the logical blob id of
        its ``"out"`` blob, and a scope built by ``scope_util.MakeScope``
        with the remote blob's parallel conf.
    """
    op_conf = op_conf_util.OperatorConf()
    op_conf.name = id_util.UniqueStr("Return_")
    return_conf = op_conf.return_conf
    # "in" is a Python keyword, so the field can only be set via setattr.
    setattr(return_conf, "in", remote_blob.unique_name)
    return_conf.out = "out"
    if allow_cpu_return_op:
        op_conf.device_tag = "cpu"

    out_lbi = logical_blob_id_util.LogicalBlobId()
    out_lbi.op_name = op_conf.name
    out_lbi.blob_name = "out"

    blob_parallel_conf = placement_proto_pb.ParallelConf()
    blob_parallel_conf.CopyFrom(remote_blob.parallel_conf)

    def BuildScope(old_scope, builder):
        return old_scope.BuildWithNewParallelConf(builder, blob_parallel_conf)

    # The session object is unused; the call is kept because
    # GetDefaultSession may have side effects (TODO confirm; drop if not).
    session_ctx.GetDefaultSession()
    new_scope = scope_util.MakeScope(BuildScope)

    return op_conf, out_lbi, new_scope
예제 #17
0
 def AppendPhyParallelDescSymbol(machine_id, device_id):
     """Register a single-device conf and append its symbol to the outer list.

     Closure: reads ``device_tag``, ``self`` and ``phy_parallel_desc_symbols``
     from the enclosing scope (not visible here).
     """
     single_device_conf = placement_pb_util.ParallelConf()
     single_device_conf.device_tag = device_tag
     device_name = "%d:%d" % (machine_id, device_id)
     single_device_conf.device_name.append(device_name)
     new_symbol = self.GetParallelDescSymbol(single_device_conf)
     phy_parallel_desc_symbols.append(new_symbol)