Example #1
0
    def setupOutputs(self):
        """Derive the ``Output`` meta data from the ``Input`` meta data.

        The output meta starts as a copy of the input meta. It then gets:

        * dtype ``self.outputDtype``;
        * axistags that always contain a channel axis;
        * a channel count of ``(input channels) * self.resultingChannels()``.

        An input without a channel axis counts as one channel. The ``drange``
        entry is dropped because the filter response need not share the
        input's data range.
        """
        # Output meta starts with a copy of the input meta, which is then modified
        self.Output.meta.assignFrom(self.Input.meta)

        inMeta = self.inputs["Input"].meta
        if inMeta.axistags.axisTypeCount(vigra.AxisType.Channels) > 0:
            channelIndex = inMeta.axistags.channelIndex
            numChannels = inMeta.shape[channelIndex]
            inShapeWithoutChannels = popFlagsFromTheKey(inMeta.shape, inMeta.axistags, 'c')
        else:
            # No channel axis: treat the data as single-channel and append the
            # channel entry after the last existing axis of the shape.
            # NOTE(review): assumes insertChannelAxis() below also places the
            # channel axis last, so shape and axistags agree — confirm.
            numChannels = 1
            inShapeWithoutChannels = inMeta.shape
            channelIndex = len(inMeta.shape)

        outMeta = self.outputs["Output"].meta
        outMeta.dtype = self.outputDtype

        # Copy so the input slot's axistags are not modified in place.
        axistags = copy.copy(inMeta.axistags)
        if axistags.axisTypeCount(vigra.AxisType.Channels) == 0:
            axistags.insertChannelAxis()
        outMeta.axistags = axistags

        # Each input channel produces resultingChannels() output channels.
        outShape = list(inShapeWithoutChannels)
        outShape.insert(channelIndex, numChannels * self.resultingChannels())
        outMeta.shape = tuple(outShape)

        # The output data range is not necessarily the same as the input data range.
        if 'drange' in self.Output.meta:
            del self.Output.meta['drange']
Example #2
0
    def execute(self, slot, subindex, rroi, result, sourceArray=None):
        """Apply ``self.vigraFilter`` to the requested region of ``Output``.

        How the region is computed:

        * The spatial part of the request is enlarged by the filter context
          (scale x window size, via ``roi.extendSlice``).
        * The filter runs separately for every contributing input channel and
          every time step.
        * Each input channel yields ``self.resultingChannels()`` output
          channels. Only the requested ones are copied into ``result``.

        Every input slot other than ``Input`` is forwarded to the filter as a
        keyword argument named after the slot.

        :param slot: the output slot being computed (level 0 only).
        :param subindex: must be empty.
        :param rroi: requested region of ``Output``.
        :param result: array receiving the requested region.
        :param sourceArray: optional pre-fetched input data used instead of
            requesting ``Input``. The feature-selection operator passes
            Gaussian-pre-smoothed data here.
        :raises RuntimeError: if the operator has none of the scale inputs
            ``sigma``, ``scale``, ``sigma0`` or ``innerScale``.
        """
        assert len(subindex) == self.Output.level == 0
        key = roiToSlice(rroi.start, rroi.stop)

        kwparams = {}
        for islot in self.inputs.values():
            if islot.name != "Input":
                kwparams[islot.name] = islot.value

        # The scale drives how much context the filter needs; filters name this
        # input differently. First match wins (same priority as before).
        for scaleName in ("sigma", "scale", "sigma0", "innerScale"):
            if scaleName in self.inputs:
                sigma = self.inputs[scaleName].value
                break
        else:
            raise RuntimeError(
                "operator %s has none of the inputs sigma/scale/sigma0/innerScale" % self.name)

        windowSize = 4.0
        if self.supportsWindow:
            kwparams['window_size'] = self.window_size
            windowSize = self.window_size

        largestSigma = sigma  # ensure enough context for the vigra operators

        shape = self.outputs["Output"].meta.shape

        axistags = self.inputs["Input"].meta.axistags
        channelAxis = axistags.index('c')
        hasTimeAxis = axistags.axisTypeCount(vigra.AxisType.Time)
        timeAxis = axistags.index('t')

        # Spatial-only key and shape: drop the channel axis, then the time axis.
        subkey = popFlagsFromTheKey(key, axistags, 'c')
        subshape = popFlagsFromTheKey(shape, axistags, 'c')
        at2 = copy.copy(axistags)
        at2.dropChannelAxis()
        subshape = popFlagsFromTheKey(subshape, at2, 't')
        subkey = popFlagsFromTheKey(subkey, at2, 't')

        oldstart, oldstop = roi.sliceToRoi(key, shape)

        # Spatial request (``subshape`` resolves any open slice bounds), grown
        # by the filter context. extendSlice also receives ``subshape``,
        # presumably to keep the read region inside the data.
        start, stop = roi.sliceToRoi(subkey, subshape)
        newStart, newStop = roi.extendSlice(start, stop, subshape, largestSigma, window=windowSize)
        readKey = roi.roiToSlice(newStart, newStop)

        # Where the requested region sits inside the enlarged read region.
        writeNewStart = start - newStart
        writeNewStop = writeNewStart + stop - start

        writeKey = list(roi.roiToSlice(writeNewStart, writeNewStop))
        if timeAxis < channelAxis:
            writeKey.insert(channelAxis - 1, slice(None, None, None))
        else:
            writeKey.insert(channelAxis, slice(None, None, None))
        writeKey = tuple(writeKey)

        channelsPerChannel = self.resultingChannels()

        if self.supportsRoi is False and largestSigma > 5:
            logger.warning("operator %s does not support roi !!", self.name)

        # i: input channel contributing to the requested output channels.
        # i2: next free output channel in ``result``.
        i2 = 0
        for i in range(int(numpy.floor(1.0 * oldstart[channelAxis] / channelsPerChannel)),
                       int(numpy.ceil(1.0 * oldstop[channelAxis] / channelsPerChannel))):
            treadKey = list(readKey)

            if hasTimeAxis:
                if channelAxis > timeAxis:
                    treadKey.insert(timeAxis, key[timeAxis])
                else:
                    treadKey.insert(timeAxis - 1, key[timeAxis])
            treadKey.insert(channelAxis, slice(i, i + 1, None))
            treadKey = tuple(treadKey)

            if sourceArray is None:
                req = self.inputs["Input"][treadKey].allocate()
                t = req.wait()
            else:
                t = sourceArray[getAllExceptAxis(len(treadKey), channelAxis, slice(i, i + 1, None))]

            t = numpy.require(t, dtype=self.inputDtype)
            t = t.view(vigra.VigraArray)
            t.axistags = copy.copy(axistags)
            t = t.insertChannelAxis()

            # Sub-range [sourceBegin, sourceEnd) of this input channel's filter
            # outputs that falls inside the requested output channels.
            sourceBegin = 0
            if oldstart[channelAxis] > i * channelsPerChannel:
                sourceBegin = oldstart[channelAxis] - i * channelsPerChannel
            sourceEnd = channelsPerChannel
            if oldstop[channelAxis] < (i + 1) * channelsPerChannel:
                sourceEnd = channelsPerChannel - ((i + 1) * channelsPerChannel - oldstop[channelAxis])
            destBegin = i2
            destEnd = i2 + sourceEnd - sourceBegin

            if channelsPerChannel > 1:
                tkey = getAllExceptAxis(len(shape), channelAxis, slice(destBegin, destEnd, None))
                resultArea = result[tkey]
            else:
                tkey = getAllExceptAxis(len(shape), channelAxis, slice(i2, i2 + 1, None))
                resultArea = result[tkey]

            i2 += destEnd - destBegin

            supportsOut = self.supportsOut
            if destEnd - destBegin != channelsPerChannel:
                supportsOut = False

            supportsOut = False  # disable for now due to vigra crashes!
            for step, image in enumerate(t.timeIter()):
                nChannelAxis = channelAxis - 1

                if timeAxis > channelAxis or not hasTimeAxis:
                    nChannelAxis = channelAxis
                twriteKey = getAllExceptAxis(image.ndim, nChannelAxis, slice(sourceBegin, sourceEnd, None))

                if hasTimeAxis > 0:
                    tresKey = getAllExceptAxis(resultArea.ndim, timeAxis, step)
                else:
                    tresKey = slice(None, None, None)

                vres = resultArea[tresKey]
                if supportsOut:
                    # Fast lane (currently force-disabled above): vigra writes
                    # straight into ``result`` through ``out=``.
                    if self.supportsRoi:
                        vroi = (tuple(writeNewStart._asint()), tuple(writeNewStop._asint()))
                        try:
                            vres = vres.view(vigra.VigraArray)
                            vres.axistags = copy.copy(image.axistags)
                            logger.debug("FAST LANE %s %s %s %s", self.name, vres.shape,
                                         image[twriteKey].shape, vroi)
                            temp = self.vigraFilter(image[twriteKey], roi=vroi, out=vres, **kwparams)
                        except Exception:
                            logger.exception("fast lane filter failed: %s %s %s %s",
                                             self.name, image.shape, vroi, kwparams)
                            raise
                    else:
                        try:
                            temp = self.vigraFilter(image, **kwparams)
                        except Exception:
                            logger.exception("fast lane filter failed: %s %s %s",
                                             self.name, image.shape, kwparams)
                            raise
                        # NOTE(review): this path never copies ``temp`` into
                        # ``vres``; fix before re-enabling the fast lane.
                        temp = temp[writeKey]
                else:
                    if self.supportsRoi:
                        vroi = (tuple(writeNewStart._asint()), tuple(writeNewStop._asint()))
                        try:
                            temp = self.vigraFilter(image, roi=vroi, **kwparams)
                        except Exception:
                            logger.exception("EXCEPT 2.1 %s %s %s %s",
                                             self.name, image.shape, vroi, kwparams)
                            raise
                    else:
                        try:
                            temp = self.vigraFilter(image, **kwparams)
                        except Exception:
                            logger.exception("EXCEPT 2.2 %s %s %s",
                                             self.name, image.shape, kwparams)
                            raise
                        # Filtered the whole read region: crop to the request.
                        temp = temp[writeKey]

                    try:
                        vres[:] = temp[twriteKey]
                    except Exception:
                        logger.error("EXCEPT3 %s %s %s | %s %s %s | %s %s %s",
                                     vres.shape, temp.shape, twriteKey,
                                     resultArea.shape, tresKey, twriteKey,
                                     step, t.shape, timeAxis)
                        raise
Example #3
0
    def execute(self, slot, subindex, rroi, result):
        """Compute the requested region of ``Features[k]`` or ``Output``.

        ``Features[k]`` requests are shifted into the matching channel range of
        ``Output`` (offset ``self.featureOutputChannels[k][0]``). They are then
        delegated to the ``Output`` branch.

        For ``Output``, the work happens in three steps:

        1. The input is read once, with enough halo for the largest scale.
        2. The input is Gaussian-smoothed once for every scale column ``j`` of
           ``self.matrix`` that has at least one enabled feature.
        3. Every enabled feature operator ``self.featureOps[i][j]`` runs (via
           ``Pool``) on that smoothed data and writes its channels into
           ``result``.

        :param slot: ``self.Features`` (with ``subindex=(k,)``) or ``self.Output``.
        :param subindex: subslot index for ``Features``; empty for ``Output``.
        :param rroi: requested region of ``slot``.
        :param result: array receiving the requested region.
        """
        assert slot == self.Features or slot == self.Output
        if slot == self.Features:
            key = roiToSlice(rroi.start, rroi.stop)
            index = subindex[0]
            subslot = self.Features[index]
            key = list(key)
            channelIndex = self.Input.meta.axistags.index('c')

            # Translate channel slice to the correct location for the output slot.
            key[channelIndex] = slice(self.featureOutputChannels[index][0] + key[channelIndex].start,
                                      self.featureOutputChannels[index][0] + key[channelIndex].stop)
            rroi = SubRegion(subslot, pslice=key)

            # Get output slot region for this channel
            return self.execute(self.Output, (), rroi, result)
        elif slot == self.outputs["Output"]:
            key = rroi.toSlice()
            cnt = 0      # output channels of the visited feature ops so far
            written = 0  # channels already scheduled into ``result``
            assert (rroi.stop <= self.outputs["Output"].meta.shape).all()
            flag = 'c'
            channelAxis = self.inputs["Input"].meta.axistags.index('c')
            axisindex = channelAxis
            oldkey = list(key)
            oldkey.pop(axisindex)

            inShape = self.inputs["Input"].meta.shape

            shape = self.outputs["Output"].meta.shape

            axistags = self.inputs["Input"].meta.axistags

            result = result.view(vigra.VigraArray)
            result.axistags = copy.copy(axistags)

            hasTimeAxis = self.inputs["Input"].meta.axistags.axisTypeCount(vigra.AxisType.Time)
            timeAxis = self.inputs["Input"].meta.axistags.index('t')

            # Spatial-only key and shape: drop the channel axis, then the time axis.
            subkey = popFlagsFromTheKey(key, axistags, 'c')
            subshape = popFlagsFromTheKey(shape, axistags, 'c')
            at2 = copy.copy(axistags)
            at2.dropChannelAxis()
            subshape = popFlagsFromTheKey(subshape, at2, 't')
            subkey = popFlagsFromTheKey(subkey, at2, 't')

            oldstart, oldstop = roi.sliceToRoi(key, shape)

            # Spatial request; ``subshape`` resolves any open slice bounds.
            start, stop = roi.sliceToRoi(subkey, subshape)
            maxSigma = max(0.7, self.maxSigma)

            # The region of the smoothed image we need to give to the feature filter (in terms of INPUT coordinates)
            vigOpSourceStart, vigOpSourceStop = roi.extendSlice(start, stop, subshape, 0.7, window=2)

            # The region of the input that we need to give to the smoothing operator (in terms of INPUT coordinates)
            newStart, newStop = roi.extendSlice(vigOpSourceStart, vigOpSourceStop, subshape, maxSigma, window=3.5)

            # Translate coordinates (now in terms of smoothed image coordinates)
            vigOpSourceStart = roi.TinyVector(vigOpSourceStart - newStart)
            vigOpSourceStop = roi.TinyVector(vigOpSourceStop - newStart)

            readKey = roi.roiToSlice(newStart, newStop)

            treadKey = list(readKey)

            if hasTimeAxis:
                if timeAxis < channelAxis:
                    treadKey.insert(timeAxis, key[timeAxis])
                else:
                    treadKey.insert(timeAxis - 1, key[timeAxis])
            if self.inputs["Input"].meta.axistags.axisTypeCount(vigra.AxisType.Channels) == 0:
                treadKey = popFlagsFromTheKey(treadKey, axistags, 'c')
            else:
                treadKey.insert(channelAxis, slice(None, None, None))

            treadKey = tuple(treadKey)

            req = self.inputs["Input"][treadKey].allocate()

            sourceArray = req.wait()
            # Clear the request's references to the data (presumably so the
            # buffer can be released early).
            req.result = None
            req.destination = None
            if sourceArray.dtype != numpy.float32:
                sourceArrayF = sourceArray.astype(numpy.float32)
                sourceArray.resize((1,), refcheck=False)
                del sourceArray
                sourceArray = sourceArrayF
            sourceArrayV = sourceArray.view(vigra.VigraArray)
            sourceArrayV.axistags = copy.copy(axistags)

            dimCol = len(self.scales)
            dimRow = self.matrix.shape[0]

            sourceArraysForSigmas = [None] * dimCol

            # Pre-smooth the input once per scale that has any enabled feature.
            for j in range(dimCol):
                if not any(self.matrix[i, j] for i in range(dimRow)):
                    continue
                # Gaussian sigmas compose in quadrature, so smoothing with
                # sqrt(s**2 - destSigma**2) leaves destSigma to be applied
                # later. NOTE(review): presumably the feature operators apply
                # it — confirm.
                destSigma = 1.0
                if self.scales[j] > destSigma:
                    tempSigma = math.sqrt(self.scales[j] ** 2 - destSigma ** 2)
                else:
                    destSigma = 0.0
                    tempSigma = self.scales[j]
                vigOpSourceShape = list(vigOpSourceStop - vigOpSourceStart)
                droi = (tuple(vigOpSourceStart._asint()), tuple(vigOpSourceStop._asint()))
                if hasTimeAxis:
                    if timeAxis < channelAxis:
                        vigOpSourceShape.insert(timeAxis, (oldstop - oldstart)[timeAxis])
                    else:
                        vigOpSourceShape.insert(timeAxis - 1, (oldstop - oldstart)[timeAxis])
                    vigOpSourceShape.insert(channelAxis, inShape[channelAxis])

                    sourceArraysForSigmas[j] = numpy.ndarray(tuple(vigOpSourceShape), numpy.float32)
                    for i, vsa in enumerate(sourceArrayV.timeIter()):
                        tmp_key = getAllExceptAxis(len(sourceArraysForSigmas[j].shape), timeAxis, i)
                        sourceArraysForSigmas[j][tmp_key] = vigra.filters.gaussianSmoothing(
                            vsa, tempSigma, roi=droi, window_size=3.5)
                else:
                    sourceArraysForSigmas[j] = vigra.filters.gaussianSmoothing(
                        sourceArrayV, sigma=tempSigma, roi=droi, window_size=3.5)

            del sourceArrayV
            try:
                sourceArray.resize((1,), refcheck=False)
            except ValueError:
                # Sometimes this fails, but that's okay.
                logger.debug("Failed to free array memory.")
            del sourceArray

            closures = []

            # One execute() call per enabled feature operator whose output
            # channels overlap the requested channel range.
            for i in range(dimRow):
                for j in range(dimCol):
                    val = self.matrix[i, j]
                    if val:
                        vop = self.featureOps[i][j]
                        oslot = vop.outputs["Output"]
                        req = None
                        inTagKeys = [ax.key for ax in oslot.meta.axistags]
                        if flag in inTagKeys:
                            slices = oslot.meta.shape[axisindex]
                            if cnt + slices >= rroi.start[axisindex] and rroi.start[axisindex] - cnt < slices and rroi.start[axisindex] + written < rroi.stop[axisindex]:
                                begin = 0
                                if cnt < rroi.start[axisindex]:
                                    begin = rroi.start[axisindex] - cnt
                                end = slices
                                if cnt + end > rroi.stop[axisindex]:
                                    end -= cnt + end - rroi.stop[axisindex]
                                key_ = copy.copy(oldkey)
                                key_.insert(axisindex, slice(begin, end, None))
                                reskey = [slice(None, None, None) for x in range(len(result.shape))]
                                reskey[axisindex] = slice(written, written + end - begin, None)

                                destArea = result[tuple(reskey)]
                                roi_ = SubRegion(self.Input, pslice=key_)
                                closure = partial(oslot.operator.execute, oslot, (), roi_, destArea,
                                                  sourceArray=sourceArraysForSigmas[j])
                                closures.append(closure)

                                written += end - begin
                            cnt += slices
                        else:
                            if cnt >= rroi.start[axisindex] and rroi.start[axisindex] + written < rroi.stop[axisindex]:
                                reskey = copy.copy(oldkey)
                                reskey.insert(axisindex, written)

                                destArea = result[tuple(reskey)]
                                logger.debug("oldkey: %s, destArea: %s, source: %s",
                                             oldkey, destArea.shape, sourceArraysForSigmas[j].shape)
                                oldroi = SubRegion(self.Input, pslice=oldkey)
                                closure = partial(oslot.operator.execute, oslot, (), oldroi, destArea,
                                                  sourceArray=sourceArraysForSigmas[j])
                                closures.append(closure)

                                written += 1
                            cnt += 1
            pool = Pool()
            for c in closures:
                pool.request(c)
            pool.wait()
            pool.clean()

            # Best-effort release of the pre-smoothed buffers. ndarray.resize
            # raises ValueError when other references exist (e.g. the partials
            # above) or the array does not own a single contiguous buffer.
            for i, smoothed in enumerate(sourceArraysForSigmas):
                if smoothed is not None:
                    try:
                        smoothed.resize((1,))
                    except ValueError:
                        sourceArraysForSigmas[i] = None