class ClassImageDeconvMachine():
    """
    Deconvolution machine that hands the dirty cube and PSF to EasyMuffin.

    The PSF comes from a ``ClassPSFServer``, the solver is ``EasyMuffin`` and
    the resulting model is stored through the ModelMachine's
    ``setMUFFINModel``.
    """

    def __init__(self, GD=None, ModelMachine=None, RefFreq=None, *args, **kw):
        """
        Parameters
        ----------
        GD - global parset; ``GD["Freq"]["NBand"]`` is read here and the
             ``GD["MUFFIN"]`` section is read in Deconvolve
        ModelMachine - model machine whose ``DicoModel["Type"]`` is "MUFFIN"
        RefFreq - reference frequency, compared against the model machine's
                  RefFreq in updateModelMachine

        Raises
        ------
        ValueError if the model machine is not of type "MUFFIN".
        """
        self.GD = GD
        self.ModelMachine = ModelMachine
        self.RefFreq = RefFreq
        if self.ModelMachine.DicoModel["Type"] != "MUFFIN":
            raise ValueError("ModelMachine Type should be MUFFIN")
        self.MultiFreqMode = (self.GD["Freq"]["NBand"] > 1)

    def SetPSF(self, DicoVariablePSF):
        """
        Build the PSF server from the shared PSF dictionary.

        The dictionary is re-attached through ``shared_dict.attach`` using
        its ``path`` attribute before being handed to the PSF server.
        """
        self.PSFServer = ClassPSFServer(self.GD)
        DicoVariablePSF = shared_dict.attach(DicoVariablePSF.path)
        self.PSFServer.setDicoVariablePSF(DicoVariablePSF)
        self.PSFServer.setRefFreq(self.ModelMachine.RefFreq)
        self.DicoVariablePSF = DicoVariablePSF
        self.setFreqs(self.PSFServer.DicoMappingDesc)

    def setMaskMachine(self, MaskMachine):
        """Store the mask machine."""
        self.MaskMachine = MaskMachine

    def setFreqs(self, DicoMappingDesc):
        """
        Store the frequency mapping and build the spectral functions machine.

        Nothing beyond storing the mapping is done when it is None.
        """
        self.DicoMappingDesc = DicoMappingDesc
        if self.DicoMappingDesc is None:
            return
        self.SpectralFunctionsMachine = ClassSpectralFunctions.ClassSpectralFunctions(
            self.DicoMappingDesc,
            RefFreq=self.DicoMappingDesc["RefFreq"])
        self.SpectralFunctionsMachine.CalcFluxBands()

    def GiveModelImage(self, *args):
        """Return the ModelMachine's model image for the given arguments."""
        return self.ModelMachine.GiveModelImage(*args)

    def Update(self, DicoDirty, **kwargs):
        """
        Method to update attributes from ClassDeconvMachine

        Only the image dict is updated; extra keyword arguments are ignored.
        """
        self.SetDirty(DicoDirty)

    def ToFile(self, fname):
        """
        Write model dict to file
        """
        self.ModelMachine.ToFile(fname)

    def FromFile(self, fname):
        """
        Read model dict from file
        """
        self.ModelMachine.FromFile(fname)

    def FromDico(self, DicoName):
        """
        Read in model dict
        """
        self.ModelMachine.FromDico(DicoName)

    def setSideLobeLevel(self, SideLobeLevel, OffsetSideLobe):
        """Store the PSF sidelobe level and sidelobe offset."""
        self.SideLobeLevel = SideLobeLevel
        self.OffsetSideLobe = OffsetSideLobe

    def Init(self, **kwargs):
        """
        One-off initialisation.

        Reads the keyword arguments "PSFVar", "PSFAve", "RefFreq",
        "GridFreqs" and "DegridFreqs".
        """
        self.SetPSF(kwargs["PSFVar"])
        if "PSFSideLobes" not in self.DicoVariablePSF.keys():
            self.DicoVariablePSF["PSFSideLobes"] = kwargs["PSFAve"]
        self.setSideLobeLevel(kwargs["PSFAve"][0], kwargs["PSFAve"][1])
        self.ModelMachine.setRefFreq(kwargs["RefFreq"])
        # store grid and degrid freqs for ease of passing to MSMF
        self.GridFreqs = kwargs["GridFreqs"]
        self.DegridFreqs = kwargs["DegridFreqs"]
        self.ModelMachine.setFreqMachine(kwargs["GridFreqs"],
                                         kwargs["DegridFreqs"])

    def SetDirty(self, DicoDirty):
        """
        Store the dirty image dict and derive the extent of the dirty image
        inside the PSF grid.

        Reads ``DicoDirty["ImageCube"]`` (4-D) and ``DicoDirty["MeanImage"]``.
        """
        self.DicoDirty = DicoDirty
        self._Dirty = self.DicoDirty["ImageCube"]
        self._MeanDirty = self.DicoDirty["MeanImage"]
        NPSF = self.PSFServer.NPSF
        _, _, NDirty, _ = self._Dirty.shape
        # Floor division keeps the pixel offsets integral under Python 3.
        off = (NPSF - NDirty) // 2
        self.DirtyExtent = (off, off + NDirty, off, off + NDirty)
        self.ModelMachine.setModelShape(self._Dirty.shape)

    def AdaptArrayShape(self, A, Nout):
        """
        Return A cropped about its centre to Nout pixels on the last two axes.

        Parameters
        ----------
        A - 4-D array; the third axis gives the input size
        Nout - requested output size

        Returns
        -------
        A itself when it already has the requested size, otherwise the
        slice of A selected by the edges returned from GiveEdges.

        Raises
        ------
        ValueError if A is smaller than Nout (growing is not supported).
        """
        _, _, Nin, _ = A.shape
        if Nin == Nout:
            return A
        if Nin < Nout:
            raise ValueError(
                "Cannot adapt array of size %i to larger size %i" % (Nin, Nout))

        N0 = A.shape[-1]
        xc0 = yc0 = N0 // 2
        N1 = Nout
        xc1 = yc1 = N1 // 2
        Aedge, _ = GiveEdges(xc0, yc0, N0, xc1, yc1, N1)
        x0d, x1d, y0d, y1d = Aedge
        return A[..., x0d:x1d, y0d:y1d]

    def updateModelMachine(self, ModelMachine):
        """
        Replace the model machine.

        Raises
        ------
        ValueError if its RefFreq differs from this machine's RefFreq.
        """
        self.ModelMachine = ModelMachine
        if self.ModelMachine.RefFreq != self.RefFreq:
            raise ValueError("freqs should be equal")

    def updateMask(self, Mask):
        """Store a 2-D mask as a boolean array of shape (1, 1, nx, ny)."""
        nx, ny = Mask.shape
        # np.bool8 was removed in NumPy 2.0; np.bool_ is the same type.
        self._MaskArray = np.zeros((1, 1, nx, ny), np.bool_)
        self._MaskArray[0, 0, :, :] = Mask[:, :]

    def Deconvolve(self):
        """
        Run EasyMuffin on the first polarisation of the dirty cube and store
        the result in the ModelMachine.

        Returns
        -------
        ("MaxIter", True, True) in all cases. A non-square dirty image
        returns immediately without running the solver.
        """
        if self._Dirty.shape[-1] != self._Dirty.shape[-2]:
            return "MaxIter", True, True

        dirty = self._Dirty
        nch = dirty.shape[0]
        nxDirty = dirty.shape[-1]

        # Use the PSF of the facet containing the brightest mean-image pixel.
        _, _, xp, yp = np.where(self._MeanDirty == np.max(self._MeanDirty))
        self.PSFServer.setLocation(xp, yp)
        self.iFacet = self.PSFServer.iFacet
        psf, _ = self.PSFServer.GivePSF()

        Nout = min(dirty.shape[-1], psf.shape[-1])
        dirty = self.AdaptArrayShape(dirty, Nout)
        # Drop the last row/column when the size is odd.
        SliceDirty = slice(0, -1) if dirty.shape[-1] % 2 != 0 else slice(0, None)
        d = dirty[:, :, SliceDirty, SliceDirty]

        psf = self.AdaptArrayShape(psf, d.shape[-1])
        SlicePSF = slice(0, -1) if psf.shape[-1] % 2 != 0 else slice(0, None)
        p = psf[:, :, SlicePSF, SlicePSF]

        # Selecting polarisation 0 already yields a 3-D (nch, nx, ny) array.
        # No squeeze here: it would also drop the channel axis when nch == 1
        # and make the 3-axis transpose fail.
        dirty_MUFFIN = d[:, 0, :, :].transpose((2, 1, 0))
        psf_MUFFIN = p[:, 0, :, :].transpose((2, 1, 0))

        EM = EasyMuffin(mu_s=self.GD['MUFFIN']['mu_s'],
                        mu_l=self.GD['MUFFIN']['mu_l'],
                        nb=self.GD['MUFFIN']['nb'],
                        truesky=dirty_MUFFIN,
                        psf=psf_MUFFIN,
                        dirty=dirty_MUFFIN)
        EM.loop(nitermax=self.GD['MUFFIN']['NMinorIter'])

        # Place the solver output back into an array of the dirty image size.
        nxModel = dirty_MUFFIN.shape[0]
        _, Bedge = GiveEdges(nxModel // 2, nxModel // 2, nxModel,
                             nxDirty // 2, nxDirty // 2, nxDirty)
        x0, x1, y0, y1 = Bedge

        Model = np.zeros((nxDirty, nxDirty, nch))
        Model[x0:x1, y0:y1, :] = EM.x
        self.ModelMachine.setMUFFINModel(Model)

        return "MaxIter", True, True  # stop deconvolution but do update model
class ClassImageDeconvMachine():
    """
    Currently constructor inputs match those in MSMF, should figure out which are truly generic and put the rest in
    parset's MinorCycleConfig option.
    These methods may be called from ClassDeconvMachine
        Init(**kwargs) - contains minor cycle specific initialisations which are only used once
            Input: currently kwargs are minor cycle specific and should be set from ClassDeconvMachine, but
                     ideally a generic interface would have these set in the parset somehow.
        Deconvolve() - does joint deconvolution over all the channels/bands.
            Output: return_code - "MaxIter"????
                    continue - whether to continue the deconvolution
                    updated - whether the model has been updated
        GiveModelImage(freq) - returns current model at freq
            Input: freq - tuple of frequencies at which to return the model
            Output: Mod - the current model at freq
        Update(DicoDirty,**kwargs) - updates to minor cycle at the end of each major cycle
            Input:  DicoDirty - updated image dict at start of each major cycle
                    Use kwargs to pass any other minor cycle specific options
        ToFile(fname) - saves dico model to file
            Input: fname - the name of the file to write the dico image to
        FromFile(fname) - reads model dict from file
            Input: fname - the name of the file to read the dico image from
    """
    def __init__(
            self,
            Gain=0.3,
            MaxMinorIter=100,
            NCPU=6,
            CycleFactor=2.5,
            FluxThreshold=None,
            RMSFactor=3,
            PeakFactor=0,
            GD=None,
            SearchMaxAbs=1,
            CleanMaskImage=None,
            ImagePolDescriptor=["I"],  # accepted for interface compatibility; not used or mutated here
            ModelMachine=None,
            **kw  # absorb any unknown keywords arguments into this
    ):
        """
        Store the minor-cycle settings and set up the gain/model machines.

        Gain is used as GainMin for the gain machine. When ModelMachine is
        None a Hogbom model machine is created. When CleanMaskImage is given,
        the image is read and its inverse is stored as the boolean clean mask.
        """
        self.SearchMaxAbs = SearchMaxAbs
        self.ModelImage = None
        self.MaxMinorIter = MaxMinorIter
        self.NCPU = NCPU
        self.MaskArray = None
        self.GD = GD
        self.MultiFreqMode = (self.GD["Freq"]["NBand"] > 1)
        self.NFreqBand = self.GD["Freq"]["NBand"]
        self.FluxThreshold = FluxThreshold
        self.CycleFactor = CycleFactor
        self.RMSFactor = RMSFactor
        self.PeakFactor = PeakFactor
        self.GainMachine = ClassGainMachine.ClassGainMachine(GainMin=Gain)
        if ModelMachine is None:
            # NOTE(review): implicit relative import, kept as in the original;
            # confirm the module is importable this way under Python 3.
            import ClassModelMachineHogbom as ClassModelMachine
            self.ModelMachine = ClassModelMachine.ClassModelMachine(
                self.GD, GainMachine=self.GainMachine)
        else:
            self.ModelMachine = ModelMachine
        # Always use the gain machine owned by the model machine.
        self.GainMachine = self.ModelMachine.GainMachine
        self.GiveEdges = GiveEdges.GiveEdges
        self._niter = 0
        if CleanMaskImage is not None:
            print("Reading mask image: %s" % CleanMaskImage, file=log)
            MaskArray = image(CleanMaskImage).getdata()
            nch, npol, _, _ = MaskArray.shape
            # np.bool8 was removed in NumPy 2.0; np.bool_ is the same type.
            self._MaskArray = np.zeros(MaskArray.shape, np.bool_)
            for ch in range(nch):
                for pol in range(npol):
                    # Transpose and flip each plane, then invert it.
                    plane = MaskArray[ch, pol].T[::-1]
                    self._MaskArray[ch, pol, :, :] = (1 - plane).astype(np.bool_)
            self.MaskArray = self._MaskArray[0]
        self._peakMode = "normal"

        self.CurrentNegMask = None
        self._NoiseMap = None
        self._PNRStop = None  # in _peakMode "sigma", provides additional stopping criterion

    def Init(self, **kwargs):
        """
        One-off initialisation.

        Reads the keyword arguments "PSFVar", "PSFAve", "RefFreq",
        "GridFreqs" and "DegridFreqs".
        """
        self.SetPSF(kwargs["PSFVar"])
        self.setSideLobeLevel(kwargs["PSFAve"][0], kwargs["PSFAve"][1])
        self.SetModelRefFreq(kwargs["RefFreq"])
        self.ModelMachine.setFreqMachine(kwargs["GridFreqs"],
                                         kwargs["DegridFreqs"])
        self.Freqs = kwargs["GridFreqs"]

    def Reset(self):
        """No-op."""
        pass

    def setMaskMachine(self, MaskMachine):
        """Store the mask machine."""
        self.MaskMachine = MaskMachine

    def SetModelRefFreq(self, RefFreq):
        """
        Sets ref freq in ModelMachine.
        """
        # The per-band frequency lists were previously accumulated here but
        # never used; the reference frequency is passed straight through.
        self.ModelMachine.setRefFreq(RefFreq)

    def SetModelShape(self):
        """
        Sets the shape params of model, call in every update step
        """
        self.ModelMachine.setModelShape(self._Dirty.shape)

    def GiveModelImage(self, *args):
        """Return the ModelMachine's model image for the given arguments."""
        return self.ModelMachine.GiveModelImage(*args)

    def setSideLobeLevel(self, SideLobeLevel, OffsetSideLobe):
        """Store the PSF sidelobe level and sidelobe offset."""
        self.SideLobeLevel = SideLobeLevel
        self.OffsetSideLobe = OffsetSideLobe

    def SetPSF(self, DicoVariablePSF):
        """Create the PSF server and give it the PSF dictionary."""
        self.PSFServer = ClassPSFServer(self.GD)
        self.PSFServer.setDicoVariablePSF(DicoVariablePSF)
        self.DicoVariablePSF = DicoVariablePSF

    def setNoiseMap(self, NoiseMap, PNRStop=10):
        """Sets the noise map. The mean dirty will be divided by the noise map before peak finding.
        If PNRStop is set, an additional stopping criterion (peak-to-noisemap) will be applied.
            Peaks are reported in units of sigmas.
        If PNRStop is not set, NoiseMap is treated as simply an (inverse) weighting that will bias
            peak selection in the minor cycle. In this mode, peaks are reported in units of flux.
        """
        self._NoiseMap = NoiseMap
        self._PNRStop = PNRStop
        self._peakMode = "sigma"

    def SetDirty(self, DicoDirty):
        """
        Store the dirty image dict and prepare the peak-search image.

        Reads ``DicoDirty["ImageCube"]`` (4-D) and ``DicoDirty["MeanImage"]``.
        In "sigma" peak mode the peak-search image is the mean image divided
        by the noise map; otherwise it is the mean image itself.
        """
        self.DicoDirty = DicoDirty
        self._Dirty = self.DicoDirty["ImageCube"]
        self._MeanDirty = self.DicoDirty["MeanImage"]

        NPSF = self.PSFServer.NPSF
        _, _, NDirty, _ = self._Dirty.shape

        # Floor division keeps the pixel offsets integral under Python 3.
        off = (NPSF - NDirty) // 2
        self.DirtyExtent = (off, off + NDirty, off, off + NDirty)

        # Compare by value: identity comparison with a string literal is
        # implementation-dependent.
        if self._peakMode == "sigma":
            print("Will search for the peak in the SNR-weighted dirty map", file=log)
            # The expression string below refers to these local names.
            a, b = self._MeanDirty, self._NoiseMap.reshape(
                self._MeanDirty.shape)
            self._PeakSearchImage = numexpr.evaluate("a/b")
        else:
            print("Will search for the peak in the unweighted dirty map", file=log)
            self._PeakSearchImage = self._MeanDirty

        if self.ModelImage is None:
            self._ModelImage = np.zeros_like(self._Dirty)
        if self.MaskArray is None:
            self._MaskArray = np.zeros(self._Dirty.shape, dtype=np.bool_)

    def SubStep(self, dxdy, LocalSM):
        """
        This is where subtraction in the image domain happens

        Parameters
        ----------
        dxdy - (dx, dy) pair giving the position at which LocalSM is centred
        LocalSM - 4-D array subtracted from the dirty cube over the
                  overlapping region
        """
        # Tuple parameters were removed in Python 3, so unpack explicitly.
        xc, yc = dxdy
        N0 = self._Dirty.shape[-1]
        N1 = LocalSM.shape[-1]

        #Get overlap indices where psf should be subtracted
        Aedge, Bedge = self.GiveEdges((xc, yc), N0, (N1 // 2, N1 // 2), N1)

        x0d, x1d, y0d, y1d = Aedge
        x0p, x1p, y0p, y1p = Bedge

        #Subtract from each channel/band
        self._Dirty[:, :, x0d:x1d, y0d:y1d] -= LocalSM[:, :, x0p:x1p, y0p:y1p]
        #Subtract from the average
        if self.MultiFreqMode:  #If multiple frequencies are present construct the weighted mean
            W = np.mean(
                np.float32(self.DicoDirty["WeightChansImages"]), axis=1
            )  #Get the weights (assuming they stay relatively the same over stokes terms)
            self._MeanDirty[0, :, x0d:x1d, y0d:y1d] -= np.sum(
                LocalSM[:, :, x0p:x1p, y0p:y1p] * W.reshape((W.size, 1, 1, 1)),
                axis=0)  #Sum over frequency
# Example no. 3
# 0
class ClassImageDeconvMachine():
    """
    Currently constructor inputs match those in MSMF, should figure out which are truly generic and put the rest in
    parset's MinorCycleConfig option.
    These methods may be called from ClassDeconvMachine
        Init(**kwargs) - contains minor cycle specific initialisations which are only used once
            Input: currently kwargs are minor cycle specific and should be set from ClassDeconvMachine, but
                     ideally a generic interface would have these set in the parset somehow.
        Deconvolve() - does joint deconvolution over all the channels/bands.
            Output: return_code - "MaxIter"????
                    continue - whether to continue the deconvolution
                    updated - whether the model has been updated
        GiveModelImage(freq) - returns current model at freq
            Input: freq - tuple of frequencies at which to return the model
            Output: Mod - the current model at freq
        Update(DicoDirty,**kwargs) - updates to minor cycle at the end of each major cycle
            Input:  DicoDirty - updated image dict at start of each major cycle
                    Use kwargs to pass any other minor cycle specific options
        ToFile(fname) - saves dico model to file
            Input: fname - the name of the file to write the dico image to
        FromFile(fname) - reads model dict from file
            Input: fname - the name of the file to read the dico image from
    """
    def __init__(
            self,
            Gain=0.3,
            MaxMinorIter=100,
            NCPU=6,
            CycleFactor=2.5,
            FluxThreshold=None,
            RMSFactor=3,
            PeakFactor=0,
            GD=None,
            SearchMaxAbs=1,
            CleanMaskImage=None,
            ImagePolDescriptor=["I"],
            ModelMachine=None,
            **kw  # absorb any unknown keywords arguments into this
    ):
        """
        Store the minor-cycle settings and set up the gain/model machines.

        GD must be subscriptable as ``GD["Freq"]["NBand"]``. When
        ModelMachine is None a Hogbom model machine is created around the
        shared gain machine. Gain, CleanMaskImage and ImagePolDescriptor are
        accepted but not used by this constructor.
        """
        # Plain settings copied from the arguments.
        self.SearchMaxAbs = SearchMaxAbs
        self.ModelImage = None
        self.MaxMinorIter = MaxMinorIter
        self.NCPU = NCPU
        self.MaskArray = None
        self.GD = GD

        # Band count decides whether we run in multi-frequency mode.
        n_band = self.GD["Freq"]["NBand"]
        self.MultiFreqMode = (n_band > 1)
        self.NFreqBand = n_band

        # Stopping-criterion parameters.
        self.FluxThreshold = FluxThreshold
        self.CycleFactor = CycleFactor
        self.RMSFactor = RMSFactor
        self.PeakFactor = PeakFactor

        self.GainMachine = ClassGainMachine.get_instance()
        if ModelMachine is not None:
            self.ModelMachine = ModelMachine
        else:
            from DDFacet.Imager.HOGBOM import ClassModelMachineHogbom as ClassModelMachine
            self.ModelMachine = ClassModelMachine.ClassModelMachine(
                self.GD, GainMachine=self.GainMachine)

        self.GiveEdges = GiveEdges.GiveEdges
        self._niter = 0

        # Peak-search state; "sigma" mode is enabled later by setNoiseMap.
        self._peakMode = "normal"
        self.CurrentNegMask = None
        self._NoiseMap = None
        self._PNRStop = None  # in _peakMode "sigma", provides additional stopping criterion

        numexpr.set_num_threads(self.NCPU)

    def Init(self, **kwargs):
        """
        One-off minor-cycle initialisation.

        Keyword arguments read
        ----------------------
        PSFVar - PSF dictionary; passed to SetPSF and indexed with
                 "WeightChansImages"
        PSFAve - sequence whose first two items are passed to
                 setSideLobeLevel
        GridFreqs - stored as self.Freqs
        DegridFreqs - mapping whose values are concatenated into the flat
                      array self.Freqs_degrid
        """
        # Set up the PSF server once; this used to be done twice, which
        # built a second server for no benefit.
        self.SetPSF(kwargs["PSFVar"])
        self.setSideLobeLevel(kwargs["PSFAve"][0], kwargs["PSFAve"][1])

        self.Freqs = kwargs["GridFreqs"]
        # Flatten every band before joining so bands with different numbers
        # of degridding frequencies still give one flat array.
        degrid_bands = [np.ravel(f) for f in kwargs["DegridFreqs"].values()]
        if degrid_bands:
            self.Freqs_degrid = np.concatenate(degrid_bands)
        else:
            self.Freqs_degrid = np.array([])

        self.ModelMachine.setPSFServer(self.PSFServer)
        self.ModelMachine.setFreqMachine(
            self.Freqs,
            self.Freqs_degrid,
            weights=kwargs["PSFVar"]["WeightChansImages"],
            PSFServer=self.PSFServer)

    def Reset(self):
        pass

    def setMaskMachine(self, MaskMachine):
        self.MaskMachine = MaskMachine
        if self.MaskMachine.ExternalMask is not None:
            print("Applying external mask", file=log)
            MaskArray = self.MaskMachine.ExternalMask
            nch, npol, _, _ = MaskArray.shape
            self._MaskArray = np.zeros(MaskArray.shape, np.bool8)
            for ch in range(nch):
                for pol in range(npol):
                    self._MaskArray[ch, pol, :, :] = np.bool8(
                        1 - MaskArray[ch, pol].copy())[:, :]
            self._MaskArray = np.ascontiguousarray(self._MaskArray)
            self.MaskArray = np.ascontiguousarray(self._MaskArray[0])

    def SetModelRefFreq(self, RefFreq):
        """
        Sets ref freq in ModelMachine.
        """
        AllFreqs = []
        AllFreqsMean = np.zeros((self.NFreqBand, ), np.float32)
        for iChannel in range(self.NFreqBand):
            AllFreqs += self.DicoVariablePSF["freqs"][iChannel]
            AllFreqsMean[iChannel] = np.mean(
                self.DicoVariablePSF["freqs"][iChannel])
        #assume that the frequency variance is somewhat the same in all the stokes images:
        #RefFreq = np.sum(AllFreqsMean.ravel() * np.mean(self.DicoVariablePSF["WeightChansImages"],axis=1).ravel())
        self.ModelMachine.setRefFreq(RefFreq)

    def SetModelShape(self):
        """
        Sets the shape params of model, call in every update step
        """
        self.ModelMachine.setModelShape(self._Dirty.shape)

    def GiveModelImage(self, *args):
        return self.ModelMachine.GiveModelImage(*args)

    def setSideLobeLevel(self, SideLobeLevel, OffsetSideLobe):
        self.SideLobeLevel = SideLobeLevel
        self.OffsetSideLobe = OffsetSideLobe

    def SetPSF(self, DicoVariablePSF):
        """
        Create a PSF server for this machine and hand it the PSF dictionary,
        asking it to normalise the PSF.
        """
        self.PSFServer = server = ClassPSFServer(self.GD)
        server.setDicoVariablePSF(DicoVariablePSF, NormalisePSF=True)
        self.DicoVariablePSF = DicoVariablePSF

    def setNoiseMap(self, NoiseMap, PNRStop=10):
        """Sets the noise map. The mean dirty will be divided by the noise map before peak finding.
        If PNRStop is set, an additional stopping criterion (peak-to-noisemap) will be applied.
            Peaks are reported in units of sigmas.
        If PNRStop is not set, NoiseMap is treated as simply an (inverse) weighting that will bias
            peak selection in the minor cycle. In this mode, peaks are reported in units of flux.
        """
        self._NoiseMap = NoiseMap
        self._PNRStop = PNRStop
        self._peakMode = "sigma"

    def SetDirty(self, DicoDirty):
        self.DicoDirty = DicoDirty
        self.WeightsChansImages = DicoDirty["WeightChansImages"].squeeze()
        self._Dirty = self.DicoDirty["ImageCube"]
        self._MeanDirty = self.DicoDirty["MeanImage"]

        self.NpixPSF = self.PSFServer.NPSF
        self.Nchan, self.Npol, self.Npix, _ = self._Dirty.shape

        # if self._peakMode is "sigma":
        #     print("Will search for the peak in the SNR-weighted dirty map", file=log)
        #     a, b = self._MeanDirty, self._NoiseMap.reshape(self._MeanDirty.shape)
        #     self._PeakSearchImage = numexpr.evaluate("a/b")
        # # elif self._peakMode is "weighted":   ######## will need to get a PeakWeightImage from somewhere for this option
        # #     print("Will search for the peak in the weighted dirty map", file=log)
        # #     a, b = self._MeanDirty, self._peakWeightImage
        # #     self._PeakSearchImage = numexpr.evaluate("a*b")
        # else:
        #     print("Will search for the peak in the unweighted dirty map", file=log)
        #     self._PeakSearchImage = self._MeanDirty

        if self.ModelImage is None:
            self._ModelImage = np.zeros_like(self._Dirty)
        if self.MaskArray is None:
            self._MaskArray = np.zeros(self._Dirty.shape, dtype=np.bool8)

    def SubStep(self, xc, yc, LocalSM):
        """
        Subtract a local sky model from the dirty cube in the image domain
        and refresh the matching region of the mean image.

        Parameters
        ----------
        (xc, yc) - The location of the component
        LocalSM - array of shape (nchan, npol, nx, ny)
                  Local Sky Model = comp * PSF * gain where the PSF should be
                  normalised to unity at the center.
        """
        # Overlap between the dirty image (A) and the PSF-sized model (B).
        psf_centre = self.NpixPSF // 2
        Aedge, Bedge = self.GiveEdges(xc, yc, self.Npix, psf_centre,
                                      psf_centre, self.NpixPSF)
        x0d, x1d, y0d, y1d = Aedge
        x0p, x1p, y0p, y1p = Bedge

        dirty_window = (slice(None), slice(None), slice(x0d, x1d),
                        slice(y0d, y1d))
        model_window = (slice(None), slice(None), slice(x0p, x1p),
                        slice(y0p, y1p))

        # The names "cube" and "sm" must stay: the numexpr expression string
        # below refers to them. The result is written in place into the
        # dirty-cube view.
        cube = self._Dirty[dirty_window]
        sm = LocalSM[model_window]
        numexpr.evaluate('cube-sm', out=cube, casting="unsafe")

        meanimage = self._MeanDirty[dirty_window]
        if not self.MultiFreqMode:
            meanimage[0, ...] = cube[0, ...]
            return

        # Multiple bands: rebuild the mean as the weighted sum over frequency.
        W = self.WeightsChansImages.reshape((self.Nchan, 1, 1, 1))
        meanimage[...] = (cube * W).sum(axis=0)

    def Deconvolve(self, **kwargs):
        """
        Runs minor cycle over image channel 'ch'.
        initMinor is number of minor iteration (keeps continuous count through major iterations)
        Nminor is max number of minor iterations

        Returns tuple of: return_code,continue,updated
        where return_code is a status string;
        continue is True if another cycle should be executed (one or more polarizations still need cleaning);
        update is True if one or more polarization models have been updated
        """
        exit_msg = ""
        continue_deconvolution = False
        update_model = False

        # # Get the PeakMap (first index will always be 0 because we only support I cleaning)
        PeakMap = self._MeanDirty[0, 0, :, :]

        #These options should probably be moved into MinorCycleConfig in parset
        DoAbs = int(self.GD["Deconv"]["AllowNegative"])
        print("  Running minor cycle [MinorIter = %i/%i, SearchMaxAbs = %i]" %
              (self._niter, self.MaxMinorIter, DoAbs),
              file=log)

        ## Determine which stopping criterion to use for flux limit
        #Get RMS stopping criterion
        NPixStats = self.GD["Deconv"]["NumRMSSamples"]
        if NPixStats:
            RandomInd = np.int64(np.random.rand(NPixStats) * self.Npix**2)
            RMS = np.std(np.real(PeakMap.ravel()[RandomInd]))
        else:
            RMS = np.std(PeakMap)

        self.RMS = RMS

        Fluxlimit_RMS = self.RMSFactor * RMS

        # Find position and intensity of first peak
        x, y, MaxDirty = NpParallel.A_whereMax(PeakMap,
                                               NCPU=self.NCPU,
                                               DoAbs=DoAbs,
                                               Mask=self.MaskArray)

        # Get peak factor stopping criterion
        Fluxlimit_Peak = MaxDirty * self.PeakFactor

        # Get side lobe stopping criterion
        Fluxlimit_Sidelobe = (
            (self.CycleFactor - 1.) / 4. * (1. - self.SideLobeLevel) +
            self.SideLobeLevel) * MaxDirty if self.CycleFactor else 0

        mm0, mm1 = PeakMap.min(), PeakMap.max()

        # Choose whichever threshold is highest
        StopFlux = max(Fluxlimit_Peak, Fluxlimit_RMS, Fluxlimit_Sidelobe,
                       self.FluxThreshold)

        print(
            "    Dirty image peak flux      = %10.6f Jy [(min, max) = (%.3g, %.3g) Jy]"
            % (MaxDirty, mm0, mm1),
            file=log)
        print(
            "      RMS-based threshold      = %10.6f Jy [rms = %.3g Jy; RMS factor %.1f]"
            % (Fluxlimit_RMS, RMS, self.RMSFactor),
            file=log)
        print(
            "      Sidelobe-based threshold = %10.6f Jy [sidelobe  = %.3f of peak; cycle factor %.1f]"
            % (Fluxlimit_Sidelobe, self.SideLobeLevel, self.CycleFactor),
            file=log)
        print("      Peak-based threshold     = %10.6f Jy [%.3f of peak]" %
              (Fluxlimit_Peak, self.PeakFactor),
              file=log)
        print("      Absolute threshold       = %10.6f Jy" %
              (self.FluxThreshold),
              file=log)
        print("    Stopping flux              = %10.6f Jy [%.3f of peak ]" %
              (StopFlux, StopFlux / MaxDirty),
              file=log)

        T = ClassTimeIt.ClassTimeIt()
        T.disable()

        ThisFlux = MaxDirty

        if ThisFlux < StopFlux:
            print(ModColor.Str(
                "    Initial maximum peak %g Jy below threshold, we're done CLEANing"
                % (ThisFlux),
                col="green"),
                  file=log)
            exit_msg = exit_msg + " " + "FluxThreshold"
            continue_deconvolution = False or continue_deconvolution
            update_model = False or update_model
            # No need to do anything further if we are already at the stopping flux
            return exit_msg, continue_deconvolution, update_model

        #Do minor cycle deconvolution loop
        try:
            for i in range(self._niter + 1, self.MaxMinorIter + 1):
                self._niter = i
                #grab a new peakmap
                PeakMap = self._MeanDirty[0, 0, :, :]

                x, y, ThisFlux = NpParallel.A_whereMax(PeakMap,
                                                       NCPU=self.NCPU,
                                                       DoAbs=DoAbs,
                                                       Mask=self.MaskArray)

                T.timeit("max0")

                if ThisFlux <= StopFlux:
                    print(ModColor.Str(
                        "    CLEANing [iter=%i] peak of %.3g Jy lower than stopping flux"
                        % (i, ThisFlux),
                        col="green"),
                          file=log)
                    cont = ThisFlux > self.FluxThreshold
                    if not cont:
                        print(ModColor.Str(
                            "    CLEANing [iter=%i] absolute flux threshold of %.3g Jy has been reached"
                            % (i, self.FluxThreshold),
                            col="green",
                            Bold=True),
                              file=log)
                    exit_msg = exit_msg + " " + "MinFluxRms"
                    continue_deconvolution = cont or continue_deconvolution
                    update_model = True or update_model

                    break  # stop cleaning if threshold reached

                # This is used to track Cleaning progress
                rounded_iter_step = 1 if i < 10 else (10 if i < 200 else (
                    100 if i < 2000 else 1000))
                # min(int(10**math.floor(math.log10(i))), 10000)
                if i >= 10 and i % rounded_iter_step == 0:
                    # if self.GD["Debug"]["PrintMinorCycleRMS"]:
                    #rms = np.std(np.real(self._CubeDirty.ravel()[self.IndStats]))
                    print("    [iter=%i] peak residual %.3g" % (i, ThisFlux),
                          file=log)

                # Find PSF corresponding to location (x,y)
                self.PSFServer.setLocation(
                    x, y)  # Selects the facet closest to (x,y)

                # Get the JonesNorm
                JonesNorm = self.DicoDirty["JonesNorm"][:, 0, x, y]

                # Get the solution (division by JonesNorm handled in fit)
                Iapp = self._Dirty[:, 0, x, y]

                # Fit a polynomial to get coeffs
                Coeffs = self.ModelMachine.FreqMachine.Fit(
                    Iapp, JonesNorm, self.WeightsChansImages)

                # Overwrite with polynoimial fit
                Iapp = self.ModelMachine.FreqMachine.Eval(Coeffs)
                T.timeit("stuff")

                PSF, meanPSF = self.PSFServer.GivePSF()  #Gives associated PSF

                T.timeit("FindScale")

                #Update model
                self.ModelMachine.AppendComponentToDictStacked((x, y), Coeffs)

                # Subtract LocalSM*CurrentGain from dirty image
                self.SubStep(
                    x, y, PSF * Iapp[:, None, None, None] *
                    self.GD["Deconv"]["Gain"])

                T.timeit("SubStep")

                T.timeit("End")

        except KeyboardInterrupt:
            print(ModColor.Str(
                "    CLEANing [iter=%i] minor cycle interrupted with Ctrl+C, peak flux %.3g"
                % (self._niter, ThisFlux)),
                  file=log)
            exit_msg = exit_msg + " " + "MaxIter"
            continue_deconvolution = False or continue_deconvolution
            update_model = True or update_model
            return exit_msg, continue_deconvolution, update_model

        if self._niter >= self.MaxMinorIter:  #Reached maximum number of iterations:
            print(ModColor.Str(
                "    CLEANing [iter=%i] Reached maximum number of iterations, peak flux %.3g"
                % (self._niter, ThisFlux)),
                  file=log)
            exit_msg = exit_msg + " " + "MaxIter"
            continue_deconvolution = False or continue_deconvolution
            update_model = True or update_model

        return exit_msg, continue_deconvolution, update_model

    def Update(self, DicoDirty, **kwargs):
        """
        Refresh the minor-cycle state from a new image dictionary.

        DicoDirty is passed to SetDirty(); afterwards SetModelShape() is
        called.  Any extra keyword arguments are accepted and ignored.
        """
        # 1) take over the new image dictionary
        self.SetDirty(DicoDirty)
        # 2) bring the model shape in line with it
        self.SetModelShape()

    def ToFile(self, fname):
        """
        Save the current model to `fname` (delegates to ModelMachine.ToFile).
        """
        model_machine = self.ModelMachine
        model_machine.ToFile(fname)

    def FromFile(self, fname):
        """
        Load the model from `fname` (delegates to ModelMachine.FromFile).
        """
        model_machine = self.ModelMachine
        model_machine.FromFile(fname)

    def updateRMS(self):
        """
        Recompute self.RMS from the peak-search image.

        When GD["Deconv"]["NumRMSSamples"] is non-zero, that many evenly
        spaced flat indices of self._PeakSearchImage are sampled; otherwise
        every pixel is used.  The chosen index (array or slice) is stored in
        self.IndStats and the standard deviation of the real part of the
        selected pixels in self.RMS.
        """
        # The unpack is kept so a non-4-D mean image still fails here.
        _, _, _, _ = self._MeanDirty.shape

        n_samples = self.GD["Deconv"]["NumRMSSamples"]
        if not n_samples:
            self.IndStats = slice(None)
        else:
            last_flat_index = self._PeakSearchImage.size - 1
            self.IndStats = np.int64(np.linspace(0, last_flat_index, n_samples))

        sampled_pixels = self._PeakSearchImage.ravel()[self.IndStats]
        self.RMS = np.std(np.real(sampled_pixels))

    def resetCounter(self):
        """
        Reset the overall minor-iteration counter (self._niter) to zero.
        """
        self._niter = 0
# Example no. 4
# 0
class ClassImageDeconvMachine():
    def __init__(
            self,
            Gain=0.3,
            MaxMinorIter=100,
            NCPU=6,
            CycleFactor=2.5,
            FluxThreshold=None,
            RMSFactor=3,
            PeakFactor=0,
            GD=None,
            SearchMaxAbs=1,
            IdSharedMem=None,
            ModelMachine=None,
            NFreqBands=1,
            RefFreq=None,
            MainCache=None,
            **kw  # absorb any unknown keywords arguments into this
    ):
        """
        Set up the SSD deconvolution machine.

        GD is indexed as GD["Freq"]["NBand"] and GD["GAClean"]["InitType"],
        and ModelMachine.DicoSMStacked["Type"] is read, so both arguments
        are required in practice despite their None defaults.

        Raises:
            ValueError: if the model machine is not of type "SSD", or if
                GD["GAClean"]["InitType"] is neither "HMP" nor "MORESANE".
        """
        # ---- plain configuration -------------------------------------------
        self.GD = GD
        self.maincache = MainCache
        self.NCPU = NCPU
        self.MaxMinorIter = MaxMinorIter
        self.SearchMaxAbs = SearchMaxAbs
        self.FluxThreshold = FluxThreshold
        self.CycleFactor = CycleFactor
        self.RMSFactor = RMSFactor
        self.PeakFactor = PeakFactor
        self.Chi2Thr = 10000
        # fall back to the process id when no shared-memory id is supplied
        self.IdSharedMem = (str(os.getpid())
                            if IdSharedMem is None else IdSharedMem)
        self.MultiFreqMode = (self.GD["Freq"]["NBand"] > 1)

        # ---- mutable state ---------------------------------------------------
        self.ModelImage = None
        self.SubPSF = None
        self._niter = 0  # overall iteration counter
        self.NChains = self.NCPU
        self.DeconvMode = "GAClean"

        self.GainMachine = ClassGainMachine.ClassGainMachine(GainMin=Gain)

        # ---- model machine (must be an SSD one) ------------------------------
        self.ModelMachine = ModelMachine
        if self.ModelMachine.DicoSMStacked["Type"] != "SSD":
            raise ValueError("ModelMachine Type should be SSD")

        # ---- initialisation machine, selected by GAClean/InitType ------------
        init_type = self.GD["GAClean"]["InitType"]
        init_kwargs = dict(MainCache=self.maincache,
                           IdSharedMem=self.IdSharedMem)
        if init_type == "HMP":
            import ClassInitSSDModelHMP
            make_init_machine = ClassInitSSDModelHMP.ClassInitSSDModelParallel
        elif init_type == "MORESANE":
            import ClassInitSSDModelMoresane
            make_init_machine = ClassInitSSDModelMoresane.ClassInitSSDModelParallel
            # only the MORESANE initialiser is handed the CPU count
            init_kwargs["NCPU"] = self.NCPU
        else:
            raise ValueError("InitType should be HMP or MORESANE")
        self.InitMachine = make_init_machine(self.GD, NFreqBands, RefFreq,
                                             **init_kwargs)

    def setMaskMachine(self, MaskMachine):
        """Remember the mask machine used by the island search and peak finding."""
        setattr(self, "MaskMachine", MaskMachine)

    def setDeconvMode(self, Mode="MetroClean"):
        """Select the deconvolution mode; Deconvolve() handles "GAClean" and "MetroClean"."""
        setattr(self, "DeconvMode", Mode)

    def Reset(self):
        """
        Reset the initialisation machine (delegates to InitMachine.Reset).
        """
        # Intended to clear whatever is left lying around in shared memory.
        # A global NpShared.DelAll() used to be called here but was disabled
        # (original author's open question: what about other users of it?).
        init_machine = self.InitMachine
        init_machine.Reset()

    def GiveModelImage(self, *args):
        """Return ModelMachine.GiveModelImage(*args) unchanged."""
        model_image = self.ModelMachine.GiveModelImage(*args)
        return model_image

    def setSideLobeLevel(self, SideLobeLevel, OffsetSideLobe):
        """Store the PSF sidelobe level (used by GiveThreshold) and its offset."""
        self.SideLobeLevel, self.OffsetSideLobe = SideLobeLevel, OffsetSideLobe

    def SetPSF(self, DicoVariablePSF):
        """
        Create a fresh PSF server for `DicoVariablePSF`.

        The server is built from self.GD, given the PSF dictionary and the
        model machine's reference frequency, and stored as self.PSFServer.
        The dictionary itself is kept as self.DicoVariablePSF.
        """
        psf_server = ClassPSFServer(self.GD)
        self.PSFServer = psf_server
        psf_server.setDicoVariablePSF(DicoVariablePSF)
        psf_server.setRefFreq(self.ModelMachine.RefFreq)
        self.DicoVariablePSF = DicoVariablePSF

    def Init(self, **kwargs):
        """
        One-off minor-cycle initialisation.

        Required keyword arguments (read in this order):
            PSFVar      - PSF dictionary, handed to SetPSF()
            PSFAve      - indexable; stored under DicoVariablePSF["PSFSideLobes"],
                          elements [0] and [1] go to setSideLobeLevel()
            RefFreq     - passed to ModelMachine.setRefFreq()
            GridFreqs   - kept as self.GridFreqs
            DegridFreqs - kept as self.DegridFreqs

        The grid/degrid frequencies are also given to
        ModelMachine.setFreqMachine() and, together with the PSF dictionary,
        to InitMachine.Init().
        """
        self.SetPSF(kwargs["PSFVar"])

        psf_ave = kwargs["PSFAve"]
        self.DicoVariablePSF["PSFSideLobes"] = psf_ave
        self.setSideLobeLevel(psf_ave[0], psf_ave[1])

        self.ModelMachine.setRefFreq(kwargs["RefFreq"])

        # keep the grid and degrid frequencies at hand for later use
        self.GridFreqs = kwargs["GridFreqs"]
        self.DegridFreqs = kwargs["DegridFreqs"]
        self.ModelMachine.setFreqMachine(self.GridFreqs, self.DegridFreqs)

        self.InitMachine.Init(self.DicoVariablePSF, self.GridFreqs,
                              self.DegridFreqs)

    def AdaptArrayShape(self, A, Nout):
        """
        Crop the two trailing (spatial) axes of `A` to Nout x Nout pixels.

        `A` must be 4-D with shape (nch, npol, Nin, Nin-like); the window is
        taken around pixel Nin // 2 on both spatial axes.

        Returns:
            `A` itself when Nin == Nout, otherwise a new array of shape
            (nch, npol, Nout, Nout) with the dtype of `A`.

        Raises:
            ValueError: if Nout is larger than Nin (growing is not
                supported; the previous code crashed with a NameError on an
                undefined name in that case).
        """
        nch, npol, Nin, _ = A.shape
        if Nin == Nout:
            return A
        if Nin < Nout:
            raise ValueError(
                "Cannot adapt array of shape %s to a larger size of %i pixels" %
                (str(A.shape), Nout))

        # Floor division keeps the indices integral under true division too.
        # For odd Nout this is the same window as before; for even Nout the
        # old window was one pixel too wide and the assignment below failed.
        x0 = Nin // 2 - Nout // 2
        x1 = x0 + Nout
        B = np.zeros((nch, npol, Nout, Nout), A.dtype)
        print >> log, "  Adapt shapes: %s -> %s" % (str(
            A.shape), str(B.shape))
        B[:] = A[..., x0:x1, x0:x1]
        return B

    def SetDirty(self, DicoDirty):
        """
        Attach a new dirty-image dictionary.

        Reads DicoDirty["ImageCube"] (kept as self._Dirty, 4-D) and
        DicoDirty["MeanImage"] (kept as self._MeanDirty), records where the
        dirty image sits inside a PSF-sized frame (self.DirtyExtent), creates
        a zero model cube if no model image has been selected yet, and tells
        the model machine the cube shape.  Requires SetPSF() to have been
        called, since self.PSFServer.NPSF is read.
        """
        self.DicoDirty = DicoDirty
        self._Dirty = self.DicoDirty["ImageCube"]
        self._MeanDirty = self.DicoDirty["MeanImage"]

        NPSF = self.PSFServer.NPSF
        _, _, NDirty, _ = self._Dirty.shape

        # Floor division: the offset is a pixel index and must stay an
        # integer regardless of which division semantics are in effect.
        off = (NPSF - NDirty) // 2

        self.DirtyExtent = (off, off + NDirty, off, off + NDirty)

        if self.ModelImage is None:
            self._ModelImage = np.zeros_like(self._Dirty)
        self.ModelMachine.setModelShape(self._Dirty.shape)

    def SearchIslands(self, Threshold):
        """
        Build the list of islands to deconvolve.

        Islands are found by ClassIslandDistanceMachine on the current
        negative mask, those whose absolute peak in the mean dirty image does
        not exceed `Threshold` are discarded, and the survivors are passed
        through CalcCrossIslandFlux, ConvexifyIsland and MergeIslands.

        Sets self.LabelIslandsImage, self.NIslands and self.ListIslands, the
        latter ordered by decreasing number of pixels.

        Raises:
            RuntimeError: if the mask machine has no current negative mask.
        """
        neg_mask = self.MaskMachine.CurrentNegMask
        if neg_mask is None:
            raise RuntimeError("A mask image should be constructible with SSD")

        IslandDistanceMachine = ClassIslandDistanceMachine.ClassIslandDistanceMachine(
            self.GD,
            neg_mask,
            self.PSFServer,
            self.DicoDirty,
            IdSharedMem=self.IdSharedMem)
        found_islands = IslandDistanceMachine.SearchIslands(Threshold)

        # ---- filter by peak flux -------------------------------------------
        mean_image = self.DicoDirty["MeanImage"]
        bright_islands = []
        for island in found_islands:
            x, y = np.array(island).T
            peak = np.max(np.abs(mean_image[0, 0, x, y]))
            if peak > Threshold:
                bright_islands.append(island)

        print >> log, "  selected %i islands [out of %i] with peak flux > %.3g Jy" % (
            len(bright_islands), len(found_islands), Threshold)

        # ---- post-process the surviving islands ----------------------------
        processed = IslandDistanceMachine.CalcCrossIslandFlux(bright_islands)
        processed = IslandDistanceMachine.ConvexifyIsland(processed)
        processed = IslandDistanceMachine.MergeIslands(processed)

        # NOTE(review): the label image is computed before the size sort
        # below - confirm that consumers do not expect labels to follow the
        # sorted order.
        self.LabelIslandsImage = IslandDistanceMachine.CalcLabelImage(
            processed)

        self.ListIslands = processed
        self.NIslands = len(self.ListIslands)

        # ---- largest islands first -----------------------------------------
        print >> log, "Sorting islands by size"
        sizes = np.array([len(island) for island in processed])
        order = np.argsort(sizes)[::-1]
        self.ListIslands = [processed[i] for i in order]

    def InitIslands(self):
        """
        Compute initial models for the islands that are large enough.

        self.DicoInitIndiv is always reset to {}.  Nothing else happens when
        GD["GAClean"]["MinSizeInit"] is -1.  Otherwise the size of every
        island in self.ListIslands (largest bounding-box side, in pixels) is
        stored in self.ListSizeIslands, and the islands at least MinSizeInit
        wide are handed to InitMachine.giveDicoInitIndiv() together with the
        current model image scaled in place by sqrt(JonesNorm).
        """
        self.DicoInitIndiv = {}
        min_size = self.GD["GAClean"]["MinSizeInit"]
        if min_size == -1:
            return

        DoAbs = int(self.GD["Deconv"]["AllowNegative"])
        print >> log, "  Running minor cycle [MinorIter = %i/%i, SearchMaxAbs = %i]" % (
            self._niter, self.MaxMinorIter, DoAbs)

        # Current model evaluated at the mean frequency of every band.
        # "freqs" is indexed by band number, so go through the index.
        band_freqs = self.DicoVariablePSF["freqs"]
        FreqsModel = np.array(
            [np.mean(band_freqs[iBand]) for iBand in range(len(band_freqs))])
        ModelImage = self.ModelMachine.GiveModelImage(FreqsModel)
        ModelImage *= np.sqrt(self.DicoDirty["JonesNorm"])

        # Island size = longest side of its bounding box
        self.ListSizeIslands = sizes = []
        for pixels in self.ListIslands:
            x, y = np.array(pixels, dtype=np.float32).T
            dx, dy = x.max() - x.min(), y.max() - y.min()
            sizes.append(np.max([dx, dy]) + 1)

        ListDoIslandsInit = [bool(size >= min_size) for size in sizes]
        n_selected = np.count_nonzero(ListDoIslandsInit)

        print >> log, "  selected %i islands larger than %i pixels for initialisation" % (
            n_selected, min_size)

        if n_selected > 0:
            self.DicoInitIndiv = self.InitMachine.giveDicoInitIndiv(
                self.ListIslands,
                ListDoIsland=ListDoIslandsInit,
                ModelImage=ModelImage,
                DicoDirty=self.DicoDirty)

    def setChannel(self, ch=0):
        """Select entry `ch` of the mean dirty image and of the model cube as the working images."""
        self.Dirty, self.ModelImage = self._MeanDirty[ch], self._ModelImage[ch]

    def GiveThreshold(self, Max):
        """
        Return the sidelobe-based stopping flux for a peak value `Max`.

        With a falsy CycleFactor the result is 0; otherwise it is
        ((CycleFactor - 1) / 4 * (1 - SideLobeLevel) + SideLobeLevel) * Max.
        """
        if not self.CycleFactor:
            return 0
        peak_fraction = ((self.CycleFactor - 1.) / 4. *
                         (1. - self.SideLobeLevel) + self.SideLobeLevel)
        return peak_fraction * Max

    def Deconvolve(self, ch=0):
        """
        Run one SSD minor cycle on entry `ch` of the mean dirty image.

        Steps:
          1. estimate the image RMS from evenly spaced pixels and hand it to
             the gain machine;
          2. derive the stopping flux as the largest of the RMS-, sidelobe-
             and peak-based limits and the absolute FluxThreshold;
          3. find the islands above that flux (SearchIslands) and compute
             their initial models (InitIslands);
          4. deconvolve the islands with DeconvListIsland according to
             self.DeconvMode ("GAClean" or "MetroClean"; any other value
             skips this step).

        Returns:
            A 3-tuple (exit message, flag, flag).  The two flags are
            presumably "continue deconvolution" and "model was updated", as
            in the sibling deconvolution machines - TODO confirm.
        """
        if self._niter >= self.MaxMinorIter:
            return "MaxIter", False, False

        self.setChannel(ch)

        DoAbs = int(self.GD["Deconv"]["AllowNegative"])
        print >> log, "  Running minor cycle [MinorIter = %i/%i, SearchMaxAbs = %i]" % (
            self._niter, self.MaxMinorIter, DoAbs)

        # ---- noise estimate from NPixStats evenly spaced pixels -------------
        NPixStats = 1000
        SampleInd = np.int64(np.linspace(0, self.Dirty.size - 1, NPixStats))
        RMS = np.std(np.real(self.Dirty.ravel()[SampleInd]))
        self.RMS = RMS
        self.GainMachine.SetRMS(RMS)

        # ---- stopping flux ---------------------------------------------------
        Fluxlimit_RMS = self.RMSFactor * RMS

        _, _, MaxDirty = NpParallel.A_whereMax(
            self.Dirty,
            NCPU=self.NCPU,
            DoAbs=DoAbs,
            Mask=self.MaskMachine.CurrentNegMask)
        Fluxlimit_Peak = MaxDirty * self.PeakFactor
        Fluxlimit_Sidelobe = self.GiveThreshold(MaxDirty)

        mm0, mm1 = self.Dirty.min(), self.Dirty.max()

        # work out upper threshold
        StopFlux = max(Fluxlimit_Peak, Fluxlimit_RMS, Fluxlimit_Sidelobe,
                       self.FluxThreshold)

        print >> log, "    Dirty image peak flux      = %10.6f Jy [(min, max) = (%.3g, %.3g) Jy]" % (
            MaxDirty, mm0, mm1)
        print >> log, "      RMS-based threshold      = %10.6f Jy [rms = %.3g Jy; RMS factor %.1f]" % (
            Fluxlimit_RMS, RMS, self.RMSFactor)
        print >> log, "      Sidelobe-based threshold = %10.6f Jy [sidelobe  = %.3f of peak; cycle factor %.1f]" % (
            Fluxlimit_Sidelobe, self.SideLobeLevel, self.CycleFactor)
        print >> log, "      Peak-based threshold     = %10.6f Jy [%.3f of peak]" % (
            Fluxlimit_Peak, self.PeakFactor)
        print >> log, "      Absolute threshold       = %10.6f Jy" % (
            self.FluxThreshold)
        print >> log, "    Stopping flux              = %10.6f Jy [%.3f of peak ]" % (
            StopFlux, StopFlux / MaxDirty)

        # Nothing has touched the image since the peak search above, so the
        # current peak is MaxDirty; no second full-image scan is needed.
        ThisFlux = MaxDirty

        if ThisFlux < StopFlux:
            print >> log, ModColor.Str(
                "    Initial maximum peak %g Jy below threshold, we're done here"
                % (ThisFlux),
                col="green")
            return "FluxThreshold", False, False

        self.SearchIslands(StopFlux)
        self.InitIslands()

        if self.DeconvMode == "GAClean":
            print >> log, "Evolving %i generations of %i sourcekin" % (
                self.GD["GAClean"]["NMaxGen"],
                self.GD["GAClean"]["NSourceKin"])

            # Split the islands by pixel count; each island keeps its
            # initial model (None when InitIslands produced nothing for it).
            ConvFFTSwitch = self.GD["SSDClean"]["ConvFFTSwitch"]
            ListBigIslands, ListInitBigIslands = [], []
            ListSmallIslands, ListInitSmallIslands = [], []
            for iIsland, Island in enumerate(self.ListIslands):
                InitIndiv = self.DicoInitIndiv.get(iIsland, None)
                if len(Island) > ConvFFTSwitch:
                    ListBigIslands.append(Island)
                    ListInitBigIslands.append(InitIndiv)
                else:
                    ListSmallIslands.append(Island)
                    ListInitSmallIslands.append(InitIndiv)

            if len(ListSmallIslands) > 0:
                print >> log, "Deconvolve small islands (<=%i pixels) (parallelised over island)" % (
                    ConvFFTSwitch)
                self.DeconvListIsland(ListSmallIslands,
                                      ParallelMode="OverIslands",
                                      ListInitIslands=ListInitSmallIslands)
            else:
                print >> log, "No small islands"

            if len(ListBigIslands) > 0:
                print >> log, "Deconvolve large islands (>%i pixels) (parallelised per island)" % (
                    ConvFFTSwitch)
                self.DeconvListIsland(ListBigIslands,
                                      ParallelMode="PerIsland",
                                      ListInitIslands=ListInitBigIslands)
            else:
                print >> log, "No large islands"

        elif self.DeconvMode == "MetroClean":
            if self.GD["MetroClean"]["MetroNChains"] != "NCPU":
                self.NChains = self.GD["MetroClean"]["MetroNChains"]
            else:
                self.NChains = self.NCPU
            print >> log, "Evolving %i chains of %i iterations" % (
                self.NChains, self.GD["MetroClean"]["MetroNIter"])

            # Only islands whose bounding box is wider than the switch value
            # are processed in this mode.
            RestoreMetroSwitch = self.GD["SSDClean"]["RestoreMetroSwitch"]
            ListBigIslands = []
            for ThisPixList in self.ListIslands:
                x, y = np.array(ThisPixList, dtype=np.float32).T
                dx, dy = x.max() - x.min(), y.max() - y.min()
                dd = np.max([dx, dy]) + 1
                if dd > RestoreMetroSwitch:
                    ListBigIslands.append(ThisPixList)

            print >> log, "Deconvolve %i large islands (>=%i pixels) (parallelised per island)" % (
                len(ListBigIslands), RestoreMetroSwitch)
            self.SelectedIslandsMask = np.zeros_like(
                self.DicoDirty["MeanImage"])
            for ThisIsland in ListBigIslands:
                x, y = np.array(ThisIsland).T
                self.SelectedIslandsMask[0, 0, x, y] = 1

            self.DeconvListIsland(ListBigIslands, ParallelMode="PerIsland")

        # NOTE(review): the original comment here read "stop deconvolution
        # but do update model", yet the second element returned is True -
        # confirm which of the two the caller expects.
        return "MaxIter", True, True

    def DeconvListIsland(self,
                         ListIslands,
                         ParallelMode="OverIsland",
                         ListInitIslands=None):
        """
        Deconvolve every island of `ListIslands` in worker processes and
        append the resulting models to the model machine.

        Args:
            ListIslands: sequence of islands, each a list of (x, y) pixels.
                An empty sequence returns immediately.
            ParallelMode: "OverIslands" (the historical default spelling
                "OverIsland" is accepted as well) starts up to self.NCPU
                workers, each handling whole islands; "PerIsland" starts a
                single worker with ParallelPerIsland switched on.
            ListInitIslands: passed to every worker unchanged.

        Raises:
            ValueError: for any other ParallelMode.  Previously such a value
                - including the default - failed later on an unbound local.
        """
        # ================== Parallel part

        NIslands = len(ListIslands)
        if NIslands == 0:
            return

        if ParallelMode in ("OverIslands", "OverIsland"):
            NCPU = np.min([self.NCPU, NIslands])
            ParallelPerIsland = False
        elif ParallelMode == "PerIsland":
            NCPU = 1  #self.NCPU
            ParallelPerIsland = True
        else:
            raise ValueError(
                "ParallelMode should be OverIslands or PerIsland, got %r" %
                (ParallelMode, ))
        # Set Parallel to False to run the workers in this process (debugging).
        Parallel = True
        StopWhenQueueEmpty = True

        work_queue = multiprocessing.Queue()

        # shared dict to hold inputs and outputs to workers (each island number is a key)
        deconv_dict = shared_dict.create("DeconvListIslands")

        NJobs = NIslands
        T = ClassTimeIt.ClassTimeIt("    ")
        T.disable()

        nchan, npol, _, _ = self._Dirty.shape
        ChanWeights = self.DicoDirty["WeightChansImages"].reshape(
            (nchan, 1, 1, 1))

        for iIsland, ThisPixList in enumerate(ListIslands):
            island_dict = deconv_dict.addSubdict(iIsland)
            island_dict["Island"] = np.array(ThisPixList)

            # integer pixel closest to the island centroid
            XY = np.array(ThisPixList, dtype=np.float32)
            xm, ym = np.mean(XY, axis=0).astype(int)
            T.timeit("xm,ym")

            # channel-weighted JonesNorm at the centroid
            JonesNorm = (self.DicoDirty["JonesNorm"][:, :, xm, ym]).reshape(
                (nchan, npol, 1, 1))
            JonesNorm = np.sum(JonesNorm * ChanWeights,
                               axis=0).reshape((1, npol, 1, 1))
            T.timeit("JonesNorm")

            IslandBestIndiv = self.ModelMachine.GiveIndividual(ThisPixList)
            T.timeit("GiveIndividual")
            FacetID = self.PSFServer.giveFacetID2(xm, ym)
            T.timeit("FacetID")

            island_dict["BestIndiv"] = IslandBestIndiv

            # job description consumed by the workers
            ListOrder = [
                iIsland, FacetID, JonesNorm.flat[0], self.RMS**2,
                island_dict.path
            ]

            work_queue.put(ListOrder)
            T.timeit("Put")

        workerlist = []

        result_queue = multiprocessing.Queue()
        Title = " Evolve pop."
        if self.DeconvMode == "MetroClean":
            Title = " Running chain"

        pBAR = ProgressBar(Title=Title)
        pBAR.render(0, NJobs)
        for ii in range(NCPU):
            Worker = WorkerDeconvIsland(
                work_queue,
                result_queue,
                self.GD,
                self._Dirty,
                self.DicoVariablePSF["CubeVariablePSF"],
                IdSharedMem=self.IdSharedMem,
                FreqsInfo=self.PSFServer.DicoMappingDesc,
                ParallelPerIsland=ParallelPerIsland,
                StopWhenQueueEmpty=StopWhenQueueEmpty,
                DeconvMode=self.DeconvMode,
                NChains=self.NChains,
                ListInitIslands=ListInitIslands)
            workerlist.append(Worker)

            if Parallel:
                Worker.start()
            else:
                Worker.run()

        # ---- collect one result per island, polling the result queue --------
        iResult = 0
        while iResult < NJobs:
            DicoResult = None
            if result_queue.qsize() != 0:
                try:
                    DicoResult = result_queue.get_nowait()
                except Exception:
                    # Best effort as before (e.g. the queue emptied between
                    # qsize() and get_nowait()), but KeyboardInterrupt and
                    # SystemExit are no longer swallowed.
                    pass

            if DicoResult is None:
                time.sleep(0.05)
                continue

            iResult += 1
            pBAR.render(iResult, NJobs)

            if DicoResult["Success"]:
                iIsland = DicoResult["iIsland"]
                island_dict = deconv_dict[iIsland]
                island_dict.reload()

                self.ModelMachine.AppendIsland(ListIslands[iIsland],
                                               island_dict["Model"].copy())

                if DicoResult["HasError"]:
                    # NOTE(review): ThisPixList is left over from the
                    # enqueueing loop above (always the last island) and
                    # self.ErrorModelMachine is not set anywhere in this
                    # class - confirm the intended arguments.
                    self.ErrorModelMachine.AppendIsland(
                        ThisPixList, ListIslands[iIsland],
                        island_dict["sModel"].copy())

        deconv_dict.delete()

        # ---- best-effort worker teardown ------------------------------------
        for Worker in workerlist:
            try:
                Worker.shutdown()
                Worker.terminate()
                Worker.join()
            except Exception:
                pass

    ###################################################################################
    ###################################################################################

    def GiveEdges(self, center0, N0, center1, N1):
        """
        Overlap of an N1-pixel facet placed at `center0` in an N0-pixel image.

        Args:
            center0: (xc0, yc0) pixel in the main image where the facet is
                centred.
            N0: size of the main image in pixels.
            center1: (xc1, yc1); must be a pair but its values are not used.
            N1: size of the facet in pixels.

        Returns:
            (Aedge, Bedge) with Aedge = [x0, x1, y0, y1] the half-open pixel
            ranges in the main image and Bedge the matching ranges in the
            facet, both clipped to the main image.

        The two centres used to be tuple-unpacking parameters, which is a
        SyntaxError on Python 3; they are ordinary positional parameters now
        and are unpacked here, so positional callers are unaffected.
        """
        xc0, yc0 = center0
        _, _ = center1  # still enforce that a pair is passed
        NpixMain = N0
        NpixFacet = N1
        # floor division: these are pixel indices
        half = NpixFacet // 2

        def _axis_edges(centre):
            # unclipped first/last pixel of the facet in main-image coordinates
            lo, hi = centre - half, centre + half
            lo_main = np.max([0, lo])
            hi_main = np.min([NpixMain - 1, hi])
            # pixels cut off on either side shrink the facet range accordingly
            lo_facet = lo_main - lo
            hi_facet = NpixFacet - (hi - hi_main)
            return lo_main, hi_main + 1, lo_facet, hi_facet

        x0main, x1main, x0facet, x1facet = _axis_edges(xc0)
        y0main, y1main, y0facet, y1facet = _axis_edges(yc0)

        Aedge = [x0main, x1main, y0main, y1main]
        Bedge = [x0facet, x1facet, y0facet, y1facet]
        return Aedge, Bedge
# Example no. 5
# 0
class ClassImageDeconvMachine():
    """
    WSCMS flavour of the image deconvolution machine.

    These methods may be called from ClassDeconvMachine
        Init(**kwargs) - contains minor cycle specific initialisations which are only used once
            Input: currently kwargs are minor cycle specific and should be set from ClassDeconvMachine, but
                     ideally a generic interface has these set in the parset somehow.
        Deconvolve() - does joint deconvolution over all the channels/bands.
            Output: return_code - "MaxIter"????
                    continue - whether to continue the deconvolution
                    updated - whether the model has been updated
        GiveModelImage(freq) - returns current model at freq
            Input: freq - tuple of frequencies at which to return the model
            Output: Mod - the current model at freq
        Update(DicoDirty,**kwargs) - updates to minor cycle at the end of each major cycle
            Input:  DicoDirty - updated image dict at start of each major cycle
                    Use kwargs to pass any other minor cycle specific options
        ToFile(fname) - saves dico model to file
            Input: fname - the name of the file to write the dico image to
        FromFile(fname) - reads model dict from file
            Input: fname - the name of the file to read the dico image from
    """
    def __init__(
            self,
            Gain=0.1,
            MaxMinorIter=50000,
            NCPU=0,
            CycleFactor=2.5,
            FluxThreshold=None,
            RMSFactor=3,
            PeakFactor=0,
            GD=None,
            SearchMaxAbs=1,
            CleanMaskImage=None,
            ImagePolDescriptor=["I"],
            ModelMachine=None,
            MainCache=None,
            CacheFileName='WSCMS',
            **kw  # absorb any unknown keywords arguments here
    ):
        """
        Store the minor-cycle options and, if configured, load the external mask.

        GD must be a mapping providing GD["Freq"]["NBand"] and
        GD["Mask"]["External"]. When ModelMachine is None a WSCMS model
        machine is created. Gain, ImagePolDescriptor and **kw are accepted
        but not used here. The CleanMaskImage argument is ignored: the mask
        path is always read from GD["Mask"]["External"].
        """
        self.SearchMaxAbs = SearchMaxAbs
        self.ModelImage = None
        self.MaxMinorIter = MaxMinorIter
        self.NCPU = NCPU
        self.MaskArray = None
        self.GD = GD
        self.NFreqBand = self.GD["Freq"]["NBand"]
        self.MultiFreqMode = (self.NFreqBand > 1)
        self.FluxThreshold = FluxThreshold
        self.CycleFactor = CycleFactor
        self.RMSFactor = RMSFactor
        self.PeakFactor = PeakFactor
        if ModelMachine is None:
            # NOTE(review): implicit relative import; on Python 3 this only
            # resolves if the module's directory is on sys.path - confirm.
            import ClassModelMachineWSCMS as ClassModelMachine
            self.ModelMachine = ClassModelMachine.ClassModelMachine(
                self.GD, GainMachine=ClassGainMachine.get_instance())
        else:
            self.ModelMachine = ModelMachine
        self.GainMachine = self.ModelMachine.GainMachine
        self._niter = 0

        # cache options
        self.maincache = MainCache
        self.CacheFileName = CacheFileName
        self.PSFHasChanged = False
        self.LastScale = 99999

        #  TODO - use MaskMachine for this
        # The argument value is deliberately overridden by the GD setting.
        CleanMaskImage = self.GD["Mask"]["External"]
        if CleanMaskImage is not None:
            print("Reading mask image: %s" % CleanMaskImage, file=log)
            MaskArray = image(CleanMaskImage).getdata()
            nch, npol, _, _ = MaskArray.shape
            # np.bool8 was removed in NumPy 2.0; np.bool_ is the same type.
            self._MaskArray = np.zeros(MaskArray.shape, np.bool_)
            for ch in range(nch):
                for pol in range(npol):
                    # Transpose, flip the first axis and invert (stored mask = 1 - file mask).
                    # The transposed plane only fits when the mask is square.
                    self._MaskArray[ch, pol, :, :] = (
                        1 - MaskArray[ch, pol].T[::-1]).astype(np.bool_)
            self.MaskArray = np.ascontiguousarray(self._MaskArray)

        self._peakMode = "normal"

        self.CurrentNegMask = None
        self._NoiseMap = None
        self._PNRStop = None  # in _peakMode "sigma", provides additional stopping criterion

    def Init(self, cache=None, facetcache=None, **kwargs):
        """
        One-off minor-cycle initialisation.

        Required kwargs: "GridFreqs", "DegridFreqs" (mapping), "PSFVar"
        (must contain "WeightChansImages"), "PSFAve" (at least two entries)
        and "MaxBaseline". A truthy ``cache`` or a changed PSF requests a
        cache reset. ``facetcache`` is accepted but unused.
        """
        # check for valid cache
        sections = ("Data", "Beam", "Selection", "Freq", "Image", "Facets",
                    "Weight", "RIME", "Comp", "CF", "WSCMS")
        cachehash = {section: self.GD[section] for section in sections}

        cachepath, _ = self.maincache.checkCache(self.CacheFileName,
                                                 cachehash,
                                                 directory=True,
                                                 reset=cache
                                                 or self.PSFHasChanged)
        # export the hash
        self.maincache.saveCache(name='WSCMS')

        self.Freqs = kwargs["GridFreqs"]
        # Concatenate the per-key degridding frequencies into one flat array.
        self.Freqs_degrid = np.asarray(
            list(kwargs["DegridFreqs"].values())).flatten()
        self.SetPSF(kwargs["PSFVar"])
        self.setSideLobeLevel(kwargs["PSFAve"][0], kwargs["PSFAve"][1])

        self.ModelMachine.setPSFServer(self.PSFServer)
        self.ModelMachine.setFreqMachine(
            self.Freqs,
            self.Freqs_degrid,
            weights=kwargs["PSFVar"]["WeightChansImages"],
            PSFServer=self.PSFServer)

        from africanus.constants import c as lightspeed
        minlambda = lightspeed / self.Freqs.min()
        # LB - note MaskArray might be modified by ScaleMachine if GD["WSCMS"]["AutoMask"] is True
        # so we should avoid keeping it as None (it is still passed through as-is here).
        self.ModelMachine.setScaleMachine(self.PSFServer,
                                          NCPU=self.NCPU,
                                          MaskArray=self.MaskArray,
                                          cachepath=cachepath,
                                          MaxBaseline=kwargs["MaxBaseline"] /
                                          minlambda)

    def Reset(self):
        """No-op."""
        pass

    def setMaskMachine(self, MaskMachine):
        """Store the mask machine."""
        self.MaskMachine = MaskMachine

    def SetModelRefFreq(self, RefFreq):
        """
        Sets ref freq in ModelMachine.
        """
        self.ModelMachine.setRefFreq(RefFreq)

    def SetModelShape(self):
        """
        Sets the shape params of model, call in every update step.

        Requires SetDirty to have been called first (reads self._Dirty).
        """
        self.ModelMachine.setModelShape(self._Dirty.shape)
        self.Nchan, self.Npol, self.Npix, _ = self._Dirty.shape
        self.NpixFacet = self.Npix // self.GD["Facets"]["NFacets"]

    def GiveModelImage(self, *args):
        """Forward to ModelMachine.GiveModelImage and return its result."""
        return self.ModelMachine.GiveModelImage(*args)

    def setSideLobeLevel(self, SideLobeLevel, OffsetSideLobe):
        """Store the PSF sidelobe level and sidelobe offset."""
        self.SideLobeLevel = SideLobeLevel
        self.OffsetSideLobe = OffsetSideLobe

    def SetPSF(self, DicoVariablePSF):
        """
        Create the PSF server from DicoVariablePSF and keep a reference to the dict.

        Keys listed for DicoVariablePSF (meanings undocumented in the original):
         'MeanFacetPSF', 'MeanImage', 'ImageCube', 'CellSizeRad',
         'ChanMappingGrid', 'ChanMappingGridChan', 'CubeMeanVariablePSF',
         'CubeVariablePSF', 'SumWeights', 'MeanJonesBand',
         'PeakNormed_CubeMeanVariablePSF', 'PeakNormed_CubeVariablePSF',
         'OutImShape', 'JonesNorm', 'Facets', 'PSFSidelobes', 'ImageInfo',
         'CentralFacet', 'freqs', 'SumJonesChan', 'SumJonesChanWeightSq',
         'EstimatesAvgPSF', 'WeightChansImages', 'FacetNorm', 'PSFGaussPars',
         'FWHMBeam'
        """
        self.PSFServer = ClassPSFServer(self.GD)
        # NormalisePSF must be true here for the beam to be applied correctly
        self.PSFServer.setDicoVariablePSF(DicoVariablePSF, NormalisePSF=True)
        self.DicoVariablePSF = DicoVariablePSF

    def setNoiseMap(self, NoiseMap, PNRStop=10):
        """
        Sets the noise map. The mean dirty will be divided by the noise map before peak finding.
        If PNRStop is set, an additional stopping criterion (peak-to-noisemap) will be applied.
            Peaks are reported in units of sigmas.
        If PNRStop is not set, NoiseMap is treated as simply an (inverse) weighting that will bias
            peak selection in the minor cycle. In this mode, peaks are reported in units of flux.
        """
        self._NoiseMap = NoiseMap
        self._PNRStop = PNRStop
        self._peakMode = "sigma"

    def SetDirty(self, DicoDirty):
        """
        The keys in DicoDirty and what they mean (see also FacetMachine.FacetsToIm docs)
         'JonesNorm' - array containing norm of Jones terms as an image
         'ImageInfo' - dictionary containing 'CellSizeRad' and 'OutImShape'
         'ImageCube' - array containing residual
         'MeanImage' - array containing mean of the residual
         'freqs' - dictionary keyed by band number containing the actual frequencies that got binned into that band
         'SumWeights' - sum of visibility weights used in normalizing the gridded correlations
         'FacetMeanResidual' - ???
         'WeightChansImages' - Weights corresponding to imaging bands (how is this computed?)
         'FacetNorm' - self.FacetImage (grid-correcting map) see FacetMachine

        Only 'ImageCube', 'MeanImage', 'JonesNorm' and 'WeightChansImages' are read here.
        """
        self.DicoDirty = DicoDirty
        self._Dirty = self.DicoDirty["ImageCube"]
        self._MeanDirty = self.DicoDirty["MeanImage"]
        self._JonesNorm = self.DicoDirty["JonesNorm"]
        # Average the weights over axis 1 and add three trailing axes so the
        # result broadcasts against 4-D image cubes.
        self.WeightsChansImages = np.mean(
            np.float32(self.DicoDirty["WeightChansImages"]),
            axis=1)[:, None, None, None]

    def SubStep(self, xy, LocalSM):
        """
        This is where subtraction in the image domain happens.

        Args:
            xy: pair (dx, dy), the pixel on which LocalSM is centred.
            LocalSM: 4-D array subtracted from the dirty cube; its last axis
                length is taken as the facet size.
        """
        # Unpacked here because tuple parameters are a syntax error on
        # Python 3 (PEP 3113); callers still pass one (dx, dy) pair.
        xc, yc = xy
        N1 = LocalSM.shape[-1]

        # Get overlap indices where psf should be subtracted
        Aedge, Bedge = GiveEdges((xc, yc), self.Npix, (N1 // 2, N1 // 2), N1)

        x0d, x1d, y0d, y1d = Aedge
        x0p, x1p, y0p, y1p = Bedge

        self._Dirty[:, :, x0d:x1d, y0d:y1d] -= LocalSM[:, :, x0p:x1p, y0p:y1p]

        # Subtract from the average
        if self.MultiFreqMode:  # If multiple frequencies are present construct the weighted mean
            self._MeanDirty[:, 0, x0d:x1d, y0d:y1d] -= np.sum(
                LocalSM[:, :, x0p:x1p, y0p:y1p] * self.WeightsChansImages,
                axis=0)  # Sum over freq
        else:
            self._MeanDirty = self._Dirty
class ClassImageDeconvMachine():
    """
    MORESANE flavour of the image deconvolution machine.

    Runs ClassMoresaneSingleSlice on each channel of the dirty cube and hands
    the resulting model (or a per-pixel spectral fit of it in multi-frequency
    mode) to the ModelMachine.
    """
    def __init__(self, GD=None, ModelMachine=None, RefFreq=None, *args, **kw):
        """
        Store GD, ModelMachine and RefFreq.

        Raises:
            ValueError: if ModelMachine.DicoModel["Type"] is not "MORESANE".
        """
        self.GD = GD
        self.ModelMachine = ModelMachine
        self.RefFreq = RefFreq
        if self.ModelMachine.DicoModel["Type"] != "MORESANE":
            raise ValueError("ModelMachine Type should be MORESANE")
        self.MultiFreqMode = (self.GD["Freq"]["NBand"] > 1)

    def SetPSF(self, DicoVariablePSF):
        """Attach to the shared PSF dict, build the PSF server and set up the frequency mapping."""
        self.PSFServer = ClassPSFServer(self.GD)
        DicoVariablePSF = shared_dict.attach(DicoVariablePSF.path)
        self.PSFServer.setDicoVariablePSF(DicoVariablePSF)
        self.PSFServer.setRefFreq(self.ModelMachine.RefFreq)
        self.DicoVariablePSF = DicoVariablePSF
        self.setFreqs(self.PSFServer.DicoMappingDesc)

    def setMaskMachine(self, MaskMachine):
        """Store the mask machine."""
        self.MaskMachine = MaskMachine

    def setFreqs(self, DicoMappingDesc):
        """Store the mapping description and, unless it is None, build the spectral functions machine."""
        self.DicoMappingDesc = DicoMappingDesc
        if self.DicoMappingDesc is None:
            return
        self.SpectralFunctionsMachine = ClassSpectralFunctions.ClassSpectralFunctions(
            self.DicoMappingDesc,
            RefFreq=self.DicoMappingDesc["RefFreq"])
        self.SpectralFunctionsMachine.CalcFluxBands()

    def GiveModelImage(self, *args):
        """Forward to ModelMachine.GiveModelImage and return its result."""
        return self.ModelMachine.GiveModelImage(*args)

    def Update(self, DicoDirty, **kwargs):
        """
        Method to update attributes from ClassDeconvMachine
        """
        # Update image dict
        self.SetDirty(DicoDirty)

    def ToFile(self, fname):
        """
        Write model dict to file
        """
        self.ModelMachine.ToFile(fname)

    def FromFile(self, fname):
        """
        Read model dict from file
        """
        self.ModelMachine.FromFile(fname)

    def FromDico(self, DicoName):
        """
        Read in model dict
        """
        self.ModelMachine.FromDico(DicoName)

    def setSideLobeLevel(self, SideLobeLevel, OffsetSideLobe):
        """Store the PSF sidelobe level and sidelobe offset."""
        self.SideLobeLevel = SideLobeLevel
        self.OffsetSideLobe = OffsetSideLobe

    def Init(self, **kwargs):
        """
        One-off initialisation.

        Required kwargs: "PSFVar", "PSFAve" (at least two entries), "RefFreq",
        "GridFreqs" and "DegridFreqs".
        """
        self.SetPSF(kwargs["PSFVar"])
        if "PSFSideLobes" not in self.DicoVariablePSF.keys():
            self.DicoVariablePSF["PSFSideLobes"] = kwargs["PSFAve"]
        self.setSideLobeLevel(kwargs["PSFAve"][0], kwargs["PSFAve"][1])
        self.ModelMachine.setRefFreq(kwargs["RefFreq"])
        # store grid and degrid freqs for ease of passing to MSMF
        self.GridFreqs = kwargs["GridFreqs"]
        self.DegridFreqs = kwargs["DegridFreqs"]
        self.ModelMachine.setFreqMachine(kwargs["GridFreqs"],
                                         kwargs["DegridFreqs"])

    def SetDirty(self, DicoDirty):
        """Take the dirty cube and mean image from DicoDirty and record the dirty extent within the PSF."""
        self.DicoDirty = DicoDirty
        self._Dirty = self.DicoDirty["ImageCube"]
        self._MeanDirty = self.DicoDirty["MeanImage"]
        NPSF = self.PSFServer.NPSF
        _, _, NDirty, _ = self._Dirty.shape
        # Floor division: the offset is a pixel index ("/" yields a float on Python 3).
        off = (NPSF - NDirty) // 2
        self.DirtyExtent = (off, off + NDirty, off, off + NDirty)
        self.ModelMachine.setModelShape(self._Dirty.shape)

    def AdaptArrayShape(self, A, Nout):
        """
        Return A cropped about its centre towards Nout pixels per side.

        A must be 4-D. If A is not larger than Nout it is returned unchanged;
        otherwise the crop window comes from GiveEdges.
        """
        _, _, Nin, _ = A.shape
        if Nin <= Nout:
            return A

        N0 = A.shape[-1]
        xc0 = yc0 = N0 // 2
        N1 = Nout
        xc1 = yc1 = N1 // 2
        Aedge, _ = GiveEdges((xc0, yc0), N0, (xc1, yc1), N1)
        x0d, x1d, y0d, y1d = Aedge
        return A[..., x0d:x1d, y0d:y1d]

    def giveSliceCut(self, A, Nout):
        """
        Return the slice selecting the central Nout pixels along one spatial axis of A.

        Returns:
            slice(None) when A already has Nout pixels, a centred slice when A
            is larger, and None when A is smaller than Nout (cannot be cut).
        """
        _, _, Nin, _ = A.shape
        if Nin == Nout:
            # The original evaluated slice(None) without returning it, so
            # equal sizes yielded None and callers indexed with None.
            return slice(None)
        if Nin < Nout:
            return None

        N0 = A.shape[-1]
        # Floor division keeps the bounds integral (slice bounds must be ints).
        xc0 = N0 // 2
        half = Nout // 2
        if Nout % 2 == 0:
            return slice(xc0 - half, xc0 + half)
        return slice(xc0 - half, xc0 + half + 1)

    def updateModelMachine(self, ModelMachine):
        """Replace the model machine; raise ValueError if its RefFreq differs from ours."""
        self.ModelMachine = ModelMachine
        if self.ModelMachine.RefFreq != self.RefFreq:
            raise ValueError("freqs should be equal")

    def updateMask(self, Mask):
        """Store a 2-D mask as a (1, 1, nx, ny) boolean array."""
        nx, ny = Mask.shape
        # np.bool8 was removed in NumPy 2.0; np.bool_ is the same type.
        self._MaskArray = np.zeros((1, 1, nx, ny), np.bool_)
        self._MaskArray[0, 0, :, :] = Mask[:, :]

    def Deconvolve(self):
        """
        Run MORESANE on every channel of the dirty cube and update the model.

        Returns:
            ("MaxIter", True, True) in all cases; the original comment reads
            "stop deconvolution but do update model". Non-square dirty images
            return immediately without touching the model.
        """
        if self._Dirty.shape[-1] != self._Dirty.shape[-2]:
            return "MaxIter", True, True

        dirty = self._Dirty
        nch, _, _, _ = dirty.shape
        Model = np.zeros_like(dirty)

        # Point the PSF server at the peak of the mean dirty image.
        _, _, xp, yp = np.where(self._MeanDirty == np.max(self._MeanDirty))
        self.PSFServer.setLocation(xp, yp)
        self.iFacet = self.PSFServer.iFacet

        psf, _ = self.PSFServer.GivePSF()

        # Common even-sized window for dirty and PSF.
        Nout = np.min([dirty.shape[-1], psf.shape[-1]])
        if Nout % 2 != 0:
            Nout -= 1

        s_dirty_cut = self.giveSliceCut(dirty, Nout)
        s_psf_cut = self.giveSliceCut(psf, 2 * Nout)

        if s_psf_cut is None:
            # PSF is smaller than twice the dirty window: fall back to same size.
            print(ModColor.Str("Could not adapt psf shape to 2*dirty shape!"),
                  file=log)
            print(ModColor.Str("   shapes are (dirty, psf) = [%s, %s]" %
                               (str(dirty.shape), str(psf.shape))),
                  file=log)
            s_psf_cut = self.giveSliceCut(psf, Nout)

        opts = self.GD["MORESANE"]
        for ch in range(nch):
            CM = ClassMoresaneSingleSlice(dirty[ch, 0, s_dirty_cut,
                                                s_dirty_cut],
                                          psf[ch, 0, s_psf_cut, s_psf_cut],
                                          mask=None,
                                          GD=None)
            model, _ = CM.giveModelResid(
                major_loop_miter=opts["NMajorIter"],
                minor_loop_miter=opts["NMinorIter"],
                loop_gain=opts["Gain"],
                sigma_level=opts["SigmaCutLevel"],
                enforce_positivity=opts["ForcePositive"])
            Model[ch, 0, s_dirty_cut, s_dirty_cut] = model[:, :]

        if self.MultiFreqMode:
            S, Al = self.DoSpectralFit(Model)
            self.ModelMachine.setModel(S, 0)
            self.ModelMachine.setModel(Al, 1)
        else:
            self.ModelMachine.setModel(Model, 0)

        return "MaxIter", True, True  # stop deconvolution but do update model

    def DoSpectralFit(self, Model):
        """
        Fit a flux and spectral index to every pixel of the per-band model.

        For each pixel the band fluxes Model[:, 0, i, j] are fitted with
        S0 * IntExpFunc(Alpha) using least_squares, starting from the mean
        flux and Alpha = -0.8. Pixels whose mean flux is zero are skipped and
        stay zero. No Jones normalisation is applied.

        Returns:
            (S, Al): float32 arrays of shape (1, 1, nx, ny) holding the fitted
            flux and spectral index.
        """
        def GiveResid(X, F, iFacet):
            # Residual between the band fluxes and the model spectrum.
            R = np.zeros_like(F)
            S0, Alpha = X
            for iBand in range(R.size):
                R[iBand] = F[
                    iBand] - S0 * self.SpectralFunctionsMachine.IntExpFunc(
                        Alpha=np.array([Alpha]).ravel(),
                        iChannel=iBand,
                        iFacet=iFacet)
            return R

        nx, ny = Model.shape[-2], Model.shape[-1]
        S = np.zeros((1, 1, nx, ny), np.float32)
        Al = np.zeros((1, 1, nx, ny), np.float32)

        for iPix in range(nx):
            for jPix in range(ny):
                F = Model[:, 0, iPix, jPix]
                F0 = np.mean(F)
                if F0 == 0:
                    continue

                x0 = (F0, -0.8)
                X = least_squares(GiveResid,
                                  x0,
                                  args=(F, self.iFacet),
                                  ftol=1e-3,
                                  gtol=1e-3,
                                  xtol=1e-3)
                x = X['x']
                S[0, 0, iPix, jPix] = x[0]
                Al[0, 0, iPix, jPix] = x[1]

        return S, Al
# Example no. 7
# 0
class ClassImageDeconvMachine():
    """
    These methods may be called from ClassDeconvMachine
        Init(**kwargs) - contains minor cycle specific initialisations which are only used once
            Input: currently kwargs are minor cycle specific and should be set from ClassDeconvMachine, but
                     ideally a generic interface has these set in the parset somehow.
        Deconvolve() - does joint deconvolution over all the channels/bands.
            Output: return_code - "MaxIter"????
                    continue - whether to continue the deconvolution
                    updated - whether the model has been updated
        GiveModelImage(freq) - returns current model at freq
            Input: freq - tuple of frequencies at which to return the model
            Output: Mod - the current model at freq
        Update(DicoDirty,**kwargs) - updates to minor cycle at the end of each major cycle
            Input:  DicoDirty - updated image dict at start of each major cycle
                    Use kwargs to pass any other minor cycle specific options
        ToFile(fname) - saves dico model to file
            Input: fname - the name of the file to write the dico image to
        FromFile(fname) - reads model dict from file
            Input: fname - the name of the file to read the dico image from
    """
    def __init__(
            self,
            Gain=0.1,
            MaxMinorIter=50000,
            NCPU=0,
            CycleFactor=2.5,
            FluxThreshold=None,
            RMSFactor=3,
            PeakFactor=0,
            GD=None,
            SearchMaxAbs=1,
            CleanMaskImage=None,
            ImagePolDescriptor=["I"],
            ModelMachine=None,
            MainCache=None,
            CacheFileName='WSCMS',
            **kw  # absorb any unknown keywords arguments here
    ):
        self.SearchMaxAbs = SearchMaxAbs
        self.ModelImage = None
        self.MaxMinorIter = MaxMinorIter
        self.NCPU = NCPU
        self.MaskArray = None
        self.GD = GD
        self.MultiFreqMode = (self.GD["Freq"]["NBand"] > 1)
        self.NFreqBand = self.GD["Freq"]["NBand"]
        self.FluxThreshold = FluxThreshold
        self.CycleFactor = CycleFactor
        self.RMSFactor = RMSFactor
        self.PeakFactor = PeakFactor
        if ModelMachine is None:
            # raise RuntimeError("You need to supply ImageDeconvMachine with a instantiated ModelMachine")
            from DDFacet.Imager.WSCMS import ClassModelMachineWSCMS as ClassModelMachine
            self.ModelMachine = ClassModelMachine.ClassModelMachine(
                self.GD, GainMachine=ClassGainMachine.get_instance())
        else:
            self.ModelMachine = ModelMachine
        self.GainMachine = self.ModelMachine.GainMachine
        self._niter = 0

        # cache options
        self.maincache = MainCache
        self.CacheFileName = CacheFileName
        self.PSFHasChanged = False
        self.LastScale = 99999

        self._peakMode = "normal"

        self.CacheFileName = CacheFileName
        self.CurrentNegMask = None
        self._NoiseMap = None
        self._PNRStop = None  # in _peakMode "sigma", provides addiitonal stopping criterion

        # # this is so that the relevant functions are registered as job handlers with APP
        # # pass to ModelMachine.setScaleMachine to set workers
        # self.FTMachine = FFTW_Scale_Manager(wisdom_file=self.GD["Cache"]["DirWisdomFFTW"])
        #
        # APP.registerJobHandlers(self)

    def Init(self,
             cache=None,
             facetcache=None,
             FacetMachine=None,
             BaseName=None,
             **kwargs):
        """
        One-off minor-cycle initialisation.

        Validates the cache, wires the facet machine, PSF server and
        frequency machine into the ModelMachine, makes sure a mask of shape
        (1, 1, Npix, Npix) exists and finally sets up the scale machine.

        Args:
            cache: truthy to force a cache reset (a changed PSF also resets).
            facetcache: accepted but unused.
            FacetMachine, BaseName: forwarded to ModelMachine.setFacetMachine
                ("required to save intermediate images" per the original).
            **kwargs: must contain "GridFreqs", "DegridFreqs" (mapping),
                "PSFVar" (with "WeightChansImages" and "OutImShape"),
                "PSFAve" (at least two entries) and "MaxBaseline".

        Raises:
            ValueError: if a mask is already set and is not (1, 1, Npix, Npix).
        """
        # check for valid cache
        hashed_sections = ("Data", "Beam", "Selection", "Freq", "Image",
                           "Facets", "Weight", "RIME", "Comp", "CF", "WSCMS")
        cachehash = {section: self.GD[section] for section in hashed_sections}

        cachepath, _ = self.maincache.checkCache(self.CacheFileName,
                                                 cachehash,
                                                 directory=True,
                                                 reset=cache
                                                 or self.PSFHasChanged)
        # export the hash
        self.maincache.saveCache(name='WSCMS')

        # required to save intermediate images
        self.FacetMachine = FacetMachine
        self.BaseName = BaseName
        self.ModelMachine.setFacetMachine(FacetMachine=self.FacetMachine,
                                          BaseName=self.BaseName)

        self.Freqs = kwargs["GridFreqs"]
        # Concatenate the per-key degridding frequencies into one flat array.
        self.Freqs_degrid = np.asarray(
            list(kwargs["DegridFreqs"].values())).flatten()
        self.SetPSF(kwargs["PSFVar"])
        self.setSideLobeLevel(kwargs["PSFAve"][0], kwargs["PSFAve"][1])

        self.ModelMachine.setPSFServer(self.PSFServer)
        self.ModelMachine.setFreqMachine(
            self.Freqs,
            self.Freqs_degrid,
            weights=kwargs["PSFVar"]["WeightChansImages"],
            PSFServer=self.PSFServer)

        from africanus.constants import c as lightspeed
        minlambda = lightspeed / self.Freqs.min()
        # LB - note MaskArray might be modified by ScaleMachine if GD["WSCMS"]["AutoMask"] is True
        # so we should avoid keeping it as None
        self.Nchan, self.Npol, self.Npix, _ = self.DicoVariablePSF[
            "OutImShape"]
        mask_shape = (1, 1, self.Npix, self.Npix)
        if self.MaskArray is None:
            # np.bool8 was removed in NumPy 2.0; np.bool_ is the same type.
            self.MaskArray = np.zeros([1, 1, self.Npix, self.Npix],
                                      dtype=np.bool_)
        elif self.MaskArray.shape != mask_shape:  # Make sure mask is correct shape
            raise ValueError(
                "Mask is incorrect shape. Expected %s but got %s" %
                (mask_shape, self.MaskArray.shape))
        self.ModelMachine.setScaleMachine(self.PSFServer,
                                          NCPU=self.NCPU,
                                          MaskArray=self.MaskArray,
                                          cachepath=cachepath,
                                          MaxBaseline=kwargs["MaxBaseline"] /
                                          minlambda)

    def Reset(self):
        pass

    def setMaskMachine(self, MaskMachine):
        self.MaskMachine = MaskMachine
        self.MaskMachine.readExternalMaskFromFits()
        if hasattr(self.MaskMachine, 'ExternalMask'):
            self.MaskArray = self.MaskMachine.ExternalMask
        else:
            self.MaskArray = None  # need to set this to None here since we don't yet know the image size etc.

    def SetModelRefFreq(self, RefFreq):
        """
        Sets ref freq in ModelMachine.
        """
        self.ModelMachine.setRefFreq(RefFreq)

    def SetModelShape(self):
        """
        Sets the shape params of model, call in every update step
        """
        assert self._Dirty.shape == (self.Nchan, self.Npol, self.Npix,
                                     self.Npix)
        self.ModelMachine.setModelShape(self._Dirty.shape)
        self.NpixFacet = self.Npix // self.GD["Facets"]["NFacets"]

    def GiveModelImage(self, *args):
        return self.ModelMachine.GiveModelImage(*args)

    def setSideLobeLevel(self, SideLobeLevel, OffsetSideLobe):
        self.SideLobeLevel = SideLobeLevel
        self.OffsetSideLobe = OffsetSideLobe

    def SetPSF(self, DicoVariablePSF):
        """
        Build a fresh PSF server around DicoVariablePSF and keep the dict.

        Keys listed for DicoVariablePSF (their meanings were left blank in
        the original documentation):
         'MeanFacetPSF', 'MeanImage', 'ImageCube', 'CellSizeRad',
         'ChanMappingGrid', 'ChanMappingGridChan', 'CubeMeanVariablePSF',
         'CubeVariablePSF', 'SumWeights', 'MeanJonesBand',
         'PeakNormed_CubeMeanVariablePSF', 'PeakNormed_CubeVariablePSF',
         'OutImShape', 'JonesNorm', 'Facets', 'PSFSidelobes', 'ImageInfo',
         'CentralFacet', 'freqs', 'SumJonesChan', 'SumJonesChanWeightSq',
         'EstimatesAvgPSF', 'WeightChansImages', 'FacetNorm', 'PSFGaussPars',
         'FWHMBeam'
        """
        self.PSFServer = ClassPSFServer(self.GD)
        psf_server = self.PSFServer
        # The beam is only applied correctly when the PSF is normalised.
        psf_server.setDicoVariablePSF(DicoVariablePSF, NormalisePSF=True)
        self.DicoVariablePSF = DicoVariablePSF

    def setNoiseMap(self, NoiseMap, PNRStop=10):
        """
        Sets the noise map. The mean dirty will be divided by the noise map before peak finding.
        If PNRStop is set, an additional stopping criterion (peak-to-noisemap) will be applied.
            Peaks are reported in units of sigmas.
        If PNRStop is not set, NoiseMap is treated as simply an (inverse) weighting that will bias
            peak selection in the minor cycle. In this mode, peaks are reported in units of flux.
        """
        self._NoiseMap = NoiseMap
        self._PNRStop = PNRStop
        self._peakMode = "sigma"

    def SetDirty(self, DicoDirty):
        """
        The keys in DicoDirty and what they mean (see also FacetMachine.FacetsToIm docs)
         'JonesNorm' - array containing norm of Jones terms as an image 
         'ImageInfo' - dictionary containing 'CellSizeRad' and 'OutImShape'
         'ImageCube' - array containing residual
         'MeanImage' - array containing mean of the residual
         'freqs' - dictionary keyed by band number containing the actual frequencies that got binned into that band 
         'SumWeights' - sum of visibility weights used in normalizing the gridded correlations
         'FacetMeanResidual' - ???
         'WeightChansImages' - Weights corresponding to imaging bands (how is this computed?)
         'FacetNorm' - self.FacetImage (grid-correcting map) see FacetMachine
        """
        self.DicoDirty = DicoDirty
        self._Dirty = self.DicoDirty["ImageCube"]
        self._MeanDirty = self.DicoDirty["MeanImage"]
        self._JonesNorm = self.DicoDirty["JonesNorm"]
        self.WeightsChansImages = np.mean(np.float32(
            self.DicoDirty["WeightChansImages"]),
                                          axis=1)[:, None, None, None]

    def SubStep(self, dx, dy, LocalSM):
        """
        Subtract a local model patch from the residual images (image domain).

        Args:
            dx, dy: position passed to GiveEdges as the patch centre.
            LocalSM: 4-D array to subtract; its last axis length is taken as
                the patch width.

        The overlap between the patch and the image is obtained from GiveEdges
        and the overlapping part of LocalSM is subtracted in place from
        self._Dirty. In multi-frequency mode the weighted sum of the patch over
        axis 0 is also subtracted from self._MeanDirty; otherwise
        self._MeanDirty is simply re-pointed at self._Dirty.
        """
        width = LocalSM.shape[-1]
        half_width = width // 2

        # Overlap indices: first set addresses the image, second the patch
        image_edges, patch_edges = GiveEdges(dx, dy, self.Npix, half_width,
                                             half_width, width)
        x0d, x1d, y0d, y1d = image_edges
        x0p, x1p, y0p, y1p = patch_edges

        overlap = LocalSM[:, :, x0p:x1p, y0p:y1p]
        self._Dirty[:, :, x0d:x1d, y0d:y1d] -= overlap

        if not self.MultiFreqMode:
            # Single band: the mean image is just the cube itself
            self._MeanDirty = self._Dirty
            return

        # Multiple bands: remove the weighted sum over frequency from the mean
        self._MeanDirty[:, 0, x0d:x1d, y0d:y1d] -= np.sum(
            overlap * self.WeightsChansImages, axis=0)

    def track_progress(self, i, ThisFlux):
        """
        Log the peak residual at progressively sparser iteration counts.

        Nothing is reported for i < 10. After that a line is written to the log
        every 10 iterations while i < 200, every 100 while i < 2000 and every
        1000 from then on.

        Args:
            i: current iteration number.
            ThisFlux: peak residual to report.
        """
        if i < 10:
            return

        if i < 200:
            report_every = 10
        elif i < 2000:
            report_every = 100
        else:
            report_every = 1000

        if i % report_every == 0:
            print("    [iter=%i] peak residual %.3g" % (i, ThisFlux), file=log)

    def check_stopping_criteria(self):
        """
        Work out the flux at which the current minor cycle should stop.

        Four candidate thresholds are computed from the mean residual and the
        largest one is used:
          * RMS-based:      RMSFactor times the standard deviation of the image
          * peak-based:     PeakFactor times the current peak
          * sidelobe-based: derived from CycleFactor and SideLobeLevel
                            (0 when CycleFactor is falsy)
          * absolute:       FluxThreshold
        The peak, every threshold and the final choice are written to the log.

        Returns:
            tuple (StopFlux, MaxDirty, RMS): the chosen stopping flux, the peak
            found by NpParallel.A_whereMax and the standard deviation of the
            mean residual.
        """
        residual = self._MeanDirty

        # RMS stopping criterion
        RMS = np.std(residual)
        rms_limit = self.RMSFactor * RMS

        # Intensity of the first peak (its position is not needed here)
        _, _, MaxDirty = NpParallel.A_whereMax(
            residual,
            NCPU=self.NCPU,
            DoAbs=self.GD["Deconv"]["AllowNegative"],
            Mask=self.MaskArray)

        # Peak factor stopping criterion
        peak_limit = MaxDirty * self.PeakFactor

        # Side lobe stopping criterion
        if self.CycleFactor:
            sidelobe_fraction = ((self.CycleFactor - 1.) / 4. *
                                 (1. - self.SideLobeLevel) + self.SideLobeLevel)
            sidelobe_limit = sidelobe_fraction * MaxDirty
        else:
            sidelobe_limit = 0

        lowest = residual.min()
        highest = residual.max()

        # Whichever threshold is highest wins
        StopFlux = max(peak_limit, rms_limit, sidelobe_limit,
                       self.FluxThreshold)

        print("    Dirty image peak flux      = %10.6f Jy [(min, max) = (%.3g, %.3g) Jy]"
              % (MaxDirty, lowest, highest), file=log)
        print("      RMS-based threshold      = %10.6f Jy [rms = %.3g Jy; RMS factor %.1f]"
              % (rms_limit, RMS, self.RMSFactor), file=log)
        print("      Sidelobe-based threshold = %10.6f Jy [sidelobe  = %.3f of peak; cycle factor %.1f]"
              % (sidelobe_limit, self.SideLobeLevel, self.CycleFactor), file=log)
        print("      Peak-based threshold     = %10.6f Jy [%.3f of peak]"
              % (peak_limit, self.PeakFactor), file=log)
        print("      Absolute threshold       = %10.6f Jy"
              % (self.FluxThreshold), file=log)
        print("    Stopping flux              = %10.6f Jy [%.3f of peak ]"
              % (StopFlux, StopFlux / MaxDirty), file=log)

        return StopFlux, MaxDirty, RMS

    def Deconvolve(self):
        """
        Run one minor cycle on the current residual images.

        Each pass asks the ModelMachine to pick a scale and run its sub-minor
        loop (which subtracts components from the dirty cube and adds them to
        the model), then recomputes the weighted mean residual, its standard
        deviation and its peak. The loop ends when the peak drops to the
        stopping flux, the cycle is judged to be diverging, every scale has
        stalled, the iteration budget (self.MaxMinorIter, counted continuously
        in self._niter) is exhausted, or the user presses Ctrl+C.

        Returns:
            tuple (exit_msg, continue_deconvolution, update_model) where
            exit_msg is a status string made of space-prefixed tags
            ("FluxThreshold", "MinFluxRms", "MaxIter"),
            continue_deconvolution is true if another cycle should be executed,
            and update_model is True if the model may have been updated.
        """
        exit_msg = ""
        continue_deconvolution = False
        update_model = False

        # These options should probably be moved into MinorCycleConfig in parset
        print("  Running minor cycle [MinorIter = %i/%i, SearchMaxAbs = %i]" %
              (self._niter, self.MaxMinorIter,
               int(self.GD["Deconv"]["AllowNegative"])),
              file=log)

        # Determine which stopping criterion to use for flux limit
        StopFlux, MaxDirty, RMS = self.check_stopping_criteria()

        TrackRMS = RMS.copy()
        ThisFlux = MaxDirty.copy()

        if ThisFlux < self.FluxThreshold:
            print(ModColor.Str(
                "    Initial maximum peak %g Jy below threshold, we're done CLEANing"
                % (ThisFlux),
                col="green"),
                  file=log)
            exit_msg = exit_msg + " " + "FluxThreshold"
            # No need to do anything further if we are already at the stopping
            # flux: nothing to continue and nothing was updated.
            return exit_msg, continue_deconvolution, update_model

        # Do minor cycle deconvolution loop
        TrackFlux = MaxDirty.copy()
        diverged = False
        diverged_count = 0
        stalled = False
        scale_stall_count = {}
        # Builtin bool: the np.bool alias was removed in NumPy 1.24
        scales_stalled = np.zeros(self.ModelMachine.ScaleMachine.Nscales,
                                  dtype=bool)
        # reset retired scales at the start of each major cycle; forbidden
        # scales start out retired and count as stalled
        self.ModelMachine.ScaleMachine.retired_scales = []
        for scale in self.ModelMachine.ScaleMachine.forbidden_scales:
            self.ModelMachine.ScaleMachine.retired_scales.append(scale)
            scales_stalled[scale] = 1
        try:
            while self._niter <= self.MaxMinorIter:
                # Check if diverging: the peak must grow past the configured
                # factor more than five times before we give up
                if np.abs(ThisFlux) > self.GD["WSCMS"][
                        "MinorDivergenceFactor"] * np.abs(TrackFlux):
                    diverged_count += 1
                    if diverged_count > 5:
                        diverged = True

                TrackFlux = ThisFlux.copy()

                if ThisFlux <= StopFlux or diverged or stalled:
                    if diverged:
                        print(ModColor.Str(
                            "    At [iter=%i] minor cycle is diverging so it has been force stopped at a flux of %.3g Jy"
                            % (self._niter, ThisFlux),
                            col="green"),
                              file=log)
                    elif stalled:
                        print(ModColor.Str(
                            "    At [iter=%i] minor cycle has stalled so it has been force stopped at a flux of %.3g Jy"
                            % (self._niter, ThisFlux),
                            col="green"),
                              file=log)
                    else:
                        print(ModColor.Str(
                            "    CLEANing [iter=%i] peak of %.3g Jy lower than stopping flux"
                            % (self._niter, ThisFlux),
                            col="green"),
                              file=log)
                    # Another cycle is only worthwhile while the peak is still
                    # above the absolute threshold
                    cont = ThisFlux > self.FluxThreshold
                    if not cont:
                        print(ModColor.Str(
                            "    CLEANing [iter=%i] absolute flux threshold of %.3g Jy has been reached"
                            % (self._niter, StopFlux),
                            col="green",
                            Bold=True),
                              file=log)
                    exit_msg = exit_msg + " " + "MinFluxRms"
                    continue_deconvolution = cont or continue_deconvolution
                    update_model = True

                    break  # stop cleaning if threshold reached

                # Find the relevant scale and do sub-minor loop. Note that the dirty cube is updated during the
                # sub-minor loop by subtracting the once convolved PSF's as components are added to the model.
                # The model is updated by adding components to the ModelMachine dictionary.
                niter, iScale = self.ModelMachine.do_minor_loop(
                    self._Dirty, self._MeanDirty, self._JonesNorm,
                    self.WeightsChansImages, ThisFlux, StopFlux, RMS)

                # compute the new mean image from the weighted sum of over frequency
                self._MeanDirty = np.sum(self._Dirty * self.WeightsChansImages,
                                         axis=0,
                                         keepdims=True)

                ThisRMS = np.std(self._MeanDirty)

                # check for and retire scales that cause stalls (relative RMS
                # change below the configured threshold)
                if np.abs((TrackRMS - ThisRMS) /
                          TrackRMS) < self.GD['WSCMS']['MinorStallThreshold']:
                    scale_stall_count[iScale] = scale_stall_count.get(
                        iScale, 0) + 1
                    # retire scale if it causes a stall more than x number of times
                    if scale_stall_count[iScale] > 10:
                        self.ModelMachine.ScaleMachine.retired_scales.append(
                            iScale)
                        scales_stalled[iScale] = 1
                        print("Retired scale %i because it was stalling." %
                              iScale,
                              file=log)
                    # if all scales have stalled then we trigger a new major cycle
                    if np.all(scales_stalled):
                        stalled = True
                TrackRMS = ThisRMS.copy()

                # find peak (its position is not needed here)
                _, _, ThisFlux = NpParallel.A_whereMax(
                    self._MeanDirty,
                    NCPU=self.NCPU,
                    DoAbs=self.GD["Deconv"]["AllowNegative"],
                    Mask=self.MaskArray)

                # update counter
                self._niter += niter

                # only report when the active scale changes
                if iScale != self.LastScale:
                    print(
                        "    [iter=%i] peak residual %.8g, rms = %.8g, scale = %i"
                        % (self._niter, ThisFlux, TrackRMS, iScale),
                        file=log)
                    self.LastScale = iScale

        except KeyboardInterrupt:
            print(ModColor.Str(
                "    CLEANing [iter=%i] minor cycle interrupted with Ctrl+C, peak flux %.3g"
                % (self._niter, ThisFlux)),
                  file=log)
            exit_msg = exit_msg + " " + "MaxIter"
            update_model = True
            return exit_msg, continue_deconvolution, update_model

        if self._niter >= self.MaxMinorIter:  # Reached maximum number of iterations
            print(ModColor.Str(
                "    CLEANing [iter=%i] Reached maximum number of iterations, peak flux %.3g"
                % (self._niter, ThisFlux)),
                  file=log)
            exit_msg = exit_msg + " " + "MaxIter"
            update_model = True

        return exit_msg, continue_deconvolution, update_model

    def Update(self, DicoDirty, **kwargs):
        """
        Method to update attributes from ClassDeconvMachine
        """
        #Update image dict
        self.SetDirty(DicoDirty)
        #self.SetModelRefFreq()
        self.SetModelShape()

    def ToFile(self, fname):
        """
        Method to write model image to file
        """
        self.ModelMachine.ToFile(fname)

    def FromFile(self, fname):
        """
        Read model dict from file SubtractModel
        """
        self.ModelMachine.FromFile(fname)

    def updateRMS(self):
        _, npol, npix, _ = self._MeanDirty.shape
        NPixStats = self.GD["Deconv"]["NumRMSSamples"]
        if NPixStats:
            #self.IndStats=np.int64(np.random.rand(NPixStats)*npix**2)
            self.IndStats = np.int64(
                np.linspace(0, self._PeakSearchImage.size - 1, NPixStats))
        else:
            self.IndStats = slice(None)
        self.RMS = np.std(np.real(
            self._PeakSearchImage.ravel()[self.IndStats]))

    def resetCounter(self):
        self._niter = 0
# Example no. 8
# 0
class ClassImageDeconvMachine():
    """
    HMP image deconvolution machine.

    Holds the minor-cycle configuration, a PSF server and one multiscale
    machine (ClassMultiScaleMachine) per facet in self.DicoMSMachine. The
    basis functions used by those machines are kept in a shared dict
    (self.facetcache) which can be restored from, or saved to, the main cache.
    """

    def __init__(self, Gain=0.3,
                 MaxMinorIter=100,
                 NCPU=1,  # psutil.cpu_count()
                 CycleFactor=2.5,
                 FluxThreshold=None,
                 RMSFactor=3,
                 PeakFactor=0,
                 PrevPeakFactor=0,
                 GD=None,
                 SearchMaxAbs=1,
                 ModelMachine=None,
                 NFreqBands=1,
                 RefFreq=None,
                 MainCache=None,
                 IdSharedMem="",
                 ParallelMode=True,
                 CacheFileName="HMPBasis",
                 **kw    # absorb any unknown keywords arguments into this
                 ):
        """
        ImageDeconvMachine constructor. Note that this should be called pretty much when setting up the imager,
        before APP workers are started, because the object registers APP handlers.

        Most arguments are stored unchanged as attributes of the same name.
        GD must be a parset-like mapping: GD["HMP"]["PeakWeightImage"] is read
        here. If ModelMachine is given it is installed via updateModelMachine,
        which raises ValueError on a type or reference-frequency mismatch.
        Gain and any unknown keyword arguments are accepted but not used here.
        """
        self.IdSharedMem = IdSharedMem
        self.SearchMaxAbs = SearchMaxAbs
        self._ModelImage = None
        self.MaxMinorIter = MaxMinorIter
        self.NCPU = NCPU
        self.Chi2Thr = 10000
        self._MaskArray = None
        self.GD = GD
        self.SubPSF = None
        self.MultiFreqMode = NFreqBands > 1
        self.NFreqBands = NFreqBands
        self.RefFreq = RefFreq
        self.FluxThreshold = FluxThreshold
        self.CycleFactor = CycleFactor
        self.RMSFactor = RMSFactor
        self.PeakFactor = PeakFactor
        self.PrevPeakFactor = PrevPeakFactor
        self.CacheFileName = CacheFileName
        self.GainMachine = ClassGainMachine.get_instance()
        # ModelMachine and PSFServer must exist before updateModelMachine runs
        self.ModelMachine = None
        self.PSFServer = None
        if ModelMachine is not None:
            self.updateModelMachine(ModelMachine)
        self.PSFHasChanged = False
        self._previous_initial_peak = None
        self.maincache = MainCache
        # reset overall iteration counter
        self._niter = 0
        self.facetcache = None
        self.MaskMachine = None
        self.ParallelMode = ParallelMode
        if self.ParallelMode:
            APP.registerJobHandlers(self)
        else:
            # we are in a worker
            numexpr.set_num_threads(NCPU)

        # peak finding mode.
        # "normal" searches for peak in mean dirty image
        # "sigma" searches for peak in mean_dirty/noise_map (setNoiseMap will have been called)
        # "weighted" searched for peak in mean_dirty*weight
        self._peakMode = "normal"

        self.CurrentNegMask = None
        self._NoiseMap = None
        self._PNRStop = None      # in _peakMode "sigma", provides additional stopping criterion

        if self.GD["HMP"]["PeakWeightImage"]:
            print("  Reading peak weighting image %s" % self.GD["HMP"]["PeakWeightImage"], file=log)
            img = image(self.GD["HMP"]["PeakWeightImage"]).getdata()
            _, _, nx, ny = img.shape
            # collapse freq and pol axes
            img = img.sum(axis=1).sum(axis=0).T[::-1].copy()
            self._peakWeightImage = img.reshape((1, 1, ny, nx))
            self._peakMode = "weighted"

        self._prevPeak = None

    def setNCPU(self, NCPU):
        """Store the CPU count and apply it to numexpr's thread pool."""
        self.NCPU = NCPU
        numexpr.set_num_threads(NCPU)

    def __del__(self):
        """Delete the facet cache if it is a SharedDict."""
        # getattr: __init__ may have raised before facetcache was assigned
        facetcache = getattr(self, "facetcache", None)
        if facetcache is not None and type(facetcache) is shared_dict.SharedDict:
            facetcache.delete()

    def updateMask(self, Mask):
        """
        Store a 2-D mask as a boolean array of shape (1, 1, nx, ny).

        Args:
            Mask: 2-D array; its values are cast to bool on assignment.
        """
        nx, ny = Mask.shape
        # np.bool_ rather than np.bool8: the latter was removed in NumPy 2.0
        self._MaskArray = np.zeros((1, 1, nx, ny), dtype=np.bool_)
        self._MaskArray[0, 0, :, :] = Mask[:, :]

    def setMaskMachine(self, MaskMachine):
        """Store the mask machine."""
        self.MaskMachine = MaskMachine

    def resetCounter(self):
        """Reset the running minor-iteration counter to zero."""
        self._niter = 0

    def updateModelMachine(self, ModelMachine):
        """
        Install a new ModelMachine and propagate it to existing facet machines.

        Raises:
            ValueError: if the model's stacked dict type is not "MSMF" or
                "HMP", or if its RefFreq differs from this machine's RefFreq.
        """
        if ModelMachine.DicoSMStacked["Type"] not in ("MSMF", "HMP"):
            raise ValueError("ModelMachine Type should be HMP")
        if ModelMachine.RefFreq != self.RefFreq:
            raise ValueError("RefFreqs should be equal")

        self.ModelMachine = ModelMachine

        # Once a PSF server exists the per-facet machines need the new model too
        if self.PSFServer is not None:
            for iFacet in range(self.PSFServer.NFacets):
                self.DicoMSMachine[iFacet].setModelMachine(self.ModelMachine)

    def GiveModelImage(self, *args):
        """Return the model image produced by the ModelMachine."""
        return self.ModelMachine.GiveModelImage(*args)

    def setSideLobeLevel(self, SideLobeLevel, OffsetSideLobe):
        """Store the PSF sidelobe level and sidelobe offset."""
        self.SideLobeLevel = SideLobeLevel
        self.OffsetSideLobe = OffsetSideLobe

    def SetPSF(self, DicoVariablePSF, quiet=False):
        """
        Create a fresh PSF server from the given PSF dictionary.

        The server is set up with NormalisePSF=True and this machine's
        reference frequency; the dictionary itself is kept in
        self.DicoVariablePSF.
        """
        self.PSFServer = ClassPSFServer(self.GD)
        self.PSFServer.setDicoVariablePSF(DicoVariablePSF, NormalisePSF=True, quiet=quiet)
        self.PSFServer.setRefFreq(self.RefFreq)
        self.DicoVariablePSF = DicoVariablePSF

    def Init(self, PSFVar, PSFAve, approx=False, cache=None, facetcache=None, **kwargs):
        """
        Init method. This is called after the first round of gridding: PSFs and such are available.
        ModelMachine must be set by now.

        PSFVar: PSF dictionary handed to SetPSF.
        PSFAve: sequence whose first two items are the sidelobe level and the
            sidelobe offset.
        facetcache: dict of basis functions. If supplied, the basis functions
            are taken from it instead of the on-disk cache.
        cache: cache the basis functions. If None, GD["Cache"]["HMP"] setting is used
        kwargs: must contain "GridFreqs" and "DegridFreqs".
        """
        # close the solutions dump, in case it was opened by a previous HMP instance
        ClassMultiScaleMachine.CleanSolutionsDump.close()
        self.SetPSF(PSFVar)
        self.setSideLobeLevel(PSFAve[0], PSFAve[1])
        if cache is None:
            cache = self.GD["Cache"]["HMP"]
        self.InitMSMF(approx=approx, cache=cache, facetcache=facetcache)
        ## OMS: why is this needed? self.RefFreq is set from self.ModelMachine in the first place
        # self.ModelMachine.setRefFreq(self.RefFreq)
        try:  # LB - this is needed because sometimes kwargs["DegridFreqs"] is an array already
            AllDegridFreqs = [kwargs["DegridFreqs"][i]
                              for i in kwargs["DegridFreqs"].keys()]
            DegridFreqs = np.unique(np.asarray(AllDegridFreqs).flatten())
        except Exception:
            # Best-effort fallback: pass the value through as given
            DegridFreqs = kwargs["DegridFreqs"]
        self.ModelMachine.setFreqMachine(kwargs["GridFreqs"], DegridFreqs)

    def Reset(self):
        """Drop all facet machines and delete the facet cache if it is writeable."""
        print("resetting HMP machine", file=log)
        self.DicoMSMachine = {}
        if type(self.facetcache) is shared_dict.SharedDict and self.facetcache.is_writeable():
            print("deleting HMP facet cache", file=log)
            self.facetcache.delete()
        self.facetcache = None

    def setNoiseMap(self, NoiseMap, PNRStop=10):
        """Sets the noise map. The mean dirty will be divided by the noise map before peak finding.
        If PNRStop is set, an additional stopping criterion (peak-to-noisemap) will be applied.
            Peaks are reported in units of sigmas.
        NOTE(review): the original notes also describe a flux-unit weighting mode for when
            PNRStop is not set, but this method always selects "sigma" mode -- confirm.
        """
        self._NoiseMap = NoiseMap
        self._PNRStop = PNRStop
        self._peakMode = "sigma"

    def _initMSM_handler(self, fcdict, sfdict, psfdict, iFacet, SideLobeLevel, OffsetSideLobe, verbose):
        """
        Job handler that initializes the multiscale machine of one facet.

        The machine itself is discarded; presumably the point is to populate
        fcdict for later reuse -- confirm against ClassMultiScaleMachine.
        """
        # init PSF server from PSF shared dict
        self.SetPSF(psfdict, quiet=True)
        self._initMSM_facet(iFacet, fcdict, sfdict, SideLobeLevel, OffsetSideLobe, verbose=verbose)

    def _initMSM_facet(self, iFacet, fcdict, sfdict, SideLobeLevel, OffsetSideLobe, MSM0=None, verbose=False):
        """
        Initializes MSM for one facet and returns it.

        If MSM0 is given its list of scales is copied, otherwise the scales
        are built from sfdict.
        """
        self.PSFServer.setFacet(iFacet)
        MSMachine = ClassMultiScaleMachine.ClassMultiScaleMachine(self.GD, fcdict, self.GainMachine, NFreqBands=self.NFreqBands)
        MSMachine.setModelMachine(self.ModelMachine)
        MSMachine.setSideLobeLevel(SideLobeLevel, OffsetSideLobe)
        MSMachine.SetFacet(iFacet)
        MSMachine.SetPSF(self.PSFServer)  # ThisPSF,ThisMeanPSF)
        MSMachine.FindPSFExtent(verbose=verbose)  # only print to log for central facet
        if MSM0 is not None:
            MSMachine.CopyListScales(MSM0)
        else:
            MSMachine.MakeListScales(verbose=verbose, scalefuncs=sfdict)
        MSMachine.MakeMultiScaleCube()
        MSMachine.MakeBasisMatrix()
        return MSMachine

    def InitMSMF(self, approx=False, cache=True, facetcache=None):
        """Initializes MSMF basis functions. If approx is True, then uses the central facet's PSF for
        all facets.
        Populates the self.facetcache dict, unless facetcache is supplied
        """
        self.DicoMSMachine = {}
        valid = True
        if facetcache is not None:
            print("HMP basis functions pre-initialized", file=log)
            self.facetcache = facetcache
        else:
            # Parset sections that determine whether a cached basis is reusable
            cachehash = {section: self.GD[section] for section in (
                "Data", "Beam", "Selection", "Freq",
                "Image", "Facets", "Weight", "RIME", "DDESolutions",
                "Comp", "CF",
                "HMP")}
            cachepath, valid = self.maincache.checkCache(self.CacheFileName, cachehash, reset=not cache or self.PSFHasChanged)
            # do not use cache in approx mode
            if approx or not cache:
                valid = False
            if valid:
                print("Initialising HMP basis functions from cache %s" % cachepath, file=log)
                self.facetcache = shared_dict.create(self.CacheFileName)
                self.facetcache.restore(cachepath)
            else:
                self.facetcache = None

        init_cache = self.facetcache is None
        if init_cache:
            self.facetcache = shared_dict.create(self.CacheFileName)

        # in any mode, start by initializing a MS machine for the central facet. This will precompute the scale
        # functions
        centralFacet = self.PSFServer.DicoVariablePSF["CentralFacet"]

        self.DicoMSMachine[centralFacet] = MSM0 = \
            self._initMSM_facet(centralFacet,
                                self.facetcache.addSubdict(centralFacet) if init_cache else self.facetcache[centralFacet],
                                None, self.SideLobeLevel, self.OffsetSideLobe, verbose=True)
        if approx:
            print("HMP approximation mode: using PSF of central facet (%d)" % centralFacet, file=log)
            for iFacet in range(self.PSFServer.NFacets):
                self.DicoMSMachine[iFacet] = MSM0
        elif self.GD["Facets"]["NFacets"] == 1 and not self.GD["DDESolutions"]["DDSols"]:
            self.DicoMSMachine[0] = MSM0

        else:
            # if no facet cache, init in parallel
            if init_cache:
                for iFacet in range(self.PSFServer.NFacets):
                    if iFacet != centralFacet:
                        fcdict = self.facetcache.addSubdict(iFacet)
                        if self.ParallelMode:
                            args = (fcdict.writeonly(), MSM0.ScaleFuncs.readonly(), self.DicoVariablePSF.readonly(),
                                    iFacet, self.SideLobeLevel, self.OffsetSideLobe, False)
                            APP.runJob("InitHMP:%d" % iFacet, self._initMSM_handler,
                                       args=args)
                        else:
                            self.DicoMSMachine[iFacet] = \
                                self._initMSM_facet(iFacet, fcdict, None,
                                                    self.SideLobeLevel, self.OffsetSideLobe, MSM0=MSM0, verbose=False)

                if self.ParallelMode:
                    APP.awaitJobResults("InitHMP:*", progress="Init HMP")
                    self.facetcache.reload()

            # now reinit from cache (since cache was computed by subprocesses)
            for iFacet in range(self.PSFServer.NFacets):
                if iFacet not in self.DicoMSMachine:
                    self.DicoMSMachine[iFacet] = \
                        self._initMSM_facet(iFacet, self.facetcache[iFacet], None,
                                            self.SideLobeLevel, self.OffsetSideLobe, MSM0=MSM0, verbose=False)

            # write cache to disk, unless in a mode where we explicitly don't want it
            if facetcache is None and not valid and cache and not approx:
                try:
                    print("  saving HMP cache to %s" % cachepath, file=log)
                    self.facetcache.save(cachepath)
                    self.maincache.saveCache(self.CacheFileName)
                    self.PSFHasChanged = False
                    print("  HMP init done", file=log)
                except Exception:
                    # A failed cache write is reported but must not abort the run
                    print(traceback.format_exc(), file=log)
                    print(ModColor.Str(
                        "WARNING: HMP cache could not be written, see error report above. Proceeding anyway."), file=log)

    def SetDirty(self, DicoDirty):
        """
        Attach a new dirty image dictionary to this machine and every facet machine.

        Besides passing DicoDirty on, this caches the dirty cube and mean
        image of the last facet machine, the per-band weights, the image in
        which peaks are searched (depending on the peak mode), the pixel
        indices used for RMS statistics and the extent of the dirty image
        within the PSF.
        """
        self.DicoDirty = DicoDirty

        for iFacet in range(self.PSFServer.NFacets):
            MSMachine = self.DicoMSMachine[iFacet]
            MSMachine.SetDirty(DicoDirty)

        # Taken from the last facet machine of the loop above
        self._CubeDirty = MSMachine._Dirty
        self._MeanDirty = MSMachine._MeanDirty

        # vector of per-band overall weights -- starts out as N,1 in the dico, so reshape
        W = np.float32(self.DicoDirty["WeightChansImages"])
        self._band_weights = W.reshape(W.size)[:, np.newaxis, np.newaxis, np.newaxis]

        # String equality, not identity: "is" only works by accident of interning.
        # numexpr.evaluate picks up the local names a and b from this frame.
        if self._peakMode == "sigma":
            print("Will search for the peak in the SNR-weighted dirty map", file=log)
            a, b = self._MeanDirty, self._NoiseMap.reshape(self._MeanDirty.shape)
            self._PeakSearchImage = numexpr.evaluate("a/b")
        elif self._peakMode == "weighted":
            print("Will search for the peak in the weighted dirty map", file=log)
            a, b = self._MeanDirty, self._peakWeightImage
            self._PeakSearchImage = numexpr.evaluate("a*b")
        else:
            print("Will search for the peak in the unweighted dirty map", file=log)
            self._PeakSearchImage = self._MeanDirty

        NPixStats = self.GD["Deconv"]["NumRMSSamples"]
        if NPixStats > 0:
            self.IndStats = np.int64(np.linspace(0, self._PeakSearchImage.size - 1, NPixStats))

        NPSF = self.PSFServer.NPSF
        _, _, NDirty, _ = self._CubeDirty.shape

        # Floor division keeps the extent integral (plain "/" yields a float on Python 3)
        off = (NPSF - NDirty) // 2
        self.DirtyExtent = (off, off + NDirty, off, off + NDirty)

    @staticmethod
    def _axis_edges(centre, n_main, n_facet):
        """Return (main_start, main_stop, facet_start, facet_stop) for one axis."""
        half = n_facet // 2
        lo = centre - half
        hi = centre + half
        main_lo = max(0, lo)
        main_hi = min(n_main - 1, hi)
        facet_lo = main_lo - lo
        facet_hi = n_facet - (hi - main_hi)
        return main_lo, main_hi + 1, facet_lo, facet_hi

    def GiveEdges(self, center0, N0, center1, N1):
        """
        Compute the overlap between a main image and a facet placed on it.

        Args:
            center0: (x, y) position in the main image at which the facet is centred.
            N0: size of the main image in pixels (per axis).
            center1: (x, y) pair; unpacked for interface compatibility but not
                used in the computation.
            N1: size of the facet in pixels (per axis).

        Returns:
            (Aedge, Bedge): two lists [x0, x1, y0, y1] of slice bounds, the
            first into the main image and the second into the facet, clipped
            to the main image boundaries.
        """
        xc0, yc0 = center0
        _xc1, _yc1 = center1  # validated by unpacking only

        x0main, x1main, x0facet, x1facet = self._axis_edges(xc0, N0, N1)
        y0main, y1main, y0facet, y1facet = self._axis_edges(yc0, N0, N1)

        Aedge = [x0main, x1main, y0main, y1main]
        Bedge = [x0facet, x1facet, y0facet, y1facet]
        return Aedge, Bedge
class ClassImageDeconvMachine():
    def __init__(
            self,
            Gain=0.3,
            MaxMinorIter=100,
            NCPU=1,  #psutil.cpu_count()
            CycleFactor=2.5,
            FluxThreshold=None,
            RMSFactor=3,
            PeakFactor=0,
            PrevPeakFactor=0,
            GD=None,
            SearchMaxAbs=1,
            ModelMachine=None,
            NFreqBands=1,
            RefFreq=None,
            MainCache=None,
            IdSharedMem="",
            ParallelMode=True,
            CacheFileName="HMPBasis",
            **kw  # absorb any unknown keywords arguments into this
    ):
        """
        Set up the HMP image deconvolution machine.

        Construct this while the imager is being set up, before the APP workers are
        started: with ParallelMode on, the instance registers its job handlers with APP.
        With ParallelMode off, the numexpr thread count is set to NCPU instead.

        Most arguments are stored unchanged on attributes of the same name
        (MainCache goes to self.maincache). Gain and any extra keywords in **kw are
        accepted but not used here. GD must be subscriptable: GD["HMP"]["PeakWeightImage"]
        is read, and when it is set the image is loaded and peak finding switches
        to "weighted" mode.
        """
        # --- straight copies of constructor arguments
        self.GD = GD
        self.IdSharedMem = IdSharedMem
        self.SearchMaxAbs = SearchMaxAbs
        self.MaxMinorIter = MaxMinorIter
        self.NCPU = NCPU
        self.NFreqBands = NFreqBands
        self.MultiFreqMode = NFreqBands > 1
        self.RefFreq = RefFreq
        self.FluxThreshold = FluxThreshold
        self.CycleFactor = CycleFactor
        self.RMSFactor = RMSFactor
        self.PeakFactor = PeakFactor
        self.PrevPeakFactor = PrevPeakFactor
        self.CacheFileName = CacheFileName

        # --- fixed defaults and empty state
        self.Chi2Thr = 10000
        self._ModelImage = None
        self._MaskArray = None
        self.SubPSF = None

        self.GainMachine = ClassGainMachine.get_instance()

        # updateModelMachine() reads self.RefFreq and self.PSFServer, so both exist by now
        self.ModelMachine = None
        self.PSFServer = None
        if ModelMachine is not None:
            self.updateModelMachine(ModelMachine)

        self.PSFHasChanged = False
        self._previous_initial_peak = None
        self.maincache = MainCache
        # overall iteration counter starts from zero
        self._niter = 0
        self.facetcache = None
        self.MaskMachine = None

        self.ParallelMode = ParallelMode
        if ParallelMode:
            APP.registerJobHandlers(self)
        else:
            # we are in a worker
            numexpr.set_num_threads(NCPU)

        # Peak finding mode:
        #   "normal"   looks for the peak in the mean dirty image
        #   "sigma"    looks for the peak in mean_dirty/noise_map (set through setNoiseMap)
        #   "weighted" looks for the peak in mean_dirty*weight
        self._peakMode = "normal"

        self.CurrentNegMask = None
        self._NoiseMap = None
        # in "sigma" mode this provides an additional stopping criterion
        self._PNRStop = None

        weight_image_name = self.GD["HMP"]["PeakWeightImage"]
        if weight_image_name:
            print("  Reading peak weighting image %s" % weight_image_name,
                  file=log)
            weights = image(weight_image_name).getdata()
            _, _, nx, ny = weights.shape
            # collapse the two leading (freq and pol) axes, then transpose and flip
            weights = weights.sum(axis=1).sum(axis=0).T[::-1].copy()
            self._peakWeightImage = weights.reshape((1, 1, ny, nx))
            self._peakMode = "weighted"

        self._prevPeak = None

    def setNCPU(self, NCPU):
        """Store the CPU count on the instance and use it as numexpr's thread count."""
        self.NCPU = ncpu = NCPU
        numexpr.set_num_threads(ncpu)

    def __del__(self):
        if shared_dict is not None and type(
                self.facetcache) is shared_dict.SharedDict:
            self.facetcache.delete()

    def updateMask(self, Mask):
        nx, ny = Mask.shape
        self._MaskArray = np.zeros((1, 1, nx, ny), np.bool8)
        self._MaskArray[0, 0, :, :] = Mask[:, :]

    def setMaskMachine(self, MaskMachine):
        self.MaskMachine = MaskMachine

    def resetCounter(self):
        self._niter = 0

    def updateModelMachine(self, ModelMachine):
        if ModelMachine.DicoSMStacked["Type"] not in ("MSMF", "HMP"):
            raise ValueError("ModelMachine Type should be HMP")
        if ModelMachine.RefFreq != self.RefFreq:
            raise ValueError("RefFreqs should be equal")

        self.ModelMachine = ModelMachine

        if self.PSFServer is not None:
            for iFacet in range(self.PSFServer.NFacets):
                self.DicoMSMachine[iFacet].setModelMachine(self.ModelMachine)

    def GiveModelImage(self, *args):
        return self.ModelMachine.GiveModelImage(*args)

    def setSideLobeLevel(self, SideLobeLevel, OffsetSideLobe):
        self.SideLobeLevel = SideLobeLevel
        self.OffsetSideLobe = OffsetSideLobe

    def SetPSF(self, DicoVariablePSF, quiet=False):
        """
        Create a fresh PSF server for the given PSF dict and keep both on the instance.

        The server is built from self.GD, loaded with DicoVariablePSF (PSF
        normalisation requested) and given self.RefFreq. quiet is passed through
        to the server.
        """
        # attach the server before configuring it
        self.PSFServer = server = ClassPSFServer(self.GD)
        server.setDicoVariablePSF(DicoVariablePSF,
                                  NormalisePSF=True,
                                  quiet=quiet)
        server.setRefFreq(self.RefFreq)
        self.DicoVariablePSF = DicoVariablePSF

    def Init(self,
             PSFVar,
             PSFAve,
             approx=False,
             cache=None,
             facetcache=None,
             **kwargs):
        """
        Second-stage initialisation, to be called once the first round of gridding
        has produced the PSFs. The ModelMachine must already be set.

        PSFVar:     PSF dict, passed to SetPSF()
        PSFAve:     sequence whose first two items are the sidelobe level and sidelobe offset
        approx:     passed on to InitMSMF()
        cache:      whether to cache the basis functions; None means use GD["Cache"]["HMP"]
        facetcache: dict of basis functions, passed on to InitMSMF()
        """
        # a previous HMP instance may have left the solutions dump open
        ClassMultiScaleMachine.CleanSolutionsDump.close()

        self.SetPSF(PSFVar)
        sidelobe_level, sidelobe_offset = PSFAve[0], PSFAve[1]
        self.setSideLobeLevel(sidelobe_level, sidelobe_offset)

        use_cache = self.GD["Cache"]["HMP"] if cache is None else cache
        self.InitMSMF(approx=approx, cache=use_cache, facetcache=facetcache)
        ## OMS: why is this needed? self.RefFreq is set from self.ModelMachine in the first place
        # self.ModelMachine.setRefFreq(self.RefFreq)

    def Reset(self):
        """
        Forget all per-facet machines and drop the facet cache.

        The cache is deleted first when it is a writeable SharedDict; in every case
        self.facetcache ends up as None.
        """
        print("resetting HMP machine", file=log)
        self.DicoMSMachine = {}
        cache = self.facetcache
        if type(cache) is shared_dict.SharedDict and cache.is_writeable():
            print("deleting HMP facet cache", file=log)
            cache.delete()
        self.facetcache = None

    def setNoiseMap(self, NoiseMap, PNRStop=10):
        """Sets the noise map. The mean dirty will be divided by the noise map before peak finding.
        If PNRStop is set, an additional stopping criterion (peak-to-noisemap) will be applied.
            Peaks are reported in units of sigmas.
        If PNRStop is not set, NoiseMap is treated as simply an (inverse) weighting that will bias
            peak selection in the minor cycle. In this mode, peaks are reported in units of flux.
        """
        self._NoiseMap = NoiseMap
        self._PNRStop = PNRStop
        self._peakMode = "sigma"

    def _initMSM_handler(self, fcdict, sfdict, psfdict, iFacet, SideLobeLevel,
                         OffsetSideLobe, verbose):
        """
        Job handler that initialises the multiscale machine of one facet.

        A PSF server is first set up from psfdict. The facet machine is then built
        for the effect this has on fcdict; the machine itself is not kept.
        """
        # set up the PSF server from the shared PSF dict
        self.SetPSF(psfdict, quiet=True)
        # the returned machine is deliberately discarded
        self._initMSM_facet(iFacet,
                            fcdict,
                            sfdict,
                            SideLobeLevel,
                            OffsetSideLobe,
                            verbose=verbose)

    def _initMSM_facet(self,
                       iFacet,
                       fcdict,
                       sfdict,
                       SideLobeLevel,
                       OffsetSideLobe,
                       MSM0=None,
                       verbose=False):
        """
        Build and return the multiscale machine for facet iFacet.

        fcdict: facet cache dict handed to the machine's constructor
        sfdict: scale functions, passed to MakeListScales when no MSM0 is given
        MSM0:   if given, the list of scales is copied from this machine instead
        verbose: forwarded to FindPSFExtent and MakeListScales
        """
        self.PSFServer.setFacet(iFacet)

        msm = ClassMultiScaleMachine.ClassMultiScaleMachine(
            self.GD, fcdict, self.GainMachine, NFreqBands=self.NFreqBands)
        msm.setModelMachine(self.ModelMachine)
        msm.setSideLobeLevel(SideLobeLevel, OffsetSideLobe)
        msm.SetFacet(iFacet)
        msm.SetPSF(self.PSFServer)
        # callers pass verbose=True only for the central facet
        msm.FindPSFExtent(verbose=verbose)

        if MSM0 is None:
            msm.MakeListScales(verbose=verbose, scalefuncs=sfdict)
        else:
            msm.CopyListScales(MSM0)

        msm.MakeMultiScaleCube()
        msm.MakeBasisMatrix()
        return msm

    def InitMSMF(self, approx=False, cache=True, facetcache=None):
        """
        Initialise the MSMF basis functions and the per-facet multiscale machines.

        approx:     if True, the machine built for the central facet (and hence its
                    PSF) is used for all facets, and the on-disk cache is not used.
        cache:      if false, the on-disk cache is reset and not used.
        facetcache: dict of pre-computed basis functions. If supplied it becomes
                    self.facetcache and the on-disk cache is not consulted;
                    otherwise self.facetcache is populated here.

        Fills self.DicoMSMachine (facet index -> machine). Reads self.PSFServer,
        self.SideLobeLevel and self.OffsetSideLobe, so SetPSF() and
        setSideLobeLevel() must have been called.

        A failure to write the cache to disk is logged and otherwise ignored.
        """
        self.DicoMSMachine = {}
        valid = True
        if facetcache is not None:
            print("HMP basis functions pre-initialized", file=log)
            self.facetcache = facetcache
        else:
            # the cache is keyed on these sections of the parset
            cachehash = {
                section: self.GD[section]
                for section in ("Data", "Beam", "Selection", "Freq", "Image",
                                "Facets", "Weight", "RIME", "DDESolutions",
                                "Comp", "CF", "HMP")
            }
            cachepath, valid = self.maincache.checkCache(self.CacheFileName,
                                                         cachehash,
                                                         reset=not cache
                                                         or self.PSFHasChanged)
            # do not use cache in approx mode
            if approx or not cache:
                valid = False
            if valid:
                print("Initialising HMP basis functions from cache %s" %
                      cachepath,
                      file=log)
                self.facetcache = shared_dict.create(self.CacheFileName)
                self.facetcache.restore(cachepath)
            else:
                self.facetcache = None

        # no usable cache at this point means the basis functions are computed from scratch
        init_cache = self.facetcache is None
        if init_cache:
            self.facetcache = shared_dict.create(self.CacheFileName)

        # in any mode, start by initializing a MS machine for the central facet. This will precompute the scale
        # functions
        centralFacet = self.PSFServer.DicoVariablePSF["CentralFacet"]

        self.DicoMSMachine[centralFacet] = MSM0 = \
            self._initMSM_facet(centralFacet,
                                self.facetcache.addSubdict(centralFacet) if init_cache else self.facetcache[centralFacet],
                                None, self.SideLobeLevel, self.OffsetSideLobe, verbose=True)
        if approx:
            print("HMP approximation mode: using PSF of central facet (%d)" %
                  centralFacet,
                  file=log)
            for iFacet in range(self.PSFServer.NFacets):
                self.DicoMSMachine[iFacet] = MSM0
        elif self.GD["Facets"]["NFacets"] == 1 and not self.GD["DDESolutions"]["DDSols"]:
            # single facet, no DDE solutions: the central machine is the only one needed
            self.DicoMSMachine[0] = MSM0

        else:
            # NOTE(review): the cache is only written to disk in this branch -- confirm
            # that skipping the save in the single-facet case is intended.
            # if no facet cache, init in parallel
            if init_cache:
                for iFacet in range(self.PSFServer.NFacets):
                    if iFacet != centralFacet:
                        fcdict = self.facetcache.addSubdict(iFacet)
                        if self.ParallelMode:
                            args = (fcdict.writeonly(),
                                    MSM0.ScaleFuncs.readonly(),
                                    self.DicoVariablePSF.readonly(), iFacet,
                                    self.SideLobeLevel, self.OffsetSideLobe,
                                    False)
                            APP.runJob("InitHMP:%d" % iFacet,
                                       self._initMSM_handler,
                                       args=args)
                        else:
                            self.DicoMSMachine[iFacet] = \
                                self._initMSM_facet(iFacet, fcdict, None,
                                                    self.SideLobeLevel, self.OffsetSideLobe, MSM0=MSM0, verbose=False)

                if self.ParallelMode:
                    APP.awaitJobResults("InitHMP:*", progress="Init HMP")
                    self.facetcache.reload()

            # now reinit from cache (since cache was computed by subprocesses)
            for iFacet in range(self.PSFServer.NFacets):
                if iFacet not in self.DicoMSMachine:
                    self.DicoMSMachine[iFacet] = \
                        self._initMSM_facet(iFacet, self.facetcache[iFacet], None,
                                            self.SideLobeLevel, self.OffsetSideLobe, MSM0=MSM0, verbose=False)

            # write cache to disk, unless in a mode where we explicitly don't want it
            if facetcache is None and not valid and cache and not approx:
                try:
                    print("  saving HMP cache to %s" % cachepath, file=log)
                    self.facetcache.save(cachepath)
                    self.maincache.saveCache(self.CacheFileName)
                    self.PSFHasChanged = False
                    print("  HMP init done", file=log)
                except Exception:
                    # Saving is best-effort: report and carry on. Catching Exception
                    # rather than everything lets KeyboardInterrupt/SystemExit through.
                    print(traceback.format_exc(), file=log)
                    print(ModColor.Str(
                        "WARNING: HMP cache could not be written, see error report above. Proceeding anyway."
                    ),
                          file=log)

    def SetDirty(self, DicoDirty):  #,DoSetMask=True):
        """
        Attach a new dirty-image dict and refresh everything derived from it.

        DicoDirty is handed to the multiscale machine of every facet, so
        self.PSFServer and self.DicoMSMachine must already be set up. The cube
        and mean dirty arrays are then taken from the machine of the last facet
        visited.

        Sets self.DicoDirty, self._CubeDirty, self._MeanDirty, self._band_weights,
        self._PeakSearchImage and self.DirtyExtent. self.IndStats is refreshed
        only when GD["Deconv"]["NumRMSSamples"] is positive.
        """
        self.DicoDirty = DicoDirty

        for iFacet in range(self.PSFServer.NFacets):
            MSMachine = self.DicoMSMachine[iFacet]
            MSMachine.SetDirty(DicoDirty)

        # taken from the last machine of the loop above
        self._CubeDirty = MSMachine._Dirty
        self._MeanDirty = MSMachine._MeanDirty

        # vector of per-band overall weights -- starts out as N,1 in the dico, so reshape
        W = np.float32(self.DicoDirty["WeightChansImages"])
        self._band_weights = W.reshape(W.size)[:, np.newaxis, np.newaxis,
                                               np.newaxis]

        # The mode is compared with ==, not "is": identity of equal strings is an
        # implementation detail. numexpr finds its operands by name in this frame,
        # so the locals must stay named a and b.
        if self._peakMode == "sigma":
            print("Will search for the peak in the SNR-weighted dirty map",
                  file=log)
            a, b = self._MeanDirty, self._NoiseMap.reshape(
                self._MeanDirty.shape)
            self._PeakSearchImage = numexpr.evaluate("a/b")
        elif self._peakMode == "weighted":
            print("Will search for the peak in the weighted dirty map",
                  file=log)
            a, b = self._MeanDirty, self._peakWeightImage
            self._PeakSearchImage = numexpr.evaluate("a*b")
        else:
            print("Will search for the peak in the unweighted dirty map",
                  file=log)
            self._PeakSearchImage = self._MeanDirty

        # evenly spaced flat indices into the peak-search image, used for the RMS estimate
        NPixStats = self.GD["Deconv"]["NumRMSSamples"]
        if NPixStats > 0:
            self.IndStats = np.int64(
                np.linspace(0, self._PeakSearchImage.size - 1, NPixStats))

        NPSF = self.PSFServer.NPSF
        _, _, NDirty, _ = self._CubeDirty.shape

        # bounds of an NDirty-wide window centred in an NPSF-wide array (same on both axes)
        off = (NPSF - NDirty) // 2
        self.DirtyExtent = (off, off + NDirty, off, off + NDirty)

#        if self._ModelImage is None:
#            self._ModelImage=np.zeros_like(self._CubeDirty)

# if DoSetMask:
#     if self._MaskArray is None:
#         self._MaskArray=np.zeros(self._MeanDirty.shape,dtype=np.bool8)
#     else:
#         maskshape = (1,1,NDirty,NDirty)
#         # check for mask shape
#         if maskshape != self._MaskArray.shape:
#             ma0 = self._MaskArray
#             _,_,nx,ny = ma0.shape
#             def match_shapes (n1,n2):
#                 if n1<n2:
#                     return slice(None), slice((n2-n1)/2,(n2-n1)/2+n1)
#                 elif n1>n2:
#                     return slice((n1-n2)/2,(n1-n2)/2+n2), slice(None)
#                 else:
#                     return slice(None), slice(None)
#             sx1, sx2 = match_shapes(NDirty, nx)
#             sy1, sy2 = match_shapes(NDirty, ny)
#             self._MaskArray = np.zeros(maskshape, dtype=np.bool8)
#             self._MaskArray[0,0,sx1,sy1] = ma0[0,0,sx2,sy2]
#             print>>log,ModColor.Str("WARNING: reshaping mask image from %dx%d to %dx%d"%(nx, ny, NDirty, NDirty))
#             print>>log,ModColor.Str("Are you sure you supplied the correct cleaning mask?")

    def GiveEdges(self, xc0, yc0, N0, xc1, yc1, N1):
        """
        Compute matching slice bounds for a facet of size N1 centred on pixel
        (xc0, yc0) of a main image of size N0.

        xc1 and yc1 are accepted but do not enter the computation.

        Returns (Aedge, Bedge):
            Aedge = [x0main, x1main, y0main, y1main], bounds clipped to the main image
            Bedge = [x0facet, x1facet, y0facet, y1facet], the corresponding bounds in the facet
        """
        half_width = N1 // 2

        def clip_axis(centre):
            # first and last main-image pixel the facet would reach without clipping
            first, last = centre - half_width, centre + half_width
            main_start = np.max([0, first])
            main_last = np.min([N0 - 1, last])
            # facet pixels cut off on the low side and on the high side
            facet_start = main_start - first
            facet_stop = N1 - (last - main_last)
            # +1 turns the last covered pixel into a slice end
            return main_start, main_last + 1, facet_start, facet_stop

        x0main, x1main, x0facet, x1facet = clip_axis(xc0)
        y0main, y1main, y0facet, y1facet = clip_axis(yc0)

        return ([x0main, x1main, y0main, y1main],
                [x0facet, x1facet, y0facet, y1facet])

    def SubStep(self, dx, dy, LocalSM):
        """
        Subtract the local sky-model patch LocalSM from the dirty cube around pixel (dx, dy).

        dx, dy:  position in the dirty image at which the patch is applied
        LocalSM: 4-D array; its last axis length is used as the patch size

        The dirty cube is updated in place. When the mean dirty image is a separate
        array it is recomputed over the affected window, and in "sigma" or
        "weighted" peak mode the same window of the peak-search image is refreshed.
        """
        N1 = LocalSM.shape[-1]
        N0x, N0y = self._MeanDirty.shape[-2::]
        # Aedge: window in the dirty image; Bedge: matching window in the patch
        Aedge, Bedge = GiveEdgesDissymetric(dx, dy, N0x, N0y, N1 // 2, N1 // 2,
                                            N1, N1)
        x0d, x1d, y0d, y1d = Aedge
        x0p, x1p, y0p, y1p = Bedge

        # numexpr finds its operands by name in this frame, so the locals must
        # stay named cube and sm. cube is a view: writing to it updates _CubeDirty.
        cube = self._CubeDirty[:, :, x0d:x1d, y0d:y1d]
        sm = LocalSM[:, :, x0p:x1p, y0p:y1p]
        numexpr.evaluate('cube-sm', out=cube, casting="unsafe")

        if self._MeanDirty is not self._CubeDirty:
            meanimage = self._MeanDirty[:, :, x0d:x1d, y0d:y1d]

            # cube.mean(axis=0, out=meanimage) should be a bit faster, but we experienced problems with some numpy versions,
            # see https://github.com/cyriltasse/DDFacet/issues/325
            # So use array copy instead (which makes an intermediate array)
            if cube.shape[0] > 1:
                meanimage[...] = (cube * self._band_weights).sum(axis=0)
            else:
                meanimage[0, ...] = cube[0, ...]

        # The mode is compared with ==, not "is": identity of equal strings is an
        # implementation detail. As above, the operand names a and b must match
        # the expression strings.
        if self._peakMode == "sigma":
            a = self._MeanDirty[:, :, x0d:x1d, y0d:y1d]
            b = self._NoiseMap[:, :, x0d:x1d, y0d:y1d]
            numexpr.evaluate("a/b",
                             out=self._PeakSearchImage[:, :, x0d:x1d, y0d:y1d])
        elif self._peakMode == "weighted":
            a = self._MeanDirty[:, :, x0d:x1d, y0d:y1d]
            b = self._peakWeightImage[:, :, x0d:x1d, y0d:y1d]
            numexpr.evaluate("a*b",
                             out=self._PeakSearchImage[:, :, x0d:x1d, y0d:y1d])

    def Plot(self):
        """
        Debugging aid: show the first polarisation of the first two bands of the dirty cube.

        Only bands that exist are drawn, so a single-band cube no longer raises
        IndexError. The subplot layout is unchanged for cubes with two or more bands.
        """
        import pylab
        pylab.clf()
        nbands_shown = min(2, self._CubeDirty.shape[0])
        for iband in range(nbands_shown):
            pylab.subplot(1, 3, iband + 1)
            pylab.imshow(self._CubeDirty[iband, 0])
            pylab.colorbar()
        pylab.draw()
        pylab.show()

    def updateRMS(self):
        _, npol, npix, _ = self._MeanDirty.shape
        NPixStats = self.GD["Deconv"]["NumRMSSamples"]
        if NPixStats:
            #self.IndStats=np.int64(np.random.rand(NPixStats)*npix**2)
            self.IndStats = np.int64(
                np.linspace(0, self._PeakSearchImage.size - 1, NPixStats))
        else:
            self.IndStats = slice(None)
        self.RMS = np.std(np.real(
            self._PeakSearchImage.ravel()[self.IndStats]))

    def setMask(self, Mask):
        self.CurrentNegMask = Mask

    def Deconvolve(self, ch=0, UpdateRMS=True):
        """
        Runs minor cycle over image channel 'ch'.
        initMinor is number of minor iteration (keeps continuous count through major iterations)
        Nminor is max number of minor iteration

        Returns tuple of: return_code,continue,updated
        where return_code is a status string;
        continue is True if another cycle should be executed;
        update is True if model has been updated (note update=False implies continue=False)
        """
        if self._niter >= self.MaxMinorIter:
            return "MaxIter", False, False

        _, npol, npix, _ = self._MeanDirty.shape
        xc = (npix) // 2

        # m0,m1=self._CubeDirty.min(),self._CubeDirty.max()
        # pylab.clf()
        # pylab.subplot(1,2,1)
        # pylab.imshow(self.Dirty[0],interpolation="nearest",vmin=m0,vmax=m1)
        # pylab.draw()
        # pylab.show(False)
        # pylab.pause(0.1)

        DoAbs = int(self.GD["Deconv"]["AllowNegative"])
        print("  Running minor cycle [MinorIter = %i/%i, SearchMaxAbs = %i]" %
              (self._niter, self.MaxMinorIter, DoAbs),
              file=log)

        if UpdateRMS: self.updateRMS()
        RMS = self.RMS
        self.GainMachine.SetRMS(RMS)

        Fluxlimit_RMS = self.RMSFactor * RMS
        #print "startmax",self._MeanDirty.shape,self._MaskArray.shape

        if self.CurrentNegMask is not None:
            print("  using externally defined Mask (self.CurrentNegMask)",
                  file=log)
            CurrentNegMask = self.CurrentNegMask
        elif self.MaskMachine:
            print("  using MaskMachine Mask", file=log)
            CurrentNegMask = self.MaskMachine.CurrentNegMask
        elif self._MaskArray is not None:
            print("  using externally defined Mask (self._MaskArray)",
                  file=log)
            CurrentNegMask = self._MaskArray
        else:
            print("  not using a mask", file=log)
            CurrentNegMask = None

        x, y, MaxDirty = NpParallel.A_whereMax(self._PeakSearchImage,
                                               NCPU=self.NCPU,
                                               DoAbs=DoAbs,
                                               Mask=CurrentNegMask)

        # ThisFlux is evaluated against stopping criteria. In weighted mode, use the true flux. Else use sigma value.
        ThisFlux = self._MeanDirty[
            0, 0, x, y] if self._peakMode is "weighted" else MaxDirty
        if DoAbs:
            ThisFlux = abs(ThisFlux)
        # in weighted or noisemap mode, look up the true max as well
        trueMaxDirty = MaxDirty if self._peakMode is "normal" else ThisFlux
        # return condition indicating cleaning is to be continued
        cont = True

        CondPeak = (self._previous_initial_peak is not None)
        CondDiverge = False
        if self._previous_initial_peak is not None:
            CondDiverge = (abs(ThisFlux) >
                           self.GD["HMP"]["MajorStallThreshold"] *
                           self._previous_initial_peak)
        CondPeakType = (self._peakMode != "sigma")

        if CondPeak and CondDiverge and CondPeakType:
            print(ModColor.Str(
                "STALL! dirty image peak %10.6g Jy, was %10.6g at previous major cycle."
                % (ThisFlux, self._previous_initial_peak),
                col="red"),
                  file=log)
            print(ModColor.Str("This will be the last major cycle"), file=log)
            cont = False

        self._previous_initial_peak = abs(ThisFlux)
        #x,y,MaxDirty=NpParallel.A_whereMax(self._MeanDirty.copy(),NCPU=1,DoAbs=DoAbs,Mask=self._MaskArray.copy())
        #A=self._MeanDirty.copy()
        #A.flat[:]=np.arange(A.size)[:]
        #x,y,MaxDirty=NpParallel.A_whereMax(A,NCPU=1,DoAbs=DoAbs)
        #print "max",x,y
        #stop

        # print>>log,"npp: %d %d %g"%(x,y,MaxDirty)
        # xy = ma.argmax(ma.masked_array(abs(self._MeanDirty), self._MaskArray))
        # x1, y1 = xy/npix, xy%npix
        # MaxDirty1 = abs(self._MeanDirty[0,0,x1,y1])
        # print>>log,"argmax: %d %d %g"%(x1,y1,MaxDirty1)

        Fluxlimit_Peak = ThisFlux * self.PeakFactor
        # if previous peak is not set (i.e. first major cycle), use current dirty image peak instead
        Fluxlimit_PrevPeak = (self._prevPeak if self._prevPeak is not None else
                              ThisFlux) * self.PrevPeakFactor
        Fluxlimit_Sidelobe = (
            (self.CycleFactor - 1.) / 4. * (1. - self.SideLobeLevel) +
            self.SideLobeLevel) * ThisFlux if self.CycleFactor else 0

        mm0, mm1 = self._PeakSearchImage.min(), self._PeakSearchImage.max()

        # work out upper peak threshold
        StopFlux = max(Fluxlimit_Peak, Fluxlimit_RMS, Fluxlimit_Sidelobe,
                       Fluxlimit_Peak, Fluxlimit_PrevPeak, self.FluxThreshold)

        print(
            "    Dirty image peak           = %10.6g Jy [(min, max) = (%.3g, %.3g) Jy]"
            % (trueMaxDirty, mm0, mm1),
            file=log)
        if self._peakMode is "sigma":
            print("      in sigma units           = %10.6g" % MaxDirty,
                  file=log)
        elif self._peakMode is "weighted":
            print("      weighted peak flux is    = %10.6g Jy" % MaxDirty,
                  file=log)
        print(
            "      RMS-based threshold      = %10.6g Jy [rms = %.3g Jy; RMS factor %.1f]"
            % (Fluxlimit_RMS, RMS, self.RMSFactor),
            file=log)
        print(
            "      Sidelobe-based threshold = %10.6g Jy [sidelobe  = %.3f of peak; cycle factor %.1f]"
            % (Fluxlimit_Sidelobe, self.SideLobeLevel, self.CycleFactor),
            file=log)
        print("      Peak-based threshold     = %10.6g Jy [%.3f of peak]" %
              (Fluxlimit_Peak, self.PeakFactor),
              file=log)
        print(
            "      Previous peak-based thr  = %10.6g Jy [%.3f of previous minor cycle peak]"
            % (Fluxlimit_PrevPeak, self.PrevPeakFactor),
            file=log)
        print("      Absolute threshold       = %10.6g Jy" %
              (self.FluxThreshold),
              file=log)
        print("    Stopping flux              = %10.6g Jy [%.3f of peak ]" %
              (StopFlux, StopFlux / ThisFlux),
              file=log)
        rms = RMS
        # MaxModelInit=np.max(np.abs(self.ModelImage))
        # Fact=4
        # self.BookKeepShape=(npix/Fact,npix/Fact)
        # BookKeep=np.zeros(self.BookKeepShape,np.float32)
        # NPixBook,_=self.BookKeepShape
        # FactorBook=float(NPixBook)/npix

        T = ClassTimeIt.ClassTimeIt()
        T.disable()

        # #print x,y
        # print>>log, "npp: %d %d %g"%(x,y,ThisFlux)
        # xy = ma.argmax(ma.masked_array(abs(self._MeanDirty), self._MaskArray))
        # x, y = xy/npix, xy%npix
        # ThisFlux = abs(self._MeanDirty[0,0,x,y])
        # print>> log, "argmax: %d %d %g"%(x, y, ThisFlux)

        if ThisFlux < StopFlux:
            print(ModColor.Str(
                "    Initial maximum peak %10.6g Jy below threshold, we're done here"
                % ThisFlux,
                col="green"),
                  file=log)
            return "FluxThreshold", False, False

        # self._MaskArray.fill(1)
        # self._MaskArray.fill(0)
        #self._MaskArray[np.abs(self._MeanDirty) > Fluxlimit_Sidelobe]=0

        #        DoneScale=np.zeros((self.MSMachine.NScales,),np.float32)

        PreviousMaxFlux = 1e30

        # pBAR= ProgressBar('white', width=50, block='=', empty=' ',Title="Cleaning   ", HeaderSize=20,TitleSize=30)
        # pBAR.disable()

        self.GainMachine.SetFluxMax(ThisFlux)
        # pBAR.render(0,"g=%3.3f"%self.GainMachine.GiveGain())
        PreviousFlux = ThisFlux

        divergence_factor = 1 + max(self.GD["HMP"]["AllowResidIncrease"], 0)

        xlast, ylast, Flast = None, None, None

        def GivePercentDone(ThisMaxFlux):
            fracDone = 1. - (ThisMaxFlux - StopFlux) / (MaxDirty - StopFlux)
            return max(int(round(100 * fracDone)), 100)

        x0 = y0 = None

        try:
            for i in range(self._niter + 1, self.MaxMinorIter + 1):
                self._niter = i

                # x,y,ThisFlux=NpParallel.A_whereMax(self.Dirty,NCPU=self.NCPU,DoAbs=1)
                x, y, peak = NpParallel.A_whereMax(self._PeakSearchImage,
                                                   NCPU=self.NCPU,
                                                   DoAbs=DoAbs,
                                                   Mask=CurrentNegMask)
                if self.GD["HMP"]["FractionRandomPeak"] is not None:
                    op = lambda x: x
                    if DoAbs: op = lambda x: np.abs(x)
                    _, _, indx, indy = np.where(
                        (op(self._PeakSearchImage) >= peak *
                         self.GD["HMP"]["FractionRandomPeak"])
                        & np.logical_not(CurrentNegMask))
                    ii = np.int64(np.random.rand(1)[0] * indx.size)
                    x, y = indx[ii], indy[ii]
                    peak = op(self._PeakSearchImage[0, 0, x, y])

                ThisFlux = float(self._MeanDirty[
                    0, 0, x, y] if self._peakMode is "weighted" else peak)
                if DoAbs:
                    ThisFlux = abs(ThisFlux)

                if xlast is not None:
                    if x == xlast and y == ylast and np.abs(
                        (Flast - peak) / Flast) < 1e-6:
                        print(ModColor.Str(
                            "    [iter=%i] peak of %.3g Jy stuck" %
                            (i, ThisFlux),
                            col="red"),
                              file=log)
                        return "Stuck", False, True
                xlast = x
                ylast = y
                Flast = peak

                #x,y=self.PSFServer.SolveOffsetLM(self._MeanDirty[0,0],x,y); ThisFlux=self._MeanDirty[0,0,x,y]
                self.GainMachine.SetFluxMax(ThisFlux)

                # #x,y=1224, 1994
                # print x,y,ThisFlux
                # x,y=np.where(np.abs(self.Dirty[0])==np.max(np.abs(self.Dirty[0])))
                # ThisFlux=self.Dirty[0,x,y]
                # print x,y,ThisFlux
                # stop

                T.timeit("max0")
                if np.abs(ThisFlux) > divergence_factor * np.abs(PreviousFlux):
                    print(ModColor.Str(
                        "    [iter=%i] peak of %.3g Jy diverging w.r.t. floor of %.3g Jy "
                        % (i, ThisFlux, PreviousFlux),
                        col="red"),
                          file=log)
                    return "Diverging", False, True
                fluxgain = np.abs(ThisFlux - self._prevPeak) / abs(
                    ThisFlux) if self._prevPeak is not None else 1e+99
                if x == x0 and y == y0 and fluxgain < 1e-6:
                    print(ModColor.Str(
                        "    [iter=%i] stalled at peak of %.3g Jy, x=%d y=%d" %
                        (i, ThisFlux, x, y),
                        col="red"),
                          file=log)
                    return "Stalled", False, True
                if np.abs(ThisFlux) < np.abs(PreviousFlux):
                    PreviousFlux = ThisFlux

                self._prevPeak = ThisFlux
                x0, y0 = x, y

                ThisPNR = ThisFlux / rms

                if ThisFlux <= (StopFlux if StopFlux is not None else 0.0) or \
                    ThisPNR <= (self._PNRStop if self._PNRStop is not None else 0.0):
                    rms = np.std(
                        np.real(self._PeakSearchImage.ravel()[self.IndStats]))
                    # pBAR.render(100,"peak %.3g"%(ThisFlux,))
                    if ThisFlux <= StopFlux:
                        print(ModColor.Str(
                            "    [iter=%i] peak of %.3g Jy lower than stopping flux, PNR %.3g"
                            % (i, ThisFlux, ThisFlux / rms),
                            col="green"),
                              file=log)
                    elif ThisPNR <= self._PNRStop:
                        print(ModColor.Str(
                            "    [iter=%i] PNR of %.3g lower than stopping PNR, peak of %.3g Jy"
                            % (i, ThisPNR, ThisFlux),
                            col="green"),
                              file=log)

                    cont = cont and ThisFlux > self.FluxThreshold
                    if not cont:
                        print(ModColor.Str(
                            "    [iter=%i] absolute flux threshold of %.3g Jy has been reached, PNR %.3g"
                            % (i, self.FluxThreshold, ThisFlux / rms),
                            col="green",
                            Bold=True),
                              file=log)
                    # DoneScale*=100./np.sum(DoneScale)
                    # for iScale in range(DoneScale.size):
                    #     print>>log,"       [Scale %i] %.1f%%"%(iScale,DoneScale[iScale])

                    # stop deconvolution if hit absolute treshold; update model
                    return "MinFluxRms", cont, True

    #            if (i>0)&((i%1000)==0):
    #                print>>log, "    [iter=%i] Peak residual flux %f Jy" % (i,ThisFlux)
    #             if (i>0)&((i%100)==0):
    #                 PercentDone=GivePercentDone(ThisFlux)
    #                 pBAR.render(PercentDone,"peak %.3g i%d"%(ThisFlux,self._niter))
                rounded_iter_step = 1 if i < 10 else (10 if i < 200 else (
                    100 if i < 2000 else 1000))
                # min(int(10**math.floor(math.log10(i))), 10000)
                if i >= 10 and i % rounded_iter_step == 0:
                    # if self.GD["Debug"]["PrintMinorCycleRMS"]:
                    rms = np.std(
                        np.real(self._PeakSearchImage.ravel()[self.IndStats]))
                    if self._peakMode is "weighted":
                        print(
                            "    [iter=%i] peak residual %.3g, gain %.3g, rms %g, PNR %.3g (weighted peak %.3g at x=%d y=%d)"
                            % (i, ThisFlux, fluxgain, rms, ThisFlux / rms,
                               peak, x, y),
                            file=log)
                    else:
                        print(
                            "    [iter=%i] peak residual %.3g, gain %.3g, rms %g, PNR %.3g (at x=%d y=%d)"
                            %
                            (i, ThisFlux, fluxgain, rms, ThisFlux / rms, x, y),
                            file=log)
                    # else:
                    #     print >>log, "    [iter=%i] peak residual %.3g" % (
                    #         i, ThisFlux)
                    ClassMultiScaleMachine.CleanSolutionsDump.flush()

                nch, npol, _, _ = self._CubeDirty.shape
                Fpol = np.float32((self._CubeDirty[:, :, x, y].reshape(
                    (nch, npol, 1, 1))).copy())
                #print "Fpol",Fpol
                dx = x - xc
                dy = y - xc

                T.timeit("stuff")

                # iScale=self.MSMachine.FindBestScale((x,y),Fpol)

                self.PSFServer.setLocation(x, y)

                PSF = self.PSFServer.GivePSF()
                MSMachine = self.DicoMSMachine[self.PSFServer.iFacet]

                LocalSM = MSMachine.GiveLocalSM(x, y, Fpol)

                T.timeit("FindScale")
                # print iScale

                # if iScale=="BadFit": continue

                # box=50
                # x0,x1=x-box,x+box
                # y0,y1=y-box,y+box
                # x0,x1=0,-1
                # y0,y1=0,-1
                # pylab.clf()
                # pylab.subplot(1,2,1)
                # pylab.imshow(self.Dirty[0][x0:x1,y0:y1],interpolation="nearest",vmin=mm0,vmax=mm1)
                # #pylab.subplot(1,3,2)
                # #pylab.imshow(self.MaskArray[0],interpolation="nearest",vmin=0,vmax=1,cmap="gray")
                # # pylab.subplot(1,2,2)
                # # pylab.imshow(self.ModelImage[0][x0:x1,y0:y1],interpolation="nearest",cmap="gray")
                # #pylab.imshow(PSF[0],interpolation="nearest",vmin=0,vmax=1)
                # #pylab.colorbar()

                # CurrentGain=self.GainMachine.GiveGain()
                CurrentGain = np.float32(self.GD["Deconv"]["Gain"])

                numexpr.evaluate('LocalSM*CurrentGain', out=LocalSM)
                self.SubStep(x, y, LocalSM)
                T.timeit("SubStep")

                # pylab.subplot(1,2,2)
                # pylab.imshow(self.Dirty[0][x0:x1,y0:y1],interpolation="nearest",vmin=mm0,vmax=mm1)#,vmin=m0,vmax=m1)

                # #pylab.imshow(PSF[0],interpolation="nearest",vmin=0,vmax=1)
                # #pylab.colorbar()
                # pylab.draw()
                # pylab.show(False)
                # pylab.pause(0.1)

                # ######################################

                # ThisComp=self.ListScales[iScale]

                # Scale=ThisComp["Scale"]
                # DoneScale[Scale]+=1

                # if ThisComp["ModelType"]=="Delta":
                #     for pol in range(npol):
                #        self.ModelImage[pol,x,y]+=Fpol[pol,0,0]*self.Gain

                # elif ThisComp["ModelType"]=="Gaussian":
                #     Gauss=ThisComp["Model"]
                #     Sup,_=Gauss.shape
                #     x0,x1=x-Sup/2,x+Sup/2+1
                #     y0,y1=y-Sup/2,y+Sup/2+1

                #     _,N0,_=self.ModelImage.shape

                #     Aedge,Bedge=self.GiveEdges(x,y,N0,Sup/2,Sup/2,Sup)
                #     x0d,x1d,y0d,y1d=Aedge
                #     x0p,x1p,y0p,y1p=Bedge

                #     for pol in range(npol):
                #         self.ModelImage[pol,x0d:x1d,y0d:y1d]+=Gauss[x0p:x1p,y0p:y1p]*pol[pol,0,0]*self.Gain

                # else:
                #     stop

                T.timeit("End")
        except KeyboardInterrupt:
            rms = np.std(np.real(self._PeakSearchImage.ravel()[self.IndStats]))
            print(ModColor.Str(
                "    [iter=%i] minor cycle interrupted with Ctrl+C, peak flux %.3g, PNR %.3g"
                % (self._niter, ThisFlux, ThisFlux / rms)),
                  file=log)
            # DoneScale*=100./np.sum(DoneScale)
            # for iScale in range(DoneScale.size):
            #     print>>log,"       [Scale %i] %.1f%%"%(iScale,DoneScale[iScale])
            return "MaxIter", False, True  # stop deconvolution but do update model

        rms = np.std(np.real(self._PeakSearchImage.ravel()[self.IndStats]))
        print(ModColor.Str(
            "    [iter=%i] Reached maximum number of iterations, peak flux %.3g, PNR %.3g"
            % (self._niter, ThisFlux, ThisFlux / rms)),
              file=log)
        # DoneScale*=100./np.sum(DoneScale)
        # for iScale in range(DoneScale.size):
        #     print>>log,"       [Scale %i] %.1f%%"%(iScale,DoneScale[iScale])
        return "MaxIter", False, True  # stop deconvolution but do update model

    def Update(self, DicoDirty, **kwargs):
        """
        Refresh this machine with a new dirty-image dictionary.

        DicoDirty is handed unchanged to self.SetDirty. Any extra keyword
        arguments are accepted but not used here. Returns None.
        """
        # Single responsibility: forward the image dict to SetDirty.
        refresh_dirty = self.SetDirty
        refresh_dirty(DicoDirty)

    def ToFile(self, fname):
        """
        Write the model dict to file.

        Delegates to self.ModelMachine.ToFile(fname); whatever that call
        returns is discarded and this method returns None.
        """
        model_machine = self.ModelMachine
        model_machine.ToFile(fname)

    def FromFile(self, fname):
        """
        Read the model dict from file.

        Delegates to self.ModelMachine.FromFile(fname); whatever that call
        returns is discarded and this method returns None.
        """
        model_machine = self.ModelMachine
        model_machine.FromFile(fname)

    def FromDico(self, DicoName):
        """
        Read in a model dict.

        Delegates to self.ModelMachine.FromDico(DicoName); whatever that call
        returns is discarded and this method returns None.
        """
        model_machine = self.ModelMachine
        model_machine.FromDico(DicoName)