예제 #1
0
    def open(self, job_name, signature=None, batch_size=None):
        """Open a job build-and-infer context for ``job_name`` and yield this session.

        Generator meant to be used as a context manager (presumably wrapped by
        ``contextlib.contextmanager``; the decorator is not visible here).
        Once ``JobBuildAndInferCtx_Open`` has succeeded, the context is always
        closed again, even if configuring the job or the caller's body raises.

        Args:
            job_name: Name of the job whose build-and-infer context is opened.
            signature: If not ``None``, applied via ``set_job_signature``.
            batch_size: Applied via ``set_job_batch_size`` only when it is an
                ``int``; any other value is ignored.

        Yields:
            ``self``, with ``cur_job_name_`` set to ``job_name`` while running
            inside the global runtime mode and the job's initial scope.
        """
        # NOTE(review): presumably raises unless the session is in the OPEN state.
        self._check_status(self.SessionStatus.OPEN)
        c_api_util.JobBuildAndInferCtx_Open(job_name)
        # Guard everything after a successful Open so the context cannot leak
        # when a configuration step or the caller's body raises.
        try:
            if signature is not None:
                self.set_job_signature(job_name, signature)

            if isinstance(batch_size, int):
                self.set_job_batch_size(job_name, batch_size)

            job_conf = self._get_job_conf(job_name)
            c_api_util.CurJobBuildAndInferCtx_SetJobConf(job_conf)

            tag_and_dev_ids = placement_util.GetDefaultMachineDeviceIds(
                self.config_proto_.resource)
            scope = scope_util.MakeInitialScope(job_conf, *tag_and_dev_ids, None,
                                                self.is_mirrored_)

            with runtime_mode.ModeScope(runtime_mode.GLOBAL_MODE):
                with scope_util.ScopeContext(scope):
                    self.cur_job_name_ = job_name
                    try:
                        yield self
                    finally:
                        # Reset even when the caller's body raised.
                        self.cur_job_name_ = None
        finally:
            # NOTE(review): Open goes through c_api_util while Close goes through
            # oneflow_api; kept exactly as in the original.
            oneflow_api.JobBuildAndInferCtx_Close()
예제 #2
0
def _JobBuildAndInferCtx(job_name):
    """Keep a job build-and-infer context open for ``job_name`` around a ``yield``.

    Generator intended for use as a context manager (presumably wrapped by
    ``contextlib.contextmanager``; the decorator is not visible here). Nothing
    is yielded to the caller. Once the open call succeeds, the context is
    closed when the caller's body finishes, whether it returns or raises.
    """
    api = c_api_util
    api.JobBuildAndInferCtx_Open(job_name)
    try:
        yield None
    finally:
        # Runs on normal exit and on exceptions alike, so the context never leaks.
        api.JobBuildAndInferCtx_Close()