示例#1
0
    def compile(cls, runnable, workspace_type=None, setup_net_list=None):
        """Compile `runnable` for this session class, reusing cached results.

        Args:
            runnable: A CompiledRunnable, TaskGroup, Task,
                core.ExecutionStep, core.Plan, or any other object that is
                handed to core.execution_step('runnable', ...).
            workspace_type: Optional workspace type. A TaskGroup without a
                type gets this one assigned; a TaskGroup that already has a
                type must match it. For every other kind of runnable it
                defaults to WorkspaceType.GLOBAL.
            setup_net_list: Passed through to cls._compile_task_group.

        Returns:
            `runnable` itself when it is already a CompiledRunnable,
            otherwise the CompiledRunnable stored in cls._compiled_cache.

        Raises:
            AssertionError: if a CompiledRunnable belongs to a different
                session class, if the workspace types disagree, or if a
                Plan contains no steps.
        """
        if isinstance(runnable, CompiledRunnable):
            # Nothing to compile; only verify it targets this session class.
            assert cls == runnable.session_class, (
                'Runnable was compiled for different session type. ' +
                'Need: %s, got: %s' %
                (cls.__name__, runnable.session_class.__name__))
            return runnable

        # A cache hit is returned as-is: workspace_type and setup_net_list
        # are not consulted on this path.
        if runnable in cls._compiled_cache:
            return cls._compiled_cache[runnable]

        if not isinstance(runnable, TaskGroup):
            # Wrap the runnable in a fresh single-task group.
            if workspace_type is None:
                workspace_type = WorkspaceType.GLOBAL
            group = TaskGroup(workspace_type=workspace_type)
            if isinstance(runnable, Task):
                task = runnable
            elif isinstance(runnable, core.ExecutionStep):
                task = Task(step=runnable)
            elif isinstance(runnable, core.Plan):
                # The steps of a Plan are meant to run one after another,
                # whereas the tasks of a TaskGroup run in parallel. A plan
                # with several steps therefore becomes ONE task: Task takes
                # the list of steps and wraps it into a root ExecutionStep.
                assert len(runnable.Steps()) > 0
                if len(runnable.Steps()) != 1:
                    task = Task(step=runnable.Steps())
                else:
                    task = Task(step=runnable.Steps()[0])
            else:
                task = Task(step=core.execution_step('runnable', runnable))
            group.add(task)
        else:
            group = runnable
            if workspace_type:
                if not runnable.workspace_type():
                    # The group has no type yet: adopt the requested one.
                    runnable._workspace_type = workspace_type
                else:
                    assert runnable.workspace_type() == workspace_type, \
                        "Require {} but already have {}".format(
                            workspace_type, runnable.workspace_type())

        result = CompiledRunnable(
            cls._compile_task_group(group, setup_net_list),
            session_class=cls)
        cls._compiled_cache[runnable] = result
        return result
示例#2
0
    def compile(cls, runnable, workspace_type=None, setup_net_list=None):
        """Return a CompiledRunnable for `runnable`, compiling it on demand.

        Args:
            runnable: What to compile. A CompiledRunnable is returned
                unchanged, a TaskGroup is compiled directly, and anything
                else is first placed into a new TaskGroup as a single Task.
            workspace_type: Optional workspace type. It is assigned to a
                TaskGroup that has none, must equal the type of a TaskGroup
                that has one, and falls back to WorkspaceType.GLOBAL when a
                new TaskGroup is created.
            setup_net_list: Handed to cls._compile_task_group as-is.

        Returns:
            The CompiledRunnable, which is also remembered in
            cls._compiled_cache under `runnable`.

        Raises:
            AssertionError: on a session-class mismatch, on conflicting
                workspace types, or for a Plan without steps.
        """
        def as_task(obj):
            """Return `obj` as a Task, wrapping it when it is not one."""
            if isinstance(obj, Task):
                return obj
            if isinstance(obj, core.ExecutionStep):
                return Task(step=obj)
            if isinstance(obj, core.Plan):
                # Plan steps are supposed to run sequentially while the
                # tasks of a TaskGroup run in parallel, so all steps of a
                # plan end up in one task.
                assert len(obj.Steps()) > 0
                if len(obj.Steps()) == 1:
                    return Task(step=obj.Steps()[0])
                # Task accepts a list of ExecutionSteps and wraps it into a
                # root ExecutionStep on its own.
                return Task(step=obj.Steps())
            return Task(step=core.execution_step('runnable', obj))

        if isinstance(runnable, CompiledRunnable):
            assert cls == runnable.session_class, (
                'Runnable was compiled for different session type. ' +
                'Need: %s, got: %s' % (
                    cls.__name__, runnable.session_class.__name__))
            return runnable

        # Cached results are returned without looking at workspace_type or
        # setup_net_list.
        if runnable in cls._compiled_cache:
            return cls._compiled_cache[runnable]

        if isinstance(runnable, TaskGroup):
            task_group = runnable
            if workspace_type and runnable.workspace_type():
                # Both sides specify a type: they have to agree.
                assert runnable.workspace_type() == workspace_type, \
                    "Require {} but already have {}".format(
                        workspace_type, runnable.workspace_type())
            elif workspace_type:
                runnable._workspace_type = workspace_type
        else:
            effective_type = (
                WorkspaceType.GLOBAL if workspace_type is None
                else workspace_type)
            task_group = TaskGroup(workspace_type=effective_type)
            task_group.add(as_task(runnable))

        compiled_runnable = CompiledRunnable(
            cls._compile_task_group(task_group, setup_net_list),
            session_class=cls)
        cls._compiled_cache[runnable] = compiled_runnable
        return compiled_runnable
示例#3
0
    def compile(cls, runnable):
        """Compile `runnable` for this session class, with memoization.

        Args:
            runnable: A CompiledRunnable, TaskGroup, Task,
                core.ExecutionStep, or any other object that is handed to
                core.execution_step('runnable', ...).

        Returns:
            `runnable` itself when it is already a CompiledRunnable,
            otherwise the CompiledRunnable kept in cls._compiled_cache.

        Raises:
            AssertionError: if a CompiledRunnable was compiled for a
                different session class.
        """
        if isinstance(runnable, CompiledRunnable):
            assert cls == runnable.session_class, (
                'Runnable was compiled for different session type. ' +
                'Need: %s, got: %s' % (
                    cls.__name__, runnable.session_class.__name__))
            return runnable

        if runnable in cls._compiled_cache:
            return cls._compiled_cache[runnable]

        if not isinstance(runnable, TaskGroup):
            # Anything that is not a group is placed, as a single task,
            # into a new group using the global workspace type.
            group = TaskGroup(workspace_type=WorkspaceType.GLOBAL)
            if isinstance(runnable, Task):
                task = runnable
            else:
                if isinstance(runnable, core.ExecutionStep):
                    wrapped_step = runnable
                else:
                    wrapped_step = core.execution_step('runnable', runnable)
                task = Task(step=wrapped_step)
            group.add(task)
        else:
            group = runnable

        result = CompiledRunnable(
            cls._compile_task_group(group), session_class=cls)
        cls._compiled_cache[runnable] = result
        return result
示例#4
0
文件: session.py 项目: xpo454/Caffe2
    def compile(cls, runnable):
        """Return the CompiledRunnable for `runnable`, building it if needed.

        A CompiledRunnable is handed back unchanged after checking that it
        was compiled for this session class. A TaskGroup is compiled
        directly. Everything else is first added, as one Task, to a new
        TaskGroup created with WorkspaceType.GLOBAL. Results are remembered
        in cls._compiled_cache, keyed by `runnable`.

        Raises:
            AssertionError: if `runnable` is a CompiledRunnable whose
                session class differs from `cls`.
        """
        if isinstance(runnable, CompiledRunnable):
            assert cls == runnable.session_class, (
                'Runnable was compiled for different session type. ' +
                'Need: %s, got: %s' %
                (cls.__name__, runnable.session_class.__name__))
            return runnable

        if runnable in cls._compiled_cache:
            return cls._compiled_cache[runnable]

        # Use the runnable as the group unless it has to be wrapped.
        task_group = runnable
        if not isinstance(runnable, TaskGroup):
            task_group = TaskGroup(workspace_type=WorkspaceType.GLOBAL)
            if isinstance(runnable, Task):
                new_task = runnable
            elif isinstance(runnable, core.ExecutionStep):
                new_task = Task(step=runnable)
            else:
                new_task = Task(
                    step=core.execution_step('runnable', runnable))
            task_group.add(new_task)

        compiled_group = cls._compile_task_group(task_group)
        outcome = CompiledRunnable(compiled_group, session_class=cls)
        cls._compiled_cache[runnable] = outcome
        return outcome
示例#5
0
文件: session.py 项目: zhxxhit/caffe2
 def run(self, runnable):
     """Run `runnable` in this session.

     The first time a runnable is seen it is turned into a TaskGroup and
     stored in self._runnable_cache: a TaskGroup is stored as-is, anything
     else is added, as a single Task, to a new TaskGroup created with
     WorkspaceType.GLOBAL. The cached group is then passed to
     self._run_task_group.

     Raises:
         AssertionError: if self.is_open() is false.
     """
     assert self.is_open(), 'Session is closed.'
     if runnable not in self._runnable_cache:
         group = runnable
         if not isinstance(runnable, TaskGroup):
             group = TaskGroup(workspace_type=WorkspaceType.GLOBAL)
             if isinstance(runnable, Task):
                 task = runnable
             else:
                 if isinstance(runnable, core.ExecutionStep):
                     wrapped_step = runnable
                 else:
                     wrapped_step = core.execution_step('runnable', runnable)
                 task = Task(step=wrapped_step)
             group.add(task)
         self._runnable_cache[runnable] = group
     self._run_task_group(self._runnable_cache[runnable])