Example #1
0
File: source.py  Project: tarsbase/scanner
    def to_proto(self, indices):
        """Serialize this source op into a ``protobufs.Op`` message.

        Args:
            indices: Not used by this method; presumably accepted so the
                signature matches the regular op ``to_proto`` (TODO confirm).

        Returns:
            A ``protobufs.Op`` with ``is_source`` set and a single input that
            reads column ``self._inputs[0]._col`` with ``op_index`` -1. Its
            ``kernel_args`` holds the converted/serialized source arguments.
            An empty argument dict leaves ``kernel_args`` unset.

        Raises:
            TypeError: if the arguments are a non-empty dict but the source has
                no registered argument protobuf to convert them with.
        """
        e = protobufs.Op()
        e.name = self._name
        e.is_source = True

        # A source reads its column directly; -1 marks "no upstream op".
        inp = e.inputs.add()
        inp.column = self._inputs[0]._col
        inp.op_index = -1

        if isinstance(self._args, dict):
            # A dict of arguments is converted with the protobuf type the
            # source registered (source_info.protobuf_name).
            if len(self._args) > 0:
                source_info = self._sc._get_source_info(self._name)
                if len(source_info.protobuf_name) > 0:
                    e.kernel_args = python_to_proto(
                        source_info.protobuf_name, self._args)
                else:
                    # kernel_args receives serialized data in the other
                    # branches; a raw dict cannot be stored in that protobuf
                    # field. Fail with a descriptive message rather than
                    # protobuf's generic type error.
                    raise TypeError(
                        'Source `{}` has no registered argument protobuf, so '
                        'its arguments cannot be given as a dict (got keys: '
                        '{}); pass a protobuf message instead.'.format(
                            self._name, list(self._args)))
        else:
            # Arguments given as a protobuf object are serialized directly.
            e.kernel_args = self._args.SerializeToString()

        return e
Example #2
0
File: op.py  Project: spillai/scanner
        def make_op(*args, **kwargs):
            """Build an ``Op`` from call arguments and return its outputs.

            For variadic ops the positional ``args`` are the inputs. Otherwise
            each column in ``op_info.input_columns`` must be passed as a keyword
            argument with that column's name. ``device``, ``batch``,
            ``bounded_state``, ``stencil``, ``extra`` and ``args`` are popped as
            op settings.

            Keywords named in ``stream_params`` are per-stream config. Each
            must be a list with one entry per stream. An entry is either a
            ``SliceList`` (one value per slice) or a plain value, which is
            wrapped as a single-slice ``SliceList``. Any remaining keywords are
            the kernel arguments unless ``args`` was passed explicitly.

            Returns:
                ``op.outputs()`` of the constructed ``Op``.

            Raises:
                ScannerException: if a required input is missing, stream
                    config arguments are missing, not lists or empty, or they
                    disagree on the number of streams or slices.
            """
            inputs = []
            if op_info.variadic_inputs:
                inputs.extend(args)
            else:
                for c in op_info.input_columns:
                    val = kwargs.pop(c.name, None)
                    if val is None:
                        raise ScannerException(
                            'Op {} required sequence {} as input'.format(
                                orig_name, c.name))
                    inputs.append(val)

            device = kwargs.pop('device', DeviceType.CPU)
            batch = kwargs.pop('batch', -1)
            bounded_state = kwargs.pop('bounded_state', -1)
            stencil = kwargs.pop('stencil', [])
            extra = kwargs.pop('extra', None)
            args = kwargs.pop('args', None)

            if len(stream_params) > 0:
                stream_configs = [(k, kwargs.pop(k)) for k in stream_params
                                  if k in kwargs]
                if len(stream_configs) == 0:
                    raise ScannerException(
                        "No arguments provided to op `{}` for stream parameters."
                        .format(orig_name))
                for k, v in stream_configs:
                    if not isinstance(v, list):
                        raise ScannerException(
                            "The argument `{}` to op `{}` is a stream config argument and must be a list."
                            .format(k, orig_name))

                # All stream config arguments must describe the same number of
                # streams; otherwise extra entries would be silently dropped
                # or indexing would fail partway through.
                first_name, first_values = stream_configs[0]
                num_streams = len(first_values)
                if num_streams == 0:
                    raise ScannerException(
                        "The stream config argument `{}` to op `{}` is empty; "
                        "it needs one entry per stream.".format(
                            first_name, orig_name))
                for k, v in stream_configs:
                    if len(v) != num_streams:
                        raise ScannerException(
                            "Stream config arguments to op `{}` must have the "
                            "same length: `{}` has {} entries but `{}` has {}."
                            .format(orig_name, k, len(v), first_name,
                                    num_streams))

                # Wrap each argument independently, so that already-sliced
                # arguments are never double-wrapped.
                stream_configs = [
                    (k, v if isinstance(v[0], SliceList)
                     else [SliceList([x]) for x in v])
                    for k, v in stream_configs
                ]

                # Each stream may have its own slice count, but every argument
                # must agree on it for that stream.
                slice_counts = []
                for i in range(num_streams):
                    counts = {len(v[i]) for _, v in stream_configs}
                    if len(counts) != 1:
                        raise ScannerException(
                            "Stream config arguments to op `{}` disagree on "
                            "the number of slices for stream {}.".format(
                                orig_name, i))
                    slice_counts.append(counts.pop())

                is_python_op = orig_name in PYTHON_OP_REGISTRY

                def serialize(config):
                    # Ops in PYTHON_OP_REGISTRY receive pickled dicts; other
                    # ops receive their stream protobuf.
                    if is_python_op:
                        return pickle.dumps(config)
                    return python_to_proto(op_info.stream_protobuf_name, config)

                stream_args = [
                    SliceList([
                        serialize({k: v[i][j] for k, v in stream_configs})
                        for j in range(slice_counts[i])
                    ])
                    for i in range(num_streams)
                ]
            else:
                stream_args = None

            op = Op(self._sc, name, inputs, device, batch, bounded_state,
                    stencil, kwargs if args is None else args, extra, stream_args)
            return op.outputs()
Example #3
0
def collect_per_stream_args(name, protobuf_name, kwargs):
    """Pop per-stream arguments out of ``kwargs`` and build one message per stream.

    The keys of ``analyze_proto(protobufs.<protobuf_name>)`` (presumably the
    message's field names — confirm) select which keyword arguments are
    per-stream arguments, and those keys are removed from ``kwargs``.
    List-valued arguments supply one value per stream. Any other value is
    shared by every stream, and ``None`` values are omitted.

    Args:
        name: Op name, used in error messages.
        protobuf_name: Name of the message type looked up on ``protobufs``.
        kwargs: Keyword-argument dict; mutated in place (matching keys popped).

    Returns:
        A list with one ``python_to_proto(protobuf_name, ...)`` result per stream.

    Raises:
        ScannerException: if no per-stream argument was given, if none of them
            is a list, or if the list-valued arguments differ in length.
    """
    stream_arg_names = list(
        analyze_proto(getattr(protobufs, protobuf_name)).keys())
    stream_args = {k: kwargs.pop(k) for k in stream_arg_names if k in kwargs}

    if len(stream_args) == 0:
        # stream_args is empty here, so list the accepted names instead.
        raise ScannerException(
            "Op `{}` received no per-stream arguments. Options: {}".format(
                name, ', '.join(stream_arg_names)))

    # The number of streams comes from the list-valued arguments, which must
    # all agree; otherwise entries would be dropped or indexing would fail.
    list_lengths = {
        k: len(v) for k, v in stream_args.items() if isinstance(v, list)
    }
    if not list_lengths:
        raise ScannerException(
            "Op `{}` needs at least one per-stream argument given as a list "
            "(one value per stream).".format(name))
    if len(set(list_lengths.values())) > 1:
        raise ScannerException(
            "Per-stream arguments to op `{}` must all have the same length; "
            "got {}".format(
                name, ', '.join('{}={}'.format(k, n)
                                for k, n in list_lengths.items())))
    num_streams = next(iter(list_lengths.values()))

    job_args = [
        python_to_proto(
            protobuf_name, {
                k: v[i] if isinstance(v, list) else v
                for k, v in stream_args.items() if v is not None
            }) for i in range(num_streams)
    ]

    return job_args
Example #4
0
    def to_proto(self, indices):
        """Serialize this op into a ``protobufs.Op`` message.

        Args:
            indices: Looked up with each input's ``_op`` to fill that input's
                ``op_index``. Inputs whose ``_op`` is ``None`` get -1.

        Returns:
            A ``protobufs.Op`` carrying the op's name, device type, stencil,
            batch, warmup, inputs and ``kernel_args``.

        Raises:
            TypeError: if the arguments are a non-empty dict, the op is not a
                Python op, and it has no registered argument protobuf.
        """
        e = protobufs.Op()
        e.name = self._name
        e.device_type = DeviceType.to_proto(protobufs, self._device)
        e.stencil.extend(self._stencil)
        e.batch = self._batch
        e.warmup = self._warmup

        if e.name == "Input":
            # The Input op uses only its first input's column and has no
            # upstream op (-1).
            inp = e.inputs.add()
            inp.column = self._inputs[0]._col
            inp.op_index = -1
        else:
            for i in self._inputs:
                inp = e.inputs.add()
                idx = indices[i._op] if i._op is not None else -1
                inp.op_index = idx
                inp.column = i._col

        if isinstance(self._args, dict):
            if self._name in self._sc._python_ops:
                # Python ops get their arguments as a pickled dict.
                e.kernel_args = pickle.dumps(self._args)
            elif len(self._args) > 0:
                # Other ops convert a dict of arguments with the protobuf type
                # they registered (op_info.protobuf_name).
                op_info = self._sc._get_op_info(self._name)
                if len(op_info.protobuf_name) > 0:
                    e.kernel_args = python_to_proto(
                        op_info.protobuf_name, self._args)
                else:
                    # kernel_args receives serialized data (pickle /
                    # SerializeToString) in the other branches; a raw dict
                    # cannot be stored in that protobuf field. Fail with a
                    # descriptive message rather than protobuf's generic
                    # type error.
                    raise TypeError(
                        'Op `{}` has no registered argument protobuf, so its '
                        'arguments cannot be given as a dict (got keys: {}); '
                        'pass a protobuf message instead.'.format(
                            self._name, list(self._args)))
        else:
            # Arguments given as a protobuf object are serialized directly.
            e.kernel_args = self._args.SerializeToString()

        return e