def _GetInterfaceBlobObject(builder, op_name):
    """Return the blob object that backs the interface op ``op_name``.

    When eager execution is enabled the blob object is read straight from the
    default session's ``var_name2var_blob`` table. Otherwise a lazy reference
    blob object is created with ``builder.MakeLazyRefBlobObject`` from the op
    attribute and parallel conf the session holds for ``op_name``.

    Args:
        builder: Object providing ``MakeLazyRefBlobObject``.
        op_name: Name of the interface op to look up.

    Returns:
        The blob object associated with ``op_name``.
    """
    sess = session_ctx.GetDefaultSession()
    if oneflow._oneflow_internal.EagerExecutionEnabled():
        return sess.var_name2var_blob[op_name].blob_object

    op_attribute = sess.OpAttribute4InterfaceOpName(op_name)
    cfg_op_attribute = oneflow._oneflow_internal.deprecated.MakeOpAttributeByString(
        str(op_attribute)
    )

    parallel_conf = sess.ParallelConf4LazyInterfaceOpName(op_name)
    cfg_parallel_conf_type = (
        oneflow._oneflow_internal.oneflow.core.job.placement.ParallelConf
    )
    if not isinstance(parallel_conf, cfg_parallel_conf_type):
        # Not a cfg ParallelConf yet: rebuild one field by field
        # (device tag, device names and, when set, the hierarchy).
        parallel_conf_cfg = placement_cfg.ParallelConf()
        parallel_conf_cfg.set_device_tag(parallel_conf.device_tag)
        for device_name in parallel_conf.device_name:
            parallel_conf_cfg.add_device_name(device_name)
        if parallel_conf.HasField("hierarchy"):
            hierarchy = shape_proto_cfg.ShapeProto()
            for dim in parallel_conf.hierarchy.dim:
                hierarchy.add_dim(dim)
            assert hierarchy.dim_size() > 0
            parallel_conf_cfg.mutable_hierarchy().CopyFrom(hierarchy)
        parallel_conf = parallel_conf_cfg

    return builder.MakeLazyRefBlobObject(op_name, cfg_op_attribute, parallel_conf)
def GetConcatSplitBoxingParallelDescSymbol(
    builder, blob_parallel_desc_symbol, max_parallel_num
):
    """Return a CPU placement symbol using one random rank id on every machine.

    A single rank id is drawn with ``random.randint(0, max_parallel_num - 1)``
    and paired with each machine id found in
    ``blob_parallel_desc_symbol.machine_id2device_id_list``.

    Args:
        builder: Object providing ``GetParallelDescSymbol``.
        blob_parallel_desc_symbol: Placement symbol whose machine ids are reused.
        max_parallel_num: Exclusive upper bound for the drawn rank id.

    Returns:
        Whatever ``builder.GetParallelDescSymbol`` returns for the new conf.
    """
    shared_rank = random.randint(0, max_parallel_num - 1)

    cpu_conf = placement_cfg.ParallelConf()
    cpu_conf.set_device_tag("cpu")
    # Only the machine ids matter here; the per-machine device lists are ignored.
    for machine_id, _device_ids in blob_parallel_desc_symbol.machine_id2device_id_list.items():
        cpu_conf.add_device_name("@%s:%s" % (machine_id, shared_rank))

    return builder.GetParallelDescSymbol(cpu_conf)
def JobBuildAndInferCtx_GetParallelConfFromProducerView(job_name, lbn):
    """Return the producer-view parallel conf of ``lbn`` as a cfg ``ParallelConf``.

    The native API returns the conf as a serialized string; it is parsed with
    ``text_format.Parse`` into a protobuf ``ParallelConf`` and then copied
    field by field into a ``placement_cfg.ParallelConf``.

    Args:
        job_name: Job name; converted with ``str`` before the native call.
        lbn: Logical blob name; converted with ``str`` before the native call.

    Returns:
        A ``placement_cfg.ParallelConf`` carrying the device tag, the device
        names and, when the parsed conf has a non-empty one, the hierarchy.
    """
    job_name = str(job_name)
    lbn = str(lbn)
    GetParallelConf = (
        oneflow._oneflow_internal.JobBuildAndInferCtx_GetSerializedParallelConfFromProducerView
    )
    serialized_conf = GetParallelConf(job_name, lbn)
    parallel_conf = text_format.Parse(serialized_conf, placement_pb.ParallelConf())

    parallel_conf_cfg = placement_cfg.ParallelConf()
    parallel_conf_cfg.set_device_tag(parallel_conf.device_tag)
    for device_name in parallel_conf.device_name:
        parallel_conf_cfg.add_device_name(device_name)

    # Carry the hierarchy over as well so the cfg copy does not silently lose
    # it. An unset or empty hierarchy is skipped, leaving the result as before.
    if parallel_conf.HasField("hierarchy") and parallel_conf.hierarchy.dim:
        hierarchy = shape_proto_cfg.ShapeProto()
        for dim in parallel_conf.hierarchy.dim:
            hierarchy.add_dim(dim)
        parallel_conf_cfg.mutable_hierarchy().CopyFrom(hierarchy)

    return parallel_conf_cfg
def MakeParallelConf(device_tag, machine_device_ids):
    """Build a ``ParallelConf`` for ``device_tag`` from machine/device id strings.

    Args:
        device_tag: Value passed to ``set_device_tag`` on the new conf.
        machine_device_ids: List or tuple of strings, each shaped like
            ``"<digits>:<digits>"`` or ``"<digits>:<digits>-<digits>"``.

    Returns:
        A ``placement_cfg.ParallelConf`` with the tag set and every id added
        as a device name, in input order.

    Raises:
        AssertionError: If ``machine_device_ids`` is not a list/tuple, or an
            entry is not a ``str`` or does not have the expected shape.
    """
    assert isinstance(machine_device_ids, (list, tuple))

    # Validate every entry before building anything, so malformed input is
    # rejected up front.
    for machine_device_id in machine_device_ids:
        assert isinstance(
            machine_device_id, str
        ), "type of machine_device_id (%s) is not string" % type(machine_device_id)
        # fullmatch instead of match("^...$"): "$" also matches just before a
        # trailing newline, which would let "0:0\n" through.
        assert re.fullmatch(r"\d+:\d+(-\d+)?", machine_device_id) is not None, (
            "machine_device_id: %s is not valid" % machine_device_id
        )

    parallel_conf = placement_cfg.ParallelConf()
    parallel_conf.set_device_tag(device_tag)
    for machine_device_id in machine_device_ids:
        parallel_conf.add_device_name(machine_device_id)
    return parallel_conf
def ReplaceDeviceTag(parallel_desc_symbol, device_tag, builder=None):
    """Return a placement symbol equal to the given one but with a new device tag.

    Device names and hierarchy dims are copied from ``parallel_desc_symbol``;
    only the device tag is replaced.

    Args:
        parallel_desc_symbol: Source placement symbol.
        device_tag: New device tag; must differ from the symbol's current tag.
        builder: Optional object providing ``GetParallelDescSymbol``. When
            omitted, a ``PlacementSymbol`` is constructed directly with the
            source symbol's ``symbol_id``.

    Returns:
        The new placement symbol.

    Raises:
        AssertionError: If ``device_tag`` equals the current tag, or the
            copied hierarchy has no dims.
    """
    assert parallel_desc_symbol.device_tag != device_tag

    retagged_conf = placement_cfg.ParallelConf()
    retagged_conf.set_device_tag(device_tag)
    for name in parallel_desc_symbol.parallel_conf.device_name():
        retagged_conf.add_device_name(name)

    shape = shape_proto_cfg.ShapeProto()
    for extent in parallel_desc_symbol.hierarchy:
        shape.add_dim(extent)
    assert shape.dim_size() > 0
    retagged_conf.mutable_hierarchy().CopyFrom(shape)

    if builder is not None:
        return builder.GetParallelDescSymbol(retagged_conf)
    return oneflow._oneflow_internal.PlacementSymbol(
        parallel_desc_symbol.symbol_id, retagged_conf
    )
def JobBuildAndInferCtx_GetParallelConfFromProducerView(job_name, lbn):
    """Return the producer-view parallel conf of ``lbn`` as a cfg ``ParallelConf``.

    The native API returns the conf as a serialized string; it is parsed with
    ``text_format.Parse`` into a protobuf ``ParallelConf`` and then copied
    field by field into a ``placement_cfg.ParallelConf``.

    Args:
        job_name: Job name; converted with ``str`` before the native call.
        lbn: Logical blob name; converted with ``str`` before the native call.

    Returns:
        A ``placement_cfg.ParallelConf`` carrying the device tag, the device
        names and, when the parsed conf has a non-empty one, the hierarchy.
    """
    job_name = str(job_name)
    lbn = str(lbn)
    GetParallelConf = (
        oneflow._oneflow_internal.JobBuildAndInferCtx_GetSerializedParallelConfFromProducerView
    )
    serialized_conf = GetParallelConf(job_name, lbn)
    parallel_conf = text_format.Parse(serialized_conf, placement_pb.ParallelConf())

    # TODO(oyy) change temporary transformation after python code migrated into cpp code
    parallel_conf_cfg = placement_cfg.ParallelConf()
    parallel_conf_cfg.set_device_tag(parallel_conf.device_tag)
    for device_name in parallel_conf.device_name:
        parallel_conf_cfg.add_device_name(device_name)

    # Carry the hierarchy over as well so the cfg copy does not silently lose
    # it. An unset or empty hierarchy is skipped, leaving the result as before.
    if parallel_conf.HasField("hierarchy") and parallel_conf.hierarchy.dim:
        hierarchy = shape_proto_cfg.ShapeProto()
        for dim in parallel_conf.hierarchy.dim:
            hierarchy.add_dim(dim)
        parallel_conf_cfg.mutable_hierarchy().CopyFrom(hierarchy)

    return parallel_conf_cfg
def RandomParallelIdPerMachine(parallel_desc_symbol, device_tag=None, builder=None):
    """Return a placement symbol with one randomly picked device per machine.

    For every machine in ``parallel_desc_symbol.machine_id2device_id_list`` a
    single device id is chosen at a random index of that machine's device list.

    Args:
        parallel_desc_symbol: Source placement symbol.
        device_tag: Device tag for the result; when ``None`` it is read from
            ``parallel_desc_symbol.parallel_conf.device_tag()``.
        builder: Optional object providing ``GetParallelDescSymbol``. When
            omitted, a ``PlacementSymbol`` is constructed directly with the
            source symbol's ``symbol_id``.

    Returns:
        The new placement symbol.
    """
    tag = device_tag
    if tag is None:
        tag = parallel_desc_symbol.parallel_conf.device_tag()
    assert tag is not None

    picked_conf = placement_cfg.ParallelConf()
    picked_conf.set_device_tag(tag)
    for machine_id, candidates in parallel_desc_symbol.machine_id2device_id_list.items():
        picked_index = random.randint(0, len(candidates) - 1)
        picked_conf.add_device_name("@%s:%s" % (machine_id, candidates[picked_index]))

    if builder is not None:
        return builder.GetParallelDescSymbol(picked_conf)
    return oneflow._oneflow_internal.PlacementSymbol(
        parallel_desc_symbol.symbol_id, picked_conf
    )
def _GetReturnOpConfAndOutLbiAndScope(remote_blob, allow_cpu_return_op=True):
    """Build the return-op conf, its output blob id and a matching scope.

    Args:
        remote_blob: Blob to return; its ``unique_name`` becomes the op input
            and its ``parallel_conf`` is copied into the new scope.
        allow_cpu_return_op: When true, the op conf's device tag is set to
            ``"cpu"``; otherwise the device tag is left untouched.

    Returns:
        A tuple ``(op_conf, lbi, scope)``: the ``OperatorConf`` of the return
        op, the ``LogicalBlobId`` naming its ``"out"`` blob, and the scope
        produced by ``scope_util.MakeScope``.
    """
    op_conf = op_conf_util.OperatorConf()
    op_conf.name = id_util.UniqueStr("Return_")
    # "in" is a Python keyword, so this field cannot be set with attribute syntax.
    setattr(op_conf.return_conf, "in", remote_blob.unique_name)
    op_conf.return_conf.out = "out"
    if allow_cpu_return_op:
        op_conf.device_tag = "cpu"

    lbi = logical_blob_id_util.LogicalBlobId()
    lbi.op_name = op_conf.name
    lbi.blob_name = "out"

    parallel_conf = placement_cfg.ParallelConf()
    parallel_conf.CopyFrom(remote_blob.parallel_conf)

    def BuildScope(old_scope, builder):
        # Derive the new scope from the old one, swapping in the blob's parallel conf.
        return builder.BuildScopeWithNewParallelConf(old_scope, parallel_conf)

    scope = scope_util.MakeScope(BuildScope)
    return (op_conf, lbi, scope)
def InterNodeOneToMany(builder, produced_blob_object, consumer_op_arg_parallel_attr):
    """Build one blob per consumer device and pack them into a logical blob.

    For every (machine id, device id) pair of the consumer placement, a
    single-device ``"cpu"`` placement symbol is created and
    ``builder.Build121To`` is called with the produced blob and that symbol.
    The results are then handed to ``PackPhysicalBoxingBlobObjectsToLogical``.

    Args:
        builder: Object providing ``GetParallelDescSymbol`` and ``Build121To``.
        produced_blob_object: Blob object to distribute; its
            ``op_arg_blob_attr`` is forwarded to the packing step.
        consumer_op_arg_parallel_attr: Consumer-side parallel attribute whose
            ``parallel_desc_symbol`` lists the target machines and devices.

    Returns:
        The result of ``PackPhysicalBoxingBlobObjectsToLogical``.
    """
    consumer_placement = consumer_op_arg_parallel_attr.parallel_desc_symbol
    machine2devices = consumer_placement.machine_id2device_id_list

    def _build_on(machine_id, device_id):
        # One single-device CPU placement per target, then a 1-to-1 build onto it.
        single_conf = placement_cfg.ParallelConf()
        single_conf.set_device_tag("cpu")
        single_conf.add_device_name("@%s:%s" % (machine_id, device_id))
        single_symbol = builder.GetParallelDescSymbol(single_conf)
        return builder.Build121To(produced_blob_object, single_symbol)

    physical_blobs = [
        _build_on(machine_id, device_id)
        for machine_id, device_ids in machine2devices.items()
        for device_id in device_ids
    ]
    return PackPhysicalBoxingBlobObjectsToLogical(
        builder,
        physical_blobs,
        consumer_op_arg_parallel_attr,
        produced_blob_object.op_arg_blob_attr,
    )
def BoxingToSingleDevice(builder): parallel_conf = placement_cfg.ParallelConf() parallel_conf.set_device_tag(blob_object.parallel_desc_symbol.device_tag) parallel_conf.add_device_name("{}:{}".format(0, 0)) tmp_parallel_desc_symbol = builder.GetParallelDescSymbol(parallel_conf) tmp_op_arg_parallel_attr = oneflow._oneflow_internal.OpArgParallelAttribute( tmp_parallel_desc_symbol, str(blob_object.op_arg_parallel_attr.sbp_parallel), str(blob_object.op_arg_parallel_attr.opt_mirrored_parallel), ) with oneflow.scope.placement( parallel_conf.device_tag(), list(parallel_conf.device_name()), ): tmp_blob_object = boxing_util.BoxingTo( builder, blob_object, tmp_op_arg_parallel_attr ) nonlocal consistent_blob_name consistent_blob_name = tmp_name if not blob_register.HasObject4BlobName(consistent_blob_name): blob_register.SetObject4BlobName(consistent_blob_name, tmp_blob_object)