Exemplo n.º 1
0
class ImagePlotNode(Node):
    """
    Flowchart node that draws an image together with overlays into one plot.

    Besides the ``image`` input, the node accepts ``points``, ``curves`` and
    ``rois`` inputs. Users may add further inputs (``allowAddInput=True``);
    each of those is drawn like ``points`` with a blue symbol pen and labelled
    with its terminal name. The target plot is chosen with :meth:`setPlot`,
    either by the program or through the combo box from :meth:`ctrlWidget`.
    """

    nodeName = 'ImagePlot'
    sigViewChanged = QtCore.Signal(object)

    def __init__(self, name, **kwargs):
        self.plot = None  # currently selected plot widget (see setPlot)
        self.plots = {}  # {name: plot} choices offered in the ctrl combo box
        self.ui = None  # ComboBox, created lazily by ctrlWidget
        self.imageItem = None
        self.plotItem = None  # self.plot.plotItem; process() draws into it
        self.scatterDefaults = dict(symbol='+', symbolPen='r', symbolSize=8, symbolBrush='r', pen=None)
        self.curveDefaults = dict(pen='w')

        # Four fixed input terminals; more can be added by the user (allowAddInput).
        opts = dict(allowAddInput=True,
                    terminals=dict(image={'io': 'in'}, points={'io': 'in'}, curves={'io': 'in'}, rois={'io': 'in'}))
        opts.update(kwargs)

        Node.__init__(self, name, **opts)

    def setPlot(self, plot):
        """Select the plot widget this node draws into.

        Called by the program or by the ctrl combo box. The widget's
        ``plotItem`` gets a locked aspect ratio and an inverted y axis.
        Passing None detaches the node from any plot.
        """
        if plot == self.plot:
            return
        self.plot = plot
        # Guard against None: the combo box can report no selection.
        self.plotItem = None if plot is None else plot.plotItem
        if self.plotItem is not None:
            self.plotItem.setAspectLocked()
            self.plotItem.invertY(True)
        self.updateUi()
        # Redraw the current inputs into the newly selected plot.
        self.update()
        self.sigViewChanged.emit(self)

    def process(self, image=None, points=None, curves=None, rois=None, display=True, **additionalPlots):
        """Clear the selected plot and redraw it from the current inputs.

        Nothing is touched when *display* is False or no plot is selected.

        :param image: array handed to :meth:`plotImage`.
        :param points: array indexed as ``points[:, 0]`` (drawn as y) and
            ``points[:, 1]`` (drawn as x).
        :param curves: iterable of curves handed to :meth:`plotCurves`.
        :param rois: iterable of objects exposing ``r0`` and ``size``; both are
            reversed before being passed to ``pg.ROI``.
        :param additionalPlots: user-added inputs; each non-None value is drawn
            like *points*, with a blue symbol pen and its terminal name.
        """
        if not display or self.plotItem is None:
            return
        self.plotItem.clear()

        if image is not None:
            self.plotImage(image)

        if points is not None:
            self.plotPoints(points)

        if curves is not None:
            self.plotCurves(curves)

        for name, additionalPlot in additionalPlots.items():
            if additionalPlot is not None:
                self.plotPoints(additionalPlot, name=name, symbolPen='b')

        if rois is not None:
            for roi in rois:
                # reverse r0/size before handing them to pyqtgraph
                self.plotItem.addItem(pg.ROI(roi.r0[::-1], roi.size[::-1]))

    def plotImage(self, image):
        """Add *image* to the plot as a ``pg.ImageItem``.

        A 4-D input is squeezed. If that leaves a 3-D array whose first axis
        has length 3 or 4, the channel order is reversed (BGR -> RGB) and the
        channel axis is moved last. Any other 3-D result is tiled into a single
        2-D mosaic by :meth:`constructImageFromBlob`.
        """
        if image.ndim == 4:
            # TODO handle batched input
            # for now, just take a 3D blob: ch x h x w
            image = np.squeeze(image)
            if image.ndim == 3:
                if image.shape[0] in (3, 4):
                    # reverse all channels (BGR -> RGB), then ch x h x w -> h x w x ch
                    # NOTE(review): with 4 channels this also moves the 4th channel to the front — confirm
                    image = np.rollaxis(image[::-1], 0, 3)
                else:
                    image = self.constructImageFromBlob(blob=image)
        # pyqtgraph expects w x h
        self.plotItem.addItem(pg.ImageItem(image.swapaxes(0, 1)))

    def plotPoints(self, points, **plotArgs):
        """Scatter-plot *points* (column 1 as x, column 0 as y) using
        ``self.scatterDefaults`` overridden by *plotArgs*."""
        defaults = self.scatterDefaults.copy()
        defaults.update(plotArgs)
        self.plotItem.plot(x=points[:, 1], y=points[:, 0], **defaults)

    def plotCurves(self, curves, **plotArgs):
        """Plot each entry of *curves* using ``self.curveDefaults`` overridden
        by *plotArgs*."""
        defaults = self.curveDefaults.copy()
        defaults.update(plotArgs)
        for curve in curves:
            self.plotItem.plot(curve, **defaults)

    def constructImageFromBlob(self, blob):
        """Tile the channels of a 3-D ``ch x h x w`` blob into one 2-D mosaic.

        Each channel is zero-padded on its bottom and right by
        ``min(max(h, w) // 10, 3)`` pixels. Channels are laid out row by row in
        a grid with ``ceil(ch / floor(sqrt(ch)))`` columns. A short last row is
        zero-filled on the right. The padding after the final row and column of
        tiles is trimmed.
        """
        imShape = np.array(blob.shape[1:])
        # pad by the smaller of the largest image dimension // 10 or 3;
        # np.pad and the final slice both require an integer
        pad = int(min(np.max(imShape) // 10, 3))
        blob = np.pad(blob, pad_width=((0, 0), (0, pad), (0, pad)), mode='constant')

        # make a squarish grid of the layers in the blob:
        # cols grow to accommodate demand first, then rows
        numChannels = blob.shape[0]
        minWidth = np.floor(np.sqrt(numChannels))
        numCols = int(np.ceil(numChannels / minWidth))

        # consecutive groups of numCols channels, each laid side by side
        rows = [np.hstack(split) for split in split_every(numCols, blob)]

        # the last row may hold fewer tiles; zero-fill it on the right so vstack works
        missing = rows[0].shape[1] - rows[-1].shape[1]
        if missing > 0:
            rows[-1] = np.pad(rows[-1], pad_width=((0, 0), (0, missing)), mode='constant')

        mosaic = np.vstack(rows)
        # explicit end indices: a [:-pad] slice would return an empty array when pad == 0
        return mosaic[:mosaic.shape[0] - pad, :mosaic.shape[1] - pad]

    def ctrlWidget(self):
        """Return the plot-selection combo box, creating it on first use."""
        if self.ui is None:
            self.ui = ComboBox()
            self.ui.currentIndexChanged.connect(self.plotSelected)
            self.updateUi()
        return self.ui

    def plotSelected(self, index):
        """Combo box slot: switch to the plot the user selected."""
        self.setPlot(self.ui.value())

    def setPlotList(self, plots):
        """
        Specify the set of plots (ImageView) that the user may
        select from.

        *plots* must be a dictionary of {name: plot} pairs.
        """
        self.plots = plots
        self.updateUi()

    def updateUi(self):
        """Refresh the combo box choices and re-select the current plot.

        Does nothing until :meth:`ctrlWidget` has created the combo box.
        """
        if self.ui is None:
            return
        # sets list and automatically preserves previous selection
        self.ui.setItems(self.plots)
        try:
            self.ui.setValue(self.plot)
        except ValueError:
            # current plot is not among the choices (e.g. None); keep the selection setItems made
            pass