コード例 #1
0
ファイル: pipeline.py プロジェクト: snoworld888/DALI
 def feed_input(self, ref, data):
     """Feed external (NumPy) data to the output of an ExternalSource operator.

     Args:
         ref: the ``Edge.EdgeReference`` produced by an ExternalSource
             operator; its ``name`` selects the input to bind.
         data: if a Python ``list``, every element is wrapped separately in
             ``Edge.TensorCPU`` and the resulting list is handed to
             ``SetExternalTensorInput``; any other object is wrapped as a
             whole in ``Edge.TensorListCPU`` and handed to
             ``SetExternalTLInput``.

     Raises:
         RuntimeError: if the pipeline has not been built yet.
         TypeError: if ``ref`` is not an ``Edge.EdgeReference``.
     """
     if not self._built:
         raise RuntimeError("Pipeline must be built first.")
     if not isinstance(ref, Edge.EdgeReference):
         received = type(ref).__name__
         raise TypeError(
             "Expected argument one to be EdgeReference. "
             "Received output type {}".format(received))

     if not isinstance(data, list):
         # Anything that is not a list is treated as one whole tensor list.
         tensor_list = Edge.TensorListCPU(data)
         self._pipe.SetExternalTLInput(ref.name, tensor_list)
         return

     # A list is fed sample by sample, one TensorCPU per element.
     samples = [Edge.TensorCPU(sample) for sample in data]
     self._pipe.SetExternalTensorInput(ref.name, samples)
コード例 #2
0
ファイル: pipeline.py プロジェクト: cirquit/DALI
 def feed_input(self, ref, data, layout=types.NHWC):
     """Bind NumPy data to the tensor produced by an ExternalSource operator.

     It is worth mentioning that `ref` should not be overridden with other
     operator outputs.

     Args:
         ref: the ``Edge.EdgeReference`` produced by the ExternalSource
             operator; its ``name`` selects the input to bind.
         data: either a ``list`` with exactly ``batch_size`` elements, each
             wrapped in ``Edge.TensorCPU`` and fed as a separate sample, or
             any other object, wrapped as a whole in ``Edge.TensorListCPU``.
         layout: tensor layout forwarded to the ``Edge`` tensor constructors
             (defaults to ``types.NHWC``).

     Raises:
         RuntimeError: if the pipeline has not been built yet, or if ``data``
             is a list whose length differs from the pipeline batch size.
         TypeError: if ``ref`` is not an ``Edge.EdgeReference``.
     """
     if not self._built:
         raise RuntimeError("Pipeline must be built first.")
     if not isinstance(ref, Edge.EdgeReference):
         raise TypeError(
             ("Expected argument one to "
              "be EdgeReference. "
              "Received output type {}").format(type(ref).__name__))
     if isinstance(data, list):
         # Report both sizes so a batch mismatch is easy to diagnose; the
         # message prefix and exception type are kept for existing callers.
         if len(data) != self._batch_size:
             raise RuntimeError(
                 "Data list provided to feed_input needs to have batch_size "
                 "length (got {}, expected {})".format(
                     len(data), self._batch_size))
         inputs = [Edge.TensorCPU(datum, layout) for datum in data]
         self._pipe.SetExternalTensorInput(ref.name, inputs)
     else:
         inp = Edge.TensorListCPU(data, layout)
         self._pipe.SetExternalTLInput(ref.name, inp)