Example #1
0
    def execute_loop(self, *outputs):
        """Build one output spec per argument and hand them to ``pybind.execute_loop``.

        Every positional argument must be a Tensor or an arbitrarily nested
        list/tuple/set/dict of Tensors.  The ``define`` of each Tensor found in
        an argument (depth-first; dict values in key-iteration order) is
        collected into one ``pybind.OutputSpec`` whose ``output_device`` is
        ``Graph.default_device()``.  All specs are passed, together with
        ``self._graph_def``, to ``pybind.execute_loop``; its return value is
        not used.

        Raises:
            ValueError: if called with no outputs, or if an argument contains
                anything other than a Tensor or a list/tuple/set/dict.
        """
        def iter_defines(node):
            # Depth-first generator over the Tensor defines reachable from node.
            if isinstance(node, Tensor):
                yield node.define
            elif isinstance(node, (list, tuple, set)):
                for child in node:
                    for name in iter_defines(child):
                        yield name
            elif isinstance(node, dict):
                for key in node:
                    for name in iter_defines(node[key]):
                        yield name
            else:
                raise ValueError("cannot execute type {}".format(node))

        if not outputs:
            raise ValueError("execute loop need at least one state")

        spec_vector = pybind.OutputSpecVector()
        for state in outputs:
            # Collect names first so a bad element fails before a spec is built.
            names = list(iter_defines(state))
            one_spec = pybind.OutputSpec()
            one_spec.output = pybind.StringVector(names)
            one_spec.output_device = Graph.default_device()
            spec_vector.append(one_spec)
        pybind.execute_loop(self._graph_def, spec_vector)
Example #2
0
    def execute(self, outputs, run_option=None, run_statistic=None):
        """Run the graph once and return the values of ``outputs``.

        ``outputs`` may be a Tensor, an output name (string), or an arbitrarily
        nested list/tuple/set/dict of those.  The return value mirrors that
        structure: each fetched name is replaced by ``numpy.array(...,
        copy=False)`` of the matching entry of ``result.outputs``, and each
        name starting with ``'^'`` is replaced by ``None``.

        Args:
            outputs: Tensor / name / nested container of them to fetch.
            run_option: optional ``pybind.RunOption``; a default-constructed
                one is used when omitted.  When its ``perf`` flag is set,
                ``run_statistic`` is required.
            run_statistic: object whose ``perf_result`` attribute is set from
                the run's statistics when ``run_option.perf`` is set.

        Returns:
            The fetched values, in the same nested shape as ``outputs``.

        Raises:
            ValueError: if perf is requested without ``run_statistic``, or if
                ``outputs`` contains an unsupported type.
            Errors reported by ``check_error(result.status)`` propagate.
        """
        if run_option and run_option.perf and run_statistic is None:
            # The original raised a bare string, which itself fails with a
            # TypeError on Python 2.6+/3; raise a real exception instead.
            raise ValueError(
                'run_statistic must be specified when perf is turned on')

        # ``unicode`` and ``long`` only exist on Python 2; fall back on Python 3.
        try:
            string_types = (str, unicode)
        except NameError:
            string_types = (str,)
        try:
            integer_types = (int, long)
        except NameError:
            integer_types = (int,)

        output_define = []

        def recursive_feed_output(x, k):
            # Appends every name in x to output_define and returns a skeleton
            # of the same shape in which each non-'^' name is replaced by its
            # position k among non-'^' names (used to index result.outputs)
            # and each '^' name by None; also returns the next free position.
            if isinstance(x, Tensor):
                x = x.define
            if isinstance(x, string_types):
                output_define.append(x)
                if x.startswith('^'):
                    return None, k
                return k, k + 1
            elif isinstance(x, (list, tuple, set)):
                rst = []
                for i in x:
                    y, k = recursive_feed_output(i, k)
                    rst.append(y)
                return x.__class__(rst), k
            elif isinstance(x, dict):
                rst = {}
                for i in x:
                    y, k = recursive_feed_output(x[i], k)
                    rst[i] = y
                return rst, k
            else:
                raise ValueError("cannot execute type {}".format(x))

        output_spec, _ = recursive_feed_output(outputs, 0)
        xdl_output_spec = pybind.OutputSpec()
        xdl_output_spec.output = pybind.StringVector(output_define)
        xdl_output_spec.output_device = Graph.default_device()
        if run_option is None:
            run_option = pybind.RunOption()
        result = pybind.execute(self._graph_def, xdl_output_spec, run_option)
        check_error(result.status)
        fetched = result.outputs
        if run_option and run_option.perf:
            run_statistic.perf_result = result.run_statistic.perf_result

        def recursive_build_result(x):
            # Replaces each integer position in the skeleton with its array.
            if x is None:
                return None
            if isinstance(x, integer_types):
                return numpy.array(fetched[x], copy=False)
            elif isinstance(x, (list, tuple, set)):
                # NOTE(review): a set of data outputs cannot be rebuilt here
                # (numpy arrays are unhashable), and tuple subclasses whose
                # constructor does not take one iterable (e.g. namedtuple)
                # also fail; only sets of '^' names round-trip.
                return x.__class__([recursive_build_result(i) for i in x])
            elif isinstance(x, dict):
                return {i: recursive_build_result(x[i]) for i in x}
            else:
                raise ValueError("Internal Error")

        return recursive_build_result(output_spec)