Ejemplo n.º 1
0
    def execute(self, slot, subindex, rroi, result):
        """Compute the requested region of ``Features`` or ``Output`` into ``result``.

        For a ``Features`` subslot, the channel range of the request is shifted
        by ``self.featureOutputChannels[subindex[0]][0]`` and the request is
        forwarded to the ``Output`` branch of this same method.

        For ``Output`` the method:

        1. enlarges the spatial request by a halo and reads that region from
           ``Input`` (converted to float32 if necessary),
        2. smooths the input once for every scale column of ``self.matrix``
           that has at least one feature selected,
        3. runs the selected feature operators (``self.featureOps[i][j]``)
           through a ``RequestPool``; each one writes directly into its channel
           range of ``result`` and receives the pre-smoothed array for its scale
           via the ``sourceArray`` keyword.

        :param slot: the output slot being requested; must be ``self.Features``
            or ``self.Output``.
        :param subindex: subslot index; only ``subindex[0]`` is read, and only
            for ``Features`` requests.
        :param rroi: the requested region (provides ``start``, ``stop``,
            ``toSlice()`` and ``pprint()``).
        :param result: destination array, filled in place.
        :returns: whatever the ``Output`` branch returns (``None``); the data is
            delivered through ``result``.
        :raises RuntimeError: with a user-oriented message if smoothing fails
            with a 'kernel longer than line' error, i.e. the data is too small
            for the selected sigma. Other ``RuntimeError`` instances propagate
            unchanged.
        """
        assert slot == self.Features or slot == self.Output
        if slot == self.Features:
            key = list(roiToSlice(rroi.start, rroi.stop))
            index = subindex[0]
            channelIndex = self.Input.meta.axistags.index('c')

            # Translate channel slice to the correct location for the output slot.
            channelOffset = self.featureOutputChannels[index][0]
            key[channelIndex] = slice(channelOffset + key[channelIndex].start,
                                      channelOffset + key[channelIndex].stop)
            rroi = SubRegion(self.Output, pslice=key)

            # Get output slot region for this channel
            return self.execute(self.Output, (), rroi, result)
        elif slot == self.outputs["Output"]:
            key = rroi.toSlice()

            logger.debug("OpPixelFeaturesPresmoothed: request %s", rroi.pprint())

            cnt = 0      # output channels passed so far while walking the feature matrix
            written = 0  # channels of `result` that already have a producer assigned
            assert (rroi.stop <= self.outputs["Output"].meta.shape).all()

            inAxistags = self.inputs["Input"].meta.axistags
            channelAxis = inAxistags.index('c')

            # The request key without its channel entry.
            oldkey = list(key)
            oldkey.pop(channelAxis)

            inShape = self.inputs["Input"].meta.shape
            hasChannelAxis = (self.Input.meta.axistags.axisTypeCount(vigra.AxisType.Channels) > 0)

            shape = self.outputs["Output"].meta.shape
            axistags = self.outputs["Output"].meta.axistags

            # A view shares memory with the caller's array, so everything
            # written below still lands in the original `result`.
            result = result.view(vigra.VigraArray)
            result.axistags = copy.copy(axistags)

            hasTimeAxis = inAxistags.axisTypeCount(vigra.AxisType.Time)
            timeAxis = inAxistags.index('t')

            # Strip channel and time entries: the halo computation below is
            # done on the remaining axes only.
            subkey = popFlagsFromTheKey(key, axistags, 'c')
            subshape = popFlagsFromTheKey(shape, axistags, 'c')
            at2 = copy.copy(axistags)
            at2.dropChannelAxis()
            subshape = popFlagsFromTheKey(subshape, at2, 't')
            subkey = popFlagsFromTheKey(subkey, at2, 't')

            oldstart, oldstop = roi.sliceToRoi(key, shape)

            # The second argument is the shape of the sliced array (it used to
            # be `subkey`, which only worked because every slice carries
            # explicit bounds).
            start, stop = roi.sliceToRoi(subkey, subshape)
            maxSigma = max(0.7, self.maxSigma)  # we use 0.7 as an approximation of not doing any smoothing
            # smoothing was already applied previously

            # The region of the smoothed image we need to give to the feature filter (in terms of INPUT coordinates)
            # 0.7, because the features receive a pre-smoothed array and don't need much of a neighborhood
            vigOpSourceStart, vigOpSourceStop = roi.enlargeRoiForHalo(start, stop, subshape, 0.7, self.WINDOW_SIZE)

            # The region of the input that we need to give to the smoothing operator (in terms of INPUT coordinates)
            newStart, newStop = roi.enlargeRoiForHalo(vigOpSourceStart, vigOpSourceStop, subshape, maxSigma, self.WINDOW_SIZE)

            # The originally requested region, expressed relative to the
            # smoothed array that is handed to the feature operators.
            newStartSmoother = roi.TinyVector(start - vigOpSourceStart)
            newStopSmoother = roi.TinyVector(stop - vigOpSourceStart)
            roiSmoother = roi.roiToSlice(newStartSmoother, newStopSmoother)

            # Translate coordinates (now in terms of smoothed image coordinates)
            vigOpSourceStart = roi.TinyVector(vigOpSourceStart - newStart)
            vigOpSourceStop = roi.TinyVector(vigOpSourceStop - newStart)

            readKey = roi.roiToSlice(newStart, newStop)

            # Re-insert the time and channel entries to build the input request.
            treadKey = list(readKey)

            if hasTimeAxis:
                # `readKey` has no channel entry, so the time position shifts
                # down by one when the channel axis precedes it.
                if timeAxis < channelAxis:
                    treadKey.insert(timeAxis, key[timeAxis])
                else:
                    treadKey.insert(timeAxis - 1, key[timeAxis])
            if not hasChannelAxis:
                treadKey = popFlagsFromTheKey(treadKey, axistags, 'c')
            else:
                # Always read every input channel.
                treadKey.insert(channelAxis, slice(None, None, None))

            treadKey = tuple(treadKey)

            req = self.inputs["Input"][treadKey]

            sourceArray = req.wait()
            req.clean()
            req.destination = None
            # The request object is not needed any more; drop our reference.
            del req

            if sourceArray.dtype != numpy.float32:
                sourceArrayF = sourceArray.astype(numpy.float32)
                try:
                    # Best effort: shrink the original buffer to release its memory early.
                    sourceArray.resize((1,), refcheck=False)
                except Exception:
                    logger.debug("Failed to free array memory.")
                del sourceArray
                sourceArray = sourceArrayF

            sourceArrayV = sourceArray.view(vigra.VigraArray)
            sourceArrayV.axistags = copy.copy(inAxistags)

            dimCol = len(self.scales)
            dimRow = self.matrix.shape[0]

            # One pre-smoothed array per scale column; stays None for unused scales.
            sourceArraysForSigmas = [None] * dimCol

            # Pre-smooth the input once for every scale that is actually used.
            try:
                for j in range(dimCol):
                    if not any(self.matrix[i, j] for i in range(dimRow)):
                        continue

                    # NOTE(review): presumably the feature operators account for
                    # the remaining sigma=1.0 themselves, hence only the
                    # difference is applied here -- confirm against featureOps.
                    destSigma = 1.0
                    if self.scales[j] > destSigma:
                        tempSigma = math.sqrt(self.scales[j]**2 - destSigma**2)
                    else:
                        tempSigma = self.scales[j]

                    droi = (tuple(vigOpSourceStart._asint()), tuple(vigOpSourceStop._asint()))

                    if hasTimeAxis:
                        # Smooth every time slice separately and collect the
                        # results in one array that has time and channel axes again.
                        vigOpSourceShape = list(vigOpSourceStop - vigOpSourceStart)
                        if timeAxis < channelAxis:
                            vigOpSourceShape.insert(timeAxis, (oldstop - oldstart)[timeAxis])
                        else:
                            vigOpSourceShape.insert(timeAxis - 1, (oldstop - oldstart)[timeAxis])
                        vigOpSourceShape.insert(channelAxis, inShape[channelAxis])

                        sourceArraysForSigmas[j] = numpy.ndarray(tuple(vigOpSourceShape), numpy.float32)

                        for i, vsa in enumerate(sourceArrayV.timeIter()):
                            tmp_key = getAllExceptAxis(len(sourceArraysForSigmas[j].shape), timeAxis, i)
                            sourceArraysForSigmas[j][tmp_key] = self._computeGaussianSmoothing(vsa, tempSigma, droi)
                    else:
                        sourceArraysForSigmas[j] = self._computeGaussianSmoothing(sourceArrayV, tempSigma, droi)

            except RuntimeError as e:
                # str(e) instead of e.message: the latter does not exist on Python 3.
                if 'kernel longer than line' in str(e):
                    message = "Feature computation error:\nYour image is too small to apply a filter with sigma=%.1f. Please select features with smaller sigmas." % self.scales[j]
                    raise RuntimeError(message)
                else:
                    # Bare raise keeps the original traceback.
                    raise

            del sourceArrayV
            try:
                sourceArray.resize((1,), refcheck=False)
            except ValueError:
                # Sometimes this fails, but that's okay.
                logger.debug("Failed to free array memory.")
            del sourceArray

            closures = []

            chanStart = rroi.start[channelAxis]
            chanStop = rroi.stop[channelAxis]

            # Walk the feature matrix in output-channel order and schedule every
            # selected operator whose channels intersect the request.
            for i in range(dimRow):
                for j in range(dimCol):
                    if not self.matrix[i, j]:
                        continue

                    vop = self.featureOps[i][j]
                    oslot = vop.outputs["Output"]

                    if hasChannelAxis:
                        slices = oslot.meta.shape[channelAxis]
                        if cnt + slices >= chanStart and chanStart - cnt < slices and chanStart + written < chanStop:
                            # Channel range of this operator's own output that
                            # falls inside the request.
                            begin = 0
                            if cnt < chanStart:
                                begin = chanStart - cnt
                            end = slices
                            if cnt + end > chanStop:
                                end -= cnt + end - chanStop

                            reskey = [slice(None, None, None)] * len(result.shape)
                            reskey[channelAxis] = slice(written, written + end - begin, None)
                            destArea = result[tuple(reskey)]

                            # readjust the roi for the new source array
                            roiSmootherList = list(roiSmoother)
                            roiSmootherList.insert(channelAxis, slice(begin, end, None))

                            if hasTimeAxis:
                                # The time slice in the ROI doesn't matter:
                                # The sourceArrayParameter below overrides the input data to be used.
                                roiSmootherList.insert(timeAxis, 0)
                            roiSmootherRegion = SubRegion(oslot, pslice=roiSmootherList)

                            closure = partial(oslot.operator.execute, oslot, (), roiSmootherRegion, destArea, sourceArray=sourceArraysForSigmas[j])
                            closures.append(closure)

                            written += end - begin
                        cnt += slices
                    else:
                        if cnt >= chanStart and chanStart + written < chanStop:
                            reskey = [slice(None, None, None)] * len(result.shape)
                            slices = oslot.meta.shape[channelAxis]
                            reskey[channelAxis] = slice(written, written + slices, None)
                            destArea = result[tuple(reskey)]
                            # Explicit format string: logging treats the first
                            # argument as the %-template for the remaining ones.
                            logger.debug("%s %s %s", oldkey, destArea.shape, sourceArraysForSigmas[j].shape)
                            oldroi = SubRegion(oslot, pslice=oldkey)
                            closure = partial(oslot.operator.execute, oslot, (), oldroi, destArea, sourceArray=sourceArraysForSigmas[j])
                            closures.append(closure)

                            written += 1
                        cnt += 1

            # Run all scheduled feature computations and wait for them.
            pool = RequestPool()
            for c in closures:
                pool.request(c)
            pool.wait()
            pool.clean()

            # Best-effort release of the smoothed arrays. Indexing (rather than
            # iterating over the elements) avoids holding an extra reference
            # while resize() performs its reference check.
            for i in range(len(sourceArraysForSigmas)):
                if sourceArraysForSigmas[i] is not None:
                    try:
                        sourceArraysForSigmas[i].resize((1,))
                    except Exception:
                        sourceArraysForSigmas[i] = None