예제 #1
0
def InterpretScope(session, function_desc, config_proto):
    """Yield once inside the job-building contexts set up for ``function_desc``.

    The job config is named after the job function, then a device tag /
    machine-device-id pair and a distribute strategy are resolved, falling
    back to defaults when the function attribute leaves them unset. An
    initial scope is built from those values, and the single ``yield`` runs
    with the job build-and-infer context, the distribute strategy, the
    global runtime mode and the scope context all entered.

    NOTE(review): this is a generator; it is presumably turned into a
    context manager by a decorator outside this view -- confirm.

    Args:
        session: Supplies ``resource`` for the default device lookup.
        function_desc: Provides ``job_config_proto``, ``job_func`` and
            ``function_attribute``.
        config_proto: Not used by this function.
    """
    conf = function_desc.job_config_proto
    conf.set_job_name(function_desc.job_func.__name__)

    default_placement = function_desc.function_attribute.default_placement_scope
    if default_placement is not None:
        # An explicit default placement must be an EmptyPlacementScope; its
        # tag and device ids are used directly.
        assert isinstance(default_placement, placement_ctx.EmptyPlacementScope)
        device_args = (
            default_placement.device_tag,
            default_placement.machine_device_ids,
        )
    else:
        device_args = placement_util.GetDefaultMachineDeviceIds(session.resource)

    strategy = function_desc.function_attribute.default_distribute_strategy
    if strategy is None:
        strategy = distribute_util.DistributeConsistentStrategy()
    mirrored = isinstance(strategy, distribute_util.DistributeMirroredStrategy)

    initial_scope = scope_util.MakeInitialScope(conf, *device_args, mirrored)

    with _JobBuildAndInferCtx(conf.job_name()), strategy:
        c_api_util.CurJobBuildAndInferCtx_SetJobConf(conf)
        # Built only after the job conf is set, matching the original order.
        global_mode = runtime_mode.ModeScope(runtime_mode.GLOBAL_MODE)
        with global_mode, scope_util.ScopeContext(initial_scope):
            yield
예제 #2
0
def InterpretScope(session, function_desc, config_proto):
    """Yield once inside the job-building contexts set up for ``function_desc``.

    The job config is named after the job function. If the function
    attribute has no default placement scope, one is created from the
    default machine/device ids of the current ``oneflow.env`` resource. The
    distribute strategy likewise falls back to the consistent strategy. The
    device tag and machine device ids are then read back from the placement
    scope's default parallel conf and used to build the initial scope. The
    single ``yield`` runs with the job build-and-infer context, the
    placement scope, the distribute strategy, the global runtime mode and
    the session initial scope all entered.

    NOTE(review): this is a generator; it is presumably turned into a
    context manager by a decorator outside this view -- confirm.

    Args:
        session: Passed to ``_SessionInitialScope`` with the initial scope.
        function_desc: Provides ``job_config_proto``, ``job_func`` and
            ``function_attribute``.
        config_proto: Not used by this function.
    """
    conf = function_desc.job_config_proto
    conf.job_name = function_desc.job_func.__name__

    placement = function_desc.function_attribute.default_placement_scope
    if placement is None:
        default_ids = placement_util.GetDefaultMachineDeviceIds(
            oneflow.env.current_resource()
        )
        placement = placement_util.GetPlacementScope(*default_ids)

    strategy = function_desc.function_attribute.default_distribute_strategy
    if strategy is None:
        strategy = distribute_util.DistributeConsistentStrategy()
    mirrored = isinstance(strategy, distribute_util.DistributeMirroredStrategy)

    # Always derive the scope's device arguments from the placement scope,
    # whether it was supplied or just created above.
    device_args = parallel_conf_util.GetDeviceTagAndMachineDeviceIds(
        placement.default_parallel_conf
    )
    initial_scope = MakeInitialScope(conf, *device_args, mirrored)

    with _JobBuildAndInferCtx(conf.job_name), placement, strategy:
        c_api_util.CurJobBuildAndInferCtx_SetJobConf(conf)
        # Built only after the job conf is set, matching the original order.
        global_mode = runtime_mode.ModeScope(runtime_mode.GLOBAL_MODE)
        with global_mode, _SessionInitialScope(session, initial_scope):
            yield