Пример #1
0
def _EagerRunModelSave(var_blobs, snapshot_path):
    """Eagerly run a model-save op over ``var_blobs``.

    Three instruction builders are submitted through ``LogicalRun`` in order:
    the path-input builder and the feed-path builder (both produced by
    ``_MakeModelIOPathInputBuilds`` from ``snapshot_path``), then the
    model-save builder defined below.

    Args:
        var_blobs: Iterable of blobs; each must expose ``blob_object``. They
            are bound to the save op's inputs ``in_0``, ``in_1``, ... in
            iteration order.
        snapshot_path: Forwarded to ``_MakeModelIOPathInputBuilds``.
    """
    deprecated_api = oneflow._oneflow_internal.deprecated

    path_op_conf, path_lbi = _GenModelIOPathInputOpConfAndRetLbi()
    path_blob_objects = deprecated_api.BnInOp2BlobObject()
    build_path_input, build_feed_path = _MakeModelIOPathInputBuilds(
        path_op_conf, snapshot_path, path_blob_objects
    )

    save_op_conf = _GenModelSaveOpConf(var_blobs, path_lbi)
    save_blob_objects = deprecated_api.BnInOp2BlobObject()

    def BuildModelSaveInstruction(builder):
        # The "out" entry of the path op is read here, i.e. only when this
        # builder actually runs (after the two path builders below).
        path_blob = path_blob_objects["out"]
        save_blob_objects["path"] = path_blob
        for index, var_blob in enumerate(var_blobs):
            save_blob_objects["in_{}".format(index)] = var_blob.blob_object

        inferred_attribute = op_infer_util.Infer(
            save_op_conf, ibn2blob_object=save_blob_objects
        )
        # The save op runs on the same parallel conf as the path blob.
        save_parallel_conf = path_blob.parallel_desc_symbol.parallel_conf
        builder.StatelessCall(
            deprecated_api.MakeOpAttributeByString(str(inferred_attribute)),
            save_parallel_conf,
            save_blob_objects,
            boxing_util.BoxingTo,
        )

    # Result is not used; the call is kept in case the lookup has side
    # effects -- TODO confirm whether it can be dropped.
    session_ctx.GetDefaultSession()

    instruction_builders = (
        build_path_input,
        build_feed_path,
        BuildModelSaveInstruction,
    )
    with scope_util.ScopeContext(scope_util.MakeScope(_BuildNotMirroredScope)):
        for build_instruction in instruction_builders:
            deprecated_api.LogicalRun(build_instruction)
Пример #2
0
def _EagerRunModelInit(var_op_conf):
    """Eagerly run a model-init op derived from ``var_op_conf``.

    The op conf comes from ``_GenModelInitOpConfAndRetLbi``; a single
    instruction builder is submitted through ``LogicalRun`` inside a scope
    created with ``_BuildNotMirroredScope``.

    Args:
        var_op_conf: Forwarded to ``_GenModelInitOpConfAndRetLbi``.

    Returns:
        The ``"out_0"`` entry of the blob-object map passed to the op call.
    """
    deprecated_api = oneflow._oneflow_internal.deprecated

    init_op_conf, _ = _GenModelInitOpConfAndRetLbi(var_op_conf)
    blob_objects = deprecated_api.BnInOp2BlobObject()

    def BuildModelInitInstruction(builder):
        # Stamp the op with the scope that is current when the builder runs.
        init_op_conf.scope_symbol_id = oneflow.current_scope().symbol_id
        # Inference is done against an empty upstream signature.
        inferred_attribute = c_api_util.InferOpConf(
            init_op_conf, op_node_signature_pb.OpNodeSignature()
        )
        scope_parallel_conf = (
            oneflow.current_scope().device_parallel_desc_symbol.parallel_conf
        )
        builder.StatelessCall(
            deprecated_api.MakeOpAttributeByString(str(inferred_attribute)),
            scope_parallel_conf,
            blob_objects,
            boxing_util.BoxingTo,
        )

    # Result is not used; the call is kept in case the lookup has side
    # effects -- TODO confirm whether it can be dropped.
    session_ctx.GetDefaultSession()

    not_mirrored_scope = scope_util.MakeScope(_BuildNotMirroredScope)
    with scope_util.ScopeContext(not_mirrored_scope):
        deprecated_api.LogicalRun(BuildModelInitInstruction)

    return blob_objects["out_0"]
Пример #3
0
    def open(self, job_name, signature, batch_size=None):
        """Open a job build-and-infer context and yield this session.

        This is a generator with a single ``yield``; it appears meant to be
        driven as a context manager (the decorator is not visible in this
        excerpt -- TODO confirm). Work done before the ``yield``: the session
        status is checked, the context for ``job_name`` is opened, the job
        signature and optional batch size are recorded, the job conf is
        installed, and an initial scope is entered in ``GLOBAL_MODE``.

        The build-and-infer context is closed in a ``finally`` clause, so it
        is released even if setup or the caller's body raises.

        Args:
            job_name: Name passed to ``JobBuildAndInferCtx_Open``.
            signature: Forwarded to ``set_job_signature``.
            batch_size: Forwarded to ``set_job_batch_size`` only when it is
                an ``int``; any other value (including ``None``) is ignored.

        Yields:
            This session object.
        """
        self._check_status(self.SessionStatus.OPEN)
        c_api_util.JobBuildAndInferCtx_Open(job_name)

        # Everything after a successful open is guarded so the context is
        # always closed, including when the caller's body raises at the yield.
        try:
            self.set_job_signature(job_name, signature)
            if isinstance(batch_size, int):
                self.set_job_batch_size(job_name, batch_size)
            job_conf = self._get_job_conf(job_name)
            c_api_util.CurJobBuildAndInferCtx_SetJobConf(job_conf)

            tag_and_dev_ids = placement_util.GetDefaultMachineDeviceIds(
                self.config_proto_.resource
            )
            scope = scope_util.MakeInitialScope(
                job_conf, *tag_and_dev_ids, self.is_mirrored_
            )

            with runtime_mode.ModeScope(runtime_mode.GLOBAL_MODE):
                with scope_util.ScopeContext(scope):
                    yield self
        finally:
            oneflow_api.JobBuildAndInferCtx_Close()
Пример #4
0
def InterpretScope(session, function_desc, config_proto):
    """Generator that sets up the build context for one job function.

    Sets ``job_conf.job_name`` to the job function's ``__name__``, chooses
    the device placement and distribute strategy, then yields once while
    inside: the job build-and-infer context, the distribute strategy, the
    ``GLOBAL_MODE`` runtime mode and the initial scope.

    Args:
        session: Its ``resource`` is used to look up default device ids when
            the function has no default placement scope.
        function_desc: Supplies ``job_config_proto``, ``job_func`` and
            ``function_attribute``.
        config_proto: Not used by this function.

    Yields:
        ``None``, once, with all the contexts above active.
    """
    job_conf = function_desc.job_config_proto
    job_conf.job_name = function_desc.job_func.__name__

    default_placement = function_desc.function_attribute.default_placement_scope
    if default_placement is not None:
        assert isinstance(default_placement, placement_ctx.EmptyPlacementScope)
        device_tag_and_ids = (
            default_placement.device_tag,
            default_placement.machine_device_ids,
        )
    else:
        device_tag_and_ids = placement_util.GetDefaultMachineDeviceIds(
            session.resource
        )

    strategy = function_desc.function_attribute.default_distribute_strategy
    if strategy is None:
        # No explicit strategy on the function: use the consistent strategy.
        strategy = distribute_util.DistributeConsistentStrategy()
    mirrored = isinstance(strategy, distribute_util.DistributeMirroredStrategy)

    initial_scope = scope_util.MakeInitialScope(
        job_conf, *device_tag_and_ids, mirrored
    )

    with _JobBuildAndInferCtx(job_conf.job_name):
        with strategy:
            # The job conf is installed only after both contexts are entered.
            c_api_util.CurJobBuildAndInferCtx_SetJobConf(job_conf)
            global_mode = runtime_mode.ModeScope(runtime_mode.GLOBAL_MODE)
            with global_mode, scope_util.ScopeContext(initial_scope):
                yield
Пример #5
0
def GetNormalModePlacementScope(device_tag, machine_device_ids):
    """Return a scope context with a new parallel description.

    A new scope is created with ``scope_util.MakeScope`` by calling
    ``BuildWithNewParallelDesc`` on the old scope with the given device tag
    and machine/device ids; the result is wrapped in
    ``scope_util.ScopeContext``.

    Args:
        device_tag: Forwarded to ``BuildWithNewParallelDesc``.
        machine_device_ids: Forwarded to ``BuildWithNewParallelDesc``.

    Returns:
        The ``scope_util.ScopeContext`` wrapping the newly made scope.
    """
    # Result is not used; the call is kept in case the lookup has side
    # effects -- TODO confirm whether it can be dropped.
    session_ctx.GetDefaultSession()

    def _with_new_parallel_desc(old_scope, builder):
        return old_scope.BuildWithNewParallelDesc(
            builder, device_tag, machine_device_ids
        )

    placement_scope = scope_util.MakeScope(_with_new_parallel_desc)
    return scope_util.ScopeContext(placement_scope)