Exemplo n.º 1
0
def InterpretScope(session, function_desc, config_proto):
    """Set up the job-building context for ``function_desc`` and yield once.

    The job name is taken from the wrapped Python function. Device placement
    comes from the function's default placement scope when one is set,
    otherwise from the session's resource. The distribute strategy falls
    back to the consistent strategy when the function declares none.

    Args:
        session: Object whose ``resource`` attribute supplies the default
            machine/device ids when no placement scope is configured.
        function_desc: Descriptor providing ``job_config_proto``,
            ``job_func`` and ``function_attribute``.
        config_proto: Not used by this function.

    Yields:
        None. While suspended, the job build/infer context, the distribute
        strategy, global runtime mode and the initial scope are all active.

    NOTE(review): this is a single-yield generator; it is presumably wrapped
    by ``contextlib.contextmanager`` outside the visible source -- confirm.
    """
    job_conf = function_desc.job_config_proto
    job_conf.set_job_name(function_desc.job_func.__name__)

    func_attr = function_desc.function_attribute

    # Resolve (device_tag, machine_device_ids) for the initial scope.
    default_placement = func_attr.default_placement_scope
    if default_placement is not None:
        assert isinstance(default_placement, placement_ctx.EmptyPlacementScope)
        device_tag_and_ids = (
            default_placement.device_tag,
            default_placement.machine_device_ids,
        )
    else:
        device_tag_and_ids = placement_util.GetDefaultMachineDeviceIds(
            session.resource)

    # Resolve the distribute strategy; consistent is the fallback.
    strategy = func_attr.default_distribute_strategy
    if strategy is None:
        strategy = distribute_util.DistributeConsistentStrategy()
    mirrored = isinstance(strategy, distribute_util.DistributeMirroredStrategy)

    initial_scope = scope_util.MakeInitialScope(
        job_conf, *device_tag_and_ids, mirrored)

    # Contexts are entered outermost-first in the same order as before:
    # job ctx -> strategy -> runtime mode -> scope.
    with _JobBuildAndInferCtx(job_conf.job_name()):
        with strategy:
            c_api_util.CurJobBuildAndInferCtx_SetJobConf(job_conf)
            with runtime_mode.ModeScope(runtime_mode.GLOBAL_MODE), \
                    scope_util.ScopeContext(initial_scope):
                yield
Exemplo n.º 2
0
    def open(self, job_name, signature=None, batch_size=None):
        """Open a job build/infer context for ``job_name`` and yield ``self``.

        After checking that the session is in the ``OPEN`` status, the
        build/infer context is opened, the optional signature and batch size
        are applied, the job configuration is installed and an initial scope
        is created from the default machine/device ids. ``self`` is then
        yielded with global runtime mode and that scope active.

        The context is always closed and ``cur_job_name_`` always reset,
        even when setup or the caller's body raises, so a failed job does
        not leave a dangling open context or a stale current job name.

        Args:
            job_name: Name of the job to open.
            signature: Optional job signature; applied only when not ``None``.
            batch_size: Optional batch size; applied only when it is an
                ``int`` (any other value, including ``None``, is ignored).

        Yields:
            This object, with ``cur_job_name_`` set to ``job_name``.

        NOTE(review): this is a single-yield generator; it is presumably
        wrapped by ``contextlib.contextmanager`` outside the visible source
        -- confirm.
        """
        self._check_status(self.SessionStatus.OPEN)
        c_api_util.JobBuildAndInferCtx_Open(job_name)

        # Everything after a successful Open must be paired with Close.
        try:
            if signature is not None:
                self.set_job_signature(job_name, signature)

            if isinstance(batch_size, int):
                self.set_job_batch_size(job_name, batch_size)

            job_conf = self._get_job_conf(job_name)
            c_api_util.CurJobBuildAndInferCtx_SetJobConf(job_conf)

            tag_and_dev_ids = placement_util.GetDefaultMachineDeviceIds(
                self.config_proto_.resource)
            scope = scope_util.MakeInitialScope(job_conf, *tag_and_dev_ids, None,
                                                self.is_mirrored_)

            with runtime_mode.ModeScope(runtime_mode.GLOBAL_MODE):
                with scope_util.ScopeContext(scope):
                    self.cur_job_name_ = job_name
                    try:
                        yield self
                    finally:
                        # Reset before the scopes unwind, as on the
                        # success path.
                        self.cur_job_name_ = None
        finally:
            oneflow_api.JobBuildAndInferCtx_Close()
Exemplo n.º 3
0
def InterpretScope(session, function_desc, config_proto):
    """Set up the job-building context for ``function_desc`` and yield once.

    The job name is taken from the wrapped Python function. When the
    function declares no default placement scope, one is built from the
    default machine/device ids of the current environment resource. The
    distribute strategy falls back to the consistent strategy when none is
    declared. The initial scope is derived from the placement scope's
    default parallel configuration.

    Args:
        session: Passed through to ``_SessionInitialScope``.
        function_desc: Descriptor providing ``job_config_proto``,
            ``job_func`` and ``function_attribute``.
        config_proto: Not used by this function.

    Yields:
        None. While suspended, the job build/infer context, the placement
        scope, the distribute strategy, global runtime mode and the
        session's initial scope are all active.

    NOTE(review): this is a single-yield generator; it is presumably wrapped
    by ``contextlib.contextmanager`` outside the visible source -- confirm.
    """
    job_conf = function_desc.job_config_proto
    job_conf.job_name = function_desc.job_func.__name__

    func_attr = function_desc.function_attribute

    # Fall back to a placement scope built from the environment defaults.
    active_placement = func_attr.default_placement_scope
    if active_placement is None:
        default_tag_and_ids = placement_util.GetDefaultMachineDeviceIds(
            oneflow.env.current_resource())
        active_placement = placement_util.GetPlacementScope(
            *default_tag_and_ids)

    strategy = func_attr.default_distribute_strategy
    if strategy is None:
        strategy = distribute_util.DistributeConsistentStrategy()
    mirrored = isinstance(strategy, distribute_util.DistributeMirroredStrategy)

    # The scope's device info always comes from the resolved placement scope.
    scope_tag_and_ids = parallel_conf_util.GetDeviceTagAndMachineDeviceIds(
        active_placement.default_parallel_conf)
    initial_scope = MakeInitialScope(job_conf, *scope_tag_and_ids, mirrored)

    # Contexts are entered outermost-first in the same order as before:
    # job ctx -> placement -> strategy -> runtime mode -> session scope.
    with _JobBuildAndInferCtx(job_conf.job_name):
        with active_placement:
            with strategy:
                c_api_util.CurJobBuildAndInferCtx_SetJobConf(job_conf)
                with runtime_mode.ModeScope(runtime_mode.GLOBAL_MODE), \
                        _SessionInitialScope(session, initial_scope):
                    yield