Exemple #1
0
def _EagerRunModelSave(var_blobs, snapshot_path):
    """Eagerly run a model-save op over ``var_blobs`` using ``snapshot_path``.

    Three instructions are issued, in this order, inside a scope built by
    ``_BuildNotMirroredScope``:

    1. the model-IO path input op,
    2. the instruction feeding ``snapshot_path`` into that op,
    3. the model-save op, whose ``"path"`` input is the path op's ``"out"``
       blob and whose ``"in_<i>"`` inputs are the blobs of ``var_blobs``.
    """
    path_input_op_conf, path_lbi = _GenModelIOPathInputOpConfAndRetLbi()
    path_input_blob_objects = (
        oneflow._oneflow_internal.deprecated.BnInOp2BlobObject())
    build_path_input, build_feed_path = _MakeModelIOPathInputBuilds(
        path_input_op_conf, snapshot_path, path_input_blob_objects)

    save_op_conf = _GenModelSaveOpConf(var_blobs, path_lbi)
    save_blob_objects = oneflow._oneflow_internal.deprecated.BnInOp2BlobObject()

    def BuildModelSaveInstruction(builder):
        # Wire the save op's inputs: the path blob plus one "in_<i>" per variable.
        path_blob = path_input_blob_objects["out"]
        save_blob_objects["path"] = path_blob
        for index, var_blob in enumerate(var_blobs):
            save_blob_objects["in_{}".format(index)] = var_blob.blob_object

        inferred_attribute = op_infer_util.Infer(
            save_op_conf, ibn2blob_object=save_blob_objects)
        # The save op runs on the same placement as the path blob.
        path_parallel_conf = path_blob.parallel_desc_symbol.parallel_conf
        cfg_attribute = (
            oneflow._oneflow_internal.deprecated.MakeOpAttributeByString(
                str(inferred_attribute)))
        builder.StatelessCall(
            cfg_attribute,
            path_parallel_conf,
            save_blob_objects,
            boxing_util.BoxingTo,
        )

    # NOTE(review): the returned session was never used by the original code;
    # the call is kept unchanged in case GetDefaultSession has side effects.
    session_ctx.GetDefaultSession()
    with scope_util.ScopeContext(scope_util.MakeScope(_BuildNotMirroredScope)):
        for build_instruction in (
                build_path_input,
                build_feed_path,
                BuildModelSaveInstruction,
        ):
            oneflow._oneflow_internal.deprecated.LogicalRun(build_instruction)
Exemple #2
0
def _EagerRunModelInit(var_op_conf):
    """Eagerly run the model-init op generated for ``var_op_conf``.

    The init op is inferred against an empty upstream signature in the current
    scope and launched on that scope's device placement.

    Returns:
        The blob object stored under ``"out_0"`` after the instruction runs.
    """
    init_op_conf, _ = _GenModelInitOpConfAndRetLbi(var_op_conf)
    blob_objects = oneflow._oneflow_internal.deprecated.BnInOp2BlobObject()

    def BuildModelInitInstruction(builder):
        # Inference is done with an empty upstream OpNodeSignature.
        empty_signature = op_node_signature_pb.OpNodeSignature()
        init_op_conf.scope_symbol_id = oneflow.current_scope().symbol_id
        inferred_attribute = c_api_util.InferOpConf(init_op_conf, empty_signature)
        device_parallel_conf = (
            oneflow.current_scope().device_parallel_desc_symbol.parallel_conf)
        builder.StatelessCall(
            oneflow._oneflow_internal.deprecated.MakeOpAttributeByString(
                str(inferred_attribute)),
            device_parallel_conf,
            blob_objects,
            boxing_util.BoxingTo,
        )

    # NOTE(review): the returned session was never used by the original code;
    # the call is kept unchanged in case GetDefaultSession has side effects.
    session_ctx.GetDefaultSession()
    with scope_util.ScopeContext(scope_util.MakeScope(_BuildNotMirroredScope)):
        oneflow._oneflow_internal.deprecated.LogicalRun(BuildModelInitInstruction)

    return blob_objects["out_0"]
Exemple #3
0
def _GetReturnOpConfAndOutLbiAndScope(remote_blob, allow_cpu_return_op=True):
    """Build a uniquely named Return op for ``remote_blob``.

    Args:
        remote_blob: blob whose ``unique_name`` becomes the Return op's input
            and whose ``parallel_conf`` is copied into the returned scope.
        allow_cpu_return_op: when true, the op's ``device_tag`` is set to
            ``"cpu"``.

    Returns:
        A tuple ``(op_conf, lbi, scope)``: the Return op's OperatorConf, the
        LogicalBlobId of its ``"out"`` blob, and a scope derived from the
        current one with ``remote_blob``'s parallel conf.
    """
    return_op_conf = op_conf_util.OperatorConf()
    return_op_conf.name = id_util.UniqueStr("Return_")
    # "in" is a Python keyword, so the field can only be set via setattr.
    setattr(return_op_conf.return_conf, "in", remote_blob.unique_name)
    return_op_conf.return_conf.out = "out"
    if allow_cpu_return_op:
        return_op_conf.device_tag = "cpu"

    out_lbi = logical_blob_id_util.LogicalBlobId()
    out_lbi.op_name = return_op_conf.name
    out_lbi.blob_name = "out"

    blob_parallel_conf = placement_proto_pb.ParallelConf()
    blob_parallel_conf.CopyFrom(remote_blob.parallel_conf)

    def BuildScope(old_scope, builder):
        return old_scope.BuildWithNewParallelConf(builder, blob_parallel_conf)

    # NOTE(review): the returned session was never used by the original code;
    # the call is kept unchanged in case GetDefaultSession has side effects.
    session_ctx.GetDefaultSession()
    return_scope = scope_util.MakeScope(BuildScope)

    return return_op_conf, out_lbi, return_scope
    def __init__(self, error_proto):
        """Take ownership of ``error_proto`` and split out summary and message.

        ``error_summary`` and ``msg`` are copied onto ``self.error_summary_``
        and ``self.msg_`` and then cleared from the stored proto.
        ``self.error_type2get_error_str`` maps an error-type field name to a
        function that renders that field's kernel debug text.

        Raises:
            AssertionError: if ``error_proto`` has no ``error_type`` set.
        """
        assert error_proto.HasField("error_type")
        self.error_proto_ = error_proto
        self.error_summary_ = self.error_proto_.error_summary
        self.error_proto_.ClearField("error_summary")
        self.msg_ = self.error_proto_.msg
        self.error_proto_.ClearField("msg")
        # NOTE(review): ``resource`` is never read; the lookup is kept unchanged.
        resource = session_ctx.GetDefaultSession().config_proto.resource

        def format_kernel_debug_str(raw_text, debug_str_label):
            # Remove backslashes and the field label, then drop the first and
            # last character of each stripped line (presumably the quotes of
            # the text-format string value -- verify).
            cleaned = raw_text.replace("\\", "").replace(debug_str_label, "")
            return "\n".join(
                line.strip()[1:-1] for line in cleaned.strip().split("\n"))

        def get_op_kernel_not_found_error_str(error_proto):
            # NOTE(review): the ``error_proto`` argument is ignored; the proto
            # stored on ``self`` is read instead.
            details = format_kernel_debug_str(
                str(self.error_proto_.op_kernel_not_found_error),
                "op_kernels_not_found_debug_str:",
            )
            return (
                "\n\nFailure messages of registered kernels for current Op node: \n%s"
                % details)

        def get_multiple_op_kernels_matched_error_str(error_proto):
            # NOTE(review): the ``error_proto`` argument is ignored; the proto
            # stored on ``self`` is read instead.
            details = format_kernel_debug_str(
                str(self.error_proto_.multiple_op_kernels_matched_error),
                "matched_op_kernels_debug_str:",
            )
            return (
                "\n\nThere exists multiple registered kernel candidates for current Op node: \n%s"
                % details)

        self.error_type2get_error_str = dict(
            op_kernel_not_found_error=get_op_kernel_not_found_error_str,
            multiple_op_kernels_matched_error=get_multiple_op_kernels_matched_error_str,
        )
Exemple #5
0
def find_or_create_module(module_name, create, reuse=False):
    """Return the module named ``module_name`` for the current global function.

    Modules are cached per job name in the default session's
    ``job_name2module_name2module_``. On a cache miss ``create()`` is called and
    its result (which must be a ``module_util.Module``) is stored. On a cache
    hit with ``reuse`` false, an assertion rejects names already recorded in
    ``sess.existed_module_names_``. In every case the name is then recorded.

    Args:
        module_name: cache key within the current job.
        create: zero-argument callable; invoked only on a cache miss.
        reuse: allow returning an already-registered module without the
            duplicate-name assertion.
    """
    assert callable(create)
    sess = session_ctx.GetDefaultSession()
    job_name = oneflow.current_global_function_desc().job_config_proto.job_name()
    if job_name not in sess.job_name2module_name2module_:
        sess.job_name2module_name2module_[job_name] = {}
    modules_of_job = sess.job_name2module_name2module_[job_name]

    if module_name in modules_of_job:
        if not reuse:
            assert module_name not in sess.existed_module_names_, (
                "duplicated module_name `%s' in global_function `%s'" %
                (module_name, job_name))
    else:
        new_module = create()
        assert isinstance(new_module, module_util.Module)
        modules_of_job[module_name] = new_module

    sess.existed_module_names_.add(module_name)
    return modules_of_job[module_name]
Exemple #6
0
    def Save(self):
        """Write the saved model under ``<saved_model_dir_>/<version_>/``.

        Written into the version directory:
          * ``checkpoint_dir_``: checkpoint produced by ``flow.checkpoint.save``
          * ``saved_model_pb_filename_``: the SavedModel proto, binary-serialized
          * ``saved_model_pbtxt_filename_``: the SavedModel proto, text format

        Before writing, unfinished graph builders are finished and each graph's
        ``op_list`` is extended with the ops of the session job of that name.

        Raises:
            ValueError: if no model version is set, or if the version
                directory already exists.
        """
        self._check_input_output_name_conflict()

        # Validate before finishing graph builders, extending op lists or
        # touching the filesystem. These checks used to run after
        # ``graph_def.op_list`` had been extended and ``saved_model_dir_``
        # created, so a failed Save() left partial state behind and a retry
        # appended every op a second time.
        if self.version_ is None:
            raise ValueError("model version is not set")

        version_dir = os.path.join(self.saved_model_dir_, str(self.version_))
        if os.path.exists(version_dir):
            raise ValueError(
                'Directory of model "{}" version "{}" already exist.'.format(
                    self.saved_model_dir_, self.version_))

        for graph_builder in self.graph_builders_.values():
            if not graph_builder.finished:
                graph_builder.Finish()

        sess = session_ctx.GetDefaultSession()
        for graph_name, graph_def in self.proto.graphs.items():
            job = sess.Job(graph_name)
            graph_def.op_list.extend(list(job.net.op))

        # makedirs also creates saved_model_dir_ if it does not exist yet.
        os.makedirs(version_dir)
        self.proto.version = self.version_

        checkpoint_path = os.path.join(version_dir, self.checkpoint_dir_)
        flow.checkpoint.save(checkpoint_path)
        self.proto.checkpoint_dir = self.checkpoint_dir_

        saved_model_pb_path = os.path.join(version_dir,
                                           self.saved_model_pb_filename_)
        with open(saved_model_pb_path, "wb") as writer:
            writer.write(self.saved_model_proto_.SerializeToString())

        saved_model_pbtxt_path = os.path.join(version_dir,
                                              self.saved_model_pbtxt_filename_)
        with open(saved_model_pbtxt_path, "wt") as writer:
            writer.write(text_format.MessageToString(self.saved_model_proto_))
Exemple #7
0
    def __init__(self, device_tag, machine_device_ids):
        """Record a device placement and derive its parallel configuration.

        Args:
            device_tag: device tag passed to ``MakeParallelConf``.
            machine_device_ids: one id spec, or a list/tuple of them. Anything
                that is not a list or tuple is wrapped in a one-element list.

        When the default session is running and its ``placement_scope_stack``
        is non-empty, ``self.scope_context_`` is set to
        ``sess.NewCurrentScope(...)`` for a scope with this placement.
        Otherwise it stays ``None``.
        """
        self.device_tag_ = device_tag
        if not isinstance(machine_device_ids, (list, tuple)):
            machine_device_ids = [machine_device_ids]
        self.machine_device_ids_ = machine_device_ids
        self.default_parallel_conf_ = MakeParallelConf(
            self.device_tag_, self.machine_device_ids_
        )
        self.machine_id2device_id_list_ = MakeMachineId2DeviceIdList(
            self.default_parallel_conf_
        )
        self.parallel_size_ = GetParallelSize(self.machine_id2device_id_list_)
        self.scope_context_ = None
        sess = session_ctx.GetDefaultSession()
        # bypass the first PlacementScope for avoiding None old_scope
        if sess.is_running and len(sess.placement_scope_stack) > 0:

            def BuildScope(old_scope, builder):
                return old_scope.BuildWithNewParallelDesc(
                    builder, device_tag, machine_device_ids
                )

            self.scope_context_ = sess.NewCurrentScope(sess.MakeScope(BuildScope))
Exemple #8
0
def MakeLazyRefBlobObject(self, interface_op_name):
    """Create a blob object lazily referencing an interface op's only output.

    The op's placement is read from the default session. A non-cfg
    ParallelConf is first converted to ``placement_cfg.ParallelConf`` by copying
    its device tag and device names.

    Raises:
        AssertionError: if the interface op does not have exactly one output.
    """
    sess = session_ctx.GetDefaultSession()
    op_attribute = sess.OpAttribute4InterfaceOpName(interface_op_name)
    assert len(op_attribute.output_bns) == 1
    output_bn = op_attribute.output_bns[0]

    parallel_conf = sess.ParallelConf4LazyInterfaceOpName(interface_op_name)
    cfg_parallel_conf_type = oneflow_api.oneflow.core.job.placement.ParallelConf
    if not isinstance(parallel_conf, cfg_parallel_conf_type):
        converted = placement_cfg.ParallelConf()
        converted.set_device_tag(parallel_conf.device_tag)
        for name in parallel_conf.device_name:
            converted.add_device_name(name)
        parallel_conf = converted
    parallel_desc_symbol = self.GetParallelDescSymbol(parallel_conf)

    serialized_attribute = str(op_attribute)
    parallel_attr = oneflow_api.GetOpArgParallelAttribute(
        parallel_desc_symbol, serialized_attribute, output_bn)
    blob_attr = oneflow_api.GetOpArgBlobAttribute(serialized_attribute, output_bn)

    lazy_blob_object = self.NewBlobObject(parallel_attr, blob_attr)
    self.LazyReference(lazy_blob_object, interface_op_name)
    return lazy_blob_object
Exemple #9
0
def nccl_fusion_max_ops(val):
    """Set ``collective_boxing_conf.nccl_fusion_max_ops`` on the default session.

    Args:
        val: value to store; must be exactly an ``int`` (asserted).
    """
    session = session_ctx.GetDefaultSession()
    assert type(val) is int
    boxing_conf = session.config_proto.resource.collective_boxing_conf
    boxing_conf.nccl_fusion_max_ops = val
Exemple #10
0
def nccl_fusion_reduce_scatter(val):
    """Set ``collective_boxing_conf.nccl_fusion_reduce_scatter`` on the default session.

    Args:
        val: flag to store; must be exactly a ``bool`` (asserted).
    """
    session = session_ctx.GetDefaultSession()
    assert type(val) is bool
    boxing_conf = session.config_proto.resource.collective_boxing_conf
    boxing_conf.nccl_fusion_reduce_scatter = val
Exemple #11
0
def nccl_fusion_broadcast(val):
    """Set ``collective_boxing_conf.nccl_fusion_broadcast`` on the default session.

    Args:
        val: flag to store; must be exactly a ``bool`` (asserted).
    """
    session = session_ctx.GetDefaultSession()
    assert type(val) is bool
    boxing_conf = session.config_proto.resource.collective_boxing_conf
    boxing_conf.nccl_fusion_broadcast = val
Exemple #12
0
def enable_fusion(val=True):
    """Set ``collective_boxing_conf.enable_fusion`` on the default session.

    Args:
        val: flag to store (default ``True``); must be exactly a ``bool``.
    """
    session = session_ctx.GetDefaultSession()
    assert type(val) is bool
    boxing_conf = session.config_proto.resource.collective_boxing_conf
    boxing_conf.enable_fusion = val
Exemple #13
0
def num_callback_threads(val):
    """Set ``collective_boxing_conf.num_callback_threads`` on the default session.

    Args:
        val: value to store; must be exactly an ``int`` (asserted).
    """
    session = session_ctx.GetDefaultSession()
    assert type(val) is int
    boxing_conf = session.config_proto.resource.collective_boxing_conf
    boxing_conf.num_callback_threads = val
Exemple #14
0
def rdma_recv_msg_buf_mbyte(val):
    """Set ``resource.rdma_recv_msg_buf_mbyte`` on the default session's config.

    Args:
        val: value to store; must be exactly an ``int`` (asserted).
    """
    session = session_ctx.GetDefaultSession()
    assert type(val) is int
    resource = session.config_proto.resource
    resource.rdma_recv_msg_buf_mbyte = val
Exemple #15
0
def collect_act_event(val=True):
    """Set ``profile_conf.collect_act_event`` on the default session's config.

    Args:
        val: flag to store (default ``True``). ``bool`` is accepted; ``int`` is
            still accepted for backward compatibility with the previous check.
    """
    sess = session_ctx.GetDefaultSession()
    # The previous check was ``type(val) is int``. It rejected the default
    # ``True`` (``type(True) is bool``), so calling with no argument always
    # failed the assertion.
    assert type(val) in (bool, int)
    sess.config_proto.profile_conf.collect_act_event = val
Exemple #16
0
def gpu_device_num(val):
    """Set ``resource.gpu_device_num`` on the default session's config.

    Args:
        val: value to store; must be exactly an ``int`` (asserted).
    """
    session = session_ctx.GetDefaultSession()
    assert type(val) is int
    resource = session.config_proto.resource
    resource.gpu_device_num = val
Exemple #17
0
def enable_numa_aware_cuda_malloc_host(val):
    """Set ``resource.enable_numa_aware_cuda_malloc_host`` on the default session.

    Args:
        val: flag to store; must be exactly a ``bool`` (asserted).
    """
    session = session_ctx.GetDefaultSession()
    assert type(val) is bool
    resource = session.config_proto.resource
    resource.enable_numa_aware_cuda_malloc_host = val
Exemple #18
0
def enable_debug_mode(val):
    """Set ``resource.enable_debug_mode`` on the default session's config.

    Args:
        val: flag to store; must be exactly a ``bool`` (asserted).
    """
    session = session_ctx.GetDefaultSession()
    assert type(val) is bool
    resource = session.config_proto.resource
    resource.enable_debug_mode = val
Exemple #19
0
def compute_thread_pool_size(val):
    """Set ``resource.compute_thread_pool_size`` on the default session's config.

    Args:
        val: value to store; must be exactly an ``int`` (asserted).
    """
    session = session_ctx.GetDefaultSession()
    assert type(val) is int
    resource = session.config_proto.resource
    resource.compute_thread_pool_size = val
Exemple #20
0
def thread_enable_local_message_queue(val):
    """Set ``resource.thread_enable_local_message_queue`` on the default session.

    Args:
        val: flag to store; must be exactly a ``bool`` (asserted).
    """
    session = session_ctx.GetDefaultSession()
    assert type(val) is bool
    resource = session.config_proto.resource
    resource.thread_enable_local_message_queue = val
Exemple #21
0
def use_rdma(val=True):
    """Set ``resource.use_rdma`` on the default session's config.

    Args:
        val: flag to store (default ``True``); must be exactly a ``bool``.
    """
    session = session_ctx.GetDefaultSession()
    assert type(val) is bool
    resource = session.config_proto.resource
    resource.use_rdma = val
Exemple #22
0
def reserved_device_mem_mbyte(val):
    """Set ``resource.reserved_device_mem_mbyte`` on the default session's config.

    Args:
        val: value to store; must be exactly an ``int`` (asserted).
    """
    session = session_ctx.GetDefaultSession()
    assert type(val) is int
    resource = session.config_proto.resource
    resource.reserved_device_mem_mbyte = val
Exemple #23
0
def nccl_enable_all_to_all(val):
    """Set ``collective_boxing_conf.nccl_enable_all_to_all`` on the default session.

    Args:
        val: flag to store; must be exactly a ``bool`` (asserted).
    """
    session = session_ctx.GetDefaultSession()
    assert type(val) is bool
    boxing_conf = session.config_proto.resource.collective_boxing_conf
    boxing_conf.nccl_enable_all_to_all = val
Exemple #24
0
def load_library(val):
    """Append ``val`` to the default session's ``config_proto.load_lib_path``.

    Args:
        val: library path; must be exactly a ``str``. The type is asserted
            before the default session is looked up.
    """
    assert type(val) is str
    session = session_ctx.GetDefaultSession()
    lib_paths = session.config_proto.load_lib_path
    lib_paths.append(val)
Exemple #25
0
def machine_num(val):
    """Set ``resource.machine_num`` on the default session's config.

    Args:
        val: value to store; must be exactly an ``int`` (asserted).
    """
    session = session_ctx.GetDefaultSession()
    assert type(val) is int
    resource = session.config_proto.resource
    resource.machine_num = val
Exemple #26
0
def rdma_mem_block_mbyte(val):
    """Set ``resource.rdma_mem_block_mbyte`` on the default session's config.

    Args:
        val: value to store; must be exactly an ``int`` (asserted).
    """
    session = session_ctx.GetDefaultSession()
    assert type(val) is int
    resource = session.config_proto.resource
    resource.rdma_mem_block_mbyte = val
Exemple #27
0
def save_downloaded_file_to_local_fs(val=True):
    """Set ``io_conf.save_downloaded_file_to_local_fs`` on the default session.

    Args:
        val: flag to store (default ``True``); must be exactly a ``bool``.
    """
    session = session_ctx.GetDefaultSession()
    assert type(val) is bool
    io_conf = session.config_proto.io_conf
    io_conf.save_downloaded_file_to_local_fs = val
Exemple #28
0
def enable_model_io_v2(val):
    """Set ``io_conf.enable_model_io_v2`` on the default session's config.

    Args:
        val: flag to store; must be exactly a ``bool`` (asserted).
    """
    session = session_ctx.GetDefaultSession()
    assert type(val) is bool
    io_conf = session.config_proto.io_conf
    io_conf.enable_model_io_v2 = val
Exemple #29
0
def sync_default_session() -> None:
    """Call ``Sync()`` on the current default session."""
    default_session = session_ctx.GetDefaultSession()
    default_session.Sync()
Exemple #30
0
def persistence_buf_byte(val):
    """Set ``io_conf.persistence_buf_byte`` on the default session's config.

    Args:
        val: value to store; must be exactly an ``int`` (asserted).
    """
    session = session_ctx.GetDefaultSession()
    assert type(val) is int
    io_conf = session.config_proto.io_conf
    io_conf.persistence_buf_byte = val